# Directory Structure ``` ├── .cursor │ └── rules │ ├── modelcontextprotocol.mdc │ └── python.mdc ├── .github │ ├── CODEOWNERS │ └── workflows │ ├── lint.yml │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .tool-versions ├── Dockerfile ├── LICENSE ├── Makefile ├── pyproject.toml ├── pytest.ini ├── README.md ├── smithery.yaml ├── src │ ├── __init__.py │ ├── __main__.py │ ├── airflow │ │ ├── __init__.py │ │ ├── airflow_client.py │ │ ├── config.py │ │ ├── connection.py │ │ ├── dag.py │ │ ├── dagrun.py │ │ ├── dagstats.py │ │ ├── dataset.py │ │ ├── eventlog.py │ │ ├── importerror.py │ │ ├── monitoring.py │ │ ├── plugin.py │ │ ├── pool.py │ │ ├── provider.py │ │ ├── taskinstance.py │ │ ├── variable.py │ │ └── xcom.py │ ├── enums.py │ ├── envs.py │ ├── main.py │ └── server.py ├── test │ ├── __init__.py │ ├── airflow │ │ ├── test_dag.py │ │ └── test_taskinstance.py │ ├── conftest.py │ ├── test_airflow_client.py │ ├── test_main.py │ └── test_server.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /.tool-versions: -------------------------------------------------------------------------------- ``` 1 | python 3.12.6 2 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | .pypirc 23 | .ruff_cache/ 24 | 25 | # Virtual Environment 26 | .env 27 | .venv 28 | env/ 29 | venv/ 30 | ENV/ 31 | 32 | # IDE 33 | .idea/ 34 | .vscode/ 35 | *.swp 36 | *.swo 37 | 38 | # OS 39 | .DS_Store 40 | Thumbs.db 41 | ``` -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- ```markdown 1 | [](https://mseep.ai/app/yangkyeongmo-mcp-server-apache-airflow) 2 | 3 | # mcp-server-apache-airflow 4 | 5 | [](https://smithery.ai/server/@yangkyeongmo/mcp-server-apache-airflow) 6 |  7 | 8 | A Model Context Protocol (MCP) server implementation for Apache Airflow, enabling seamless integration with MCP clients. This project provides a standardized way to interact with Apache Airflow through the Model Context Protocol. 9 | 10 | <a href="https://glama.ai/mcp/servers/e99b6vx9lw"> 11 | <img width="380" height="200" src="https://glama.ai/mcp/servers/e99b6vx9lw/badge" alt="Server for Apache Airflow MCP server" /> 12 | </a> 13 | 14 | ## About 15 | 16 | This project implements a [Model Context Protocol](https://modelcontextprotocol.io/introduction) server that wraps Apache Airflow's REST API, allowing MCP clients to interact with Airflow in a standardized way. It uses the official Apache Airflow client library to ensure compatibility and maintainability. 
17 | 18 | ## Feature Implementation Status 19 | 20 | | Feature | API Path | Status | 21 | | -------------------------------- | --------------------------------------------------------------------------------------------- | ------ | 22 | | **DAG Management** | | | 23 | | List DAGs | `/api/v1/dags` | ✅ | 24 | | Get DAG Details | `/api/v1/dags/{dag_id}` | ✅ | 25 | | Pause DAG | `/api/v1/dags/{dag_id}` | ✅ | 26 | | Unpause DAG | `/api/v1/dags/{dag_id}` | ✅ | 27 | | Update DAG | `/api/v1/dags/{dag_id}` | ✅ | 28 | | Delete DAG | `/api/v1/dags/{dag_id}` | ✅ | 29 | | Get DAG Source | `/api/v1/dagSources/{file_token}` | ✅ | 30 | | Patch Multiple DAGs | `/api/v1/dags` | ✅ | 31 | | Reparse DAG File | `/api/v1/dagSources/{file_token}/reparse` | ✅ | 32 | | **DAG Runs** | | | 33 | | List DAG Runs | `/api/v1/dags/{dag_id}/dagRuns` | ✅ | 34 | | Create DAG Run | `/api/v1/dags/{dag_id}/dagRuns` | ✅ | 35 | | Get DAG Run Details | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}` | ✅ | 36 | | Update DAG Run | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}` | ✅ | 37 | | Delete DAG Run | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}` | ✅ | 38 | | Get DAG Runs Batch | `/api/v1/dags/~/dagRuns/list` | ✅ | 39 | | Clear DAG Run | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear` | ✅ | 40 | | Set DAG Run Note | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/setNote` | ✅ | 41 | | Get Upstream Dataset Events | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/upstreamDatasetEvents` | ✅ | 42 | | **Tasks** | | | 43 | | List DAG Tasks | `/api/v1/dags/{dag_id}/tasks` | ✅ | 44 | | Get Task Details | `/api/v1/dags/{dag_id}/tasks/{task_id}` | ✅ | 45 | | Get Task Instance | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}` | ✅ | 46 | | List Task Instances | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances` | ✅ | 47 | | Update Task Instance | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}` | ✅ | 48 | | Get Task Instance Log | 
`/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/logs/{task_try_number}` | ✅ | 49 | | Clear Task Instances | `/api/v1/dags/{dag_id}/clearTaskInstances` | ✅ | 50 | | Set Task Instances State | `/api/v1/dags/{dag_id}/updateTaskInstancesState` | ✅ | 51 | | List Task Instance Tries | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/tries` | ✅ | 52 | | **Variables** | | | 53 | | List Variables | `/api/v1/variables` | ✅ | 54 | | Create Variable | `/api/v1/variables` | ✅ | 55 | | Get Variable | `/api/v1/variables/{variable_key}` | ✅ | 56 | | Update Variable | `/api/v1/variables/{variable_key}` | ✅ | 57 | | Delete Variable | `/api/v1/variables/{variable_key}` | ✅ | 58 | | **Connections** | | | 59 | | List Connections | `/api/v1/connections` | ✅ | 60 | | Create Connection | `/api/v1/connections` | ✅ | 61 | | Get Connection | `/api/v1/connections/{connection_id}` | ✅ | 62 | | Update Connection | `/api/v1/connections/{connection_id}` | ✅ | 63 | | Delete Connection | `/api/v1/connections/{connection_id}` | ✅ | 64 | | Test Connection | `/api/v1/connections/test` | ✅ | 65 | | **Pools** | | | 66 | | List Pools | `/api/v1/pools` | ✅ | 67 | | Create Pool | `/api/v1/pools` | ✅ | 68 | | Get Pool | `/api/v1/pools/{pool_name}` | ✅ | 69 | | Update Pool | `/api/v1/pools/{pool_name}` | ✅ | 70 | | Delete Pool | `/api/v1/pools/{pool_name}` | ✅ | 71 | | **XComs** | | | 72 | | List XComs | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries` | ✅ | 73 | | Get XCom Entry | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}` | ✅ | 74 | | **Datasets** | | | 75 | | List Datasets | `/api/v1/datasets` | ✅ | 76 | | Get Dataset | `/api/v1/datasets/{uri}` | ✅ | 77 | | Get Dataset Events | `/api/v1/datasetEvents` | ✅ | 78 | | Create Dataset Event | `/api/v1/datasetEvents` | ✅ | 79 | | Get DAG Dataset Queued Event | `/api/v1/dags/{dag_id}/dagRuns/queued/datasetEvents/{uri}` | ✅ | 80 | | Get DAG 
Dataset Queued Events | `/api/v1/dags/{dag_id}/dagRuns/queued/datasetEvents` | ✅ | 81 | | Delete DAG Dataset Queued Event | `/api/v1/dags/{dag_id}/dagRuns/queued/datasetEvents/{uri}` | ✅ | 82 | | Delete DAG Dataset Queued Events | `/api/v1/dags/{dag_id}/dagRuns/queued/datasetEvents` | ✅ | 83 | | Get Dataset Queued Events | `/api/v1/datasets/{uri}/dagRuns/queued/datasetEvents` | ✅ | 84 | | Delete Dataset Queued Events | `/api/v1/datasets/{uri}/dagRuns/queued/datasetEvents` | ✅ | 85 | | **Monitoring** | | | 86 | | Get Health | `/api/v1/health` | ✅ | 87 | | **DAG Stats** | | | 88 | | Get DAG Stats | `/api/v1/dags/statistics` | ✅ | 89 | | **Config** | | | 90 | | Get Config | `/api/v1/config` | ✅ | 91 | | **Plugins** | | | 92 | | Get Plugins | `/api/v1/plugins` | ✅ | 93 | | **Providers** | | | 94 | | List Providers | `/api/v1/providers` | ✅ | 95 | | **Event Logs** | | | 96 | | List Event Logs | `/api/v1/eventLogs` | ✅ | 97 | | Get Event Log | `/api/v1/eventLogs/{event_log_id}` | ✅ | 98 | | **System** | | | 99 | | Get Import Errors | `/api/v1/importErrors` | ✅ | 100 | | Get Import Error Details | `/api/v1/importErrors/{import_error_id}` | ✅ | 101 | | Get Health Status | `/api/v1/health` | ✅ | 102 | | Get Version | `/api/v1/version` | ✅ | 103 | 104 | ## Setup 105 | 106 | ### Dependencies 107 | 108 | This project depends on the official Apache Airflow client library (`apache-airflow-client`). It will be automatically installed when you install this package. 
109 | 110 | ### Environment Variables 111 | 112 | Set the following environment variables: 113 | 114 | ``` 115 | AIRFLOW_HOST=<your-airflow-host> # Optional, defaults to http://localhost:8080 116 | AIRFLOW_API_VERSION=v1 # Optional, defaults to v1 117 | READ_ONLY=true # Optional, enables read-only mode (true/false, defaults to false) 118 | ``` 119 | 120 | #### Authentication 121 | 122 | Choose one of the following authentication methods: 123 | 124 | **Basic Authentication (default):** 125 | ``` 126 | AIRFLOW_USERNAME=<your-airflow-username> 127 | AIRFLOW_PASSWORD=<your-airflow-password> 128 | ``` 129 | 130 | **JWT Token Authentication:** 131 | ``` 132 | AIRFLOW_JWT_TOKEN=<your-jwt-token> 133 | ``` 134 | 135 | To obtain a JWT token, you can use Airflow's authentication endpoint: 136 | 137 | ```bash 138 | ENDPOINT_URL="http://localhost:8080" # Replace with your Airflow endpoint 139 | curl -X 'POST' \ 140 | "${ENDPOINT_URL}/auth/token" \ 141 | -H 'Content-Type: application/json' \ 142 | -d '{ "username": "<your-username>", "password": "<your-password>" }' 143 | ``` 144 | 145 | > **Note**: If both JWT token and basic authentication credentials are provided, JWT token takes precedence. 
146 | 147 | ### Usage with Claude Desktop 148 | 149 | Add to your `claude_desktop_config.json`: 150 | 151 | **Basic Authentication:** 152 | ```json 153 | { 154 | "mcpServers": { 155 | "mcp-server-apache-airflow": { 156 | "command": "uvx", 157 | "args": ["mcp-server-apache-airflow"], 158 | "env": { 159 | "AIRFLOW_HOST": "https://your-airflow-host", 160 | "AIRFLOW_USERNAME": "your-username", 161 | "AIRFLOW_PASSWORD": "your-password" 162 | } 163 | } 164 | } 165 | } 166 | ``` 167 | 168 | **JWT Token Authentication:** 169 | ```json 170 | { 171 | "mcpServers": { 172 | "mcp-server-apache-airflow": { 173 | "command": "uvx", 174 | "args": ["mcp-server-apache-airflow"], 175 | "env": { 176 | "AIRFLOW_HOST": "https://your-airflow-host", 177 | "AIRFLOW_JWT_TOKEN": "your-jwt-token" 178 | } 179 | } 180 | } 181 | } 182 | ``` 183 | 184 | For read-only mode (recommended for safety): 185 | 186 | **Basic Authentication:** 187 | ```json 188 | { 189 | "mcpServers": { 190 | "mcp-server-apache-airflow": { 191 | "command": "uvx", 192 | "args": ["mcp-server-apache-airflow"], 193 | "env": { 194 | "AIRFLOW_HOST": "https://your-airflow-host", 195 | "AIRFLOW_USERNAME": "your-username", 196 | "AIRFLOW_PASSWORD": "your-password", 197 | "READ_ONLY": "true" 198 | } 199 | } 200 | } 201 | } 202 | ``` 203 | 204 | **JWT Token Authentication:** 205 | ```json 206 | { 207 | "mcpServers": { 208 | "mcp-server-apache-airflow": { 209 | "command": "uvx", 210 | "args": ["mcp-server-apache-airflow", "--read-only"], 211 | "env": { 212 | "AIRFLOW_HOST": "https://your-airflow-host", 213 | "AIRFLOW_JWT_TOKEN": "your-jwt-token" 214 | } 215 | } 216 | } 217 | } 218 | ``` 219 | 220 | Alternative configuration using `uv`: 221 | 222 | **Basic Authentication:** 223 | ```json 224 | { 225 | "mcpServers": { 226 | "mcp-server-apache-airflow": { 227 | "command": "uv", 228 | "args": [ 229 | "--directory", 230 | "/path/to/mcp-server-apache-airflow", 231 | "run", 232 | "mcp-server-apache-airflow" 233 | ], 234 | "env": { 235 | 
"AIRFLOW_HOST": "https://your-airflow-host", 236 | "AIRFLOW_USERNAME": "your-username", 237 | "AIRFLOW_PASSWORD": "your-password" 238 | } 239 | } 240 | } 241 | } 242 | ``` 243 | 244 | **JWT Token Authentication:** 245 | ```json 246 | { 247 | "mcpServers": { 248 | "mcp-server-apache-airflow": { 249 | "command": "uv", 250 | "args": [ 251 | "--directory", 252 | "/path/to/mcp-server-apache-airflow", 253 | "run", 254 | "mcp-server-apache-airflow" 255 | ], 256 | "env": { 257 | "AIRFLOW_HOST": "https://your-airflow-host", 258 | "AIRFLOW_JWT_TOKEN": "your-jwt-token" 259 | } 260 | } 261 | } 262 | } 263 | ``` 264 | 265 | Replace `/path/to/mcp-server-apache-airflow` with the actual path where you've cloned the repository. 266 | 267 | ### Selecting the API groups 268 | 269 | You can select the API groups you want to use by setting the `--apis` flag. 270 | 271 | ```bash 272 | uv run mcp-server-apache-airflow --apis dag --apis dagrun 273 | ``` 274 | 275 | The default is to use all APIs. 276 | 277 | Allowed values are: 278 | 279 | - config 280 | - connections 281 | - dag 282 | - dagrun 283 | - dagstats 284 | - dataset 285 | - eventlog 286 | - importerror 287 | - monitoring 288 | - plugin 289 | - pool 290 | - provider 291 | - taskinstance 292 | - variable 293 | - xcom 294 | 295 | ### Read-Only Mode 296 | 297 | You can run the server in read-only mode by using the `--read-only` flag or by setting the `READ_ONLY=true` environment variable. This will only expose tools that perform read operations (GET requests) and exclude any tools that create, update, or delete resources. 298 | 299 | Using the command-line flag: 300 | ```bash 301 | uv run mcp-server-apache-airflow --read-only 302 | ``` 303 | 304 | Using the environment variable: 305 | ```bash 306 | READ_ONLY=true uv run mcp-server-apache-airflow 307 | ``` 308 | 309 | In read-only mode, the server will only expose tools like: 310 | - Listing DAGs, DAG runs, tasks, variables, connections, etc. 
311 | - Getting details of specific resources 312 | - Reading configurations and monitoring information 313 | - Testing connections (non-destructive) 314 | 315 | Write operations like creating, updating, deleting DAGs, variables, connections, triggering DAG runs, etc. will not be available in read-only mode. 316 | 317 | You can combine read-only mode with API group selection: 318 | 319 | ```bash 320 | uv run mcp-server-apache-airflow --read-only --apis dag --apis variable 321 | ``` 322 | 323 | ### Manual Execution 324 | 325 | You can also run the server manually: 326 | 327 | ```bash 328 | make run 329 | ``` 330 | 331 | `make run` accepts following options: 332 | 333 | Options: 334 | 335 | - `--port`: Port to listen on for SSE (default: 8000) 336 | - `--transport`: Transport type (stdio/sse/http, default: stdio) 337 | 338 | Or, you could run the sse server directly, which accepts same parameters: 339 | 340 | ```bash 341 | make run-sse 342 | ``` 343 | 344 | Also, you could start service directly using `uv` like in the following command: 345 | 346 | ```bash 347 | uv run src --transport http --port 8080 348 | ``` 349 | 350 | ### Installing via Smithery 351 | 352 | To install Apache Airflow MCP Server for Claude Desktop automatically via [Smithery](https://smithery.ai/server/@yangkyeongmo/mcp-server-apache-airflow): 353 | 354 | ```bash 355 | npx -y @smithery/cli install @yangkyeongmo/mcp-server-apache-airflow --client claude 356 | ``` 357 | 358 | ## Development 359 | 360 | ### Setting up Development Environment 361 | 362 | 1. Clone the repository: 363 | ```bash 364 | git clone https://github.com/yangkyeongmo/mcp-server-apache-airflow.git 365 | cd mcp-server-apache-airflow 366 | ``` 367 | 368 | 2. Install development dependencies: 369 | ```bash 370 | uv sync --dev 371 | ``` 372 | 373 | 3. 
Create a `.env` file for environment variables (optional for development): 374 | ```bash 375 | touch .env 376 | ``` 377 | 378 | > **Note**: No environment variables are required for running tests. The `AIRFLOW_HOST` defaults to `http://localhost:8080` for development and testing purposes. 379 | 380 | ### Running Tests 381 | 382 | The project uses pytest for testing with the following commands available: 383 | 384 | ```bash 385 | # Run all tests 386 | make test 387 | ``` 388 | 389 | ### Code Quality 390 | 391 | ```bash 392 | # Run linting 393 | make lint 394 | 395 | # Run code formatting 396 | make format 397 | ``` 398 | 399 | ### Continuous Integration 400 | 401 | The project includes a GitHub Actions workflow (`.github/workflows/test.yml`) that automatically: 402 | 403 | - Runs tests on Python 3.10, 3.11, and 3.12 404 | - Executes linting checks using ruff 405 | - Runs on every push and pull request to `main` branch 406 | 407 | The CI pipeline ensures code quality and compatibility across supported Python versions before any changes are merged. 408 | 409 | ## Contributing 410 | 411 | Contributions are welcome! Please feel free to submit a Pull Request. 412 | 413 | The package is deployed automatically to PyPI when project.version is updated in `pyproject.toml`. 414 | Follow semver for versioning. 415 | 416 | Please include version update in the PR in order to apply the changes to core logic. 
417 | 418 | ## License 419 | 420 | [MIT License](LICENSE) 421 | ``` -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /src/airflow/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- ```python 1 | # Test package initialization 2 | ``` -------------------------------------------------------------------------------- /src/__main__.py: -------------------------------------------------------------------------------- ```python 1 | import sys 2 | 3 | from src.main import main 4 | 5 | sys.exit(main()) 6 | ``` -------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- ```python 1 | from fastmcp import FastMCP 2 | 3 | app = FastMCP("mcp-apache-airflow") 4 | ``` -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- ``` 1 | [pytest] 2 | minversion = 6.0 3 | addopts = -ra -q --strict-markers --strict-config 4 | testpaths = test 5 | python_files = test_*.py 6 | python_classes = Test* 7 | python_functions = test_* 8 | asyncio_mode = auto 9 | markers = 10 | integration: marks tests as integration tests (deselect with '-m "not integration"') 11 | slow: marks tests as slow (deselect with '-m "not slow"') 12 | unit: marks tests as unit tests 13 | ``` -------------------------------------------------------------------------------- /test/conftest.py: 
from enum import Enum


class APIType(str, Enum):
    """Airflow REST API groups that can be enabled/disabled via the ``--apis`` flag.

    Inherits ``str`` so members compare equal to (and serialize as) their plain
    string values. Each value matches a module name under ``src/airflow/``.
    """

    CONFIG = "config"
    CONNECTION = "connection"
    DAG = "dag"
    DAGRUN = "dagrun"
    DAGSTATS = "dagstats"
    DATASET = "dataset"
    EVENTLOG = "eventlog"
    IMPORTERROR = "importerror"
    MONITORING = "monitoring"
    PLUGIN = "plugin"
    POOL = "pool"
    PROVIDER = "provider"
    TASKINSTANCE = "taskinstance"
    VARIABLE = "variable"
    XCOM = "xcom"
10 | 11 | # Install the project dependencies 12 | RUN pip install uv 13 | RUN uv sync 14 | 15 | # Expose the port that the server will run on 16 | EXPOSE 8000 17 | 18 | # Command to run the server 19 | CMD ["uv", "run", "src", "--transport", "sse"] 20 | ``` -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- ```yaml 1 | name: Lint and Format Check 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.12" 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install ruff 24 | 25 | - name: Check formatting 26 | run: ruff format --check . 27 | 28 | - name: Run linting 29 | run: ruff check . 
import os
from urllib.parse import urlparse

# Environment variables for Airflow connection
# AIRFLOW_HOST defaults to localhost for development/testing if not provided
_airflow_host_raw = os.getenv("AIRFLOW_HOST", "http://localhost:8080")
# Strip any path component and trailing slash so a later urljoin() can append
# "/api/<version>" without producing duplicated or doubled path segments.
AIRFLOW_HOST = urlparse(_airflow_host_raw)._replace(path="").geturl().rstrip("/")

# Authentication - supports both basic auth and JWT token auth
# (when both are set, the JWT token takes precedence; see airflow_client.py)
AIRFLOW_USERNAME = os.getenv("AIRFLOW_USERNAME")
AIRFLOW_PASSWORD = os.getenv("AIRFLOW_PASSWORD")
AIRFLOW_JWT_TOKEN = os.getenv("AIRFLOW_JWT_TOKEN")
AIRFLOW_API_VERSION = os.getenv("AIRFLOW_API_VERSION", "v1")

# Environment variable for read-only mode
# Accepts the common truthy spellings: "true", "1", "yes", "on" (case-insensitive).
READ_ONLY = os.getenv("READ_ONLY", "false").lower() in ("true", "1", "yes", "on")
from typing import Any, Callable, Dict, List, Optional, Union

import mcp.types as types
from airflow_client.client.api.dag_stats_api import DagStatsApi

from src.airflow.airflow_client import api_client

dag_stats_api = DagStatsApi(api_client)


def get_all_functions() -> list[tuple[Callable, str, str, bool]]:
    """Return list of (function, name, description, is_read_only) tuples for registration."""
    return [
        (get_dag_stats, "get_dag_stats", "Get DAG stats", True),
    ]


async def get_dag_stats(
    dag_ids: Optional[List[str]] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """Fetch aggregate DAG statistics, optionally restricted to the given DAG ids."""
    # Forward only the arguments the caller actually supplied.
    optional_params: Dict[str, Any] = {"dag_ids": dag_ids}
    kwargs = {name: value for name, value in optional_params.items() if value is not None}

    result = dag_stats_api.get_dag_stats(**kwargs)
    return [types.TextContent(type="text", text=str(result.to_dict()))]
Publish package to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.PYPI_API_TOKEN }} 33 | # Alternatively, if using trusted publishing (recommended): 34 | # See: https://docs.pypi.org/trusted-publishers/ 35 | # attestation-check-repository: yangkyeongmo/mcp-server-apache-airflow 36 | # attestation-check-workflow: publish.yml # Optional: if your workflow file is named differently 37 | ``` -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- ```yaml 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install uv 26 | uses: astral-sh/setup-uv@v3 27 | with: 28 | version: "latest" 29 | 30 | - name: Create .env file 31 | run: | 32 | touch .env 33 | echo "# Environment variables for testing" > .env 34 | 35 | - name: Install dependencies 36 | run: | 37 | uv sync --dev 38 | 39 | - name: Run linting 40 | run: | 41 | make lint 42 | 43 | - name: Run tests 44 | run: | 45 | make test 46 | 47 | - name: Upload coverage reports 48 | if: matrix.python-version == '3.11' 49 | uses: codecov/codecov-action@v4 50 | with: 51 | fail_ci_if_error: false 52 | ``` -------------------------------------------------------------------------------- /smithery.yaml: -------------------------------------------------------------------------------- ```yaml 1 | # Smithery configuration file: https://smithery.ai/docs/config#smitheryyaml 2 | 3 | startCommand: 4 | type: stdio 5 | 
from typing import Any, Callable, Dict, List, Optional, Union

import mcp.types as types
from airflow_client.client.api.plugin_api import PluginApi

from src.airflow.airflow_client import api_client

plugin_api = PluginApi(api_client)


def get_all_functions() -> list[tuple[Callable, str, str, bool]]:
    """Return list of (function, name, description, is_read_only) tuples for registration."""
    return [
        (get_plugins, "get_plugins", "Get a list of loaded plugins", True),
    ]


async def get_plugins(
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Return the plugins loaded into the Airflow instance.

    Args:
        limit: Maximum number of items to return.
        offset: Number of items to skip before collecting the result set.

    Returns:
        A list with one text content item describing the loaded plugins.
    """
    # Forward only the pagination arguments the caller actually supplied.
    optional_params: Dict[str, Any] = {"limit": limit, "offset": offset}
    kwargs = {name: value for name, value in optional_params.items() if value is not None}

    result = plugin_api.get_plugins(**kwargs)
    return [types.TextContent(type="text", text=str(result.to_dict()))]
31 | """ 32 | response = monitoring_api.get_version() 33 | return [types.TextContent(type="text", text=str(response.to_dict()))] 34 | ``` -------------------------------------------------------------------------------- /src/airflow/config.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.config_api import ConfigApi 5 | 6 | from src.airflow.airflow_client import api_client 7 | 8 | config_api = ConfigApi(api_client) 9 | 10 | 11 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 12 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 13 | return [ 14 | (get_config, "get_config", "Get current configuration", True), 15 | (get_value, "get_value", "Get a specific option from configuration", True), 16 | ] 17 | 18 | 19 | async def get_config( 20 | section: Optional[str] = None, 21 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 22 | # Build parameters dictionary 23 | kwargs: Dict[str, Any] = {} 24 | if section is not None: 25 | kwargs["section"] = section 26 | 27 | response = config_api.get_config(**kwargs) 28 | return [types.TextContent(type="text", text=str(response.to_dict()))] 29 | 30 | 31 | async def get_value( 32 | section: str, option: str 33 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 34 | response = config_api.get_value(section=section, option=option) 35 | return [types.TextContent(type="text", text=str(response.to_dict()))] 36 | ``` -------------------------------------------------------------------------------- /src/airflow/provider.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from 
airflow_client.client.api.provider_api import ProviderApi 5 | 6 | from src.airflow.airflow_client import api_client 7 | 8 | provider_api = ProviderApi(api_client) 9 | 10 | 11 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 12 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 13 | return [ 14 | (get_providers, "get_providers", "Get a list of loaded providers", True), 15 | ] 16 | 17 | 18 | async def get_providers( 19 | limit: Optional[int] = None, 20 | offset: Optional[int] = None, 21 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 22 | """ 23 | Get a list of providers. 24 | 25 | Args: 26 | limit: The numbers of items to return. 27 | offset: The number of items to skip before starting to collect the result set. 28 | 29 | Returns: 30 | A list of providers with their details. 31 | """ 32 | # Build parameters dictionary 33 | kwargs: Dict[str, Any] = {} 34 | if limit is not None: 35 | kwargs["limit"] = limit 36 | if offset is not None: 37 | kwargs["offset"] = offset 38 | 39 | response = provider_api.get_providers(**kwargs) 40 | return [types.TextContent(type="text", text=str(response.to_dict()))] 41 | ``` -------------------------------------------------------------------------------- /src/airflow/importerror.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.import_error_api import ImportErrorApi 5 | 6 | from src.airflow.airflow_client import api_client 7 | 8 | import_error_api = ImportErrorApi(api_client) 9 | 10 | 11 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 12 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 13 | return [ 14 | (get_import_errors, "get_import_errors", "List import errors", True), 15 | (get_import_error, 
"get_import_error", "Get a specific import error by ID", True), 16 | ] 17 | 18 | 19 | async def get_import_errors( 20 | limit: Optional[int] = None, 21 | offset: Optional[int] = None, 22 | order_by: Optional[str] = None, 23 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 24 | # Build parameters dictionary 25 | kwargs: Dict[str, Any] = {} 26 | if limit is not None: 27 | kwargs["limit"] = limit 28 | if offset is not None: 29 | kwargs["offset"] = offset 30 | if order_by is not None: 31 | kwargs["order_by"] = order_by 32 | 33 | response = import_error_api.get_import_errors(**kwargs) 34 | return [types.TextContent(type="text", text=str(response.to_dict()))] 35 | 36 | 37 | async def get_import_error( 38 | import_error_id: int, 39 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 40 | response = import_error_api.get_import_error(import_error_id=import_error_id) 41 | return [types.TextContent(type="text", text=str(response.to_dict()))] 42 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- ```toml 1 | [project] 2 | name = "mcp-server-apache-airflow" 3 | version = "0.2.9" 4 | description = "Model Context Protocol (MCP) server for Apache Airflow" 5 | authors = [ 6 | { name = "Gyeongmo Yang", email = "[email protected]" } 7 | ] 8 | dependencies = [ 9 | "httpx>=0.24.1", 10 | "click>=8.1.7", 11 | "mcp>=0.1.0", 12 | "apache-airflow-client>=2.7.0,<3", 13 | "fastmcp>=2.11.3", 14 | "PyJWT>=2.8.0", 15 | ] 16 | requires-python = ">=3.10" 17 | readme = "README.md" 18 | license = { text = "MIT" } 19 | classifiers = [ 20 | "Development Status :: 3 - Alpha", 21 | "Intended Audience :: Developers", 22 | "License :: OSI Approved :: MIT License", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | 
"Programming Language :: Python :: 3.12", 27 | "Topic :: Software Development :: Libraries :: Python Modules", 28 | ] 29 | keywords = ["mcp", "airflow", "apache-airflow", "model-context-protocol"] 30 | 31 | [project.optional-dependencies] 32 | dev = [ 33 | "build>=1.2.2.post1", 34 | "twine>=6.1.0", 35 | ] 36 | 37 | [project.urls] 38 | Homepage = "https://github.com/yangkyeongmo/mcp-server-apache-airflow" 39 | Repository = "https://github.com/yangkyeongmo/mcp-server-apache-airflow.git" 40 | "Bug Tracker" = "https://github.com/yangkyeongmo/mcp-server-apache-airflow/issues" 41 | 42 | [build-system] 43 | requires = ["hatchling"] 44 | build-backend = "hatchling.build" 45 | 46 | [project.scripts] 47 | mcp-server-apache-airflow = "src.main:main" 48 | 49 | [tool.hatch.build.targets.wheel] 50 | packages = ["src"] 51 | 52 | [tool.hatch.build] 53 | include = [ 54 | "src/**/*.py", 55 | "README.md", 56 | "LICENSE", 57 | ] 58 | 59 | [tool.ruff] 60 | line-length = 120 61 | 62 | [tool.ruff.lint] 63 | select = ["E", "W", "F", "B", "I"] 64 | 65 | [dependency-groups] 66 | dev = [ 67 | "ruff>=0.11.0", 68 | "pytest>=7.0.0", 69 | "pytest-cov>=4.0.0", 70 | "pytest-asyncio>=0.21.0", 71 | ] 72 | ``` -------------------------------------------------------------------------------- /src/airflow/xcom.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.x_com_api import XComApi 5 | 6 | from src.airflow.airflow_client import api_client 7 | 8 | xcom_api = XComApi(api_client) 9 | 10 | 11 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 12 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 13 | return [ 14 | (get_xcom_entries, "get_xcom_entries", "Get all XCom entries", True), 15 | (get_xcom_entry, "get_xcom_entry", "Get an XCom entry", True), 16 | ] 17 | 18 | 19 | 
async def get_xcom_entries( 20 | dag_id: str, 21 | dag_run_id: str, 22 | task_id: str, 23 | map_index: Optional[int] = None, 24 | xcom_key: Optional[str] = None, 25 | limit: Optional[int] = None, 26 | offset: Optional[int] = None, 27 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 28 | # Build parameters dictionary 29 | kwargs: Dict[str, Any] = {} 30 | if map_index is not None: 31 | kwargs["map_index"] = map_index 32 | if xcom_key is not None: 33 | kwargs["xcom_key"] = xcom_key 34 | if limit is not None: 35 | kwargs["limit"] = limit 36 | if offset is not None: 37 | kwargs["offset"] = offset 38 | 39 | response = xcom_api.get_xcom_entries(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, **kwargs) 40 | return [types.TextContent(type="text", text=str(response.to_dict()))] 41 | 42 | 43 | async def get_xcom_entry( 44 | dag_id: str, 45 | dag_run_id: str, 46 | task_id: str, 47 | xcom_key: str, 48 | map_index: Optional[int] = None, 49 | deserialize: Optional[bool] = None, 50 | stringify: Optional[bool] = None, 51 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 52 | # Build parameters dictionary 53 | kwargs: Dict[str, Any] = {} 54 | if map_index is not None: 55 | kwargs["map_index"] = map_index 56 | if deserialize is not None: 57 | kwargs["deserialize"] = deserialize 58 | if stringify is not None: 59 | kwargs["stringify"] = stringify 60 | 61 | response = xcom_api.get_xcom_entry( 62 | dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, xcom_key=xcom_key, **kwargs 63 | ) 64 | return [types.TextContent(type="text", text=str(response.to_dict()))] 65 | ``` -------------------------------------------------------------------------------- /src/airflow/eventlog.py: -------------------------------------------------------------------------------- ```python 1 | from datetime import datetime 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | 4 | import mcp.types as types 5 | from 
airflow_client.client.api.event_log_api import EventLogApi 6 | 7 | from src.airflow.airflow_client import api_client 8 | 9 | event_log_api = EventLogApi(api_client) 10 | 11 | 12 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 13 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 14 | return [ 15 | (get_event_logs, "get_event_logs", "List log entries from event log", True), 16 | (get_event_log, "get_event_log", "Get a specific log entry by ID", True), 17 | ] 18 | 19 | 20 | async def get_event_logs( 21 | limit: Optional[int] = None, 22 | offset: Optional[int] = None, 23 | order_by: Optional[str] = None, 24 | dag_id: Optional[str] = None, 25 | task_id: Optional[str] = None, 26 | run_id: Optional[str] = None, 27 | map_index: Optional[int] = None, 28 | try_number: Optional[int] = None, 29 | event: Optional[str] = None, 30 | owner: Optional[str] = None, 31 | before: Optional[datetime] = None, 32 | after: Optional[datetime] = None, 33 | included_events: Optional[str] = None, 34 | excluded_events: Optional[str] = None, 35 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 36 | # Build parameters dictionary 37 | kwargs: Dict[str, Any] = {} 38 | if limit is not None: 39 | kwargs["limit"] = limit 40 | if offset is not None: 41 | kwargs["offset"] = offset 42 | if order_by is not None: 43 | kwargs["order_by"] = order_by 44 | if dag_id is not None: 45 | kwargs["dag_id"] = dag_id 46 | if task_id is not None: 47 | kwargs["task_id"] = task_id 48 | if run_id is not None: 49 | kwargs["run_id"] = run_id 50 | if map_index is not None: 51 | kwargs["map_index"] = map_index 52 | if try_number is not None: 53 | kwargs["try_number"] = try_number 54 | if event is not None: 55 | kwargs["event"] = event 56 | if owner is not None: 57 | kwargs["owner"] = owner 58 | if before is not None: 59 | kwargs["before"] = before 60 | if after is not None: 61 | kwargs["after"] = after 62 | if included_events is not 
None: 63 | kwargs["included_events"] = included_events 64 | if excluded_events is not None: 65 | kwargs["excluded_events"] = excluded_events 66 | 67 | response = event_log_api.get_event_logs(**kwargs) 68 | return [types.TextContent(type="text", text=str(response.to_dict()))] 69 | 70 | 71 | async def get_event_log( 72 | event_log_id: int, 73 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 74 | response = event_log_api.get_event_log(event_log_id=event_log_id) 75 | return [types.TextContent(type="text", text=str(response.to_dict()))] 76 | ``` -------------------------------------------------------------------------------- /src/airflow/variable.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.variable_api import VariableApi 5 | 6 | from src.airflow.airflow_client import api_client 7 | 8 | variable_api = VariableApi(api_client) 9 | 10 | 11 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 12 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 13 | return [ 14 | (list_variables, "list_variables", "List all variables", True), 15 | (create_variable, "create_variable", "Create a variable", False), 16 | (get_variable, "get_variable", "Get a variable by key", True), 17 | (update_variable, "update_variable", "Update a variable by key", False), 18 | (delete_variable, "delete_variable", "Delete a variable by key", False), 19 | ] 20 | 21 | 22 | async def list_variables( 23 | limit: Optional[int] = None, 24 | offset: Optional[int] = None, 25 | order_by: Optional[str] = None, 26 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 27 | # Build parameters dictionary 28 | kwargs: Dict[str, Any] = {} 29 | if limit is not None: 30 | kwargs["limit"] = limit 31 | if offset is not None: 32 
|         kwargs["offset"] = offset
33 |     if order_by is not None:
34 |         kwargs["order_by"] = order_by
35 | 
36 |     response = variable_api.get_variables(**kwargs)
37 |     return [types.TextContent(type="text", text=str(response.to_dict()))]
38 | 
39 | 
40 | async def create_variable(
41 |     key: str, value: str, description: Optional[str] = None
42 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
43 |     variable_request = {
44 |         "key": key,
45 |         "value": value,
46 |     }
47 |     if description is not None:
48 |         variable_request["description"] = description
49 | 
50 |     response = variable_api.post_variables(variable_request=variable_request)
51 |     return [types.TextContent(type="text", text=str(response.to_dict()))]
52 | 
53 | 
54 | async def get_variable(key: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
55 |     response = variable_api.get_variable(variable_key=key)
56 |     return [types.TextContent(type="text", text=str(response.to_dict()))]
57 | 
58 | 
59 | async def update_variable(
60 |     key: str, value: Optional[str] = None, description: Optional[str] = None
61 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
62 |     update_request = {}
63 |     if value is not None:
64 |         update_request["value"] = value
65 |     if description is not None:
66 |         update_request["description"] = description
67 | 
68 |     response = variable_api.patch_variable(
69 |         variable_key=key, update_mask=list(update_request.keys()), variable_request=update_request
70 |     )
71 |     return [types.TextContent(type="text", text=str(response.to_dict()))]
72 | 
73 | 
74 | async def delete_variable(key: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
75 |     # DELETE /variables/{key} returns 204 No Content, so the client returns None;
76 |     # calling response.to_dict() here raised AttributeError. Return a message like delete_pool does.
77 |     variable_api.delete_variable(variable_key=key)
78 |     return [types.TextContent(type="text", text=f"Variable '{key}' deleted successfully.")]
79 | ```
--------------------------------------------------------------------------------
/test/test_server.py: 
-------------------------------------------------------------------------------- ```python 1 | """Tests for the server module using pytest framework.""" 2 | 3 | import pytest 4 | from fastmcp import FastMCP 5 | from fastmcp.tools import Tool 6 | 7 | 8 | class TestServer: 9 | """Test cases for the server module.""" 10 | 11 | def test_app_instance_type(self): 12 | """Test that app instance is of correct type.""" 13 | from src.server import app 14 | 15 | # Verify app is an instance of FastMCP 16 | assert isinstance(app, FastMCP) 17 | 18 | def test_app_instance_name(self): 19 | """Test that app instance has the correct name.""" 20 | from src.server import app 21 | 22 | # Verify the app name is set correctly 23 | assert app.name == "mcp-apache-airflow" 24 | 25 | def test_app_instance_is_singleton(self): 26 | """Test that importing the app multiple times returns the same instance.""" 27 | from src.server import app as app1 28 | from src.server import app as app2 29 | 30 | # Verify same instance is returned 31 | assert app1 is app2 32 | 33 | def test_app_has_required_methods(self): 34 | """Test that app instance has required FastMCP methods.""" 35 | from src.server import app 36 | 37 | # Verify essential methods exist 38 | assert hasattr(app, "add_tool") 39 | assert hasattr(app, "run") 40 | assert callable(app.add_tool) 41 | assert callable(app.run) 42 | 43 | def test_app_initialization_attributes(self): 44 | """Test that app is properly initialized with default attributes.""" 45 | from src.server import app 46 | 47 | # Verify basic FastMCP attributes 48 | assert app.name is not None 49 | assert app.name == "mcp-apache-airflow" 50 | 51 | # Verify app can be used (doesn't raise exceptions on basic operations) 52 | try: 53 | # These should not raise exceptions 54 | str(app) 55 | repr(app) 56 | except Exception as e: 57 | pytest.fail(f"Basic app operations failed: {e}") 58 | 59 | def test_app_name_format(self): 60 | """Test that app name follows expected format.""" 61 | from 
src.server import app 62 | 63 | # Verify name format 64 | assert isinstance(app.name, str) 65 | assert app.name.startswith("mcp-") 66 | assert "airflow" in app.name 67 | 68 | @pytest.mark.integration 69 | def test_app_tool_registration_capability(self): 70 | """Test that app can register tools without errors.""" 71 | from src.server import app 72 | 73 | # Mock function to register 74 | def test_tool(): 75 | return "test result" 76 | 77 | # This should not raise an exception 78 | try: 79 | app.add_tool(Tool.from_function(test_tool, name="test_tool", description="Test tool")) 80 | except Exception as e: 81 | pytest.fail(f"Tool registration failed: {e}") 82 | 83 | def test_app_module_level_initialization(self): 84 | """Test that app is initialized at module level.""" 85 | # Import should work without any setup 86 | from src.server import app 87 | 88 | # App should be ready to use immediately 89 | assert app is not None 90 | assert hasattr(app, "name") 91 | assert app.name == "mcp-apache-airflow" 92 | ``` -------------------------------------------------------------------------------- /src/airflow/pool.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.pool_api import PoolApi 5 | from airflow_client.client.model.pool import Pool 6 | 7 | from src.airflow.airflow_client import api_client 8 | 9 | pool_api = PoolApi(api_client) 10 | 11 | 12 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 13 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 14 | return [ 15 | (get_pools, "get_pools", "List pools", True), 16 | (get_pool, "get_pool", "Get a pool by name", True), 17 | (delete_pool, "delete_pool", "Delete a pool", False), 18 | (post_pool, "post_pool", "Create a pool", False), 19 | (patch_pool, "patch_pool", "Update a pool", False), 20 | ] 21 
| 22 | 23 | async def get_pools( 24 | limit: Optional[int] = None, 25 | offset: Optional[int] = None, 26 | order_by: Optional[str] = None, 27 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 28 | """ 29 | List pools. 30 | 31 | Args: 32 | limit: The numbers of items to return. 33 | offset: The number of items to skip before starting to collect the result set. 34 | order_by: The name of the field to order the results by. Prefix a field name with `-` to reverse the sort order. 35 | 36 | Returns: 37 | A list of pools. 38 | """ 39 | # Build parameters dictionary 40 | kwargs: Dict[str, Any] = {} 41 | if limit is not None: 42 | kwargs["limit"] = limit 43 | if offset is not None: 44 | kwargs["offset"] = offset 45 | if order_by is not None: 46 | kwargs["order_by"] = order_by 47 | 48 | response = pool_api.get_pools(**kwargs) 49 | return [types.TextContent(type="text", text=str(response.to_dict()))] 50 | 51 | 52 | async def get_pool( 53 | pool_name: str, 54 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 55 | """ 56 | Get a pool by name. 57 | 58 | Args: 59 | pool_name: The pool name. 60 | 61 | Returns: 62 | The pool details. 63 | """ 64 | response = pool_api.get_pool(pool_name=pool_name) 65 | return [types.TextContent(type="text", text=str(response.to_dict()))] 66 | 67 | 68 | async def delete_pool( 69 | pool_name: str, 70 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 71 | """ 72 | Delete a pool. 73 | 74 | Args: 75 | pool_name: The pool name. 76 | 77 | Returns: 78 | A confirmation message. 
79 | """ 80 | pool_api.delete_pool(pool_name=pool_name) 81 | return [types.TextContent(type="text", text=f"Pool '{pool_name}' deleted successfully.")] 82 | 83 | 84 | async def post_pool( 85 | name: str, 86 | slots: int, 87 | description: Optional[str] = None, 88 | include_deferred: Optional[bool] = None, 89 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 90 | """ 91 | Create a pool. 92 | 93 | Args: 94 | name: The pool name. 95 | slots: The number of slots. 96 | description: The pool description. 97 | include_deferred: Whether to include deferred tasks in slot calculations. 98 | 99 | Returns: 100 | The created pool details. 101 | """ 102 | pool = Pool( 103 | name=name, 104 | slots=slots, 105 | ) 106 | 107 | if description is not None: 108 | pool.description = description 109 | 110 | if include_deferred is not None: 111 | pool.include_deferred = include_deferred 112 | 113 | response = pool_api.post_pool(pool=pool) 114 | return [types.TextContent(type="text", text=str(response.to_dict()))] 115 | 116 | 117 | async def patch_pool( 118 | pool_name: str, 119 | slots: Optional[int] = None, 120 | description: Optional[str] = None, 121 | include_deferred: Optional[bool] = None, 122 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 123 | """ 124 | Update a pool. 125 | 126 | Args: 127 | pool_name: The pool name. 128 | slots: The number of slots. 129 | description: The pool description. 130 | include_deferred: Whether to include deferred tasks in slot calculations. 131 | 132 | Returns: 133 | The updated pool details. 
134 | """ 135 | pool = Pool() 136 | 137 | if slots is not None: 138 | pool.slots = slots 139 | 140 | if description is not None: 141 | pool.description = description 142 | 143 | if include_deferred is not None: 144 | pool.include_deferred = include_deferred 145 | 146 | response = pool_api.patch_pool(pool_name=pool_name, pool=pool) 147 | return [types.TextContent(type="text", text=str(response.to_dict()))] 148 | ``` -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | 3 | import click 4 | from fastmcp.tools import Tool 5 | 6 | from src.airflow.config import get_all_functions as get_config_functions 7 | from src.airflow.connection import get_all_functions as get_connection_functions 8 | from src.airflow.dag import get_all_functions as get_dag_functions 9 | from src.airflow.dagrun import get_all_functions as get_dagrun_functions 10 | from src.airflow.dagstats import get_all_functions as get_dagstats_functions 11 | from src.airflow.dataset import get_all_functions as get_dataset_functions 12 | from src.airflow.eventlog import get_all_functions as get_eventlog_functions 13 | from src.airflow.importerror import get_all_functions as get_importerror_functions 14 | from src.airflow.monitoring import get_all_functions as get_monitoring_functions 15 | from src.airflow.plugin import get_all_functions as get_plugin_functions 16 | from src.airflow.pool import get_all_functions as get_pool_functions 17 | from src.airflow.provider import get_all_functions as get_provider_functions 18 | from src.airflow.taskinstance import get_all_functions as get_taskinstance_functions 19 | from src.airflow.variable import get_all_functions as get_variable_functions 20 | from src.airflow.xcom import get_all_functions as get_xcom_functions 21 | from src.enums import APIType 22 | from src.envs import READ_ONLY 23 | 24 | APITYPE_TO_FUNCTIONS = { 
25 | APIType.CONFIG: get_config_functions,
26 | APIType.CONNECTION: get_connection_functions,
27 | APIType.DAG: get_dag_functions,
28 | APIType.DAGRUN: get_dagrun_functions,
29 | APIType.DAGSTATS: get_dagstats_functions,
30 | APIType.DATASET: get_dataset_functions,
31 | APIType.EVENTLOG: get_eventlog_functions,
32 | APIType.IMPORTERROR: get_importerror_functions,
33 | APIType.MONITORING: get_monitoring_functions,
34 | APIType.PLUGIN: get_plugin_functions,
35 | APIType.POOL: get_pool_functions,
36 | APIType.PROVIDER: get_provider_functions,
37 | APIType.TASKINSTANCE: get_taskinstance_functions,
38 | APIType.VARIABLE: get_variable_functions,
39 | APIType.XCOM: get_xcom_functions,
40 | }
41 | 
42 | 
43 | def filter_functions_for_read_only(functions: list[tuple]) -> list[tuple]:
44 |     """
45 |     Filter functions to only include read-only operations.
46 | 
47 |     Args:
48 |         functions: List of (func, name, description, is_read_only) tuples
49 | 
50 |     Returns:
51 |         List of (func, name, description, is_read_only) tuples with only read-only functions
52 |     """
53 |     return [
54 |         (func, name, description, is_read_only) for func, name, description, is_read_only in functions if is_read_only
55 |     ]
56 | 
57 | 
58 | @click.command()
59 | @click.option(
60 |     "--transport",
61 |     type=click.Choice(["stdio", "sse", "http"]),
62 |     default="stdio",
63 |     help="Transport type",
64 | )
65 | @click.option("--mcp-port", default=8000, help="Port to run MCP service in case of SSE or HTTP transports.")
66 | @click.option("--mcp-host", default="0.0.0.0", help="Host to run MCP service in case of SSE or HTTP transports.")
67 | @click.option(
68 |     "--apis",
69 |     type=click.Choice([api.value for api in APIType]),
70 |     default=[api.value for api in APIType],
71 |     multiple=True,
72 |     help="APIs to run, default is all",
73 | )
74 | @click.option(
75 |     "--read-only",
76 |     is_flag=True,
77 |     default=READ_ONLY,
78 |     help="Only expose read-only tools (GET operations, no CREATE/UPDATE/DELETE)",
79 | )
80 | def 
main(transport: str, mcp_host: str, mcp_port: int, apis: list[str], read_only: bool) -> None: 81 | from src.server import app 82 | 83 | for api in apis: 84 | logging.debug(f"Adding API: {api}") 85 | get_function = APITYPE_TO_FUNCTIONS[APIType(api)] 86 | try: 87 | functions = get_function() 88 | except NotImplementedError: 89 | continue 90 | 91 | # Filter functions for read-only mode if requested 92 | if read_only: 93 | functions = filter_functions_for_read_only(functions) 94 | 95 | for func, name, description, *_ in functions: 96 | app.add_tool(Tool.from_function(func, name=name, description=description)) 97 | 98 | logging.debug(f"Starting MCP server for Apache Airflow with {transport} transport") 99 | params_to_run = {} 100 | 101 | if transport in {"sse", "http"}: 102 | if transport == "sse": 103 | logging.warning("NOTE: the SSE transport is going to be deprecated.") 104 | 105 | params_to_run = {"port": int(mcp_port), "host": mcp_host} 106 | 107 | app.run(transport=transport, **params_to_run) 108 | ``` -------------------------------------------------------------------------------- /src/airflow/connection.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.connection_api import ConnectionApi 5 | 6 | from src.airflow.airflow_client import api_client 7 | 8 | connection_api = ConnectionApi(api_client) 9 | 10 | 11 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 12 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 13 | return [ 14 | (list_connections, "list_connections", "List all connections", True), 15 | (create_connection, "create_connection", "Create a connection", False), 16 | (get_connection, "get_connection", "Get a connection by ID", True), 17 | (update_connection, "update_connection", "Update a connection by ID", False), 18 
| (delete_connection, "delete_connection", "Delete a connection by ID", False), 19 | (test_connection, "test_connection", "Test a connection", True), 20 | ] 21 | 22 | 23 | async def list_connections( 24 | limit: Optional[int] = None, 25 | offset: Optional[int] = None, 26 | order_by: Optional[str] = None, 27 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 28 | # Build parameters dictionary 29 | kwargs: Dict[str, Any] = {} 30 | if limit is not None: 31 | kwargs["limit"] = limit 32 | if offset is not None: 33 | kwargs["offset"] = offset 34 | if order_by is not None: 35 | kwargs["order_by"] = order_by 36 | 37 | response = connection_api.get_connections(**kwargs) 38 | return [types.TextContent(type="text", text=str(response.to_dict()))] 39 | 40 | 41 | async def create_connection( 42 | conn_id: str, 43 | conn_type: str, 44 | host: Optional[str] = None, 45 | port: Optional[int] = None, 46 | login: Optional[str] = None, 47 | password: Optional[str] = None, 48 | schema: Optional[str] = None, 49 | extra: Optional[str] = None, 50 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 51 | connection_request = { 52 | "connection_id": conn_id, 53 | "conn_type": conn_type, 54 | } 55 | if host is not None: 56 | connection_request["host"] = host 57 | if port is not None: 58 | connection_request["port"] = port 59 | if login is not None: 60 | connection_request["login"] = login 61 | if password is not None: 62 | connection_request["password"] = password 63 | if schema is not None: 64 | connection_request["schema"] = schema 65 | if extra is not None: 66 | connection_request["extra"] = extra 67 | 68 | response = connection_api.post_connection(connection_request=connection_request) 69 | return [types.TextContent(type="text", text=str(response.to_dict()))] 70 | 71 | 72 | async def get_connection(conn_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 73 | response = 
connection_api.get_connection(connection_id=conn_id)
74 |     return [types.TextContent(type="text", text=str(response.to_dict()))]
75 | 
76 | 
77 | async def update_connection(
78 |     conn_id: str,
79 |     conn_type: Optional[str] = None,
80 |     host: Optional[str] = None,
81 |     port: Optional[int] = None,
82 |     login: Optional[str] = None,
83 |     password: Optional[str] = None,
84 |     schema: Optional[str] = None,
85 |     extra: Optional[str] = None,
86 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
87 |     update_request = {}
88 |     if conn_type is not None:
89 |         update_request["conn_type"] = conn_type
90 |     if host is not None:
91 |         update_request["host"] = host
92 |     if port is not None:
93 |         update_request["port"] = port
94 |     if login is not None:
95 |         update_request["login"] = login
96 |     if password is not None:
97 |         update_request["password"] = password
98 |     if schema is not None:
99 |         update_request["schema"] = schema
100 |     if extra is not None:
101 |         update_request["extra"] = extra
102 | 
103 |     response = connection_api.patch_connection(
104 |         connection_id=conn_id, update_mask=list(update_request.keys()), connection_request=update_request
105 |     )
106 |     return [types.TextContent(type="text", text=str(response.to_dict()))]
107 | 
108 | 
109 | async def delete_connection(conn_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
110 |     # DELETE /connections/{id} returns 204 No Content, so the client returns None;
111 |     # calling response.to_dict() here raised AttributeError. Return a message like delete_pool does.
112 |     connection_api.delete_connection(connection_id=conn_id)
113 |     return [types.TextContent(type="text", text=f"Connection '{conn_id}' deleted successfully.")]
114 | 
115 | 
116 | async def test_connection(
117 |     conn_type: str,
118 |     host: Optional[str] = None,
119 |     port: Optional[int] = None,
120 |     login: Optional[str] = None,
121 |     password: Optional[str] = None,
122 |     schema: Optional[str] = None,
123 |     extra: Optional[str] = None,
124 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
125 |     connection_request = {
126 |         "conn_type": conn_type,
127 |     }
128 |     if host is not 
None: 127 | connection_request["host"] = host 128 | if port is not None: 129 | connection_request["port"] = port 130 | if login is not None: 131 | connection_request["login"] = login 132 | if password is not None: 133 | connection_request["password"] = password 134 | if schema is not None: 135 | connection_request["schema"] = schema 136 | if extra is not None: 137 | connection_request["extra"] = extra 138 | 139 | response = connection_api.test_connection(connection_request=connection_request) 140 | return [types.TextContent(type="text", text=str(response.to_dict()))] 141 | ``` -------------------------------------------------------------------------------- /src/airflow/taskinstance.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.task_instance_api import TaskInstanceApi 5 | 6 | from src.airflow.airflow_client import api_client 7 | 8 | task_instance_api = TaskInstanceApi(api_client) 9 | 10 | 11 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 12 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 13 | return [ 14 | (get_task_instance, "get_task_instance", "Get a task instance by DAG ID, task ID, and DAG run ID", True), 15 | (list_task_instances, "list_task_instances", "List task instances by DAG ID and DAG run ID", True), 16 | ( 17 | update_task_instance, 18 | "update_task_instance", 19 | "Update a task instance by DAG ID, DAG run ID, and task ID", 20 | False, 21 | ), 22 | ( 23 | get_log, 24 | "get_log", 25 | "Get the log from a task instance by DAG ID, task ID, DAG run ID and task try number", 26 | True, 27 | ), 28 | ( 29 | list_task_instance_tries, 30 | "list_task_instance_tries", 31 | "List task instance tries by DAG ID, DAG run ID, and task ID", 32 | True, 33 | ), 34 | ] 35 | 36 | 37 | async def get_task_instance( 38 | dag_id: 
str, task_id: str, dag_run_id: str 39 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 40 | response = task_instance_api.get_task_instance(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id) 41 | return [types.TextContent(type="text", text=str(response.to_dict()))] 42 | 43 | 44 | async def list_task_instances( 45 | dag_id: str, 46 | dag_run_id: str, 47 | execution_date_gte: Optional[str] = None, 48 | execution_date_lte: Optional[str] = None, 49 | start_date_gte: Optional[str] = None, 50 | start_date_lte: Optional[str] = None, 51 | end_date_gte: Optional[str] = None, 52 | end_date_lte: Optional[str] = None, 53 | updated_at_gte: Optional[str] = None, 54 | updated_at_lte: Optional[str] = None, 55 | duration_gte: Optional[float] = None, 56 | duration_lte: Optional[float] = None, 57 | state: Optional[List[str]] = None, 58 | pool: Optional[List[str]] = None, 59 | queue: Optional[List[str]] = None, 60 | limit: Optional[int] = None, 61 | offset: Optional[int] = None, 62 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 63 | # Build parameters dictionary 64 | kwargs: Dict[str, Any] = {} 65 | if execution_date_gte is not None: 66 | kwargs["execution_date_gte"] = execution_date_gte 67 | if execution_date_lte is not None: 68 | kwargs["execution_date_lte"] = execution_date_lte 69 | if start_date_gte is not None: 70 | kwargs["start_date_gte"] = start_date_gte 71 | if start_date_lte is not None: 72 | kwargs["start_date_lte"] = start_date_lte 73 | if end_date_gte is not None: 74 | kwargs["end_date_gte"] = end_date_gte 75 | if end_date_lte is not None: 76 | kwargs["end_date_lte"] = end_date_lte 77 | if updated_at_gte is not None: 78 | kwargs["updated_at_gte"] = updated_at_gte 79 | if updated_at_lte is not None: 80 | kwargs["updated_at_lte"] = updated_at_lte 81 | if duration_gte is not None: 82 | kwargs["duration_gte"] = duration_gte 83 | if duration_lte is not None: 84 | kwargs["duration_lte"] = duration_lte 85 | 
if state is not None: 86 | kwargs["state"] = state 87 | if pool is not None: 88 | kwargs["pool"] = pool 89 | if queue is not None: 90 | kwargs["queue"] = queue 91 | if limit is not None: 92 | kwargs["limit"] = limit 93 | if offset is not None: 94 | kwargs["offset"] = offset 95 | 96 | response = task_instance_api.get_task_instances(dag_id=dag_id, dag_run_id=dag_run_id, **kwargs) 97 | return [types.TextContent(type="text", text=str(response.to_dict()))] 98 | 99 | 100 | async def update_task_instance( 101 | dag_id: str, dag_run_id: str, task_id: str, state: Optional[str] = None 102 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 103 | update_request = {} 104 | if state is not None: 105 | update_request["state"] = state 106 | 107 | response = task_instance_api.patch_task_instance( 108 | dag_id=dag_id, 109 | dag_run_id=dag_run_id, 110 | task_id=task_id, 111 | update_mask=list(update_request.keys()), 112 | task_instance_request=update_request, 113 | ) 114 | return [types.TextContent(type="text", text=str(response.to_dict()))] 115 | 116 | 117 | async def get_log( 118 | dag_id: str, task_id: str, dag_run_id: str, task_try_number: int 119 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 120 | response = task_instance_api.get_log( 121 | dag_id=dag_id, 122 | dag_run_id=dag_run_id, 123 | task_id=task_id, 124 | task_try_number=task_try_number, 125 | ) 126 | return [types.TextContent(type="text", text=str(response.to_dict()))] 127 | 128 | 129 | async def list_task_instance_tries( 130 | dag_id: str, 131 | dag_run_id: str, 132 | task_id: str, 133 | limit: Optional[int] = None, 134 | offset: Optional[int] = None, 135 | order_by: Optional[str] = None, 136 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 137 | # Build parameters dictionary 138 | kwargs: Dict[str, Any] = {} 139 | if limit is not None: 140 | kwargs["limit"] = limit 141 | if offset is not None: 142 | kwargs["offset"] = 
offset 143 | if order_by is not None: 144 | kwargs["order_by"] = order_by 145 | 146 | response = task_instance_api.get_task_instance_tries( 147 | dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, **kwargs 148 | ) 149 | return [types.TextContent(type="text", text=str(response.to_dict()))] 150 | ``` -------------------------------------------------------------------------------- /src/airflow/dataset.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.dataset_api import DatasetApi 5 | 6 | from src.airflow.airflow_client import api_client 7 | 8 | dataset_api = DatasetApi(api_client) 9 | 10 | 11 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 12 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 13 | return [ 14 | (get_datasets, "get_datasets", "List datasets", True), 15 | (get_dataset, "get_dataset", "Get a dataset by URI", True), 16 | (get_dataset_events, "get_dataset_events", "Get dataset events", True), 17 | (create_dataset_event, "create_dataset_event", "Create dataset event", False), 18 | (get_dag_dataset_queued_event, "get_dag_dataset_queued_event", "Get a queued Dataset event for a DAG", True), 19 | (get_dag_dataset_queued_events, "get_dag_dataset_queued_events", "Get queued Dataset events for a DAG", True), 20 | ( 21 | delete_dag_dataset_queued_event, 22 | "delete_dag_dataset_queued_event", 23 | "Delete a queued Dataset event for a DAG", 24 | False, 25 | ), 26 | ( 27 | delete_dag_dataset_queued_events, 28 | "delete_dag_dataset_queued_events", 29 | "Delete queued Dataset events for a DAG", 30 | False, 31 | ), 32 | (get_dataset_queued_events, "get_dataset_queued_events", "Get queued Dataset events for a Dataset", True), 33 | ( 34 | delete_dataset_queued_events, 35 | "delete_dataset_queued_events", 36 | "Delete queued 
Dataset events for a Dataset", 37 | False, 38 | ), 39 | ] 40 | 41 | 42 | async def get_datasets( 43 | limit: Optional[int] = None, 44 | offset: Optional[int] = None, 45 | order_by: Optional[str] = None, 46 | uri_pattern: Optional[str] = None, 47 | dag_ids: Optional[str] = None, 48 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 49 | # Build parameters dictionary 50 | kwargs: Dict[str, Any] = {} 51 | if limit is not None: 52 | kwargs["limit"] = limit 53 | if offset is not None: 54 | kwargs["offset"] = offset 55 | if order_by is not None: 56 | kwargs["order_by"] = order_by 57 | if uri_pattern is not None: 58 | kwargs["uri_pattern"] = uri_pattern 59 | if dag_ids is not None: 60 | kwargs["dag_ids"] = dag_ids 61 | 62 | response = dataset_api.get_datasets(**kwargs) 63 | return [types.TextContent(type="text", text=str(response.to_dict()))] 64 | 65 | 66 | async def get_dataset( 67 | uri: str, 68 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 69 | response = dataset_api.get_dataset(uri=uri) 70 | return [types.TextContent(type="text", text=str(response.to_dict()))] 71 | 72 | 73 | async def get_dataset_events( 74 | limit: Optional[int] = None, 75 | offset: Optional[int] = None, 76 | order_by: Optional[str] = None, 77 | dataset_id: Optional[int] = None, 78 | source_dag_id: Optional[str] = None, 79 | source_task_id: Optional[str] = None, 80 | source_run_id: Optional[str] = None, 81 | source_map_index: Optional[int] = None, 82 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 83 | # Build parameters dictionary 84 | kwargs: Dict[str, Any] = {} 85 | if limit is not None: 86 | kwargs["limit"] = limit 87 | if offset is not None: 88 | kwargs["offset"] = offset 89 | if order_by is not None: 90 | kwargs["order_by"] = order_by 91 | if dataset_id is not None: 92 | kwargs["dataset_id"] = dataset_id 93 | if source_dag_id is not None: 94 | kwargs["source_dag_id"] = source_dag_id 95 | if 
source_task_id is not None: 96 | kwargs["source_task_id"] = source_task_id 97 | if source_run_id is not None: 98 | kwargs["source_run_id"] = source_run_id 99 | if source_map_index is not None: 100 | kwargs["source_map_index"] = source_map_index 101 | 102 | response = dataset_api.get_dataset_events(**kwargs) 103 | return [types.TextContent(type="text", text=str(response.to_dict()))] 104 | 105 | 106 | async def create_dataset_event( 107 | dataset_uri: str, 108 | extra: Optional[Dict[str, Any]] = None, 109 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 110 | event_request = { 111 | "dataset_uri": dataset_uri, 112 | } 113 | if extra is not None: 114 | event_request["extra"] = extra 115 | 116 | response = dataset_api.create_dataset_event(create_dataset_event=event_request) 117 | return [types.TextContent(type="text", text=str(response.to_dict()))] 118 | 119 | 120 | async def get_dag_dataset_queued_event( 121 | dag_id: str, 122 | uri: str, 123 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 124 | response = dataset_api.get_dag_dataset_queued_event(dag_id=dag_id, uri=uri) 125 | return [types.TextContent(type="text", text=str(response.to_dict()))] 126 | 127 | 128 | async def get_dag_dataset_queued_events( 129 | dag_id: str, 130 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 131 | response = dataset_api.get_dag_dataset_queued_events(dag_id=dag_id) 132 | return [types.TextContent(type="text", text=str(response.to_dict()))] 133 | 134 | 135 | async def delete_dag_dataset_queued_event( 136 | dag_id: str, 137 | uri: str, 138 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 139 | response = dataset_api.delete_dag_dataset_queued_event(dag_id=dag_id, uri=uri) 140 | return [types.TextContent(type="text", text=str(response.to_dict()))] 141 | 142 | 143 | async def delete_dag_dataset_queued_events( 144 | dag_id: str, 145 | before: Optional[str] = 
None, 146 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 147 | kwargs: Dict[str, Any] = {} 148 | if before is not None: 149 | kwargs["before"] = before 150 | 151 | response = dataset_api.delete_dag_dataset_queued_events(dag_id=dag_id, **kwargs) 152 | return [types.TextContent(type="text", text=str(response.to_dict()))] 153 | 154 | 155 | async def get_dataset_queued_events( 156 | uri: str, 157 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 158 | response = dataset_api.get_dataset_queued_events(uri=uri) 159 | return [types.TextContent(type="text", text=str(response.to_dict()))] 160 | 161 | 162 | async def delete_dataset_queued_events( 163 | uri: str, 164 | before: Optional[str] = None, 165 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 166 | kwargs: Dict[str, Any] = {} 167 | if before is not None: 168 | kwargs["before"] = before 169 | 170 | response = dataset_api.delete_dataset_queued_events(uri=uri, **kwargs) 171 | return [types.TextContent(type="text", text=str(response.to_dict()))] 172 | ``` -------------------------------------------------------------------------------- /test/test_airflow_client.py: -------------------------------------------------------------------------------- ```python 1 | """Tests for the airflow client authentication module.""" 2 | 3 | import os 4 | import sys 5 | from unittest.mock import patch 6 | 7 | from airflow_client.client import ApiClient 8 | 9 | 10 | class TestAirflowClientAuthentication: 11 | """Test cases for airflow client authentication configuration.""" 12 | 13 | def test_basic_auth_configuration(self): 14 | """Test that basic authentication is configured correctly.""" 15 | with patch.dict( 16 | os.environ, 17 | { 18 | "AIRFLOW_HOST": "http://localhost:8080", 19 | "AIRFLOW_USERNAME": "testuser", 20 | "AIRFLOW_PASSWORD": "testpass", 21 | "AIRFLOW_API_VERSION": "v1", 22 | }, 23 | clear=True, 24 | ): 25 | # Clear any cached 
modules 26 | modules_to_clear = ["src.envs", "src.airflow.airflow_client"] 27 | for module in modules_to_clear: 28 | if module in sys.modules: 29 | del sys.modules[module] 30 | 31 | # Re-import after setting environment 32 | from src.airflow.airflow_client import api_client, configuration 33 | 34 | # Verify configuration 35 | assert configuration.host == "http://localhost:8080/api/v1" 36 | assert configuration.username == "testuser" 37 | assert configuration.password == "testpass" 38 | assert isinstance(api_client, ApiClient) 39 | 40 | def test_jwt_token_auth_configuration(self): 41 | """Test that JWT token authentication is configured correctly.""" 42 | with patch.dict( 43 | os.environ, 44 | { 45 | "AIRFLOW_HOST": "http://localhost:8080", 46 | "AIRFLOW_JWT_TOKEN": "test.jwt.token", 47 | "AIRFLOW_API_VERSION": "v1", 48 | }, 49 | clear=True, 50 | ): 51 | # Clear any cached modules 52 | modules_to_clear = ["src.envs", "src.airflow.airflow_client"] 53 | for module in modules_to_clear: 54 | if module in sys.modules: 55 | del sys.modules[module] 56 | 57 | # Re-import after setting environment 58 | from src.airflow.airflow_client import api_client, configuration 59 | 60 | # Verify configuration 61 | assert configuration.host == "http://localhost:8080/api/v1" 62 | assert configuration.api_key == {"Authorization": "Bearer test.jwt.token"} 63 | assert configuration.api_key_prefix == {"Authorization": ""} 64 | assert isinstance(api_client, ApiClient) 65 | 66 | def test_jwt_token_takes_precedence_over_basic_auth(self): 67 | """Test that JWT token takes precedence when both auth methods are provided.""" 68 | with patch.dict( 69 | os.environ, 70 | { 71 | "AIRFLOW_HOST": "http://localhost:8080", 72 | "AIRFLOW_USERNAME": "testuser", 73 | "AIRFLOW_PASSWORD": "testpass", 74 | "AIRFLOW_JWT_TOKEN": "test.jwt.token", 75 | "AIRFLOW_API_VERSION": "v1", 76 | }, 77 | clear=True, 78 | ): 79 | # Clear any cached modules 80 | modules_to_clear = ["src.envs", "src.airflow.airflow_client"] 81 | 
for module in modules_to_clear: 82 | if module in sys.modules: 83 | del sys.modules[module] 84 | 85 | # Re-import after setting environment 86 | from src.airflow.airflow_client import api_client, configuration 87 | 88 | # Verify JWT token is used (not basic auth) 89 | assert configuration.host == "http://localhost:8080/api/v1" 90 | assert configuration.api_key == {"Authorization": "Bearer test.jwt.token"} 91 | assert configuration.api_key_prefix == {"Authorization": ""} 92 | # Basic auth should not be set when JWT is present 93 | assert not hasattr(configuration, "username") or configuration.username is None 94 | assert not hasattr(configuration, "password") or configuration.password is None 95 | assert isinstance(api_client, ApiClient) 96 | 97 | def test_no_auth_configuration(self): 98 | """Test that configuration works with no authentication (for testing/development).""" 99 | with patch.dict(os.environ, {"AIRFLOW_HOST": "http://localhost:8080", "AIRFLOW_API_VERSION": "v1"}, clear=True): 100 | # Clear any cached modules 101 | modules_to_clear = ["src.envs", "src.airflow.airflow_client"] 102 | for module in modules_to_clear: 103 | if module in sys.modules: 104 | del sys.modules[module] 105 | 106 | # Re-import after setting environment 107 | from src.airflow.airflow_client import api_client, configuration 108 | 109 | # Verify configuration 110 | assert configuration.host == "http://localhost:8080/api/v1" 111 | # No auth should be set 112 | assert not hasattr(configuration, "username") or configuration.username is None 113 | assert not hasattr(configuration, "password") or configuration.password is None 114 | # api_key might be an empty dict by default, but should not have Authorization 115 | assert "Authorization" not in getattr(configuration, "api_key", {}) 116 | assert isinstance(api_client, ApiClient) 117 | 118 | def test_environment_variable_parsing(self): 119 | """Test that environment variables are parsed correctly.""" 120 | with patch.dict( 121 | os.environ, 
122 | { 123 | "AIRFLOW_HOST": "https://airflow.example.com:8080/custom", 124 | "AIRFLOW_JWT_TOKEN": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...", 125 | "AIRFLOW_API_VERSION": "v2", 126 | }, 127 | clear=True, 128 | ): 129 | # Clear any cached modules 130 | modules_to_clear = ["src.envs", "src.airflow.airflow_client"] 131 | for module in modules_to_clear: 132 | if module in sys.modules: 133 | del sys.modules[module] 134 | 135 | # Re-import after setting environment 136 | from src.airflow.airflow_client import configuration 137 | from src.envs import AIRFLOW_API_VERSION, AIRFLOW_HOST, AIRFLOW_JWT_TOKEN 138 | 139 | # Verify environment variables are parsed correctly 140 | assert AIRFLOW_HOST == "https://airflow.example.com:8080" 141 | assert AIRFLOW_JWT_TOKEN == "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9..." 142 | assert AIRFLOW_API_VERSION == "v2" 143 | 144 | # Verify configuration uses parsed values 145 | assert configuration.host == "https://airflow.example.com:8080/api/v2" 146 | assert configuration.api_key == {"Authorization": "Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9..."} 147 | assert configuration.api_key_prefix == {"Authorization": ""} 148 | ``` -------------------------------------------------------------------------------- /src/airflow/dagrun.py: -------------------------------------------------------------------------------- ```python 1 | from datetime import datetime 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | 4 | import mcp.types as types 5 | from airflow_client.client.api.dag_run_api import DAGRunApi 6 | from airflow_client.client.model.clear_dag_run import ClearDagRun 7 | from airflow_client.client.model.dag_run import DAGRun 8 | from airflow_client.client.model.set_dag_run_note import SetDagRunNote 9 | from airflow_client.client.model.update_dag_run_state import UpdateDagRunState 10 | 11 | from src.airflow.airflow_client import api_client 12 | from src.envs import AIRFLOW_HOST 13 | 14 | dag_run_api = DAGRunApi(api_client) 15 | 16 | 
17 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 18 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 19 | return [ 20 | (post_dag_run, "post_dag_run", "Trigger a DAG by ID", False), 21 | (get_dag_runs, "get_dag_runs", "Get DAG runs by ID", True), 22 | (get_dag_runs_batch, "get_dag_runs_batch", "List DAG runs (batch)", True), 23 | (get_dag_run, "get_dag_run", "Get a DAG run by DAG ID and DAG run ID", True), 24 | (update_dag_run_state, "update_dag_run_state", "Update a DAG run state by DAG ID and DAG run ID", False), 25 | (delete_dag_run, "delete_dag_run", "Delete a DAG run by DAG ID and DAG run ID", False), 26 | (clear_dag_run, "clear_dag_run", "Clear a DAG run", False), 27 | (set_dag_run_note, "set_dag_run_note", "Update the DagRun note", False), 28 | (get_upstream_dataset_events, "get_upstream_dataset_events", "Get dataset events for a DAG run", True), 29 | ] 30 | 31 | 32 | def get_dag_run_url(dag_id: str, dag_run_id: str) -> str: 33 | return f"{AIRFLOW_HOST}/dags/{dag_id}/grid?dag_run_id={dag_run_id}" 34 | 35 | 36 | async def post_dag_run( 37 | dag_id: str, 38 | dag_run_id: Optional[str] = None, 39 | data_interval_end: Optional[datetime] = None, 40 | data_interval_start: Optional[datetime] = None, 41 | execution_date: Optional[datetime] = None, 42 | logical_date: Optional[datetime] = None, 43 | note: Optional[str] = None, 44 | # state: Optional[str] = None, # TODO: add state 45 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 46 | # Build kwargs dictionary with only non-None values 47 | kwargs = {} 48 | 49 | # Add non-read-only fields that can be set during creation 50 | if dag_run_id is not None: 51 | kwargs["dag_run_id"] = dag_run_id 52 | if data_interval_end is not None: 53 | kwargs["data_interval_end"] = data_interval_end 54 | if data_interval_start is not None: 55 | kwargs["data_interval_start"] = data_interval_start 56 | if execution_date is not None: 57 | 
kwargs["execution_date"] = execution_date 58 | if logical_date is not None: 59 | kwargs["logical_date"] = logical_date 60 | if note is not None: 61 | kwargs["note"] = note 62 | 63 | # Create DAGRun without read-only fields 64 | dag_run = DAGRun(**kwargs) 65 | 66 | response = dag_run_api.post_dag_run(dag_id=dag_id, dag_run=dag_run) 67 | return [types.TextContent(type="text", text=str(response.to_dict()))] 68 | 69 | 70 | async def get_dag_runs( 71 | dag_id: str, 72 | limit: Optional[int] = None, 73 | offset: Optional[int] = None, 74 | execution_date_gte: Optional[str] = None, 75 | execution_date_lte: Optional[str] = None, 76 | start_date_gte: Optional[str] = None, 77 | start_date_lte: Optional[str] = None, 78 | end_date_gte: Optional[str] = None, 79 | end_date_lte: Optional[str] = None, 80 | updated_at_gte: Optional[str] = None, 81 | updated_at_lte: Optional[str] = None, 82 | state: Optional[List[str]] = None, 83 | order_by: Optional[str] = None, 84 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 85 | # Build parameters dictionary 86 | kwargs: Dict[str, Any] = {} 87 | if limit is not None: 88 | kwargs["limit"] = limit 89 | if offset is not None: 90 | kwargs["offset"] = offset 91 | if execution_date_gte is not None: 92 | kwargs["execution_date_gte"] = execution_date_gte 93 | if execution_date_lte is not None: 94 | kwargs["execution_date_lte"] = execution_date_lte 95 | if start_date_gte is not None: 96 | kwargs["start_date_gte"] = start_date_gte 97 | if start_date_lte is not None: 98 | kwargs["start_date_lte"] = start_date_lte 99 | if end_date_gte is not None: 100 | kwargs["end_date_gte"] = end_date_gte 101 | if end_date_lte is not None: 102 | kwargs["end_date_lte"] = end_date_lte 103 | if updated_at_gte is not None: 104 | kwargs["updated_at_gte"] = updated_at_gte 105 | if updated_at_lte is not None: 106 | kwargs["updated_at_lte"] = updated_at_lte 107 | if state is not None: 108 | kwargs["state"] = state 109 | if order_by is not None: 
110 | kwargs["order_by"] = order_by 111 | 112 | response = dag_run_api.get_dag_runs(dag_id=dag_id, **kwargs) 113 | 114 | # Convert response to dictionary for easier manipulation 115 | response_dict = response.to_dict() 116 | 117 | # Add UI links to each DAG run 118 | for dag_run in response_dict.get("dag_runs", []): 119 | dag_run["ui_url"] = get_dag_run_url(dag_id, dag_run["dag_run_id"]) 120 | 121 | return [types.TextContent(type="text", text=str(response_dict))] 122 | 123 | 124 | async def get_dag_runs_batch( 125 | dag_ids: Optional[List[str]] = None, 126 | execution_date_gte: Optional[str] = None, 127 | execution_date_lte: Optional[str] = None, 128 | start_date_gte: Optional[str] = None, 129 | start_date_lte: Optional[str] = None, 130 | end_date_gte: Optional[str] = None, 131 | end_date_lte: Optional[str] = None, 132 | state: Optional[List[str]] = None, 133 | order_by: Optional[str] = None, 134 | page_offset: Optional[int] = None, 135 | page_limit: Optional[int] = None, 136 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 137 | # Build request dictionary 138 | request: Dict[str, Any] = {} 139 | if dag_ids is not None: 140 | request["dag_ids"] = dag_ids 141 | if execution_date_gte is not None: 142 | request["execution_date_gte"] = execution_date_gte 143 | if execution_date_lte is not None: 144 | request["execution_date_lte"] = execution_date_lte 145 | if start_date_gte is not None: 146 | request["start_date_gte"] = start_date_gte 147 | if start_date_lte is not None: 148 | request["start_date_lte"] = start_date_lte 149 | if end_date_gte is not None: 150 | request["end_date_gte"] = end_date_gte 151 | if end_date_lte is not None: 152 | request["end_date_lte"] = end_date_lte 153 | if state is not None: 154 | request["state"] = state 155 | if order_by is not None: 156 | request["order_by"] = order_by 157 | if page_offset is not None: 158 | request["page_offset"] = page_offset 159 | if page_limit is not None: 160 | request["page_limit"] 
= page_limit 161 | 162 | response = dag_run_api.get_dag_runs_batch(list_dag_runs_form=request) 163 | 164 | # Convert response to dictionary for easier manipulation 165 | response_dict = response.to_dict() 166 | 167 | # Add UI links to each DAG run 168 | for dag_run in response_dict.get("dag_runs", []): 169 | dag_run["ui_url"] = get_dag_run_url(dag_run["dag_id"], dag_run["dag_run_id"]) 170 | 171 | return [types.TextContent(type="text", text=str(response_dict))] 172 | 173 | 174 | async def get_dag_run( 175 | dag_id: str, dag_run_id: str 176 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 177 | response = dag_run_api.get_dag_run(dag_id=dag_id, dag_run_id=dag_run_id) 178 | 179 | # Convert response to dictionary for easier manipulation 180 | response_dict = response.to_dict() 181 | 182 | # Add UI link to DAG run 183 | response_dict["ui_url"] = get_dag_run_url(dag_id, dag_run_id) 184 | 185 | return [types.TextContent(type="text", text=str(response_dict))] 186 | 187 | 188 | async def update_dag_run_state( 189 | dag_id: str, dag_run_id: str, state: Optional[str] = None 190 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 191 | state_update = UpdateDagRunState(state=state) 192 | response = dag_run_api.update_dag_run_state( 193 | dag_id=dag_id, 194 | dag_run_id=dag_run_id, 195 | update_dag_run_state=state_update, 196 | ) 197 | return [types.TextContent(type="text", text=str(response.to_dict()))] 198 | 199 | 200 | async def delete_dag_run( 201 | dag_id: str, dag_run_id: str 202 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 203 | response = dag_run_api.delete_dag_run(dag_id=dag_id, dag_run_id=dag_run_id) 204 | return [types.TextContent(type="text", text=str(response.to_dict()))] 205 | 206 | 207 | async def clear_dag_run( 208 | dag_id: str, dag_run_id: str, dry_run: Optional[bool] = None 209 | ) -> List[Union[types.TextContent, types.ImageContent,
types.EmbeddedResource]]: 210 | dag_run_clear = ClearDagRun(dry_run=dry_run) 211 | response = dag_run_api.clear_dag_run(dag_id=dag_id, dag_run_id=dag_run_id, clear_dag_run=dag_run_clear) 212 | return [types.TextContent(type="text", text=str(response.to_dict()))] 213 | 214 | 215 | async def set_dag_run_note( 216 | dag_id: str, dag_run_id: str, note: str 217 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 218 | note_update = SetDagRunNote(note=note) 219 | response = dag_run_api.set_dag_run_note(dag_id=dag_id, dag_run_id=dag_run_id, set_dag_run_note=note_update) 220 | return [types.TextContent(type="text", text=str(response.to_dict()))] 221 | 222 | 223 | async def get_upstream_dataset_events( 224 | dag_id: str, dag_run_id: str 225 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 226 | response = dag_run_api.get_upstream_dataset_events(dag_id=dag_id, dag_run_id=dag_run_id) 227 | return [types.TextContent(type="text", text=str(response.to_dict()))] 228 | ``` -------------------------------------------------------------------------------- /test/airflow/test_taskinstance.py: -------------------------------------------------------------------------------- ```python 1 | """Unit tests for taskinstance module using pytest framework.""" 2 | 3 | from unittest.mock import MagicMock, patch 4 | 5 | import mcp.types as types 6 | import pytest 7 | 8 | from src.airflow.taskinstance import ( 9 | get_task_instance, 10 | list_task_instance_tries, 11 | list_task_instances, 12 | update_task_instance, 13 | ) 14 | 15 | 16 | class TestTaskInstanceModule: 17 | """ 18 | Test suite for verifying the behavior of taskinstance module's functions.
19 | 20 | Covers: 21 | - get_task_instance 22 | - list_task_instances 23 | - update_task_instance 24 | - list_task_instance_tries 25 | 26 | Each test uses parameterization to exercise a range of valid inputs and asserts: 27 | - Correct structure and content of the returned TextContent 28 | - Proper use of optional parameters 29 | - That the underlying API client methods are invoked with the right arguments 30 | """ 31 | 32 | @pytest.mark.asyncio 33 | @pytest.mark.parametrize( 34 | "dag_id, task_id, dag_run_id, expected_state", 35 | [ 36 | ("dag_1", "task_a", "run_001", "success"), 37 | ("dag_2", "task_b", "run_002", "failed"), 38 | ("dag_3", "task_c", "run_003", "running"), 39 | ], 40 | ids=[ 41 | "success-task-dag_1", 42 | "failed-task-dag_2", 43 | "running-task-dag_3", 44 | ], 45 | ) 46 | async def test_get_task_instance(self, dag_id, task_id, dag_run_id, expected_state): 47 | """ 48 | Test `get_task_instance` returns correct TextContent output and calls API once 49 | for different task states. 
50 | """ 51 | mock_response = MagicMock() 52 | mock_response.to_dict.return_value = { 53 | "dag_id": dag_id, 54 | "task_id": task_id, 55 | "dag_run_id": dag_run_id, 56 | "state": expected_state, 57 | } 58 | 59 | with patch( 60 | "src.airflow.taskinstance.task_instance_api.get_task_instance", 61 | return_value=mock_response, 62 | ) as mock_get: 63 | result = await get_task_instance(dag_id=dag_id, task_id=task_id, dag_run_id=dag_run_id) 64 | 65 | assert isinstance(result, list) 66 | assert len(result) == 1 67 | content = result[0] 68 | assert isinstance(content, types.TextContent) 69 | assert content.type == "text" 70 | assert dag_id in content.text 71 | assert task_id in content.text 72 | assert dag_run_id in content.text 73 | assert expected_state in content.text 74 | 75 | mock_get.assert_called_once_with(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id) 76 | 77 | @pytest.mark.asyncio 78 | @pytest.mark.parametrize( 79 | "params", 80 | [ 81 | {"dag_id": "dag_basic", "dag_run_id": "run_basic"}, 82 | { 83 | "dag_id": "dag_with_state", 84 | "dag_run_id": "run_with_state", 85 | "state": ["success", "failed"], 86 | }, 87 | { 88 | "dag_id": "dag_with_dates", 89 | "dag_run_id": "run_with_dates", 90 | "start_date_gte": "2024-01-01T00:00:00Z", 91 | "end_date_lte": "2024-01-10T23:59:59Z", 92 | }, 93 | { 94 | "dag_id": "dag_with_filters", 95 | "dag_run_id": "run_filters", 96 | "pool": ["default_pool"], 97 | "queue": ["default"], 98 | "limit": 5, 99 | "offset": 10, 100 | "duration_gte": 5.0, 101 | "duration_lte": 100.5, 102 | }, 103 | { 104 | "dag_id": "dag_with_all", 105 | "dag_run_id": "run_all", 106 | "execution_date_gte": "2024-01-01T00:00:00Z", 107 | "execution_date_lte": "2024-01-02T00:00:00Z", 108 | "start_date_gte": "2024-01-01T01:00:00Z", 109 | "start_date_lte": "2024-01-01T23:00:00Z", 110 | "end_date_gte": "2024-01-01T02:00:00Z", 111 | "end_date_lte": "2024-01-01T23:59:00Z", 112 | "updated_at_gte": "2024-01-01T03:00:00Z", 113 | "updated_at_lte": 
"2024-01-01T04:00:00Z", 114 | "duration_gte": 1.0, 115 | "duration_lte": 99.9, 116 | "state": ["queued"], 117 | "pool": ["my_pool"], 118 | "queue": ["my_queue"], 119 | "limit": 50, 120 | "offset": 0, 121 | }, 122 | { 123 | "dag_id": "dag_with_empty_lists", 124 | "dag_run_id": "run_empty_lists", 125 | "state": [], 126 | "pool": [], 127 | "queue": [], 128 | }, 129 | ], 130 | ids=[ 131 | "basic-required-params", 132 | "with-state-filter", 133 | "with-date-range", 134 | "with-resource-filters", 135 | "all-filters-included", 136 | "empty-lists-filter", 137 | ], 138 | ) 139 | async def test_list_task_instances(self, params): 140 | """ 141 | Test `list_task_instances` with various combinations of filters. 142 | Validates output content and verifies API call arguments. 143 | """ 144 | mock_response = MagicMock() 145 | mock_response.to_dict.return_value = { 146 | "dag_id": params["dag_id"], 147 | "dag_run_id": params["dag_run_id"], 148 | "instances": [ 149 | {"task_id": "task_1", "state": "success"}, 150 | {"task_id": "task_2", "state": "running"}, 151 | ], 152 | } 153 | 154 | with patch( 155 | "src.airflow.taskinstance.task_instance_api.get_task_instances", 156 | return_value=mock_response, 157 | ) as mock_get: 158 | result = await list_task_instances(**params) 159 | 160 | assert isinstance(result, list) 161 | assert len(result) == 1 162 | assert isinstance(result[0], types.TextContent) 163 | assert result[0].type == "text" 164 | assert params["dag_id"] in result[0].text 165 | assert params["dag_run_id"] in result[0].text 166 | 167 | mock_get.assert_called_once_with( 168 | dag_id=params["dag_id"], 169 | dag_run_id=params["dag_run_id"], 170 | **{k: v for k, v in params.items() if k not in {"dag_id", "dag_run_id"} and v is not None}, 171 | ) 172 | 173 | @pytest.mark.asyncio 174 | @pytest.mark.parametrize( 175 | "dag_id, dag_run_id, task_id, state", 176 | [ 177 | ("dag_1", "run_001", "task_a", "success"), 178 | ("dag_2", "run_002", "task_b", "failed"), 179 | ("dag_3", 
"run_003", "task_c", None), 180 | ], 181 | ids=["set-success-state", "set-failed-state", "no-state-update"], 182 | ) 183 | async def test_update_task_instance(self, dag_id, dag_run_id, task_id, state): 184 | """ 185 | Test `update_task_instance` for updating state and validating request payload. 186 | Also verifies that patch API is called with the correct update mask. 187 | """ 188 | mock_response = MagicMock() 189 | mock_response.to_dict.return_value = { 190 | "dag_id": dag_id, 191 | "dag_run_id": dag_run_id, 192 | "task_id": task_id, 193 | "state": state, 194 | } 195 | 196 | with patch( 197 | "src.airflow.taskinstance.task_instance_api.patch_task_instance", 198 | return_value=mock_response, 199 | ) as mock_patch: 200 | result = await update_task_instance(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, state=state) 201 | 202 | assert isinstance(result, list) 203 | assert len(result) == 1 204 | content = result[0] 205 | assert isinstance(content, types.TextContent) 206 | assert content.type == "text" 207 | assert dag_id in content.text 208 | assert dag_run_id in content.text 209 | assert task_id in content.text 210 | if state is not None: 211 | assert state in content.text 212 | 213 | expected_mask = ["state"] if state is not None else [] 214 | expected_request = {"state": state} if state is not None else {} 215 | mock_patch.assert_called_once_with( 216 | dag_id=dag_id, 217 | dag_run_id=dag_run_id, 218 | task_id=task_id, 219 | update_mask=expected_mask, 220 | task_instance_request=expected_request, 221 | ) 222 | 223 | @pytest.mark.asyncio 224 | @pytest.mark.parametrize( 225 | "dag_id, dag_run_id, task_id, limit, offset, order_by", 226 | [ 227 | ("dag_basic", "run_001", "task_a", None, None, None), 228 | ("dag_with_limit", "run_002", "task_b", 5, None, None), 229 | ("dag_with_offset", "run_003", "task_c", None, 10, None), 230 | ("dag_with_order_by", "run_004", "task_d", None, None, "-start_date"), 231 | ("dag_all_params", "run_005", "task_e", 10, 0, 
"end_date"), 232 | ("dag_zero_limit", "run_006", "task_f", 0, None, None), 233 | ("dag_zero_offset", "run_007", "task_g", None, 0, None), 234 | ("dag_empty_order", "run_008", "task_h", None, None, ""), 235 | ], 236 | ids=[ 237 | "basic-required-only", 238 | "with-limit", 239 | "with-offset", 240 | "with-order_by-desc", 241 | "with-all-filters", 242 | "limit-zero", 243 | "offset-zero", 244 | "order_by-empty-string", 245 | ], 246 | ) 247 | async def test_list_task_instance_tries(self, dag_id, dag_run_id, task_id, limit, offset, order_by): 248 | """ 249 | Test `list_task_instance_tries` across various filter combinations, 250 | validating correct API call and response formatting. 251 | """ 252 | mock_response = MagicMock() 253 | mock_response.to_dict.return_value = { 254 | "dag_id": dag_id, 255 | "dag_run_id": dag_run_id, 256 | "task_id": task_id, 257 | "tries": [ 258 | {"try_number": 1, "state": "queued"}, 259 | {"try_number": 2, "state": "success"}, 260 | ], 261 | } 262 | 263 | with patch( 264 | "src.airflow.taskinstance.task_instance_api.get_task_instance_tries", 265 | return_value=mock_response, 266 | ) as mock_get: 267 | result = await list_task_instance_tries( 268 | dag_id=dag_id, 269 | dag_run_id=dag_run_id, 270 | task_id=task_id, 271 | limit=limit, 272 | offset=offset, 273 | order_by=order_by, 274 | ) 275 | 276 | assert isinstance(result, list) 277 | assert len(result) == 1 278 | content = result[0] 279 | assert isinstance(content, types.TextContent) 280 | assert content.type == "text" 281 | assert dag_id in content.text 282 | assert dag_run_id in content.text 283 | assert task_id in content.text 284 | assert "tries" in content.text 285 | 286 | mock_get.assert_called_once_with( 287 | dag_id=dag_id, 288 | dag_run_id=dag_run_id, 289 | task_id=task_id, 290 | **{ 291 | k: v 292 | for k, v in { 293 | "limit": limit, 294 | "offset": offset, 295 | "order_by": order_by, 296 | }.items() 297 | if v is not None 298 | }, 299 | ) 300 | ``` 
-------------------------------------------------------------------------------- /src/airflow/dag.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import mcp.types as types 4 | from airflow_client.client.api.dag_api import DAGApi 5 | from airflow_client.client.model.clear_task_instances import ClearTaskInstances 6 | from airflow_client.client.model.dag import DAG 7 | from airflow_client.client.model.update_task_instances_state import UpdateTaskInstancesState 8 | 9 | from src.airflow.airflow_client import api_client 10 | from src.envs import AIRFLOW_HOST 11 | 12 | dag_api = DAGApi(api_client) 13 | 14 | 15 | def get_all_functions() -> list[tuple[Callable, str, str, bool]]: 16 | """Return list of (function, name, description, is_read_only) tuples for registration.""" 17 | return [ 18 | (get_dags, "fetch_dags", "Fetch all DAGs", True), 19 | (get_dag, "get_dag", "Get a DAG by ID", True), 20 | (get_dag_details, "get_dag_details", "Get a simplified representation of DAG", True), 21 | (get_dag_source, "get_dag_source", "Get a source code", True), 22 | (pause_dag, "pause_dag", "Pause a DAG by ID", False), 23 | (unpause_dag, "unpause_dag", "Unpause a DAG by ID", False), 24 | (get_dag_tasks, "get_dag_tasks", "Get tasks for DAG", True), 25 | (get_task, "get_task", "Get a task by ID", True), 26 | (get_tasks, "get_tasks", "Get tasks for DAG", True), 27 | (patch_dag, "patch_dag", "Update a DAG", False), 28 | (patch_dags, "patch_dags", "Update multiple DAGs", False), 29 | (delete_dag, "delete_dag", "Delete a DAG", False), 30 | (clear_task_instances, "clear_task_instances", "Clear a set of task instances", False), 31 | (set_task_instances_state, "set_task_instances_state", "Set a state of task instances", False), 32 | (reparse_dag_file, "reparse_dag_file", "Request re-parsing of a DAG file", False), 33 | ] 34 | 35 | 36 | def get_dag_url(dag_id: str) -> str: 37 | 
return f"{AIRFLOW_HOST}/dags/{dag_id}/grid" 38 | 39 | 40 | async def get_dags( 41 | limit: Optional[int] = None, 42 | offset: Optional[int] = None, 43 | order_by: Optional[str] = None, 44 | tags: Optional[List[str]] = None, 45 | only_active: Optional[bool] = None, 46 | paused: Optional[bool] = None, 47 | dag_id_pattern: Optional[str] = None, 48 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 49 | # Build parameters dictionary 50 | kwargs: Dict[str, Any] = {} 51 | if limit is not None: 52 | kwargs["limit"] = limit 53 | if offset is not None: 54 | kwargs["offset"] = offset 55 | if order_by is not None: 56 | kwargs["order_by"] = order_by 57 | if tags is not None: 58 | kwargs["tags"] = tags 59 | if only_active is not None: 60 | kwargs["only_active"] = only_active 61 | if paused is not None: 62 | kwargs["paused"] = paused 63 | if dag_id_pattern is not None: 64 | kwargs["dag_id_pattern"] = dag_id_pattern 65 | 66 | # Use the client to fetch DAGs 67 | response = dag_api.get_dags(**kwargs) 68 | 69 | # Convert response to dictionary for easier manipulation 70 | response_dict = response.to_dict() 71 | 72 | # Add UI links to each DAG 73 | for dag in response_dict.get("dags", []): 74 | dag["ui_url"] = get_dag_url(dag["dag_id"]) 75 | 76 | return [types.TextContent(type="text", text=str(response_dict))] 77 | 78 | 79 | async def get_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 80 | response = dag_api.get_dag(dag_id=dag_id) 81 | 82 | # Convert response to dictionary for easier manipulation 83 | response_dict = response.to_dict() 84 | 85 | # Add UI link to DAG 86 | response_dict["ui_url"] = get_dag_url(dag_id) 87 | 88 | return [types.TextContent(type="text", text=str(response_dict))] 89 | 90 | 91 | async def get_dag_details( 92 | dag_id: str, fields: Optional[List[str]] = None 93 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 94 | # Build parameters dictionary 95 
| kwargs: Dict[str, Any] = {} 96 | if fields is not None: 97 | kwargs["fields"] = fields 98 | 99 | response = dag_api.get_dag_details(dag_id=dag_id, **kwargs) 100 | return [types.TextContent(type="text", text=str(response.to_dict()))] 101 | 102 | 103 | async def get_dag_source(file_token: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 104 | response = dag_api.get_dag_source(file_token=file_token) 105 | return [types.TextContent(type="text", text=str(response.to_dict()))] 106 | 107 | 108 | async def pause_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 109 | dag = DAG(is_paused=True) 110 | response = dag_api.patch_dag(dag_id=dag_id, dag=dag, update_mask=["is_paused"]) 111 | return [types.TextContent(type="text", text=str(response.to_dict()))] 112 | 113 | 114 | async def unpause_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 115 | dag = DAG(is_paused=False) 116 | response = dag_api.patch_dag(dag_id=dag_id, dag=dag, update_mask=["is_paused"]) 117 | return [types.TextContent(type="text", text=str(response.to_dict()))] 118 | 119 | 120 | async def get_dag_tasks(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 121 | response = dag_api.get_tasks(dag_id=dag_id) 122 | return [types.TextContent(type="text", text=str(response.to_dict()))] 123 | 124 | 125 | async def patch_dag( 126 | dag_id: str, is_paused: Optional[bool] = None, tags: Optional[List[str]] = None 127 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 128 | update_request = {} 129 | update_mask = [] 130 | 131 | if is_paused is not None: 132 | update_request["is_paused"] = is_paused 133 | update_mask.append("is_paused") 134 | if tags is not None: 135 | update_request["tags"] = tags 136 | update_mask.append("tags") 137 | 138 | dag = DAG(**update_request) 139 | 140 | response = dag_api.patch_dag(dag_id=dag_id, 
dag=dag, update_mask=update_mask) 141 | return [types.TextContent(type="text", text=str(response.to_dict()))] 142 | 143 | 144 | async def patch_dags( 145 | dag_id_pattern: Optional[str] = None, 146 | is_paused: Optional[bool] = None, 147 | tags: Optional[List[str]] = None, 148 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 149 | update_request = {} 150 | update_mask = [] 151 | 152 | if is_paused is not None: 153 | update_request["is_paused"] = is_paused 154 | update_mask.append("is_paused") 155 | if tags is not None: 156 | update_request["tags"] = tags 157 | update_mask.append("tags") 158 | 159 | dag = DAG(**update_request) 160 | 161 | kwargs = {} 162 | if dag_id_pattern is not None: 163 | kwargs["dag_id_pattern"] = dag_id_pattern 164 | 165 | response = dag_api.patch_dags(dag_id_pattern=dag_id_pattern, dag=dag, update_mask=update_mask, **kwargs) 166 | return [types.TextContent(type="text", text=str(response.to_dict()))] 167 | 168 | 169 | async def delete_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 170 | response = dag_api.delete_dag(dag_id=dag_id) 171 | return [types.TextContent(type="text", text=str(response.to_dict()))] 172 | 173 | 174 | async def get_task( 175 | dag_id: str, task_id: str 176 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 177 | response = dag_api.get_task(dag_id=dag_id, task_id=task_id) 178 | return [types.TextContent(type="text", text=str(response.to_dict()))] 179 | 180 | 181 | async def get_tasks( 182 | dag_id: str, order_by: Optional[str] = None 183 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 184 | kwargs = {} 185 | if order_by is not None: 186 | kwargs["order_by"] = order_by 187 | 188 | response = dag_api.get_tasks(dag_id=dag_id, **kwargs) 189 | return [types.TextContent(type="text", text=str(response.to_dict()))] 190 | 191 | 192 | async def clear_task_instances( 193 | dag_id: str, 
194 | task_ids: Optional[List[str]] = None, 195 | start_date: Optional[str] = None, 196 | end_date: Optional[str] = None, 197 | include_subdags: Optional[bool] = None, 198 | include_parentdag: Optional[bool] = None, 199 | include_upstream: Optional[bool] = None, 200 | include_downstream: Optional[bool] = None, 201 | include_future: Optional[bool] = None, 202 | include_past: Optional[bool] = None, 203 | dry_run: Optional[bool] = None, 204 | reset_dag_runs: Optional[bool] = None, 205 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 206 | clear_request = {} 207 | if task_ids is not None: 208 | clear_request["task_ids"] = task_ids 209 | if start_date is not None: 210 | clear_request["start_date"] = start_date 211 | if end_date is not None: 212 | clear_request["end_date"] = end_date 213 | if include_subdags is not None: 214 | clear_request["include_subdags"] = include_subdags 215 | if include_parentdag is not None: 216 | clear_request["include_parentdag"] = include_parentdag 217 | if include_upstream is not None: 218 | clear_request["include_upstream"] = include_upstream 219 | if include_downstream is not None: 220 | clear_request["include_downstream"] = include_downstream 221 | if include_future is not None: 222 | clear_request["include_future"] = include_future 223 | if include_past is not None: 224 | clear_request["include_past"] = include_past 225 | if dry_run is not None: 226 | clear_request["dry_run"] = dry_run 227 | if reset_dag_runs is not None: 228 | clear_request["reset_dag_runs"] = reset_dag_runs 229 | 230 | clear_task_instances = ClearTaskInstances(**clear_request) 231 | 232 | response = dag_api.post_clear_task_instances(dag_id=dag_id, clear_task_instances=clear_task_instances) 233 | return [types.TextContent(type="text", text=str(response.to_dict()))] 234 | 235 | 236 | async def set_task_instances_state( 237 | dag_id: str, 238 | state: str, 239 | task_ids: Optional[List[str]] = None, 240 | execution_date: Optional[str] = 
None, 241 | include_upstream: Optional[bool] = None, 242 | include_downstream: Optional[bool] = None, 243 | include_future: Optional[bool] = None, 244 | include_past: Optional[bool] = None, 245 | dry_run: Optional[bool] = None, 246 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 247 | state_request = {"state": state} 248 | if task_ids is not None: 249 | state_request["task_ids"] = task_ids 250 | if execution_date is not None: 251 | state_request["execution_date"] = execution_date 252 | if include_upstream is not None: 253 | state_request["include_upstream"] = include_upstream 254 | if include_downstream is not None: 255 | state_request["include_downstream"] = include_downstream 256 | if include_future is not None: 257 | state_request["include_future"] = include_future 258 | if include_past is not None: 259 | state_request["include_past"] = include_past 260 | if dry_run is not None: 261 | state_request["dry_run"] = dry_run 262 | 263 | update_task_instances_state = UpdateTaskInstancesState(**state_request) 264 | 265 | response = dag_api.post_set_task_instances_state( 266 | dag_id=dag_id, 267 | update_task_instances_state=update_task_instances_state, 268 | ) 269 | return [types.TextContent(type="text", text=str(response.to_dict()))] 270 | 271 | 272 | async def reparse_dag_file( 273 | file_token: str, 274 | ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 275 | response = dag_api.reparse_dag_file(file_token=file_token) 276 | return [types.TextContent(type="text", text=str(response.to_dict()))] 277 | ``` -------------------------------------------------------------------------------- /test/test_main.py: -------------------------------------------------------------------------------- ```python 1 | """Tests for the main module using pytest framework.""" 2 | 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | from click.testing import CliRunner 7 | 8 | from src.enums import APIType 9 | from 
src.main import APITYPE_TO_FUNCTIONS, Tool, main


class TestMain:
    """Test cases for the main module."""

    @pytest.fixture
    def runner(self):
        """Set up CLI test runner."""
        return CliRunner()

    def test_apitype_to_functions_mapping(self):
        """Test that all API types are mapped to functions."""
        # Verify all APIType enum values have corresponding functions
        for api_type in APIType:
            assert api_type in APITYPE_TO_FUNCTIONS
            assert APITYPE_TO_FUNCTIONS[api_type] is not None

    def test_apitype_to_functions_completeness(self):
        """Test that the function mapping is complete and contains only valid APITypes."""
        # Verify mapping keys match APIType enum values
        expected_keys = set(APIType)
        actual_keys = set(APITYPE_TO_FUNCTIONS.keys())
        assert expected_keys == actual_keys

    @patch("src.server.app")
    def test_main_default_options(self, mock_app, runner):
        """Test main function with default options."""
        # Mock get_function to return valid functions
        mock_functions = [(lambda: None, "test_function", "Test description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {api: lambda: mock_functions for api in APIType}):
            result = runner.invoke(main, [])

        assert result.exit_code == 0
        # Verify app.add_tool was called for each API type
        expected_calls = len(APIType)  # One call per API type
        assert mock_app.add_tool.call_count == expected_calls
        # Verify app.run was called with stdio transport
        mock_app.run.assert_called_once_with(transport="stdio")

    @patch("src.server.app")
    def test_main_sse_transport(self, mock_app, runner):
        """Test main function with SSE transport."""
        mock_functions = [(lambda: None, "test_function", "Test description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {api: lambda: mock_functions for api in APIType}):
            result = runner.invoke(main, ["--transport", "sse"])

        assert result.exit_code == 0
        mock_app.run.assert_called_once_with(transport="sse", port=8000, host="0.0.0.0")

    @patch("src.server.app")
    def test_main_specific_apis(self, mock_app, runner):
        """Test main function with specific APIs selected."""
        mock_functions = [(lambda: None, "test_function", "Test description")]

        selected_apis = ["config", "connection"]
        with patch.dict(APITYPE_TO_FUNCTIONS, {api: lambda: mock_functions for api in APIType}):
            result = runner.invoke(main, ["--apis", "config", "--apis", "connection"])

        assert result.exit_code == 0
        # Should only add tools for selected APIs
        assert mock_app.add_tool.call_count == len(selected_apis)

    @patch("src.server.app")
    def test_main_not_implemented_error_handling(self, mock_app, runner):
        """Test main function handles NotImplementedError gracefully."""

        def raise_not_implemented():
            raise NotImplementedError("Not implemented")

        # Mock one API to raise NotImplementedError
        modified_mapping = APITYPE_TO_FUNCTIONS.copy()
        modified_mapping[APIType.CONFIG] = raise_not_implemented

        mock_functions = [(lambda: None, "test_function", "Test description")]

        # Other APIs should still work
        for api in APIType:
            if api != APIType.CONFIG:
                modified_mapping[api] = lambda: mock_functions

        with patch.dict(APITYPE_TO_FUNCTIONS, modified_mapping, clear=True):
            result = runner.invoke(main, [])

        assert result.exit_code == 0
        # Should add tools for all APIs except the one that raised NotImplementedError
        expected_calls = len(APIType) - 1
        assert mock_app.add_tool.call_count == expected_calls

    def test_cli_transport_choices(self, runner):
        """Test CLI transport option only accepts valid choices."""
        result = runner.invoke(main, ["--transport", "invalid"])
        assert result.exit_code != 0
        assert "Invalid value for '--transport'" in result.output

    def test_cli_apis_choices(self, runner):
        """Test CLI apis option only accepts valid choices."""
        result = runner.invoke(main, ["--apis", "invalid"])
        assert result.exit_code != 0
        assert "Invalid value for '--apis'" in result.output

    @patch("src.server.app")
    def test_function_registration_flow(self, mock_app, runner):
        """Test the complete function registration flow."""

        def mock_function():
            # .add_tools in FastMCP does not allow adding functions with *args
            # it limits to use Mock and MagicMock
            pass

        mock_functions = [(mock_function, "test_name", "test_description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--apis", "config"])

        assert result.exit_code == 0
        mock_app.add_tool.assert_called_once_with(
            Tool.from_function(mock_function, name="test_name", description="test_description")
        )

    @patch("src.server.app")
    def test_multiple_functions_per_api(self, mock_app, runner):
        """Test handling multiple functions per API."""
        mock_functions = [
            (lambda: "func1", "func1_name", "func1_desc"),
            (lambda: "func2", "func2_name", "func2_desc"),
            (lambda: "func3", "func3_name", "func3_desc"),
        ]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--apis", "config"])

        assert result.exit_code == 0
        # Should register all functions
        assert mock_app.add_tool.call_count == 3

    def test_help_option(self, runner):
        """Test CLI help option."""
        result = runner.invoke(main, ["--help"])
        assert result.exit_code == 0
        assert "Transport type" in result.output
        assert "APIs to run" in result.output

    @pytest.mark.parametrize("transport", ["stdio", "sse", "http"])
    @patch("src.server.app")
    def test_main_transport_options(self, mock_app, transport, runner):
        """Test main function with different transport options."""
        mock_functions = [(lambda: None, "test_function", "Test description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--transport", transport, "--apis", "config"])

        assert result.exit_code == 0
        # stdio takes no port/host; networked transports default to 8000 / 0.0.0.0.
        if transport == "stdio":
            mock_app.run.assert_called_once_with(transport=transport)
        else:
            mock_app.run.assert_called_once_with(transport=transport, port=8000, host="0.0.0.0")

    @pytest.mark.parametrize("transport", ["sse", "http"])
    @pytest.mark.parametrize("port", [None, "12345"])
    @pytest.mark.parametrize("host", [None, "127.0.0.1"])
    @patch("src.server.app")
    def test_port_and_host_options(self, mock_app, transport, port, host, runner):
        """Test that port and host are set for SSE and HTTP transports"""
        mock_functions = [(lambda: None, "test_function", "Test description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            ext_params = []
            if port:
                ext_params += ["--mcp-port", port]
            if host:
                ext_params += ["--mcp-host", host]
            runner.invoke(main, ["--transport", transport, "--apis", "config"] + ext_params)

        # Omitted options must fall back to the documented defaults.
        expected_params = {}
        expected_params["port"] = int(port) if port else 8000
        expected_params["host"] = host if host else "0.0.0.0"
        mock_app.run.assert_called_once_with(transport=transport, **expected_params)

    @pytest.mark.parametrize("api_name", [api.value for api in APIType])
    @patch("src.server.app")
    def test_individual_api_selection(self, mock_app, api_name, runner):
        """Test selecting individual APIs."""
        mock_functions = [(lambda: None, "test_function", "Test description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType(api_name): lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--apis", api_name])

        assert result.exit_code == 0
        assert mock_app.add_tool.call_count == 1

    def test_filter_functions_for_read_only(self):
        """Test that filter_functions_for_read_only correctly filters functions."""
        from src.main import filter_functions_for_read_only

        # Mock function objects
        def mock_read_func():
            pass

        def mock_write_func():
            pass

        # Test functions with mixed read/write status
        functions = [
            (mock_read_func, "get_something", "Get something", True),
            (mock_write_func, "create_something", "Create something", False),
            (mock_read_func, "list_something", "List something", True),
            (mock_write_func, "delete_something", "Delete something", False),
        ]

        filtered = filter_functions_for_read_only(functions)

        # Should only have the read-only functions
        assert len(filtered) == 2
        assert filtered[0][1] == "get_something"
        assert filtered[1][1] == "list_something"

        # Verify all returned functions are read-only
        for _, _, _, is_read_only in filtered:
            assert is_read_only is True

    def test_connection_functions_have_correct_read_only_status(self):
        """Test that connection functions are correctly marked as read-only or write."""
        from src.airflow.connection import get_all_functions

        functions = get_all_functions()
        function_names = {name: is_read_only for _, name, _, is_read_only in functions}

        # Verify read-only functions
        assert function_names["list_connections"] is True
        assert function_names["get_connection"] is True
        assert function_names["test_connection"] is True

        # Verify write functions
        assert function_names["create_connection"] is False
        assert function_names["update_connection"] is False
        assert function_names["delete_connection"] is False

    def test_dag_functions_have_correct_read_only_status(self):
        """Test that DAG functions are correctly marked as read-only or write."""
        from src.airflow.dag import get_all_functions

        functions = get_all_functions()
        function_names = {name: is_read_only for _, name, _, is_read_only in functions}

        # Verify read-only functions
        assert function_names["fetch_dags"] is True
        assert function_names["get_dag"] is True
        assert function_names["get_dag_details"] is True
        assert function_names["get_dag_source"] is True
        assert function_names["get_dag_tasks"] is True
        assert function_names["get_task"] is True
        assert function_names["get_tasks"] is True

        # Verify write functions
        assert function_names["pause_dag"] is False
        assert function_names["unpause_dag"] is False
        assert function_names["patch_dag"] is False
        assert function_names["patch_dags"] is False
        assert function_names["delete_dag"] is False
        assert function_names["clear_task_instances"] is False
        assert function_names["set_task_instances_state"] is False
        assert function_names["reparse_dag_file"] is False

    @patch("src.server.app")
    def test_main_read_only_mode(self, mock_app, runner):
        """Test main function with read-only flag."""
        # Create mock functions with mixed read/write status
        mock_functions = [
            (lambda: None, "read_function", "Read function", True),
            (lambda: None, "write_function", "Write function", False),
            (lambda: None, "another_read_function", "Another read function", True),
        ]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--read-only", "--apis", "config"])

        assert result.exit_code == 0
        # Should only register read-only functions (2 out of 3)
        assert mock_app.add_tool.call_count == 2

        # Verify the correct functions were registered
        call_args_list = mock_app.add_tool.call_args_list

        registered_names = [call.args[0].name for call in call_args_list]
        assert "read_function" in registered_names
        assert "another_read_function" in registered_names
        assert "write_function" not in registered_names

    @patch("src.server.app")
    def test_main_read_only_mode_with_no_read_functions(self, mock_app, runner):
        """Test main function with read-only flag when API has no read-only functions."""
        # Create mock functions with only write operations
        mock_functions = [
            (lambda: None, "write_function1", "Write function 1", False),
            (lambda: None, "write_function2", "Write function 2", False),
        ]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--read-only", "--apis", "config"])

        assert result.exit_code == 0
        # Should not register any functions
        assert mock_app.add_tool.call_count == 0

    def test_cli_read_only_flag_in_help(self, runner):
        """Test that read-only flag appears in help."""
        result = runner.invoke(main, ["--help"])
        assert result.exit_code == 0
        assert "--read-only" in result.output
        assert "Only expose read-only tools" in result.output
```

--------------------------------------------------------------------------------
/test/airflow/test_dag.py:
--------------------------------------------------------------------------------

```python
"""Table-driven tests for the dag module using pytest framework."""

from unittest.mock import ANY, MagicMock, patch

import mcp.types as types
import pytest

from src.airflow.dag import (
    clear_task_instances,
    delete_dag,
    get_dag,
    get_dag_details,
    get_dag_source,
    get_dag_tasks,
    get_dag_url,
    get_dags,
    get_task,
    get_tasks,
    patch_dag,
    pause_dag,
    reparse_dag_file,
    set_task_instances_state,
    unpause_dag,
)


class TestDagModule:
    """Table-driven test cases for the dag module."""

    @pytest.fixture
    def mock_dag_api(self):
        """Create a mock DAG API instance."""
        # Patch the module-level dag_api so no real Airflow client is constructed.
        with patch("src.airflow.dag.dag_api") as mock_api:
            yield mock_api

    def test_get_dag_url(self):
        """Test DAG URL generation."""
        test_cases = [
            # (dag_id, expected_url)
            ("test_dag", "http://localhost:8080/dags/test_dag/grid"),
            ("my-complex_dag.v2", "http://localhost:8080/dags/my-complex_dag.v2/grid"),
            ("", "http://localhost:8080/dags//grid"),
        ]

        for dag_id, expected_url in test_cases:
            with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"):
                result = get_dag_url(dag_id)
                assert result == expected_url

    @pytest.mark.parametrize(
        "test_case",
        [
            # Test case structure: (input_params, mock_response_dict, expected_result_partial)
            {
                "name": "get_dags_no_params",
                "input": {},
                "mock_response": {"dags": [{"dag_id": "test_dag", "description": "Test"}], "total_entries": 1},
                "expected_call_kwargs": {},
                "expected_ui_urls": True,
            },
            {
                "name": "get_dags_with_limit_offset",
                "input": {"limit": 10, "offset": 5},
                "mock_response": {"dags": [{"dag_id": "dag1"}, {"dag_id": "dag2"}], "total_entries": 2},
                "expected_call_kwargs": {"limit": 10, "offset": 5},
                "expected_ui_urls": True,
            },
            {
                "name": "get_dags_with_filters",
                "input": {"tags": ["prod", "daily"], "only_active": True, "paused": False, "dag_id_pattern": "prod_*"},
                "mock_response": {"dags": [{"dag_id": "prod_dag1"}], "total_entries": 1},
                "expected_call_kwargs": {
                    "tags": ["prod", "daily"],
                    "only_active": True,
                    "paused": False,
                    "dag_id_pattern": "prod_*",
                },
                "expected_ui_urls": True,
            },
        ],
    )
    async def test_get_dags_table_driven(self, test_case, mock_dag_api):
        """Table-driven test for get_dags function."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.to_dict.return_value = test_case["mock_response"]
        mock_dag_api.get_dags.return_value = mock_response

        # Execute function
        with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"):
            result = await get_dags(**test_case["input"])

        # Verify API call
        mock_dag_api.get_dags.assert_called_once_with(**test_case["expected_call_kwargs"])

        # Verify result structure
        assert len(result) == 1
        assert isinstance(result[0], types.TextContent)

        # Parse result and verify UI URLs were added if expected
        if test_case["expected_ui_urls"]:
            result_text = result[0].text
            assert "ui_url" in result_text

    @pytest.mark.parametrize(
        "test_case",
        [
            {
                "name": "get_dag_basic",
                "input": {"dag_id": "test_dag"},
                "mock_response": {"dag_id": "test_dag", "description": "Test DAG", "is_paused": False},
                "expected_call_kwargs": {"dag_id": "test_dag"},
            },
            {
                "name": "get_dag_complex_id",
                "input": {"dag_id": "complex-dag_name.v2"},
                "mock_response": {"dag_id": "complex-dag_name.v2", "description": "Complex DAG", "is_paused": True},
                "expected_call_kwargs": {"dag_id": "complex-dag_name.v2"},
            },
        ],
    )
    async def test_get_dag_table_driven(self, test_case, mock_dag_api):
        """Table-driven test for get_dag function."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.to_dict.return_value = test_case["mock_response"]
        mock_dag_api.get_dag.return_value = mock_response

        # Execute function
        with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"):
            result = await get_dag(**test_case["input"])

        # Verify API call
        mock_dag_api.get_dag.assert_called_once_with(**test_case["expected_call_kwargs"])

        # Verify result structure and UI URL addition
        assert len(result) == 1
        assert isinstance(result[0], types.TextContent)
        assert "ui_url" in result[0].text

    @pytest.mark.parametrize(
        "test_case",
        [
            {
                "name": "get_dag_details_no_fields",
                "input": {"dag_id": "test_dag"},
                "mock_response": {"dag_id": "test_dag", "file_path": "/path/to/dag.py"},
                "expected_call_kwargs": {"dag_id": "test_dag"},
            },
            {
                "name": "get_dag_details_with_fields",
                "input": {"dag_id": "test_dag", "fields": ["dag_id", "description"]},
                "mock_response": {"dag_id": "test_dag", "description": "Test"},
                "expected_call_kwargs": {"dag_id": "test_dag", "fields": ["dag_id", "description"]},
            },
        ],
    )
    async def test_get_dag_details_table_driven(self, test_case, mock_dag_api):
        """Table-driven test for get_dag_details function."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.to_dict.return_value = test_case["mock_response"]
        mock_dag_api.get_dag_details.return_value = mock_response

        # Execute function
        result = await get_dag_details(**test_case["input"])

        # Verify API call and result
        mock_dag_api.get_dag_details.assert_called_once_with(**test_case["expected_call_kwargs"])
        assert len(result) == 1
        assert isinstance(result[0], types.TextContent)

    @pytest.mark.parametrize(
        "test_case",
        [
            {
                "name": "pause_dag",
                "function": pause_dag,
                "input": {"dag_id": "test_dag"},
                "mock_response": {"dag_id": "test_dag", "is_paused": True},
                "expected_call_kwargs": {"dag_id": "test_dag", "dag": ANY, "update_mask": ["is_paused"]},
                "expected_dag_is_paused": True,
            },
            {
                "name":
"unpause_dag", 186 | "function": unpause_dag, 187 | "input": {"dag_id": "test_dag"}, 188 | "mock_response": {"dag_id": "test_dag", "is_paused": False}, 189 | "expected_call_kwargs": {"dag_id": "test_dag", "dag": ANY, "update_mask": ["is_paused"]}, 190 | "expected_dag_is_paused": False, 191 | }, 192 | ], 193 | ) 194 | async def test_pause_unpause_dag_table_driven(self, test_case, mock_dag_api): 195 | """Table-driven test for pause_dag and unpause_dag functions.""" 196 | # Setup mock response 197 | mock_response = MagicMock() 198 | mock_response.to_dict.return_value = test_case["mock_response"] 199 | mock_dag_api.patch_dag.return_value = mock_response 200 | 201 | # Execute function 202 | result = await test_case["function"](**test_case["input"]) 203 | 204 | # Verify API call and result 205 | mock_dag_api.patch_dag.assert_called_once_with(**test_case["expected_call_kwargs"]) 206 | 207 | # Verify the DAG object has correct is_paused value 208 | actual_call_args = mock_dag_api.patch_dag.call_args 209 | actual_dag = actual_call_args.kwargs["dag"] 210 | assert actual_dag["is_paused"] == test_case["expected_dag_is_paused"] 211 | 212 | assert len(result) == 1 213 | assert isinstance(result[0], types.TextContent) 214 | 215 | @pytest.mark.parametrize( 216 | "test_case", 217 | [ 218 | { 219 | "name": "get_tasks_no_order", 220 | "input": {"dag_id": "test_dag"}, 221 | "mock_response": { 222 | "tasks": [ 223 | {"task_id": "task1", "operator": "BashOperator"}, 224 | {"task_id": "task2", "operator": "PythonOperator"}, 225 | ] 226 | }, 227 | "expected_call_kwargs": {"dag_id": "test_dag"}, 228 | }, 229 | { 230 | "name": "get_tasks_with_order", 231 | "input": {"dag_id": "test_dag", "order_by": "task_id"}, 232 | "mock_response": { 233 | "tasks": [ 234 | {"task_id": "task1", "operator": "BashOperator"}, 235 | {"task_id": "task2", "operator": "PythonOperator"}, 236 | ] 237 | }, 238 | "expected_call_kwargs": {"dag_id": "test_dag", "order_by": "task_id"}, 239 | }, 240 | ], 241 | ) 242 | 
async def test_get_tasks_table_driven(self, test_case, mock_dag_api): 243 | """Table-driven test for get_tasks function.""" 244 | # Setup mock response 245 | mock_response = MagicMock() 246 | mock_response.to_dict.return_value = test_case["mock_response"] 247 | mock_dag_api.get_tasks.return_value = mock_response 248 | 249 | # Execute function 250 | result = await get_tasks(**test_case["input"]) 251 | 252 | # Verify API call and result 253 | mock_dag_api.get_tasks.assert_called_once_with(**test_case["expected_call_kwargs"]) 254 | assert len(result) == 1 255 | assert isinstance(result[0], types.TextContent) 256 | 257 | @pytest.mark.parametrize( 258 | "test_case", 259 | [ 260 | { 261 | "name": "patch_dag_pause_only", 262 | "input": {"dag_id": "test_dag", "is_paused": True}, 263 | "mock_response": {"dag_id": "test_dag", "is_paused": True}, 264 | "expected_update_mask": ["is_paused"], 265 | }, 266 | { 267 | "name": "patch_dag_tags_only", 268 | "input": {"dag_id": "test_dag", "tags": ["prod", "daily"]}, 269 | "mock_response": {"dag_id": "test_dag", "tags": ["prod", "daily"]}, 270 | "expected_update_mask": ["tags"], 271 | }, 272 | { 273 | "name": "patch_dag_both_fields", 274 | "input": {"dag_id": "test_dag", "is_paused": False, "tags": ["dev"]}, 275 | "mock_response": {"dag_id": "test_dag", "is_paused": False, "tags": ["dev"]}, 276 | "expected_update_mask": ["is_paused", "tags"], 277 | }, 278 | ], 279 | ) 280 | async def test_patch_dag_table_driven(self, test_case, mock_dag_api): 281 | """Table-driven test for patch_dag function.""" 282 | # Setup mock response 283 | mock_response = MagicMock() 284 | mock_response.to_dict.return_value = test_case["mock_response"] 285 | mock_dag_api.patch_dag.return_value = mock_response 286 | 287 | # Execute function 288 | with patch("src.airflow.dag.DAG") as mock_dag_class: 289 | mock_dag_instance = MagicMock() 290 | mock_dag_class.return_value = mock_dag_instance 291 | 292 | result = await patch_dag(**test_case["input"]) 293 | 294 | # 
Verify DAG instance creation and API call 295 | expected_update_request = {k: v for k, v in test_case["input"].items() if k != "dag_id"} 296 | mock_dag_class.assert_called_once_with(**expected_update_request) 297 | 298 | mock_dag_api.patch_dag.assert_called_once_with( 299 | dag_id=test_case["input"]["dag_id"], 300 | dag=mock_dag_instance, 301 | update_mask=test_case["expected_update_mask"], 302 | ) 303 | 304 | assert len(result) == 1 305 | assert isinstance(result[0], types.TextContent) 306 | 307 | @pytest.mark.parametrize( 308 | "test_case", 309 | [ 310 | { 311 | "name": "clear_task_instances_minimal", 312 | "input": {"dag_id": "test_dag"}, 313 | "mock_response": {"message": "Task instances cleared"}, 314 | "expected_clear_request": {}, 315 | }, 316 | { 317 | "name": "clear_task_instances_full", 318 | "input": { 319 | "dag_id": "test_dag", 320 | "task_ids": ["task1", "task2"], 321 | "start_date": "2023-01-01", 322 | "end_date": "2023-01-31", 323 | "include_subdags": True, 324 | "include_upstream": True, 325 | "dry_run": True, 326 | }, 327 | "mock_response": {"message": "Dry run completed"}, 328 | "expected_clear_request": { 329 | "task_ids": ["task1", "task2"], 330 | "start_date": "2023-01-01", 331 | "end_date": "2023-01-31", 332 | "include_subdags": True, 333 | "include_upstream": True, 334 | "dry_run": True, 335 | }, 336 | }, 337 | ], 338 | ) 339 | async def test_clear_task_instances_table_driven(self, test_case, mock_dag_api): 340 | """Table-driven test for clear_task_instances function.""" 341 | # Setup mock response 342 | mock_response = MagicMock() 343 | mock_response.to_dict.return_value = test_case["mock_response"] 344 | mock_dag_api.post_clear_task_instances.return_value = mock_response 345 | 346 | # Execute function 347 | with patch("src.airflow.dag.ClearTaskInstances") as mock_clear_class: 348 | mock_clear_instance = MagicMock() 349 | mock_clear_class.return_value = mock_clear_instance 350 | 351 | result = await 
clear_task_instances(**test_case["input"]) 352 | 353 | # Verify ClearTaskInstances creation and API call 354 | mock_clear_class.assert_called_once_with(**test_case["expected_clear_request"]) 355 | mock_dag_api.post_clear_task_instances.assert_called_once_with( 356 | dag_id=test_case["input"]["dag_id"], clear_task_instances=mock_clear_instance 357 | ) 358 | 359 | assert len(result) == 1 360 | assert isinstance(result[0], types.TextContent) 361 | 362 | @pytest.mark.parametrize( 363 | "test_case", 364 | [ 365 | { 366 | "name": "set_task_state_minimal", 367 | "input": {"dag_id": "test_dag", "state": "success"}, 368 | "mock_response": {"message": "Task state updated"}, 369 | "expected_state_request": {"state": "success"}, 370 | }, 371 | { 372 | "name": "set_task_state_full", 373 | "input": { 374 | "dag_id": "test_dag", 375 | "state": "failed", 376 | "task_ids": ["task1"], 377 | "execution_date": "2023-01-01T00:00:00Z", 378 | "include_upstream": True, 379 | "include_downstream": False, 380 | "dry_run": True, 381 | }, 382 | "mock_response": {"message": "Dry run state update"}, 383 | "expected_state_request": { 384 | "state": "failed", 385 | "task_ids": ["task1"], 386 | "execution_date": "2023-01-01T00:00:00Z", 387 | "include_upstream": True, 388 | "include_downstream": False, 389 | "dry_run": True, 390 | }, 391 | }, 392 | ], 393 | ) 394 | async def test_set_task_instances_state_table_driven(self, test_case, mock_dag_api): 395 | """Table-driven test for set_task_instances_state function.""" 396 | # Setup mock response 397 | mock_response = MagicMock() 398 | mock_response.to_dict.return_value = test_case["mock_response"] 399 | mock_dag_api.post_set_task_instances_state.return_value = mock_response 400 | 401 | # Execute function 402 | with patch("src.airflow.dag.UpdateTaskInstancesState") as mock_state_class: 403 | mock_state_instance = MagicMock() 404 | mock_state_class.return_value = mock_state_instance 405 | 406 | result = await 
set_task_instances_state(**test_case["input"]) 407 | 408 | # Verify UpdateTaskInstancesState creation and API call 409 | mock_state_class.assert_called_once_with(**test_case["expected_state_request"]) 410 | mock_dag_api.post_set_task_instances_state.assert_called_once_with( 411 | dag_id=test_case["input"]["dag_id"], update_task_instances_state=mock_state_instance 412 | ) 413 | 414 | assert len(result) == 1 415 | assert isinstance(result[0], types.TextContent) 416 | 417 | @pytest.mark.parametrize( 418 | "test_case", 419 | [ 420 | { 421 | "name": "simple_functions_get_dag_source", 422 | "function": get_dag_source, 423 | "api_method": "get_dag_source", 424 | "input": {"file_token": "test_token"}, 425 | "mock_response": {"content": "DAG source code"}, 426 | "expected_call_kwargs": {"file_token": "test_token"}, 427 | }, 428 | { 429 | "name": "simple_functions_get_dag_tasks", 430 | "function": get_dag_tasks, 431 | "api_method": "get_tasks", 432 | "input": {"dag_id": "test_dag"}, 433 | "mock_response": {"tasks": []}, 434 | "expected_call_kwargs": {"dag_id": "test_dag"}, 435 | }, 436 | { 437 | "name": "simple_functions_get_task", 438 | "function": get_task, 439 | "api_method": "get_task", 440 | "input": {"dag_id": "test_dag", "task_id": "test_task"}, 441 | "mock_response": {"task_id": "test_task", "operator": "BashOperator"}, 442 | "expected_call_kwargs": {"dag_id": "test_dag", "task_id": "test_task"}, 443 | }, 444 | { 445 | "name": "simple_functions_delete_dag", 446 | "function": delete_dag, 447 | "api_method": "delete_dag", 448 | "input": {"dag_id": "test_dag"}, 449 | "mock_response": {"message": "DAG deleted"}, 450 | "expected_call_kwargs": {"dag_id": "test_dag"}, 451 | }, 452 | { 453 | "name": "simple_functions_reparse_dag_file", 454 | "function": reparse_dag_file, 455 | "api_method": "reparse_dag_file", 456 | "input": {"file_token": "test_token"}, 457 | "mock_response": {"message": "DAG file reparsed"}, 458 | "expected_call_kwargs": {"file_token": "test_token"}, 459 | 
}, 460 | ], 461 | ) 462 | async def test_simple_functions_table_driven(self, test_case, mock_dag_api): 463 | """Table-driven test for simple functions that directly call API methods.""" 464 | # Setup mock response 465 | mock_response = MagicMock() 466 | mock_response.to_dict.return_value = test_case["mock_response"] 467 | getattr(mock_dag_api, test_case["api_method"]).return_value = mock_response 468 | 469 | # Execute function 470 | result = await test_case["function"](**test_case["input"]) 471 | 472 | # Verify API call and result 473 | getattr(mock_dag_api, test_case["api_method"]).assert_called_once_with(**test_case["expected_call_kwargs"]) 474 | assert len(result) == 1 475 | assert isinstance(result[0], types.TextContent) 476 | assert str(test_case["mock_response"]) in result[0].text 477 | 478 | @pytest.mark.integration 479 | async def test_dag_functions_integration_flow(self, mock_dag_api): 480 | """Integration test showing typical DAG management workflow.""" 481 | # Test data for a complete workflow 482 | dag_id = "integration_test_dag" 483 | 484 | # Mock responses for each step 485 | mock_responses = { 486 | "get_dag": {"dag_id": dag_id, "is_paused": True}, 487 | "patch_dag": {"dag_id": dag_id, "is_paused": False}, 488 | "get_tasks": {"tasks": [{"task_id": "task1"}, {"task_id": "task2"}]}, 489 | "delete_dag": {"message": "DAG deleted successfully"}, 490 | } 491 | 492 | # Setup mock responses 493 | for method, response in mock_responses.items(): 494 | mock_response = MagicMock() 495 | mock_response.to_dict.return_value = response 496 | getattr(mock_dag_api, method).return_value = mock_response 497 | 498 | # Execute workflow steps 499 | with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"): 500 | # 1. Get DAG info 501 | dag_info = await get_dag(dag_id) 502 | assert len(dag_info) == 1 503 | 504 | # 2. 
Unpause DAG 505 | with patch("src.airflow.dag.DAG") as mock_dag_class: 506 | mock_dag_class.return_value = MagicMock() 507 | unpause_result = await patch_dag(dag_id, is_paused=False) 508 | assert len(unpause_result) == 1 509 | 510 | # 3. Get tasks 511 | tasks_result = await get_tasks(dag_id) 512 | assert len(tasks_result) == 1 513 | 514 | # 4. Delete DAG 515 | delete_result = await delete_dag(dag_id) 516 | assert len(delete_result) == 1 517 | 518 | # Verify all API calls were made 519 | mock_dag_api.get_dag.assert_called_once_with(dag_id=dag_id) 520 | mock_dag_api.patch_dag.assert_called_once() 521 | mock_dag_api.get_tasks.assert_called_once_with(dag_id=dag_id) 522 | mock_dag_api.delete_dag.assert_called_once_with(dag_id=dag_id) 523 | ```