# Directory Structure

```
├── .cursor
│   └── rules
│       ├── modelcontextprotocol.mdc
│       └── python.mdc
├── .github
│   ├── CODEOWNERS
│   └── workflows
│       ├── lint.yml
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── .tool-versions
├── Dockerfile
├── LICENSE
├── Makefile
├── pyproject.toml
├── pytest.ini
├── README.md
├── smithery.yaml
├── src
│   ├── __init__.py
│   ├── __main__.py
│   ├── airflow
│   │   ├── __init__.py
│   │   ├── airflow_client.py
│   │   ├── config.py
│   │   ├── connection.py
│   │   ├── dag.py
│   │   ├── dagrun.py
│   │   ├── dagstats.py
│   │   ├── dataset.py
│   │   ├── eventlog.py
│   │   ├── importerror.py
│   │   ├── monitoring.py
│   │   ├── plugin.py
│   │   ├── pool.py
│   │   ├── provider.py
│   │   ├── taskinstance.py
│   │   ├── variable.py
│   │   └── xcom.py
│   ├── enums.py
│   ├── envs.py
│   ├── main.py
│   └── server.py
├── test
│   ├── __init__.py
│   ├── airflow
│   │   ├── test_dag.py
│   │   └── test_taskinstance.py
│   ├── conftest.py
│   ├── test_airflow_client.py
│   ├── test_main.py
│   └── test_server.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/.tool-versions:
--------------------------------------------------------------------------------

```
python 3.12.6
```

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

```
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.pypirc
.ruff_cache/

# Virtual Environment
.env
.venv
env/
venv/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db
```

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

```markdown
[MseeP.ai Security Assessment](https://mseep.ai/app/yangkyeongmo-mcp-server-apache-airflow)

# mcp-server-apache-airflow

[Smithery](https://smithery.ai/server/@yangkyeongmo/mcp-server-apache-airflow)

A Model Context Protocol (MCP) server implementation for Apache Airflow, enabling seamless integration with MCP clients. This project provides a standardized way to interact with Apache Airflow through the Model Context Protocol.

<a href="https://glama.ai/mcp/servers/e99b6vx9lw">
  <img width="380" height="200" src="https://glama.ai/mcp/servers/e99b6vx9lw/badge" alt="Server for Apache Airflow MCP server" />
</a>

## About

This project implements a [Model Context Protocol](https://modelcontextprotocol.io/introduction) server that wraps Apache Airflow's REST API, allowing MCP clients to interact with Airflow in a standardized way. It uses the official Apache Airflow client library to ensure compatibility and maintainability.
## Feature Implementation Status | Feature | API Path | Status | | -------------------------------- | --------------------------------------------------------------------------------------------- | ------ | | **DAG Management** | | | | List DAGs | `/api/v1/dags` | ✅ | | Get DAG Details | `/api/v1/dags/{dag_id}` | ✅ | | Pause DAG | `/api/v1/dags/{dag_id}` | ✅ | | Unpause DAG | `/api/v1/dags/{dag_id}` | ✅ | | Update DAG | `/api/v1/dags/{dag_id}` | ✅ | | Delete DAG | `/api/v1/dags/{dag_id}` | ✅ | | Get DAG Source | `/api/v1/dagSources/{file_token}` | ✅ | | Patch Multiple DAGs | `/api/v1/dags` | ✅ | | Reparse DAG File | `/api/v1/dagSources/{file_token}/reparse` | ✅ | | **DAG Runs** | | | | List DAG Runs | `/api/v1/dags/{dag_id}/dagRuns` | ✅ | | Create DAG Run | `/api/v1/dags/{dag_id}/dagRuns` | ✅ | | Get DAG Run Details | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}` | ✅ | | Update DAG Run | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}` | ✅ | | Delete DAG Run | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}` | ✅ | | Get DAG Runs Batch | `/api/v1/dags/~/dagRuns/list` | ✅ | | Clear DAG Run | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear` | ✅ | | Set DAG Run Note | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/setNote` | ✅ | | Get Upstream Dataset Events | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/upstreamDatasetEvents` | ✅ | | **Tasks** | | | | List DAG Tasks | `/api/v1/dags/{dag_id}/tasks` | ✅ | | Get Task Details | `/api/v1/dags/{dag_id}/tasks/{task_id}` | ✅ | | Get Task Instance | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}` | ✅ | | List Task Instances | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances` | ✅ | | Update Task Instance | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}` | ✅ | | Get Task Instance Log | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/logs/{task_try_number}` | ✅ | | Clear Task Instances | `/api/v1/dags/{dag_id}/clearTaskInstances` | ✅ | | Set Task Instances State | `/api/v1/dags/{dag_id}/updateTaskInstancesState` | ✅ | | List Task Instance Tries | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/tries` | ✅ | | **Variables** | | | | List Variables | `/api/v1/variables` | ✅ | | Create Variable | `/api/v1/variables` | ✅ | | Get Variable | `/api/v1/variables/{variable_key}` | ✅ | | Update Variable | `/api/v1/variables/{variable_key}` | ✅ | | Delete Variable | `/api/v1/variables/{variable_key}` | ✅ | | **Connections** | | | | List Connections | `/api/v1/connections` | ✅ | | Create Connection | `/api/v1/connections` | ✅ | | Get Connection | `/api/v1/connections/{connection_id}` | ✅ | | Update Connection | `/api/v1/connections/{connection_id}` | ✅ | | Delete Connection | `/api/v1/connections/{connection_id}` | ✅ | | Test Connection | `/api/v1/connections/test` | ✅ | | **Pools** | | | | List Pools | `/api/v1/pools` | ✅ | | Create Pool | `/api/v1/pools` | ✅ | | Get Pool | `/api/v1/pools/{pool_name}` | ✅ | | Update Pool | `/api/v1/pools/{pool_name}` | ✅ | | Delete Pool | `/api/v1/pools/{pool_name}` | ✅ | | **XComs** | | | | List XComs | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries` | ✅ | | Get XCom Entry | `/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}` | ✅ | | **Datasets** | | | | List Datasets | `/api/v1/datasets` | ✅ | | Get Dataset | `/api/v1/datasets/{uri}` | ✅ | | Get Dataset Events | `/api/v1/datasetEvents` | ✅ | | Create Dataset Event | `/api/v1/datasetEvents` | ✅ | | Get DAG Dataset Queued Event | 
`/api/v1/dags/{dag_id}/dagRuns/queued/datasetEvents/{uri}` | ✅ |
| Get DAG Dataset Queued Events | `/api/v1/dags/{dag_id}/dagRuns/queued/datasetEvents` | ✅ |
| Delete DAG Dataset Queued Event | `/api/v1/dags/{dag_id}/dagRuns/queued/datasetEvents/{uri}` | ✅ |
| Delete DAG Dataset Queued Events | `/api/v1/dags/{dag_id}/dagRuns/queued/datasetEvents` | ✅ |
| Get Dataset Queued Events | `/api/v1/datasets/{uri}/dagRuns/queued/datasetEvents` | ✅ |
| Delete Dataset Queued Events | `/api/v1/datasets/{uri}/dagRuns/queued/datasetEvents` | ✅ |
| **Monitoring** | | |
| Get Health | `/api/v1/health` | ✅ |
| **DAG Stats** | | |
| Get DAG Stats | `/api/v1/dags/statistics` | ✅ |
| **Config** | | |
| Get Config | `/api/v1/config` | ✅ |
| **Plugins** | | |
| Get Plugins | `/api/v1/plugins` | ✅ |
| **Providers** | | |
| List Providers | `/api/v1/providers` | ✅ |
| **Event Logs** | | |
| List Event Logs | `/api/v1/eventLogs` | ✅ |
| Get Event Log | `/api/v1/eventLogs/{event_log_id}` | ✅ |
| **System** | | |
| Get Import Errors | `/api/v1/importErrors` | ✅ |
| Get Import Error Details | `/api/v1/importErrors/{import_error_id}` | ✅ |
| Get Health Status | `/api/v1/health` | ✅ |
| Get Version | `/api/v1/version` | ✅ |

## Setup

### Dependencies

This project depends on the official Apache Airflow client library (`apache-airflow-client`). It is installed automatically when you install this package.

### Environment Variables

Set the following environment variables:

```
AIRFLOW_HOST=<your-airflow-host>  # Optional, defaults to http://localhost:8080
AIRFLOW_API_VERSION=v1  # Optional, defaults to v1
READ_ONLY=true  # Optional, enables read-only mode (true/false, defaults to false)
```

#### Authentication

Choose one of the following authentication methods:

**Basic Authentication (default):**

```
AIRFLOW_USERNAME=<your-airflow-username>
AIRFLOW_PASSWORD=<your-airflow-password>
```

**JWT Token Authentication:**

```
AIRFLOW_JWT_TOKEN=<your-jwt-token>
```

To obtain a JWT token, you can use Airflow's authentication endpoint:

```bash
ENDPOINT_URL="http://localhost:8080"  # Replace with your Airflow endpoint
curl -X 'POST' \
  "${ENDPOINT_URL}/auth/token" \
  -H 'Content-Type: application/json' \
  -d '{
    "username": "<your-username>",
    "password": "<your-password>"
  }'
```

> **Note**: If both a JWT token and basic authentication credentials are provided, the JWT token takes precedence.
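If the call succeeds, the token comes back in the JSON response body. As a minimal sketch, assuming the token is returned in an `access_token` field (the exact field name may vary across Airflow versions), you can extract it with `jq` and export it for the server:

```bash
# Reuses ENDPOINT_URL from the example above. The jq filter assumes the
# response body looks like {"access_token": "..."}; adjust the field name
# if your Airflow version returns a different shape.
export AIRFLOW_JWT_TOKEN="$(curl -s -X POST \
  "${ENDPOINT_URL}/auth/token" \
  -H 'Content-Type: application/json' \
  -d '{"username": "<your-username>", "password": "<your-password>"}' \
  | jq -r '.access_token')"
```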
### Usage with Claude Desktop

Add to your `claude_desktop_config.json`:

**Basic Authentication:**

```json
{
  "mcpServers": {
    "mcp-server-apache-airflow": {
      "command": "uvx",
      "args": ["mcp-server-apache-airflow"],
      "env": {
        "AIRFLOW_HOST": "https://your-airflow-host",
        "AIRFLOW_USERNAME": "your-username",
        "AIRFLOW_PASSWORD": "your-password"
      }
    }
  }
}
```

**JWT Token Authentication:**

```json
{
  "mcpServers": {
    "mcp-server-apache-airflow": {
      "command": "uvx",
      "args": ["mcp-server-apache-airflow"],
      "env": {
        "AIRFLOW_HOST": "https://your-airflow-host",
        "AIRFLOW_JWT_TOKEN": "your-jwt-token"
      }
    }
  }
}
```

For read-only mode (recommended for safety), enable it either through the `READ_ONLY` environment variable or the `--read-only` flag, as the two examples below show:

**Basic Authentication:**

```json
{
  "mcpServers": {
    "mcp-server-apache-airflow": {
      "command": "uvx",
      "args": ["mcp-server-apache-airflow"],
      "env": {
        "AIRFLOW_HOST": "https://your-airflow-host",
        "AIRFLOW_USERNAME": "your-username",
        "AIRFLOW_PASSWORD": "your-password",
        "READ_ONLY": "true"
      }
    }
  }
}
```

**JWT Token Authentication:**

```json
{
  "mcpServers": {
    "mcp-server-apache-airflow": {
      "command": "uvx",
      "args": ["mcp-server-apache-airflow", "--read-only"],
      "env": {
        "AIRFLOW_HOST": "https://your-airflow-host",
        "AIRFLOW_JWT_TOKEN": "your-jwt-token"
      }
    }
  }
}
```

Alternative configuration using `uv`:

**Basic Authentication:**

```json
{
  "mcpServers": {
    "mcp-server-apache-airflow": {
      "command": "uv",
      "args": [
        "--directory",
        "/path/to/mcp-server-apache-airflow",
        "run",
        "mcp-server-apache-airflow"
      ],
      "env": {
        "AIRFLOW_HOST": "https://your-airflow-host",
        "AIRFLOW_USERNAME": "your-username",
        "AIRFLOW_PASSWORD": "your-password"
      }
    }
  }
}
```

**JWT Token Authentication:**

```json
{
  "mcpServers": {
    "mcp-server-apache-airflow": {
      "command": "uv",
      "args": [
        "--directory",
        "/path/to/mcp-server-apache-airflow",
        "run",
        "mcp-server-apache-airflow"
      ],
      "env": {
        "AIRFLOW_HOST": "https://your-airflow-host",
        "AIRFLOW_JWT_TOKEN": "your-jwt-token"
      }
    }
  }
}
```

Replace `/path/to/mcp-server-apache-airflow` with the actual path where you've cloned the repository.

### Selecting the API groups

You can select the API groups you want to use by passing the `--apis` flag, repeating it for each group:

```bash
uv run mcp-server-apache-airflow --apis dag --apis dagrun
```

The default is to use all APIs.

Allowed values are:

- config
- connection
- dag
- dagrun
- dagstats
- dataset
- eventlog
- importerror
- monitoring
- plugin
- pool
- provider
- taskinstance
- variable
- xcom

### Read-Only Mode

You can run the server in read-only mode by using the `--read-only` flag or by setting the `READ_ONLY=true` environment variable. This will only expose tools that perform read operations (GET requests) and exclude any tools that create, update, or delete resources.

Using the command-line flag:

```bash
uv run mcp-server-apache-airflow --read-only
```

Using the environment variable:

```bash
READ_ONLY=true uv run mcp-server-apache-airflow
```

In read-only mode, the server will only expose tools like:

- Listing DAGs, DAG runs, tasks, variables, connections, etc.
- Getting details of specific resources
- Reading configurations and monitoring information
- Testing connections (non-destructive)

Write operations like creating, updating, or deleting DAGs, variables, and connections, or triggering DAG runs, will not be available in read-only mode.
You can combine read-only mode with API group selection:

```bash
uv run mcp-server-apache-airflow --read-only --apis dag --apis variable
```

### Manual Execution

You can also run the server manually:

```bash
make run
```

`make run` accepts the following options:

- `--port`: Port to listen on for SSE (default: 8000)
- `--transport`: Transport type (stdio/sse/http, default: stdio)

Alternatively, you can run the SSE server directly, which accepts the same parameters:

```bash
make run-sse
```

You can also start the service directly with `uv`:

```bash
uv run src --transport http --mcp-port 8080
```

### Installing via Smithery

To install the Apache Airflow MCP Server for Claude Desktop automatically via [Smithery](https://smithery.ai/server/@yangkyeongmo/mcp-server-apache-airflow):

```bash
npx -y @smithery/cli install @yangkyeongmo/mcp-server-apache-airflow --client claude
```

## Development

### Setting up Development Environment

1. Clone the repository:

```bash
git clone https://github.com/yangkyeongmo/mcp-server-apache-airflow.git
cd mcp-server-apache-airflow
```

2. Install development dependencies:

```bash
uv sync --dev
```

3. Create a `.env` file for environment variables (optional for development):

```bash
touch .env
```

> **Note**: No environment variables are required for running tests. `AIRFLOW_HOST` defaults to `http://localhost:8080` for development and testing purposes.

### Running Tests

The project uses pytest for testing, with the following commands available:

```bash
# Run all tests
make test
```

### Code Quality

```bash
# Run linting
make lint

# Run code formatting
make format
```

### Continuous Integration

The project includes a GitHub Actions workflow (`.github/workflows/test.yml`) that automatically:

- Runs tests on Python 3.10, 3.11, and 3.12
- Executes linting checks using ruff
- Runs on every push and pull request to the `main` branch

The CI pipeline ensures code quality and compatibility across supported Python versions before any changes are merged.

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

The package is deployed to PyPI automatically when `project.version` is updated in `pyproject.toml`. Versioning follows semver, so include the version bump in your PR for changes to the core logic to be released.
## License [MIT License](LICENSE) ``` -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /src/airflow/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- ```python # Test package initialization ``` -------------------------------------------------------------------------------- /src/__main__.py: -------------------------------------------------------------------------------- ```python import sys from src.main import main sys.exit(main()) ``` -------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- ```python from fastmcp import FastMCP app = FastMCP("mcp-apache-airflow") ``` -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- ``` [pytest] minversion = 6.0 addopts = -ra -q --strict-markers --strict-config testpaths = test python_files = test_*.py python_classes = Test* python_functions = test_* asyncio_mode = auto markers = integration: marks tests as integration tests (deselect with '-m "not integration"') slow: marks tests as slow (deselect with '-m "not slow"') unit: marks tests as unit tests ``` -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- ```python """ Pytest configuration and shared fixtures for the test suite. This file contains shared test fixtures, configurations, and utilities that can be used across all test modules. """ import sys from pathlib import Path # Add the src directory to the Python path for imports during testing src_path = Path(__file__).parent.parent / "src" if str(src_path) not in sys.path: sys.path.insert(0, str(src_path)) ``` -------------------------------------------------------------------------------- /src/enums.py: -------------------------------------------------------------------------------- ```python from enum import Enum class APIType(str, Enum): CONFIG = "config" CONNECTION = "connection" DAG = "dag" DAGRUN = "dagrun" DAGSTATS = "dagstats" DATASET = "dataset" EVENTLOG = "eventlog" IMPORTERROR = "importerror" MONITORING = "monitoring" PLUGIN = "plugin" POOL = "pool" PROVIDER = "provider" TASKINSTANCE = "taskinstance" VARIABLE = "variable" XCOM = "xcom" ``` -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- ```dockerfile # Generated by https://smithery.ai. See: https://smithery.ai/docs/config#dockerfile # Use a Python base image FROM python:3.10-slim # Set the working directory WORKDIR /app # Copy the contents of the repository to the working directory COPY . . 
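
# Note: copying the full source tree before installing dependencies means the
# layers below are rebuilt on every source change. Copying pyproject.toml and
# uv.lock first, running `uv sync`, and then copying the rest of the source
# would make better use of the Docker layer cache.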
# Install the project dependencies RUN pip install uv RUN uv sync # Expose the port that the server will run on EXPOSE 8000 # Command to run the server CMD ["uv", "run", "src", "--transport", "sse"] ``` -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- ```yaml name: Lint and Format Check on: push: branches: [ main ] pull_request: branches: [ main ] jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip pip install ruff - name: Check formatting run: ruff format --check . - name: Run linting run: ruff check . ``` -------------------------------------------------------------------------------- /src/envs.py: -------------------------------------------------------------------------------- ```python import os from urllib.parse import urlparse # Environment variables for Airflow connection # AIRFLOW_HOST defaults to localhost for development/testing if not provided _airflow_host_raw = os.getenv("AIRFLOW_HOST", "http://localhost:8080") AIRFLOW_HOST = urlparse(_airflow_host_raw)._replace(path="").geturl().rstrip("/") # Authentication - supports both basic auth and JWT token auth AIRFLOW_USERNAME = os.getenv("AIRFLOW_USERNAME") AIRFLOW_PASSWORD = os.getenv("AIRFLOW_PASSWORD") AIRFLOW_JWT_TOKEN = os.getenv("AIRFLOW_JWT_TOKEN") AIRFLOW_API_VERSION = os.getenv("AIRFLOW_API_VERSION", "v1") # Environment variable for read-only mode READ_ONLY = os.getenv("READ_ONLY", "false").lower() in ("true", "1", "yes", "on") ``` -------------------------------------------------------------------------------- /src/airflow/airflow_client.py: -------------------------------------------------------------------------------- ```python from urllib.parse import urljoin from airflow_client.client import ApiClient, Configuration from src.envs import ( AIRFLOW_API_VERSION, AIRFLOW_HOST, AIRFLOW_JWT_TOKEN, AIRFLOW_PASSWORD, AIRFLOW_USERNAME, ) # Create a configuration and API client configuration = Configuration( host=urljoin(AIRFLOW_HOST, f"/api/{AIRFLOW_API_VERSION}"), ) # Set up authentication - prefer JWT token if available, fallback to basic auth if AIRFLOW_JWT_TOKEN: configuration.api_key = {"Authorization": f"Bearer {AIRFLOW_JWT_TOKEN}"} configuration.api_key_prefix = {"Authorization": ""} elif AIRFLOW_USERNAME and AIRFLOW_PASSWORD: configuration.username = AIRFLOW_USERNAME configuration.password = AIRFLOW_PASSWORD api_client = ApiClient(configuration) ``` -------------------------------------------------------------------------------- /src/airflow/dagstats.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.dag_stats_api import DagStatsApi from src.airflow.airflow_client import api_client dag_stats_api = DagStatsApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_dag_stats, "get_dag_stats", "Get DAG stats", True), ] async def get_dag_stats( dag_ids: Optional[List[str]] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if dag_ids is 
not None: kwargs["dag_ids"] = dag_ids response = dag_stats_api.get_dag_stats(**kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- ```yaml name: Publish Python Package on: release: types: [created] jobs: publish: runs-on: ubuntu-latest permissions: contents: read id-token: write # Required for trusted publishing steps: - name: Checkout repository uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 with: python-version: '3.x' - name: Install dependencies run: pip install build - name: Build package run: python -m build - name: Publish package to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} # Alternatively, if using trusted publishing (recommended): # See: https://docs.pypi.org/trusted-publishers/ # attestation-check-repository: yangkyeongmo/mcp-server-apache-airflow # attestation-check-workflow: publish.yml # Optional: if your workflow file is named differently ``` -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- ```yaml name: Run Tests on: push: branches: [ main ] pull_request: branches: [ main ] jobs: test: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.10", "3.11", "3.12"] steps: - name: Checkout code uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install uv uses: astral-sh/setup-uv@v3 with: version: "latest" - name: Create .env file run: | touch .env echo "# Environment variables for testing" > .env - name: Install dependencies run: | uv sync --dev - name: Run linting run: | make lint - name: Run tests run: | make test - name: Upload coverage reports if: matrix.python-version == '3.11' uses: codecov/codecov-action@v4 with: fail_ci_if_error: false ``` -------------------------------------------------------------------------------- /smithery.yaml: -------------------------------------------------------------------------------- ```yaml # Smithery configuration file: https://smithery.ai/docs/config#smitheryyaml startCommand: type: stdio configSchema: # JSON Schema defining the configuration options for the MCP. type: object required: - airflowHost - airflowUsername - airflowPassword properties: airflowHost: type: string description: The host URL for the Airflow instance. airflowUsername: type: string description: The username for Airflow authentication. airflowPassword: type: string description: The password for Airflow authentication. airflowApiVersion: type: string description: The Airflow API version to use (defaults to v1). commandFunction: # A function that produces the CLI command to start the MCP on stdio. 
|-
    (config) => ({
      // Run the package entrypoint via uv, matching the Dockerfile's CMD.
      // Running src/server.py directly would only construct the FastMCP app
      // without registering any tools or starting a transport.
      command: 'uv',
      args: ['run', 'src'],
      env: {
        AIRFLOW_HOST: config.airflowHost,
        AIRFLOW_USERNAME: config.airflowUsername,
        AIRFLOW_PASSWORD: config.airflowPassword,
        AIRFLOW_API_VERSION: config.airflowApiVersion || 'v1'
      }
    })
```

--------------------------------------------------------------------------------
/src/airflow/plugin.py:
--------------------------------------------------------------------------------

```python
from typing import Any, Callable, Dict, List, Optional, Union

import mcp.types as types
from airflow_client.client.api.plugin_api import PluginApi

from src.airflow.airflow_client import api_client

plugin_api = PluginApi(api_client)


def get_all_functions() -> list[tuple[Callable, str, str, bool]]:
    """Return list of (function, name, description, is_read_only) tuples for registration."""
    return [
        (get_plugins, "get_plugins", "Get a list of loaded plugins", True),
    ]


async def get_plugins(
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Get a list of loaded plugins.

    Args:
        limit: The numbers of items to return.
        offset: The number of items to skip before starting to collect the result set.

    Returns:
        A list of loaded plugins.
    """
    # Build parameters dictionary
    kwargs: Dict[str, Any] = {}
    if limit is not None:
        kwargs["limit"] = limit
    if offset is not None:
        kwargs["offset"] = offset

    response = plugin_api.get_plugins(**kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]
```

--------------------------------------------------------------------------------
/src/airflow/monitoring.py:
--------------------------------------------------------------------------------

```python
from typing import Callable, List, Union

import mcp.types as types
from airflow_client.client.api.monitoring_api import MonitoringApi

from src.airflow.airflow_client import api_client

monitoring_api = MonitoringApi(api_client)


def get_all_functions() -> list[tuple[Callable, str, str, bool]]:
    """Return list of (function, name, description, is_read_only) tuples for registration."""
    return [
        (get_health, "get_health", "Get instance status", True),
        (get_version, "get_version", "Get version information", True),
    ]


async def get_health() -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Get the status of Airflow's metadatabase, triggerer and scheduler.

    It includes info about the metadatabase and the last heartbeat of the scheduler and triggerer.
    """
    response = monitoring_api.get_health()
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_version() -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Get version information about Airflow.
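
    The response typically includes the Airflow version string and, when
    available, the git version it was built from.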
""" response = monitoring_api.get_version() return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /src/airflow/config.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.config_api import ConfigApi from src.airflow.airflow_client import api_client config_api = ConfigApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_config, "get_config", "Get current configuration", True), (get_value, "get_value", "Get a specific option from configuration", True), ] async def get_config( section: Optional[str] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if section is not None: kwargs["section"] = section response = config_api.get_config(**kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] async def get_value( section: str, option: str ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: response = config_api.get_value(section=section, option=option) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /src/airflow/provider.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.provider_api import ProviderApi from src.airflow.airflow_client import api_client provider_api = ProviderApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_providers, "get_providers", "Get a list of loaded providers", True), ] async def get_providers( limit: Optional[int] = None, offset: Optional[int] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: """ Get a list of providers. Args: limit: The numbers of items to return. offset: The number of items to skip before starting to collect the result set. Returns: A list of providers with their details. 
""" # Build parameters dictionary kwargs: Dict[str, Any] = {} if limit is not None: kwargs["limit"] = limit if offset is not None: kwargs["offset"] = offset response = provider_api.get_providers(**kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /src/airflow/importerror.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.import_error_api import ImportErrorApi from src.airflow.airflow_client import api_client import_error_api = ImportErrorApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_import_errors, "get_import_errors", "List import errors", True), (get_import_error, "get_import_error", "Get a specific import error by ID", True), ] async def get_import_errors( limit: Optional[int] = None, offset: Optional[int] = None, order_by: Optional[str] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if limit is not None: kwargs["limit"] = limit if offset is not None: kwargs["offset"] = offset if order_by is not None: kwargs["order_by"] = order_by response = import_error_api.get_import_errors(**kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] async def get_import_error( import_error_id: int, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: response = import_error_api.get_import_error(import_error_id=import_error_id) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- ```toml [project] name = "mcp-server-apache-airflow" version = "0.2.9" description = "Model Context Protocol (MCP) server for Apache Airflow" authors = [ { name = "Gyeongmo Yang", email = "[email protected]" } ] dependencies = [ "httpx>=0.24.1", "click>=8.1.7", "mcp>=0.1.0", "apache-airflow-client>=2.7.0,<3", "fastmcp>=2.11.3", "PyJWT>=2.8.0", ] requires-python = ">=3.10" readme = "README.md" license = { text = "MIT" } classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Software Development :: Libraries :: Python Modules", ] keywords = ["mcp", "airflow", "apache-airflow", "model-context-protocol"] [project.optional-dependencies] dev = [ "build>=1.2.2.post1", "twine>=6.1.0", ] [project.urls] Homepage = "https://github.com/yangkyeongmo/mcp-server-apache-airflow" Repository = "https://github.com/yangkyeongmo/mcp-server-apache-airflow.git" "Bug Tracker" = "https://github.com/yangkyeongmo/mcp-server-apache-airflow/issues" [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [project.scripts] mcp-server-apache-airflow = "src.main:main" [tool.hatch.build.targets.wheel] packages = ["src"] [tool.hatch.build] include = [ "src/**/*.py", "README.md", "LICENSE", ] [tool.ruff] line-length = 120 [tool.ruff.lint] select = ["E", 
"W", "F", "B", "I"] [dependency-groups] dev = [ "ruff>=0.11.0", "pytest>=7.0.0", "pytest-cov>=4.0.0", "pytest-asyncio>=0.21.0", ] ``` -------------------------------------------------------------------------------- /src/airflow/xcom.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.x_com_api import XComApi from src.airflow.airflow_client import api_client xcom_api = XComApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_xcom_entries, "get_xcom_entries", "Get all XCom entries", True), (get_xcom_entry, "get_xcom_entry", "Get an XCom entry", True), ] async def get_xcom_entries( dag_id: str, dag_run_id: str, task_id: str, map_index: Optional[int] = None, xcom_key: Optional[str] = None, limit: Optional[int] = None, offset: Optional[int] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if map_index is not None: kwargs["map_index"] = map_index if xcom_key is not None: kwargs["xcom_key"] = xcom_key if limit is not None: kwargs["limit"] = limit if offset is not None: kwargs["offset"] = offset response = xcom_api.get_xcom_entries(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, **kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] async def get_xcom_entry( dag_id: str, dag_run_id: str, task_id: str, xcom_key: str, map_index: Optional[int] = None, deserialize: Optional[bool] = None, stringify: Optional[bool] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if map_index is not None: kwargs["map_index"] = map_index if deserialize is not None: kwargs["deserialize"] = deserialize if stringify is not None: kwargs["stringify"] = stringify response = xcom_api.get_xcom_entry( dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, xcom_key=xcom_key, **kwargs ) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /src/airflow/eventlog.py: -------------------------------------------------------------------------------- ```python from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.event_log_api import EventLogApi from src.airflow.airflow_client import api_client event_log_api = EventLogApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_event_logs, "get_event_logs", "List log entries from event log", True), (get_event_log, "get_event_log", "Get a specific log entry by ID", True), ] async def get_event_logs( limit: Optional[int] = None, offset: Optional[int] = None, order_by: Optional[str] = None, dag_id: Optional[str] = None, task_id: Optional[str] = None, run_id: Optional[str] = None, map_index: Optional[int] = None, try_number: Optional[int] = None, event: Optional[str] = None, owner: Optional[str] = None, before: Optional[datetime] = None, after: Optional[datetime] = None, included_events: Optional[str] = None, excluded_events: Optional[str] = None, ) -> 
List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if limit is not None: kwargs["limit"] = limit if offset is not None: kwargs["offset"] = offset if order_by is not None: kwargs["order_by"] = order_by if dag_id is not None: kwargs["dag_id"] = dag_id if task_id is not None: kwargs["task_id"] = task_id if run_id is not None: kwargs["run_id"] = run_id if map_index is not None: kwargs["map_index"] = map_index if try_number is not None: kwargs["try_number"] = try_number if event is not None: kwargs["event"] = event if owner is not None: kwargs["owner"] = owner if before is not None: kwargs["before"] = before if after is not None: kwargs["after"] = after if included_events is not None: kwargs["included_events"] = included_events if excluded_events is not None: kwargs["excluded_events"] = excluded_events response = event_log_api.get_event_logs(**kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] async def get_event_log( event_log_id: int, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: response = event_log_api.get_event_log(event_log_id=event_log_id) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /src/airflow/variable.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.variable_api import VariableApi from src.airflow.airflow_client import api_client variable_api = VariableApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (list_variables, "list_variables", "List all variables", True), (create_variable, "create_variable", "Create a variable", False), (get_variable, "get_variable", "Get a variable by key", True), (update_variable, "update_variable", "Update a variable by key", False), (delete_variable, "delete_variable", "Delete a variable by key", False), ] async def list_variables( limit: Optional[int] = None, offset: Optional[int] = None, order_by: Optional[str] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if limit is not None: kwargs["limit"] = limit if offset is not None: kwargs["offset"] = offset if order_by is not None: kwargs["order_by"] = order_by response = variable_api.get_variables(**kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] async def create_variable( key: str, value: str, description: Optional[str] = None ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: variable_request = { "key": key, "value": value, } if description is not None: variable_request["description"] = description response = variable_api.post_variables(variable_request=variable_request) return [types.TextContent(type="text", text=str(response.to_dict()))] async def get_variable(key: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: response = variable_api.get_variable(variable_key=key) return [types.TextContent(type="text", text=str(response.to_dict()))] async def update_variable( key: str, value: Optional[str] = None, description: Optional[str] = None ) -> List[Union[types.TextContent, 
types.ImageContent, types.EmbeddedResource]]: update_request = {} if value is not None: update_request["value"] = value if description is not None: update_request["description"] = description response = variable_api.patch_variable( variable_key=key, update_mask=list(update_request.keys()), variable_request=update_request ) return [types.TextContent(type="text", text=str(response.to_dict()))] async def delete_variable(key: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: response = variable_api.delete_variable(variable_key=key) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /test/test_server.py: -------------------------------------------------------------------------------- ```python """Tests for the server module using pytest framework.""" import pytest from fastmcp import FastMCP from fastmcp.tools import Tool class TestServer: """Test cases for the server module.""" def test_app_instance_type(self): """Test that app instance is of correct type.""" from src.server import app # Verify app is an instance of FastMCP assert isinstance(app, FastMCP) def test_app_instance_name(self): """Test that app instance has the correct name.""" from src.server import app # Verify the app name is set correctly assert app.name == "mcp-apache-airflow" def test_app_instance_is_singleton(self): """Test that importing the app multiple times returns the same instance.""" from src.server import app as app1 from src.server import app as app2 # Verify same instance is returned assert app1 is app2 def test_app_has_required_methods(self): """Test that app instance has required FastMCP methods.""" from src.server import app # Verify essential methods exist assert hasattr(app, "add_tool") assert hasattr(app, "run") assert callable(app.add_tool) assert callable(app.run) def test_app_initialization_attributes(self): """Test that app is properly initialized with default attributes.""" from src.server import app # Verify basic FastMCP attributes assert app.name is not None assert app.name == "mcp-apache-airflow" # Verify app can be used (doesn't raise exceptions on basic operations) try: # These should not raise exceptions str(app) repr(app) except Exception as e: pytest.fail(f"Basic app operations failed: {e}") def test_app_name_format(self): """Test that app name follows expected format.""" from src.server import app # Verify name format assert isinstance(app.name, str) assert app.name.startswith("mcp-") assert "airflow" in app.name @pytest.mark.integration def test_app_tool_registration_capability(self): """Test that app can register tools without errors.""" from src.server import app # Mock function to register def test_tool(): return "test result" # This should not raise an exception try: app.add_tool(Tool.from_function(test_tool, name="test_tool", description="Test tool")) except Exception as e: pytest.fail(f"Tool registration failed: {e}") def test_app_module_level_initialization(self): """Test that app is initialized at module level.""" # Import should work without any setup from src.server import app # App should be ready to use immediately assert app is not None assert hasattr(app, "name") assert app.name == "mcp-apache-airflow" ``` -------------------------------------------------------------------------------- /src/airflow/pool.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, 
Optional, Union import mcp.types as types from airflow_client.client.api.pool_api import PoolApi from airflow_client.client.model.pool import Pool from src.airflow.airflow_client import api_client pool_api = PoolApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_pools, "get_pools", "List pools", True), (get_pool, "get_pool", "Get a pool by name", True), (delete_pool, "delete_pool", "Delete a pool", False), (post_pool, "post_pool", "Create a pool", False), (patch_pool, "patch_pool", "Update a pool", False), ] async def get_pools( limit: Optional[int] = None, offset: Optional[int] = None, order_by: Optional[str] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: """ List pools. Args: limit: The numbers of items to return. offset: The number of items to skip before starting to collect the result set. order_by: The name of the field to order the results by. Prefix a field name with `-` to reverse the sort order. Returns: A list of pools. """ # Build parameters dictionary kwargs: Dict[str, Any] = {} if limit is not None: kwargs["limit"] = limit if offset is not None: kwargs["offset"] = offset if order_by is not None: kwargs["order_by"] = order_by response = pool_api.get_pools(**kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] async def get_pool( pool_name: str, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: """ Get a pool by name. Args: pool_name: The pool name. Returns: The pool details. """ response = pool_api.get_pool(pool_name=pool_name) return [types.TextContent(type="text", text=str(response.to_dict()))] async def delete_pool( pool_name: str, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: """ Delete a pool. Args: pool_name: The pool name. Returns: A confirmation message. """ pool_api.delete_pool(pool_name=pool_name) return [types.TextContent(type="text", text=f"Pool '{pool_name}' deleted successfully.")] async def post_pool( name: str, slots: int, description: Optional[str] = None, include_deferred: Optional[bool] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: """ Create a pool. Args: name: The pool name. slots: The number of slots. description: The pool description. include_deferred: Whether to include deferred tasks in slot calculations. Returns: The created pool details. """ pool = Pool( name=name, slots=slots, ) if description is not None: pool.description = description if include_deferred is not None: pool.include_deferred = include_deferred response = pool_api.post_pool(pool=pool) return [types.TextContent(type="text", text=str(response.to_dict()))] async def patch_pool( pool_name: str, slots: Optional[int] = None, description: Optional[str] = None, include_deferred: Optional[bool] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: """ Update a pool. Args: pool_name: The pool name. slots: The number of slots. description: The pool description. include_deferred: Whether to include deferred tasks in slot calculations. Returns: The updated pool details. 
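
    Note:
        Only the fields that are provided are set on the Pool object sent in
        the patch request.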
""" pool = Pool() if slots is not None: pool.slots = slots if description is not None: pool.description = description if include_deferred is not None: pool.include_deferred = include_deferred response = pool_api.patch_pool(pool_name=pool_name, pool=pool) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- ```python import logging import click from fastmcp.tools import Tool from src.airflow.config import get_all_functions as get_config_functions from src.airflow.connection import get_all_functions as get_connection_functions from src.airflow.dag import get_all_functions as get_dag_functions from src.airflow.dagrun import get_all_functions as get_dagrun_functions from src.airflow.dagstats import get_all_functions as get_dagstats_functions from src.airflow.dataset import get_all_functions as get_dataset_functions from src.airflow.eventlog import get_all_functions as get_eventlog_functions from src.airflow.importerror import get_all_functions as get_importerror_functions from src.airflow.monitoring import get_all_functions as get_monitoring_functions from src.airflow.plugin import get_all_functions as get_plugin_functions from src.airflow.pool import get_all_functions as get_pool_functions from src.airflow.provider import get_all_functions as get_provider_functions from src.airflow.taskinstance import get_all_functions as get_taskinstance_functions from src.airflow.variable import get_all_functions as get_variable_functions from src.airflow.xcom import get_all_functions as get_xcom_functions from src.enums import APIType from src.envs import READ_ONLY APITYPE_TO_FUNCTIONS = { APIType.CONFIG: get_config_functions, APIType.CONNECTION: get_connection_functions, APIType.DAG: get_dag_functions, APIType.DAGRUN: get_dagrun_functions, APIType.DAGSTATS: get_dagstats_functions, APIType.DATASET: get_dataset_functions, APIType.EVENTLOG: get_eventlog_functions, APIType.IMPORTERROR: get_importerror_functions, APIType.MONITORING: get_monitoring_functions, APIType.PLUGIN: get_plugin_functions, APIType.POOL: get_pool_functions, APIType.PROVIDER: get_provider_functions, APIType.TASKINSTANCE: get_taskinstance_functions, APIType.VARIABLE: get_variable_functions, APIType.XCOM: get_xcom_functions, } def filter_functions_for_read_only(functions: list[tuple]) -> list[tuple]: """ Filter functions to only include read-only operations. 
    Args:
        functions: List of (func, name, description, is_read_only) tuples

    Returns:
        List of (func, name, description, is_read_only) tuples with only read-only functions
    """
    return [
        (func, name, description, is_read_only)
        for func, name, description, is_read_only in functions
        if is_read_only
    ]


@click.command()
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse", "http"]),
    default="stdio",
    help="Transport type",
)
@click.option("--mcp-port", default=8000, help="Port to run the MCP service on for the SSE and HTTP transports.")
@click.option("--mcp-host", default="0.0.0.0", help="Host to run the MCP service on for the SSE and HTTP transports.")
@click.option(
    "--apis",
    type=click.Choice([api.value for api in APIType]),
    default=[api.value for api in APIType],
    multiple=True,
    help="APIs to run; defaults to all",
)
@click.option(
    "--read-only",
    is_flag=True,
    default=READ_ONLY,
    help="Only expose read-only tools (GET operations, no CREATE/UPDATE/DELETE)",
)
def main(transport: str, mcp_host: str, mcp_port: int, apis: list[str], read_only: bool) -> None:
    from src.server import app

    for api in apis:
        logging.debug(f"Adding API: {api}")
        get_function = APITYPE_TO_FUNCTIONS[APIType(api)]
        try:
            functions = get_function()
        except NotImplementedError:
            continue

        # Filter functions for read-only mode if requested
        if read_only:
            functions = filter_functions_for_read_only(functions)

        for func, name, description, *_ in functions:
            app.add_tool(Tool.from_function(func, name=name, description=description))

    logging.debug(f"Starting MCP server for Apache Airflow with {transport} transport")
    params_to_run = {}
    if transport in {"sse", "http"}:
        if transport == "sse":
            logging.warning("NOTE: the SSE transport is going to be deprecated.")
        params_to_run = {"port": int(mcp_port), "host": mcp_host}
    app.run(transport=transport, **params_to_run)
```

--------------------------------------------------------------------------------
/src/airflow/connection.py:
--------------------------------------------------------------------------------

```python
from typing import Any, Callable, Dict, List, Optional, Union

import mcp.types as types
from airflow_client.client.api.connection_api import ConnectionApi

from src.airflow.airflow_client import api_client

connection_api = ConnectionApi(api_client)


def get_all_functions() -> list[tuple[Callable, str, str, bool]]:
    """Return list of (function, name, description, is_read_only) tuples for registration."""
    return [
        (list_connections, "list_connections", "List all connections", True),
        (create_connection, "create_connection", "Create a connection", False),
        (get_connection, "get_connection", "Get a connection by ID", True),
        (update_connection, "update_connection", "Update a connection by ID", False),
        (delete_connection, "delete_connection", "Delete a connection by ID", False),
        (test_connection, "test_connection", "Test a connection", True),
    ]


async def list_connections(
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    order_by: Optional[str] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build parameters dictionary
    kwargs: Dict[str, Any] = {}
    if limit is not None:
        kwargs["limit"] = limit
    if offset is not None:
        kwargs["offset"] = offset
    if order_by is not None:
        kwargs["order_by"] = order_by

    response = connection_api.get_connections(**kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def create_connection(
    conn_id: str,
    conn_type: str,
    host: Optional[str] = None,
    port: Optional[int] = None,
    login: Optional[str] = None,
password: Optional[str] = None, schema: Optional[str] = None, extra: Optional[str] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: connection_request = { "connection_id": conn_id, "conn_type": conn_type, } if host is not None: connection_request["host"] = host if port is not None: connection_request["port"] = port if login is not None: connection_request["login"] = login if password is not None: connection_request["password"] = password if schema is not None: connection_request["schema"] = schema if extra is not None: connection_request["extra"] = extra response = connection_api.post_connection(connection_request=connection_request) return [types.TextContent(type="text", text=str(response.to_dict()))] async def get_connection(conn_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: response = connection_api.get_connection(connection_id=conn_id) return [types.TextContent(type="text", text=str(response.to_dict()))] async def update_connection( conn_id: str, conn_type: Optional[str] = None, host: Optional[str] = None, port: Optional[int] = None, login: Optional[str] = None, password: Optional[str] = None, schema: Optional[str] = None, extra: Optional[str] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: update_request = {} if conn_type is not None: update_request["conn_type"] = conn_type if host is not None: update_request["host"] = host if port is not None: update_request["port"] = port if login is not None: update_request["login"] = login if password is not None: update_request["password"] = password if schema is not None: update_request["schema"] = schema if extra is not None: update_request["extra"] = extra response = connection_api.patch_connection( connection_id=conn_id, update_mask=list(update_request.keys()), connection_request=update_request ) return [types.TextContent(type="text", text=str(response.to_dict()))] async def delete_connection(conn_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: response = connection_api.delete_connection(connection_id=conn_id) return [types.TextContent(type="text", text=str(response.to_dict()))] async def test_connection( conn_type: str, host: Optional[str] = None, port: Optional[int] = None, login: Optional[str] = None, password: Optional[str] = None, schema: Optional[str] = None, extra: Optional[str] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: connection_request = { "conn_type": conn_type, } if host is not None: connection_request["host"] = host if port is not None: connection_request["port"] = port if login is not None: connection_request["login"] = login if password is not None: connection_request["password"] = password if schema is not None: connection_request["schema"] = schema if extra is not None: connection_request["extra"] = extra response = connection_api.test_connection(connection_request=connection_request) return [types.TextContent(type="text", text=str(response.to_dict()))] ``` -------------------------------------------------------------------------------- /src/airflow/taskinstance.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.task_instance_api import TaskInstanceApi from src.airflow.airflow_client import api_client task_instance_api = TaskInstanceApi(api_client) def 
get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_task_instance, "get_task_instance", "Get a task instance by DAG ID, task ID, and DAG run ID", True), (list_task_instances, "list_task_instances", "List task instances by DAG ID and DAG run ID", True), ( update_task_instance, "update_task_instance", "Update a task instance by DAG ID, DAG run ID, and task ID", False, ), ( get_log, "get_log", "Get the log from a task instance by DAG ID, task ID, DAG run ID and task try number", True, ), ( list_task_instance_tries, "list_task_instance_tries", "List task instance tries by DAG ID, DAG run ID, and task ID", True, ), ] async def get_task_instance( dag_id: str, task_id: str, dag_run_id: str ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: response = task_instance_api.get_task_instance(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id) return [types.TextContent(type="text", text=str(response.to_dict()))] async def list_task_instances( dag_id: str, dag_run_id: str, execution_date_gte: Optional[str] = None, execution_date_lte: Optional[str] = None, start_date_gte: Optional[str] = None, start_date_lte: Optional[str] = None, end_date_gte: Optional[str] = None, end_date_lte: Optional[str] = None, updated_at_gte: Optional[str] = None, updated_at_lte: Optional[str] = None, duration_gte: Optional[float] = None, duration_lte: Optional[float] = None, state: Optional[List[str]] = None, pool: Optional[List[str]] = None, queue: Optional[List[str]] = None, limit: Optional[int] = None, offset: Optional[int] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if execution_date_gte is not None: kwargs["execution_date_gte"] = execution_date_gte if execution_date_lte is not None: kwargs["execution_date_lte"] = execution_date_lte if start_date_gte is not None: kwargs["start_date_gte"] = start_date_gte if start_date_lte is not None: kwargs["start_date_lte"] = start_date_lte if end_date_gte is not None: kwargs["end_date_gte"] = end_date_gte if end_date_lte is not None: kwargs["end_date_lte"] = end_date_lte if updated_at_gte is not None: kwargs["updated_at_gte"] = updated_at_gte if updated_at_lte is not None: kwargs["updated_at_lte"] = updated_at_lte if duration_gte is not None: kwargs["duration_gte"] = duration_gte if duration_lte is not None: kwargs["duration_lte"] = duration_lte if state is not None: kwargs["state"] = state if pool is not None: kwargs["pool"] = pool if queue is not None: kwargs["queue"] = queue if limit is not None: kwargs["limit"] = limit if offset is not None: kwargs["offset"] = offset response = task_instance_api.get_task_instances(dag_id=dag_id, dag_run_id=dag_run_id, **kwargs) return [types.TextContent(type="text", text=str(response.to_dict()))] async def update_task_instance( dag_id: str, dag_run_id: str, task_id: str, state: Optional[str] = None ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: update_request = {} if state is not None: update_request["state"] = state response = task_instance_api.patch_task_instance( dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, update_mask=list(update_request.keys()), task_instance_request=update_request, ) return [types.TextContent(type="text", text=str(response.to_dict()))] async def get_log( dag_id: str, task_id: str, dag_run_id: str, task_try_number: int ) -> 

async def get_log(
    dag_id: str, task_id: str, dag_run_id: str, task_try_number: int
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = task_instance_api.get_log(
        dag_id=dag_id,
        dag_run_id=dag_run_id,
        task_id=task_id,
        task_try_number=task_try_number,
    )
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def list_task_instance_tries(
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    order_by: Optional[str] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build parameters dictionary
    kwargs: Dict[str, Any] = {}
    if limit is not None:
        kwargs["limit"] = limit
    if offset is not None:
        kwargs["offset"] = offset
    if order_by is not None:
        kwargs["order_by"] = order_by

    response = task_instance_api.get_task_instance_tries(
        dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, **kwargs
    )
    return [types.TextContent(type="text", text=str(response.to_dict()))]
```
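
The task-instance tools mirror the REST API's query parameters one-to-one and simply omit anything left as `None`. A minimal usage sketch, assuming the same environment-driven client configuration; all IDs below are hypothetical:

```python
# Hypothetical usage sketch: list failed task instances, then pull a log.
# The DAG/run/task IDs are made up.
import asyncio

from src.airflow.taskinstance import get_log, list_task_instances


async def demo() -> None:
    instances = await list_task_instances(
        dag_id="example_dag",
        dag_run_id="manual__2024-01-01T00:00:00+00:00",
        state=["failed"],
        limit=10,
    )
    print(instances[0].text)

    logs = await get_log(
        dag_id="example_dag",
        task_id="example_task",
        dag_run_id="manual__2024-01-01T00:00:00+00:00",
        task_try_number=1,
    )
    print(logs[0].text)


asyncio.run(demo())
```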
--------------------------------------------------------------------------------
/src/airflow/dataset.py:
--------------------------------------------------------------------------------

```python
from typing import Any, Callable, Dict, List, Optional, Union

import mcp.types as types
from airflow_client.client.api.dataset_api import DatasetApi

from src.airflow.airflow_client import api_client

dataset_api = DatasetApi(api_client)


def get_all_functions() -> list[tuple[Callable, str, str, bool]]:
    """Return list of (function, name, description, is_read_only) tuples for registration."""
    return [
        (get_datasets, "get_datasets", "List datasets", True),
        (get_dataset, "get_dataset", "Get a dataset by URI", True),
        (get_dataset_events, "get_dataset_events", "Get dataset events", True),
        (create_dataset_event, "create_dataset_event", "Create dataset event", False),
        (get_dag_dataset_queued_event, "get_dag_dataset_queued_event", "Get a queued Dataset event for a DAG", True),
        (get_dag_dataset_queued_events, "get_dag_dataset_queued_events", "Get queued Dataset events for a DAG", True),
        (
            delete_dag_dataset_queued_event,
            "delete_dag_dataset_queued_event",
            "Delete a queued Dataset event for a DAG",
            False,
        ),
        (
            delete_dag_dataset_queued_events,
            "delete_dag_dataset_queued_events",
            "Delete queued Dataset events for a DAG",
            False,
        ),
        (get_dataset_queued_events, "get_dataset_queued_events", "Get queued Dataset events for a Dataset", True),
        (
            delete_dataset_queued_events,
            "delete_dataset_queued_events",
            "Delete queued Dataset events for a Dataset",
            False,
        ),
    ]


async def get_datasets(
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    order_by: Optional[str] = None,
    uri_pattern: Optional[str] = None,
    dag_ids: Optional[str] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build parameters dictionary
    kwargs: Dict[str, Any] = {}
    if limit is not None:
        kwargs["limit"] = limit
    if offset is not None:
        kwargs["offset"] = offset
    if order_by is not None:
        kwargs["order_by"] = order_by
    if uri_pattern is not None:
        kwargs["uri_pattern"] = uri_pattern
    if dag_ids is not None:
        kwargs["dag_ids"] = dag_ids

    response = dataset_api.get_datasets(**kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_dataset(
    uri: str,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dataset_api.get_dataset(uri=uri)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_dataset_events(
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    order_by: Optional[str] = None,
    dataset_id: Optional[int] = None,
    source_dag_id: Optional[str] = None,
    source_task_id: Optional[str] = None,
    source_run_id: Optional[str] = None,
    source_map_index: Optional[int] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build parameters dictionary
    kwargs: Dict[str, Any] = {}
    if limit is not None:
        kwargs["limit"] = limit
    if offset is not None:
        kwargs["offset"] = offset
    if order_by is not None:
        kwargs["order_by"] = order_by
    if dataset_id is not None:
        kwargs["dataset_id"] = dataset_id
    if source_dag_id is not None:
        kwargs["source_dag_id"] = source_dag_id
    if source_task_id is not None:
        kwargs["source_task_id"] = source_task_id
    if source_run_id is not None:
        kwargs["source_run_id"] = source_run_id
    if source_map_index is not None:
        kwargs["source_map_index"] = source_map_index

    response = dataset_api.get_dataset_events(**kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def create_dataset_event(
    dataset_uri: str,
    extra: Optional[Dict[str, Any]] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    event_request = {
        "dataset_uri": dataset_uri,
    }
    if extra is not None:
        event_request["extra"] = extra

    response = dataset_api.create_dataset_event(create_dataset_event=event_request)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_dag_dataset_queued_event(
    dag_id: str,
    uri: str,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dataset_api.get_dag_dataset_queued_event(dag_id=dag_id, uri=uri)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_dag_dataset_queued_events(
    dag_id: str,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dataset_api.get_dag_dataset_queued_events(dag_id=dag_id)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def delete_dag_dataset_queued_event(
    dag_id: str,
    uri: str,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dataset_api.delete_dag_dataset_queued_event(dag_id=dag_id, uri=uri)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def delete_dag_dataset_queued_events(
    dag_id: str,
    before: Optional[str] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    kwargs: Dict[str, Any] = {}
    if before is not None:
        kwargs["before"] = before

    response = dataset_api.delete_dag_dataset_queued_events(dag_id=dag_id, **kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_dataset_queued_events(
    uri: str,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dataset_api.get_dataset_queued_events(uri=uri)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def delete_dataset_queued_events(
    uri: str,
    before: Optional[str] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    kwargs: Dict[str, Any] = {}
    if before is not None:
        kwargs["before"] = before

    response = dataset_api.delete_dataset_queued_events(uri=uri, **kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]
```
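
A minimal sketch of emitting a dataset event and reading recent events back; the URI, `extra` payload, and ordering field are hypothetical values chosen for illustration:

```python
# Hypothetical usage sketch for the dataset tools; values are made up.
import asyncio

from src.airflow.dataset import create_dataset_event, get_dataset_events


async def demo() -> None:
    created = await create_dataset_event(
        dataset_uri="s3://bucket/daily-export",
        extra={"triggered_by": "backfill-script"},
    )
    print(created[0].text)

    events = await get_dataset_events(limit=5)
    print(events[0].text)


asyncio.run(demo())
```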
--------------------------------------------------------------------------------
/test/test_airflow_client.py:
--------------------------------------------------------------------------------

```python
"""Tests for the airflow client authentication module."""

import os
import sys
from unittest.mock import patch

from airflow_client.client import ApiClient


class TestAirflowClientAuthentication:
    """Test cases for airflow client authentication configuration."""

    def test_basic_auth_configuration(self):
        """Test that basic authentication is configured correctly."""
        with patch.dict(
            os.environ,
            {
                "AIRFLOW_HOST": "http://localhost:8080",
                "AIRFLOW_USERNAME": "testuser",
                "AIRFLOW_PASSWORD": "testpass",
                "AIRFLOW_API_VERSION": "v1",
            },
            clear=True,
        ):
            # Clear any cached modules
            modules_to_clear = ["src.envs", "src.airflow.airflow_client"]
            for module in modules_to_clear:
                if module in sys.modules:
                    del sys.modules[module]

            # Re-import after setting environment
            from src.airflow.airflow_client import api_client, configuration

            # Verify configuration
            assert configuration.host == "http://localhost:8080/api/v1"
            assert configuration.username == "testuser"
            assert configuration.password == "testpass"
            assert isinstance(api_client, ApiClient)

    def test_jwt_token_auth_configuration(self):
        """Test that JWT token authentication is configured correctly."""
        with patch.dict(
            os.environ,
            {
                "AIRFLOW_HOST": "http://localhost:8080",
                "AIRFLOW_JWT_TOKEN": "test.jwt.token",
                "AIRFLOW_API_VERSION": "v1",
            },
            clear=True,
        ):
            # Clear any cached modules
            modules_to_clear = ["src.envs", "src.airflow.airflow_client"]
            for module in modules_to_clear:
                if module in sys.modules:
                    del sys.modules[module]

            # Re-import after setting environment
            from src.airflow.airflow_client import api_client, configuration

            # Verify configuration
            assert configuration.host == "http://localhost:8080/api/v1"
            assert configuration.api_key == {"Authorization": "Bearer test.jwt.token"}
            assert configuration.api_key_prefix == {"Authorization": ""}
            assert isinstance(api_client, ApiClient)

    def test_jwt_token_takes_precedence_over_basic_auth(self):
        """Test that JWT token takes precedence when both auth methods are provided."""
        with patch.dict(
            os.environ,
            {
                "AIRFLOW_HOST": "http://localhost:8080",
                "AIRFLOW_USERNAME": "testuser",
                "AIRFLOW_PASSWORD": "testpass",
                "AIRFLOW_JWT_TOKEN": "test.jwt.token",
                "AIRFLOW_API_VERSION": "v1",
            },
            clear=True,
        ):
            # Clear any cached modules
            modules_to_clear = ["src.envs", "src.airflow.airflow_client"]
            for module in modules_to_clear:
                if module in sys.modules:
                    del sys.modules[module]

            # Re-import after setting environment
            from src.airflow.airflow_client import api_client, configuration

            # Verify JWT token is used (not basic auth)
            assert configuration.host == "http://localhost:8080/api/v1"
            assert configuration.api_key == {"Authorization": "Bearer test.jwt.token"}
            assert configuration.api_key_prefix == {"Authorization": ""}
            # Basic auth should not be set when JWT is present
            assert not hasattr(configuration, "username") or configuration.username is None
            assert not hasattr(configuration, "password") or configuration.password is None
            assert isinstance(api_client, ApiClient)

    def test_no_auth_configuration(self):
        """Test that configuration works with no authentication (for testing/development)."""
        with patch.dict(os.environ, {"AIRFLOW_HOST": "http://localhost:8080", "AIRFLOW_API_VERSION": "v1"}, clear=True):
            # Clear any cached modules
            modules_to_clear = ["src.envs", "src.airflow.airflow_client"]
            for module in modules_to_clear:
                if module in sys.modules:
                    del sys.modules[module]

            # Re-import after setting environment
            from src.airflow.airflow_client import api_client, configuration

            # Verify configuration
            assert configuration.host == "http://localhost:8080/api/v1"
            # No auth should be set
            assert not hasattr(configuration, "username") or configuration.username is None
            assert not hasattr(configuration, "password") or configuration.password is None
            # api_key might be an empty dict by default, but should not have Authorization
            assert "Authorization" not in getattr(configuration, "api_key", {})
            assert isinstance(api_client, ApiClient)

    def test_environment_variable_parsing(self):
        """Test that environment variables are parsed correctly."""
        with patch.dict(
            os.environ,
            {
                "AIRFLOW_HOST": "https://airflow.example.com:8080/custom",
                "AIRFLOW_JWT_TOKEN": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...",
                "AIRFLOW_API_VERSION": "v2",
            },
            clear=True,
        ):
            # Clear any cached modules
            modules_to_clear = ["src.envs", "src.airflow.airflow_client"]
            for module in modules_to_clear:
                if module in sys.modules:
                    del sys.modules[module]

            # Re-import after setting environment
            from src.airflow.airflow_client import configuration
            from src.envs import AIRFLOW_API_VERSION, AIRFLOW_HOST, AIRFLOW_JWT_TOKEN

            # Verify environment variables are parsed correctly
            assert AIRFLOW_HOST == "https://airflow.example.com:8080"
            assert AIRFLOW_JWT_TOKEN == "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9..."
            assert AIRFLOW_API_VERSION == "v2"

            # Verify configuration uses parsed values
            assert configuration.host == "https://airflow.example.com:8080/api/v2"
            assert configuration.api_key == {"Authorization": "Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9..."}
            assert configuration.api_key_prefix == {"Authorization": ""}
```
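
Each test repeats the same purge-and-reimport steps because `src.airflow.airflow_client` reads its environment at import time. A sketch of how that idiom could be factored into a single helper; `fresh_client_config` is a hypothetical name, not part of the repo:

```python
# Sketch of a reusable version of the purge-and-reimport idiom used above.
# `fresh_client_config` is a hypothetical helper, not part of the repo.
import importlib
import sys


def fresh_client_config():
    """Drop cached modules so src.airflow.airflow_client re-reads os.environ."""
    for module in ("src.envs", "src.airflow.airflow_client"):
        sys.modules.pop(module, None)
    return importlib.import_module("src.airflow.airflow_client")
```

Inside a `patch.dict(os.environ, ...)` block, `fresh_client_config().configuration` would then behave exactly like the inline re-import in each test.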
--------------------------------------------------------------------------------
/src/airflow/dagrun.py:
--------------------------------------------------------------------------------

```python
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

import mcp.types as types
from airflow_client.client.api.dag_run_api import DAGRunApi
from airflow_client.client.model.clear_dag_run import ClearDagRun
from airflow_client.client.model.dag_run import DAGRun
from airflow_client.client.model.set_dag_run_note import SetDagRunNote
from airflow_client.client.model.update_dag_run_state import UpdateDagRunState

from src.airflow.airflow_client import api_client
from src.envs import AIRFLOW_HOST

dag_run_api = DAGRunApi(api_client)


def get_all_functions() -> list[tuple[Callable, str, str, bool]]:
    """Return list of (function, name, description, is_read_only) tuples for registration."""
    return [
        (post_dag_run, "post_dag_run", "Trigger a DAG by ID", False),
        (get_dag_runs, "get_dag_runs", "Get DAG runs by ID", True),
        (get_dag_runs_batch, "get_dag_runs_batch", "List DAG runs (batch)", True),
        (get_dag_run, "get_dag_run", "Get a DAG run by DAG ID and DAG run ID", True),
        (update_dag_run_state, "update_dag_run_state", "Update a DAG run state by DAG ID and DAG run ID", False),
        (delete_dag_run, "delete_dag_run", "Delete a DAG run by DAG ID and DAG run ID", False),
        (clear_dag_run, "clear_dag_run", "Clear a DAG run", False),
        (set_dag_run_note, "set_dag_run_note", "Update the DagRun note", False),
        (get_upstream_dataset_events, "get_upstream_dataset_events", "Get dataset events for a DAG run", True),
    ]


def get_dag_run_url(dag_id: str, dag_run_id: str) -> str:
    return f"{AIRFLOW_HOST}/dags/{dag_id}/grid?dag_run_id={dag_run_id}"


async def post_dag_run(
    dag_id: str,
    dag_run_id: Optional[str] = None,
    data_interval_end: Optional[datetime] = None,
    data_interval_start: Optional[datetime] = None,
    execution_date: Optional[datetime] = None,
    logical_date: Optional[datetime] = None,
    note: Optional[str] = None,
    # state: Optional[str] = None,  # TODO: add state
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build kwargs dictionary with only non-None values
    kwargs = {}

    # Add non-read-only fields that can be set during creation
    if dag_run_id is not None:
        kwargs["dag_run_id"] = dag_run_id
    if data_interval_end is not None:
        kwargs["data_interval_end"] = data_interval_end
    if data_interval_start is not None:
        kwargs["data_interval_start"] = data_interval_start
    if execution_date is not None:
        kwargs["execution_date"] = execution_date
    if logical_date is not None:
        kwargs["logical_date"] = logical_date
    if note is not None:
        kwargs["note"] = note

    # Create DAGRun without read-only fields
    dag_run = DAGRun(**kwargs)

    response = dag_run_api.post_dag_run(dag_id=dag_id, dag_run=dag_run)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_dag_runs(
    dag_id: str,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    execution_date_gte: Optional[str] = None,
    execution_date_lte: Optional[str] = None,
    start_date_gte: Optional[str] = None,
    start_date_lte: Optional[str] = None,
    end_date_gte: Optional[str] = None,
    end_date_lte: Optional[str] = None,
    updated_at_gte: Optional[str] = None,
    updated_at_lte: Optional[str] = None,
    state: Optional[List[str]] = None,
    order_by: Optional[str] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build parameters dictionary
    kwargs: Dict[str, Any] = {}
    if limit is not None:
        kwargs["limit"] = limit
    if offset is not None:
        kwargs["offset"] = offset
    if execution_date_gte is not None:
        kwargs["execution_date_gte"] = execution_date_gte
    if execution_date_lte is not None:
        kwargs["execution_date_lte"] = execution_date_lte
    if start_date_gte is not None:
        kwargs["start_date_gte"] = start_date_gte
    if start_date_lte is not None:
        kwargs["start_date_lte"] = start_date_lte
    if end_date_gte is not None:
        kwargs["end_date_gte"] = end_date_gte
    if end_date_lte is not None:
        kwargs["end_date_lte"] = end_date_lte
    if updated_at_gte is not None:
        kwargs["updated_at_gte"] = updated_at_gte
    if updated_at_lte is not None:
        kwargs["updated_at_lte"] = updated_at_lte
    if state is not None:
        kwargs["state"] = state
    if order_by is not None:
        kwargs["order_by"] = order_by

    response = dag_run_api.get_dag_runs(dag_id=dag_id, **kwargs)

    # Convert response to dictionary for easier manipulation
    response_dict = response.to_dict()

    # Add UI links to each DAG run
    for dag_run in response_dict.get("dag_runs", []):
        dag_run["ui_url"] = get_dag_run_url(dag_id, dag_run["dag_run_id"])

    return [types.TextContent(type="text", text=str(response_dict))]

async def get_dag_runs_batch(
    dag_ids: Optional[List[str]] = None,
    execution_date_gte: Optional[str] = None,
    execution_date_lte: Optional[str] = None,
    start_date_gte: Optional[str] = None,
    start_date_lte: Optional[str] = None,
    end_date_gte: Optional[str] = None,
    end_date_lte: Optional[str] = None,
    state: Optional[List[str]] = None,
    order_by: Optional[str] = None,
    page_offset: Optional[int] = None,
    page_limit: Optional[int] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build request dictionary
    request: Dict[str, Any] = {}
    if dag_ids is not None:
        request["dag_ids"] = dag_ids
    if execution_date_gte is not None:
        request["execution_date_gte"] = execution_date_gte
    if execution_date_lte is not None:
        request["execution_date_lte"] = execution_date_lte
    if start_date_gte is not None:
        request["start_date_gte"] = start_date_gte
    if start_date_lte is not None:
        request["start_date_lte"] = start_date_lte
    if end_date_gte is not None:
        request["end_date_gte"] = end_date_gte
    if end_date_lte is not None:
        request["end_date_lte"] = end_date_lte
    if state is not None:
        request["state"] = state
    if order_by is not None:
        request["order_by"] = order_by
    if page_offset is not None:
        request["page_offset"] = page_offset
    if page_limit is not None:
        request["page_limit"] = page_limit

    response = dag_run_api.get_dag_runs_batch(list_dag_runs_form=request)

    # Convert response to dictionary for easier manipulation
    response_dict = response.to_dict()

    # Add UI links to each DAG run
    for dag_run in response_dict.get("dag_runs", []):
        dag_run["ui_url"] = get_dag_run_url(dag_run["dag_id"], dag_run["dag_run_id"])

    return [types.TextContent(type="text", text=str(response_dict))]


async def get_dag_run(
    dag_id: str, dag_run_id: str
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_run_api.get_dag_run(dag_id=dag_id, dag_run_id=dag_run_id)

    # Convert response to dictionary for easier manipulation
    response_dict = response.to_dict()

    # Add UI link to DAG run
    response_dict["ui_url"] = get_dag_run_url(dag_id, dag_run_id)

    return [types.TextContent(type="text", text=str(response_dict))]


async def update_dag_run_state(
    dag_id: str, dag_run_id: str, state: Optional[str] = None
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    update_dag_run_state = UpdateDagRunState(state=state)
    response = dag_run_api.update_dag_run_state(
        dag_id=dag_id,
        dag_run_id=dag_run_id,
        update_dag_run_state=update_dag_run_state,
    )
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def delete_dag_run(
    dag_id: str, dag_run_id: str
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_run_api.delete_dag_run(dag_id=dag_id, dag_run_id=dag_run_id)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def clear_dag_run(
    dag_id: str, dag_run_id: str, dry_run: Optional[bool] = None
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    clear_dag_run = ClearDagRun(dry_run=dry_run)
    response = dag_run_api.clear_dag_run(dag_id=dag_id, dag_run_id=dag_run_id, clear_dag_run=clear_dag_run)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def set_dag_run_note(
    dag_id: str, dag_run_id: str, note: str
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    set_dag_run_note = SetDagRunNote(note=note)
    response = dag_run_api.set_dag_run_note(dag_id=dag_id, dag_run_id=dag_run_id, set_dag_run_note=set_dag_run_note)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_upstream_dataset_events(
    dag_id: str, dag_run_id: str
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_run_api.get_upstream_dataset_events(dag_id=dag_id, dag_run_id=dag_run_id)
    return [types.TextContent(type="text", text=str(response.to_dict()))]
```
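
`post_dag_run` forwards only non-None fields into the `DAGRun` model, and the read helpers decorate their results with a `ui_url` pointing at the Airflow grid view. A minimal sketch with hypothetical IDs:

```python
# Hypothetical usage sketch: trigger a DAG, then inspect the run.
# dag_id and dag_run_id are made up.
import asyncio

from src.airflow.dagrun import get_dag_run, post_dag_run


async def demo() -> None:
    triggered = await post_dag_run(dag_id="example_dag", note="triggered from MCP sketch")
    print(triggered[0].text)

    run = await get_dag_run(dag_id="example_dag", dag_run_id="manual__2024-01-01T00:00:00+00:00")
    print(run[0].text)  # includes the injected ui_url field


asyncio.run(demo())
```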
--------------------------------------------------------------------------------
/test/airflow/test_taskinstance.py:
--------------------------------------------------------------------------------

```python
"""Unit tests for taskinstance module using pytest framework."""

from unittest.mock import MagicMock, patch

import mcp.types as types
import pytest

from src.airflow.taskinstance import (
    get_task_instance,
    list_task_instance_tries,
    list_task_instances,
    update_task_instance,
)


class TestTaskInstanceModule:
    """
    Test suite for verifying the behavior of taskinstance module's functions.

    Covers:
    - get_task_instance
    - list_task_instances
    - update_task_instance
    - list_task_instance_tries

    Each test uses parameterization to exercise a range of valid inputs and asserts:
    - Correct structure and content of the returned TextContent
    - Proper use of optional parameters
    - That the underlying API client methods are invoked with the right arguments
    """

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "dag_id, task_id, dag_run_id, expected_state",
        [
            ("dag_1", "task_a", "run_001", "success"),
            ("dag_2", "task_b", "run_002", "failed"),
            ("dag_3", "task_c", "run_003", "running"),
        ],
        ids=[
            "success-task-dag_1",
            "failed-task-dag_2",
            "running-task-dag_3",
        ],
    )
    async def test_get_task_instance(self, dag_id, task_id, dag_run_id, expected_state):
        """
        Test `get_task_instance` returns correct TextContent output
        and calls the API once for different task states.
        """
        mock_response = MagicMock()
        mock_response.to_dict.return_value = {
            "dag_id": dag_id,
            "task_id": task_id,
            "dag_run_id": dag_run_id,
            "state": expected_state,
        }

        with patch(
            "src.airflow.taskinstance.task_instance_api.get_task_instance",
            return_value=mock_response,
        ) as mock_get:
            result = await get_task_instance(dag_id=dag_id, task_id=task_id, dag_run_id=dag_run_id)

            assert isinstance(result, list)
            assert len(result) == 1
            content = result[0]
            assert isinstance(content, types.TextContent)
            assert content.type == "text"
            assert dag_id in content.text
            assert task_id in content.text
            assert dag_run_id in content.text
            assert expected_state in content.text

            mock_get.assert_called_once_with(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "params",
        [
            {"dag_id": "dag_basic", "dag_run_id": "run_basic"},
            {
                "dag_id": "dag_with_state",
                "dag_run_id": "run_with_state",
                "state": ["success", "failed"],
            },
            {
                "dag_id": "dag_with_dates",
                "dag_run_id": "run_with_dates",
                "start_date_gte": "2024-01-01T00:00:00Z",
                "end_date_lte": "2024-01-10T23:59:59Z",
            },
            {
                "dag_id": "dag_with_filters",
                "dag_run_id": "run_filters",
                "pool": ["default_pool"],
                "queue": ["default"],
                "limit": 5,
                "offset": 10,
                "duration_gte": 5.0,
                "duration_lte": 100.5,
            },
            {
                "dag_id": "dag_with_all",
                "dag_run_id": "run_all",
                "execution_date_gte": "2024-01-01T00:00:00Z",
                "execution_date_lte": "2024-01-02T00:00:00Z",
                "start_date_gte": "2024-01-01T01:00:00Z",
                "start_date_lte": "2024-01-01T23:00:00Z",
                "end_date_gte": "2024-01-01T02:00:00Z",
                "end_date_lte": "2024-01-01T23:59:00Z",
                "updated_at_gte": "2024-01-01T03:00:00Z",
                "updated_at_lte": "2024-01-01T04:00:00Z",
                "duration_gte": 1.0,
                "duration_lte": 99.9,
                "state": ["queued"],
                "pool": ["my_pool"],
                "queue": ["my_queue"],
                "limit": 50,
                "offset": 0,
            },
            {
                "dag_id": "dag_with_empty_lists",
                "dag_run_id": "run_empty_lists",
                "state": [],
                "pool": [],
                "queue": [],
            },
        ],
        ids=[
            "basic-required-params",
            "with-state-filter",
            "with-date-range",
            "with-resource-filters",
            "all-filters-included",
            "empty-lists-filter",
        ],
    )
    async def test_list_task_instances(self, params):
        """
        Test `list_task_instances` with various combinations of filters.
        Validates output content and verifies API call arguments.
        """
""" mock_response = MagicMock() mock_response.to_dict.return_value = { "dag_id": params["dag_id"], "dag_run_id": params["dag_run_id"], "instances": [ {"task_id": "task_1", "state": "success"}, {"task_id": "task_2", "state": "running"}, ], } with patch( "src.airflow.taskinstance.task_instance_api.get_task_instances", return_value=mock_response, ) as mock_get: result = await list_task_instances(**params) assert isinstance(result, list) assert len(result) == 1 assert isinstance(result[0], types.TextContent) assert result[0].type == "text" assert params["dag_id"] in result[0].text assert params["dag_run_id"] in result[0].text mock_get.assert_called_once_with( dag_id=params["dag_id"], dag_run_id=params["dag_run_id"], **{k: v for k, v in params.items() if k not in {"dag_id", "dag_run_id"} and v is not None}, ) @pytest.mark.asyncio @pytest.mark.parametrize( "dag_id, dag_run_id, task_id, state", [ ("dag_1", "run_001", "task_a", "success"), ("dag_2", "run_002", "task_b", "failed"), ("dag_3", "run_003", "task_c", None), ], ids=["set-success-state", "set-failed-state", "no-state-update"], ) async def test_update_task_instance(self, dag_id, dag_run_id, task_id, state): """ Test `update_task_instance` for updating state and validating request payload. Also verifies that patch API is called with the correct update mask. """ mock_response = MagicMock() mock_response.to_dict.return_value = { "dag_id": dag_id, "dag_run_id": dag_run_id, "task_id": task_id, "state": state, } with patch( "src.airflow.taskinstance.task_instance_api.patch_task_instance", return_value=mock_response, ) as mock_patch: result = await update_task_instance(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, state=state) assert isinstance(result, list) assert len(result) == 1 content = result[0] assert isinstance(content, types.TextContent) assert content.type == "text" assert dag_id in content.text assert dag_run_id in content.text assert task_id in content.text if state is not None: assert state in content.text expected_mask = ["state"] if state is not None else [] expected_request = {"state": state} if state is not None else {} mock_patch.assert_called_once_with( dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, update_mask=expected_mask, task_instance_request=expected_request, ) @pytest.mark.asyncio @pytest.mark.parametrize( "dag_id, dag_run_id, task_id, limit, offset, order_by", [ ("dag_basic", "run_001", "task_a", None, None, None), ("dag_with_limit", "run_002", "task_b", 5, None, None), ("dag_with_offset", "run_003", "task_c", None, 10, None), ("dag_with_order_by", "run_004", "task_d", None, None, "-start_date"), ("dag_all_params", "run_005", "task_e", 10, 0, "end_date"), ("dag_zero_limit", "run_006", "task_f", 0, None, None), ("dag_zero_offset", "run_007", "task_g", None, 0, None), ("dag_empty_order", "run_008", "task_h", None, None, ""), ], ids=[ "basic-required-only", "with-limit", "with-offset", "with-order_by-desc", "with-all-filters", "limit-zero", "offset-zero", "order_by-empty-string", ], ) async def test_list_task_instance_tries(self, dag_id, dag_run_id, task_id, limit, offset, order_by): """ Test `list_task_instance_tries` across various filter combinations, validating correct API call and response formatting. 
""" mock_response = MagicMock() mock_response.to_dict.return_value = { "dag_id": dag_id, "dag_run_id": dag_run_id, "task_id": task_id, "tries": [ {"try_number": 1, "state": "queued"}, {"try_number": 2, "state": "success"}, ], } with patch( "src.airflow.taskinstance.task_instance_api.get_task_instance_tries", return_value=mock_response, ) as mock_get: result = await list_task_instance_tries( dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, limit=limit, offset=offset, order_by=order_by, ) assert isinstance(result, list) assert len(result) == 1 content = result[0] assert isinstance(content, types.TextContent) assert content.type == "text" assert dag_id in content.text assert dag_run_id in content.text assert task_id in content.text assert "tries" in content.text mock_get.assert_called_once_with( dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, **{ k: v for k, v in { "limit": limit, "offset": offset, "order_by": order_by, }.items() if v is not None }, ) ``` -------------------------------------------------------------------------------- /src/airflow/dag.py: -------------------------------------------------------------------------------- ```python from typing import Any, Callable, Dict, List, Optional, Union import mcp.types as types from airflow_client.client.api.dag_api import DAGApi from airflow_client.client.model.clear_task_instances import ClearTaskInstances from airflow_client.client.model.dag import DAG from airflow_client.client.model.update_task_instances_state import UpdateTaskInstancesState from src.airflow.airflow_client import api_client from src.envs import AIRFLOW_HOST dag_api = DAGApi(api_client) def get_all_functions() -> list[tuple[Callable, str, str, bool]]: """Return list of (function, name, description, is_read_only) tuples for registration.""" return [ (get_dags, "fetch_dags", "Fetch all DAGs", True), (get_dag, "get_dag", "Get a DAG by ID", True), (get_dag_details, "get_dag_details", "Get a simplified representation of DAG", True), (get_dag_source, "get_dag_source", "Get a source code", True), (pause_dag, "pause_dag", "Pause a DAG by ID", False), (unpause_dag, "unpause_dag", "Unpause a DAG by ID", False), (get_dag_tasks, "get_dag_tasks", "Get tasks for DAG", True), (get_task, "get_task", "Get a task by ID", True), (get_tasks, "get_tasks", "Get tasks for DAG", True), (patch_dag, "patch_dag", "Update a DAG", False), (patch_dags, "patch_dags", "Update multiple DAGs", False), (delete_dag, "delete_dag", "Delete a DAG", False), (clear_task_instances, "clear_task_instances", "Clear a set of task instances", False), (set_task_instances_state, "set_task_instances_state", "Set a state of task instances", False), (reparse_dag_file, "reparse_dag_file", "Request re-parsing of a DAG file", False), ] def get_dag_url(dag_id: str) -> str: return f"{AIRFLOW_HOST}/dags/{dag_id}/grid" async def get_dags( limit: Optional[int] = None, offset: Optional[int] = None, order_by: Optional[str] = None, tags: Optional[List[str]] = None, only_active: Optional[bool] = None, paused: Optional[bool] = None, dag_id_pattern: Optional[str] = None, ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: # Build parameters dictionary kwargs: Dict[str, Any] = {} if limit is not None: kwargs["limit"] = limit if offset is not None: kwargs["offset"] = offset if order_by is not None: kwargs["order_by"] = order_by if tags is not None: kwargs["tags"] = tags if only_active is not None: kwargs["only_active"] = only_active if paused is not None: kwargs["paused"] = paused if 
--------------------------------------------------------------------------------
/src/airflow/dag.py:
--------------------------------------------------------------------------------

```python
from typing import Any, Callable, Dict, List, Optional, Union

import mcp.types as types
from airflow_client.client.api.dag_api import DAGApi
from airflow_client.client.model.clear_task_instances import ClearTaskInstances
from airflow_client.client.model.dag import DAG
from airflow_client.client.model.update_task_instances_state import UpdateTaskInstancesState

from src.airflow.airflow_client import api_client
from src.envs import AIRFLOW_HOST

dag_api = DAGApi(api_client)


def get_all_functions() -> list[tuple[Callable, str, str, bool]]:
    """Return list of (function, name, description, is_read_only) tuples for registration."""
    return [
        (get_dags, "fetch_dags", "Fetch all DAGs", True),
        (get_dag, "get_dag", "Get a DAG by ID", True),
        (get_dag_details, "get_dag_details", "Get a simplified representation of a DAG", True),
        (get_dag_source, "get_dag_source", "Get DAG source code", True),
        (pause_dag, "pause_dag", "Pause a DAG by ID", False),
        (unpause_dag, "unpause_dag", "Unpause a DAG by ID", False),
        (get_dag_tasks, "get_dag_tasks", "Get tasks for a DAG", True),
        (get_task, "get_task", "Get a task by ID", True),
        (get_tasks, "get_tasks", "Get tasks for a DAG", True),
        (patch_dag, "patch_dag", "Update a DAG", False),
        (patch_dags, "patch_dags", "Update multiple DAGs", False),
        (delete_dag, "delete_dag", "Delete a DAG", False),
        (clear_task_instances, "clear_task_instances", "Clear a set of task instances", False),
        (set_task_instances_state, "set_task_instances_state", "Set the state of task instances", False),
        (reparse_dag_file, "reparse_dag_file", "Request re-parsing of a DAG file", False),
    ]


def get_dag_url(dag_id: str) -> str:
    return f"{AIRFLOW_HOST}/dags/{dag_id}/grid"


async def get_dags(
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    order_by: Optional[str] = None,
    tags: Optional[List[str]] = None,
    only_active: Optional[bool] = None,
    paused: Optional[bool] = None,
    dag_id_pattern: Optional[str] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build parameters dictionary
    kwargs: Dict[str, Any] = {}
    if limit is not None:
        kwargs["limit"] = limit
    if offset is not None:
        kwargs["offset"] = offset
    if order_by is not None:
        kwargs["order_by"] = order_by
    if tags is not None:
        kwargs["tags"] = tags
    if only_active is not None:
        kwargs["only_active"] = only_active
    if paused is not None:
        kwargs["paused"] = paused
    if dag_id_pattern is not None:
        kwargs["dag_id_pattern"] = dag_id_pattern

    # Use the client to fetch DAGs
    response = dag_api.get_dags(**kwargs)

    # Convert response to dictionary for easier manipulation
    response_dict = response.to_dict()

    # Add UI links to each DAG
    for dag in response_dict.get("dags", []):
        dag["ui_url"] = get_dag_url(dag["dag_id"])

    return [types.TextContent(type="text", text=str(response_dict))]


async def get_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_api.get_dag(dag_id=dag_id)

    # Convert response to dictionary for easier manipulation
    response_dict = response.to_dict()

    # Add UI link to DAG
    response_dict["ui_url"] = get_dag_url(dag_id)

    return [types.TextContent(type="text", text=str(response_dict))]


async def get_dag_details(
    dag_id: str, fields: Optional[List[str]] = None
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    # Build parameters dictionary
    kwargs: Dict[str, Any] = {}
    if fields is not None:
        kwargs["fields"] = fields

    response = dag_api.get_dag_details(dag_id=dag_id, **kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_dag_source(file_token: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_api.get_dag_source(file_token=file_token)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def pause_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    dag = DAG(is_paused=True)
    response = dag_api.patch_dag(dag_id=dag_id, dag=dag, update_mask=["is_paused"])
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def unpause_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    dag = DAG(is_paused=False)
    response = dag_api.patch_dag(dag_id=dag_id, dag=dag, update_mask=["is_paused"])
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_dag_tasks(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_api.get_tasks(dag_id=dag_id)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def patch_dag(
    dag_id: str, is_paused: Optional[bool] = None, tags: Optional[List[str]] = None
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    update_request = {}
    update_mask = []
    if is_paused is not None:
        update_request["is_paused"] = is_paused
        update_mask.append("is_paused")
    if tags is not None:
        update_request["tags"] = tags
        update_mask.append("tags")

    dag = DAG(**update_request)
    response = dag_api.patch_dag(dag_id=dag_id, dag=dag, update_mask=update_mask)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def patch_dags(
    dag_id_pattern: Optional[str] = None,
    is_paused: Optional[bool] = None,
    tags: Optional[List[str]] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    update_request = {}
    update_mask = []
    if is_paused is not None:
        update_request["is_paused"] = is_paused
        update_mask.append("is_paused")
    if tags is not None:
        update_request["tags"] = tags
        update_mask.append("tags")

    dag = DAG(**update_request)
    kwargs = {}
    if dag_id_pattern is not None:
        kwargs["dag_id_pattern"] = dag_id_pattern

    # Pass dag_id_pattern only through kwargs; passing it explicitly as well
    # would raise "got multiple values for keyword argument 'dag_id_pattern'".
    response = dag_api.patch_dags(dag=dag, update_mask=update_mask, **kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]

async def delete_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_api.delete_dag(dag_id=dag_id)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_task(
    dag_id: str, task_id: str
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_api.get_task(dag_id=dag_id, task_id=task_id)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def get_tasks(
    dag_id: str, order_by: Optional[str] = None
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    kwargs = {}
    if order_by is not None:
        kwargs["order_by"] = order_by

    response = dag_api.get_tasks(dag_id=dag_id, **kwargs)
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def clear_task_instances(
    dag_id: str,
    task_ids: Optional[List[str]] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    include_subdags: Optional[bool] = None,
    include_parentdag: Optional[bool] = None,
    include_upstream: Optional[bool] = None,
    include_downstream: Optional[bool] = None,
    include_future: Optional[bool] = None,
    include_past: Optional[bool] = None,
    dry_run: Optional[bool] = None,
    reset_dag_runs: Optional[bool] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    clear_request = {}
    if task_ids is not None:
        clear_request["task_ids"] = task_ids
    if start_date is not None:
        clear_request["start_date"] = start_date
    if end_date is not None:
        clear_request["end_date"] = end_date
    if include_subdags is not None:
        clear_request["include_subdags"] = include_subdags
    if include_parentdag is not None:
        clear_request["include_parentdag"] = include_parentdag
    if include_upstream is not None:
        clear_request["include_upstream"] = include_upstream
    if include_downstream is not None:
        clear_request["include_downstream"] = include_downstream
    if include_future is not None:
        clear_request["include_future"] = include_future
    if include_past is not None:
        clear_request["include_past"] = include_past
    if dry_run is not None:
        clear_request["dry_run"] = dry_run
    if reset_dag_runs is not None:
        clear_request["reset_dag_runs"] = reset_dag_runs

    clear_task_instances = ClearTaskInstances(**clear_request)
    response = dag_api.post_clear_task_instances(dag_id=dag_id, clear_task_instances=clear_task_instances)
    return [types.TextContent(type="text", text=str(response.to_dict()))]

async def set_task_instances_state(
    dag_id: str,
    state: str,
    task_ids: Optional[List[str]] = None,
    execution_date: Optional[str] = None,
    include_upstream: Optional[bool] = None,
    include_downstream: Optional[bool] = None,
    include_future: Optional[bool] = None,
    include_past: Optional[bool] = None,
    dry_run: Optional[bool] = None,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    state_request = {"state": state}
    if task_ids is not None:
        state_request["task_ids"] = task_ids
    if execution_date is not None:
        state_request["execution_date"] = execution_date
    if include_upstream is not None:
        state_request["include_upstream"] = include_upstream
    if include_downstream is not None:
        state_request["include_downstream"] = include_downstream
    if include_future is not None:
        state_request["include_future"] = include_future
    if include_past is not None:
        state_request["include_past"] = include_past
    if dry_run is not None:
        state_request["dry_run"] = dry_run

    update_task_instances_state = UpdateTaskInstancesState(**state_request)
    response = dag_api.post_set_task_instances_state(
        dag_id=dag_id,
        update_task_instances_state=update_task_instances_state,
    )
    return [types.TextContent(type="text", text=str(response.to_dict()))]


async def reparse_dag_file(
    file_token: str,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    response = dag_api.reparse_dag_file(file_token=file_token)
    return [types.TextContent(type="text", text=str(response.to_dict()))]
```
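
`pause_dag` and `unpause_dag` are thin wrappers over `patch_dag` with a fixed `is_paused` update mask, while `patch_dag` derives the mask from whichever fields are supplied. A minimal sketch with a hypothetical DAG ID:

```python
# Hypothetical usage sketch: pause a DAG, retag it, then unpause it.
# The dag_id and tags are made up.
import asyncio

from src.airflow.dag import patch_dag, pause_dag, unpause_dag


async def demo() -> None:
    print((await pause_dag("example_dag"))[0].text)
    # patch_dag builds the update_mask from the fields actually provided.
    print((await patch_dag("example_dag", tags=["prod", "daily"]))[0].text)
    print((await unpause_dag("example_dag"))[0].text)


asyncio.run(demo())
```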
--------------------------------------------------------------------------------
/test/test_main.py:
--------------------------------------------------------------------------------

```python
"""Tests for the main module using pytest framework."""

from unittest.mock import patch

import pytest
from click.testing import CliRunner

from src.enums import APIType
from src.main import APITYPE_TO_FUNCTIONS, Tool, main


class TestMain:
    """Test cases for the main module."""

    @pytest.fixture
    def runner(self):
        """Set up CLI test runner."""
        return CliRunner()

    def test_apitype_to_functions_mapping(self):
        """Test that all API types are mapped to functions."""
        # Verify all APIType enum values have corresponding functions
        for api_type in APIType:
            assert api_type in APITYPE_TO_FUNCTIONS
            assert APITYPE_TO_FUNCTIONS[api_type] is not None

    def test_apitype_to_functions_completeness(self):
        """Test that the function mapping is complete and contains only valid APITypes."""
        # Verify mapping keys match APIType enum values
        expected_keys = set(APIType)
        actual_keys = set(APITYPE_TO_FUNCTIONS.keys())
        assert expected_keys == actual_keys

    @patch("src.server.app")
    def test_main_default_options(self, mock_app, runner):
        """Test main function with default options."""
        # Mock get_function to return valid functions
        mock_functions = [(lambda: None, "test_function", "Test description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {api: lambda: mock_functions for api in APIType}):
            result = runner.invoke(main, [])

        assert result.exit_code == 0

        # Verify app.add_tool was called for each API type
        expected_calls = len(APIType)  # One call per API type
        assert mock_app.add_tool.call_count == expected_calls

        # Verify app.run was called with stdio transport
        mock_app.run.assert_called_once_with(transport="stdio")

    @patch("src.server.app")
    def test_main_sse_transport(self, mock_app, runner):
        """Test main function with SSE transport."""
        mock_functions = [(lambda: None, "test_function", "Test description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {api: lambda: mock_functions for api in APIType}):
            result = runner.invoke(main, ["--transport", "sse"])

        assert result.exit_code == 0
        mock_app.run.assert_called_once_with(transport="sse", port=8000, host="0.0.0.0")

    @patch("src.server.app")
    def test_main_specific_apis(self, mock_app, runner):
        """Test main function with specific APIs selected."""
        mock_functions = [(lambda: None, "test_function", "Test description")]
        selected_apis = ["config", "connection"]

        with patch.dict(APITYPE_TO_FUNCTIONS, {api: lambda: mock_functions for api in APIType}):
            result = runner.invoke(main, ["--apis", "config", "--apis", "connection"])

        assert result.exit_code == 0
        # Should only add tools for selected APIs
        assert mock_app.add_tool.call_count == len(selected_apis)

    @patch("src.server.app")
    def test_main_not_implemented_error_handling(self, mock_app, runner):
        """Test main function handles NotImplementedError gracefully."""

        def raise_not_implemented():
            raise NotImplementedError("Not implemented")

        # Mock one API to raise NotImplementedError
        modified_mapping = APITYPE_TO_FUNCTIONS.copy()
        modified_mapping[APIType.CONFIG] = raise_not_implemented

        mock_functions = [(lambda: None, "test_function", "Test description")]
        # Other APIs should still work
        for api in APIType:
            if api != APIType.CONFIG:
                modified_mapping[api] = lambda: mock_functions

        with patch.dict(APITYPE_TO_FUNCTIONS, modified_mapping, clear=True):
            result = runner.invoke(main, [])

        assert result.exit_code == 0
        # Should add tools for all APIs except the one that raised NotImplementedError
        expected_calls = len(APIType) - 1
        assert mock_app.add_tool.call_count == expected_calls

    def test_cli_transport_choices(self, runner):
        """Test CLI transport option only accepts valid choices."""
        result = runner.invoke(main, ["--transport", "invalid"])
        assert result.exit_code != 0
        assert "Invalid value for '--transport'" in result.output

    def test_cli_apis_choices(self, runner):
        """Test CLI apis option only accepts valid choices."""
        result = runner.invoke(main, ["--apis", "invalid"])
        assert result.exit_code != 0
        assert "Invalid value for '--apis'" in result.output

    @patch("src.server.app")
    def test_function_registration_flow(self, mock_app, runner):
        """Test the complete function registration flow."""

        def mock_function():
            # FastMCP's add_tool does not accept functions that take *args,
            # which rules out passing Mock or MagicMock objects here.
            pass

        mock_functions = [(mock_function, "test_name", "test_description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--apis", "config"])

        assert result.exit_code == 0
        mock_app.add_tool.assert_called_once_with(
            Tool.from_function(mock_function, name="test_name", description="test_description")
        )

    @patch("src.server.app")
    def test_multiple_functions_per_api(self, mock_app, runner):
        """Test handling multiple functions per API."""
        mock_functions = [
            (lambda: "func1", "func1_name", "func1_desc"),
            (lambda: "func2", "func2_name", "func2_desc"),
            (lambda: "func3", "func3_name", "func3_desc"),
        ]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--apis", "config"])

        assert result.exit_code == 0
        # Should register all functions
        assert mock_app.add_tool.call_count == 3

    def test_help_option(self, runner):
        """Test CLI help option."""
        result = runner.invoke(main, ["--help"])
        assert result.exit_code == 0
        assert "Transport type" in result.output
        assert "APIs to run" in result.output

    @pytest.mark.parametrize("transport", ["stdio", "sse", "http"])
    @patch("src.server.app")
    def test_main_transport_options(self, mock_app, transport, runner):
        """Test main function with different transport options."""
        mock_functions = [(lambda: None, "test_function", "Test description")]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--transport", transport, "--apis", "config"])

        assert result.exit_code == 0
        if transport == "stdio":
            mock_app.run.assert_called_once_with(transport=transport)
        else:
            mock_app.run.assert_called_once_with(transport=transport, port=8000, host="0.0.0.0")
["--mcp-port", port] if host: ext_params += ["--mcp-host", host] runner.invoke(main, ["--transport", transport, "--apis", "config"] + ext_params) expected_params = {} expected_params["port"] = int(port) if port else 8000 expected_params["host"] = host if host else "0.0.0.0" mock_app.run.assert_called_once_with(transport=transport, **expected_params) @pytest.mark.parametrize("api_name", [api.value for api in APIType]) @patch("src.server.app") def test_individual_api_selection(self, mock_app, api_name, runner): """Test selecting individual APIs.""" mock_functions = [(lambda: None, "test_function", "Test description")] with patch.dict(APITYPE_TO_FUNCTIONS, {APIType(api_name): lambda: mock_functions}, clear=True): result = runner.invoke(main, ["--apis", api_name]) assert result.exit_code == 0 assert mock_app.add_tool.call_count == 1 def test_filter_functions_for_read_only(self): """Test that filter_functions_for_read_only correctly filters functions.""" from src.main import filter_functions_for_read_only # Mock function objects def mock_read_func(): pass def mock_write_func(): pass # Test functions with mixed read/write status functions = [ (mock_read_func, "get_something", "Get something", True), (mock_write_func, "create_something", "Create something", False), (mock_read_func, "list_something", "List something", True), (mock_write_func, "delete_something", "Delete something", False), ] filtered = filter_functions_for_read_only(functions) # Should only have the read-only functions assert len(filtered) == 2 assert filtered[0][1] == "get_something" assert filtered[1][1] == "list_something" # Verify all returned functions are read-only for _, _, _, is_read_only in filtered: assert is_read_only is True def test_connection_functions_have_correct_read_only_status(self): """Test that connection functions are correctly marked as read-only or write.""" from src.airflow.connection import get_all_functions functions = get_all_functions() function_names = {name: is_read_only for _, name, _, is_read_only in functions} # Verify read-only functions assert function_names["list_connections"] is True assert function_names["get_connection"] is True assert function_names["test_connection"] is True # Verify write functions assert function_names["create_connection"] is False assert function_names["update_connection"] is False assert function_names["delete_connection"] is False def test_dag_functions_have_correct_read_only_status(self): """Test that DAG functions are correctly marked as read-only or write.""" from src.airflow.dag import get_all_functions functions = get_all_functions() function_names = {name: is_read_only for _, name, _, is_read_only in functions} # Verify read-only functions assert function_names["fetch_dags"] is True assert function_names["get_dag"] is True assert function_names["get_dag_details"] is True assert function_names["get_dag_source"] is True assert function_names["get_dag_tasks"] is True assert function_names["get_task"] is True assert function_names["get_tasks"] is True # Verify write functions assert function_names["pause_dag"] is False assert function_names["unpause_dag"] is False assert function_names["patch_dag"] is False assert function_names["patch_dags"] is False assert function_names["delete_dag"] is False assert function_names["clear_task_instances"] is False assert function_names["set_task_instances_state"] is False assert function_names["reparse_dag_file"] is False @patch("src.server.app") def test_main_read_only_mode(self, mock_app, runner): """Test main function with 
        """Test main function with read-only flag."""
        # Create mock functions with mixed read/write status
        mock_functions = [
            (lambda: None, "read_function", "Read function", True),
            (lambda: None, "write_function", "Write function", False),
            (lambda: None, "another_read_function", "Another read function", True),
        ]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--read-only", "--apis", "config"])

        assert result.exit_code == 0
        # Should only register read-only functions (2 out of 3)
        assert mock_app.add_tool.call_count == 2

        # Verify the correct functions were registered
        call_args_list = mock_app.add_tool.call_args_list
        registered_names = [call.args[0].name for call in call_args_list]
        assert "read_function" in registered_names
        assert "another_read_function" in registered_names
        assert "write_function" not in registered_names

    @patch("src.server.app")
    def test_main_read_only_mode_with_no_read_functions(self, mock_app, runner):
        """Test main function with read-only flag when API has no read-only functions."""
        # Create mock functions with only write operations
        mock_functions = [
            (lambda: None, "write_function1", "Write function 1", False),
            (lambda: None, "write_function2", "Write function 2", False),
        ]

        with patch.dict(APITYPE_TO_FUNCTIONS, {APIType.CONFIG: lambda: mock_functions}, clear=True):
            result = runner.invoke(main, ["--read-only", "--apis", "config"])

        assert result.exit_code == 0
        # Should not register any functions
        assert mock_app.add_tool.call_count == 0

    def test_cli_read_only_flag_in_help(self, runner):
        """Test that read-only flag appears in help."""
        result = runner.invoke(main, ["--help"])
        assert result.exit_code == 0
        assert "--read-only" in result.output
        assert "Only expose read-only tools" in result.output
```
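
The same click entry point these tests exercise can be driven in-process outside of pytest. A sketch using the options shown above:

```python
# Sketch: invoke the CLI in-process, exactly as the tests do via CliRunner.
from click.testing import CliRunner

from src.main import main

result = CliRunner().invoke(main, ["--help"])
print(result.output)  # shows the --transport, --apis, and --read-only options
```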
{"limit": 10, "offset": 5}, "expected_ui_urls": True, }, { "name": "get_dags_with_filters", "input": {"tags": ["prod", "daily"], "only_active": True, "paused": False, "dag_id_pattern": "prod_*"}, "mock_response": {"dags": [{"dag_id": "prod_dag1"}], "total_entries": 1}, "expected_call_kwargs": { "tags": ["prod", "daily"], "only_active": True, "paused": False, "dag_id_pattern": "prod_*", }, "expected_ui_urls": True, }, ], ) async def test_get_dags_table_driven(self, test_case, mock_dag_api): """Table-driven test for get_dags function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.get_dags.return_value = mock_response # Execute function with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"): result = await get_dags(**test_case["input"]) # Verify API call mock_dag_api.get_dags.assert_called_once_with(**test_case["expected_call_kwargs"]) # Verify result structure assert len(result) == 1 assert isinstance(result[0], types.TextContent) # Parse result and verify UI URLs were added if expected if test_case["expected_ui_urls"]: result_text = result[0].text assert "ui_url" in result_text @pytest.mark.parametrize( "test_case", [ { "name": "get_dag_basic", "input": {"dag_id": "test_dag"}, "mock_response": {"dag_id": "test_dag", "description": "Test DAG", "is_paused": False}, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "get_dag_complex_id", "input": {"dag_id": "complex-dag_name.v2"}, "mock_response": {"dag_id": "complex-dag_name.v2", "description": "Complex DAG", "is_paused": True}, "expected_call_kwargs": {"dag_id": "complex-dag_name.v2"}, }, ], ) async def test_get_dag_table_driven(self, test_case, mock_dag_api): """Table-driven test for get_dag function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.get_dag.return_value = mock_response # Execute function with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"): result = await get_dag(**test_case["input"]) # Verify API call mock_dag_api.get_dag.assert_called_once_with(**test_case["expected_call_kwargs"]) # Verify result structure and UI URL addition assert len(result) == 1 assert isinstance(result[0], types.TextContent) assert "ui_url" in result[0].text @pytest.mark.parametrize( "test_case", [ { "name": "get_dag_details_no_fields", "input": {"dag_id": "test_dag"}, "mock_response": {"dag_id": "test_dag", "file_path": "/path/to/dag.py"}, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "get_dag_details_with_fields", "input": {"dag_id": "test_dag", "fields": ["dag_id", "description"]}, "mock_response": {"dag_id": "test_dag", "description": "Test"}, "expected_call_kwargs": {"dag_id": "test_dag", "fields": ["dag_id", "description"]}, }, ], ) async def test_get_dag_details_table_driven(self, test_case, mock_dag_api): """Table-driven test for get_dag_details function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.get_dag_details.return_value = mock_response # Execute function result = await get_dag_details(**test_case["input"]) # Verify API call and result mock_dag_api.get_dag_details.assert_called_once_with(**test_case["expected_call_kwargs"]) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "pause_dag", "function": pause_dag, "input": {"dag_id": "test_dag"}, "mock_response": 
{"dag_id": "test_dag", "is_paused": True}, "expected_call_kwargs": {"dag_id": "test_dag", "dag": ANY, "update_mask": ["is_paused"]}, "expected_dag_is_paused": True, }, { "name": "unpause_dag", "function": unpause_dag, "input": {"dag_id": "test_dag"}, "mock_response": {"dag_id": "test_dag", "is_paused": False}, "expected_call_kwargs": {"dag_id": "test_dag", "dag": ANY, "update_mask": ["is_paused"]}, "expected_dag_is_paused": False, }, ], ) async def test_pause_unpause_dag_table_driven(self, test_case, mock_dag_api): """Table-driven test for pause_dag and unpause_dag functions.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.patch_dag.return_value = mock_response # Execute function result = await test_case["function"](**test_case["input"]) # Verify API call and result mock_dag_api.patch_dag.assert_called_once_with(**test_case["expected_call_kwargs"]) # Verify the DAG object has correct is_paused value actual_call_args = mock_dag_api.patch_dag.call_args actual_dag = actual_call_args.kwargs["dag"] assert actual_dag["is_paused"] == test_case["expected_dag_is_paused"] assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "get_tasks_no_order", "input": {"dag_id": "test_dag"}, "mock_response": { "tasks": [ {"task_id": "task1", "operator": "BashOperator"}, {"task_id": "task2", "operator": "PythonOperator"}, ] }, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "get_tasks_with_order", "input": {"dag_id": "test_dag", "order_by": "task_id"}, "mock_response": { "tasks": [ {"task_id": "task1", "operator": "BashOperator"}, {"task_id": "task2", "operator": "PythonOperator"}, ] }, "expected_call_kwargs": {"dag_id": "test_dag", "order_by": "task_id"}, }, ], ) async def test_get_tasks_table_driven(self, test_case, mock_dag_api): """Table-driven test for get_tasks function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.get_tasks.return_value = mock_response # Execute function result = await get_tasks(**test_case["input"]) # Verify API call and result mock_dag_api.get_tasks.assert_called_once_with(**test_case["expected_call_kwargs"]) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "patch_dag_pause_only", "input": {"dag_id": "test_dag", "is_paused": True}, "mock_response": {"dag_id": "test_dag", "is_paused": True}, "expected_update_mask": ["is_paused"], }, { "name": "patch_dag_tags_only", "input": {"dag_id": "test_dag", "tags": ["prod", "daily"]}, "mock_response": {"dag_id": "test_dag", "tags": ["prod", "daily"]}, "expected_update_mask": ["tags"], }, { "name": "patch_dag_both_fields", "input": {"dag_id": "test_dag", "is_paused": False, "tags": ["dev"]}, "mock_response": {"dag_id": "test_dag", "is_paused": False, "tags": ["dev"]}, "expected_update_mask": ["is_paused", "tags"], }, ], ) async def test_patch_dag_table_driven(self, test_case, mock_dag_api): """Table-driven test for patch_dag function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.patch_dag.return_value = mock_response # Execute function with patch("src.airflow.dag.DAG") as mock_dag_class: mock_dag_instance = MagicMock() mock_dag_class.return_value = mock_dag_instance result = await patch_dag(**test_case["input"]) # Verify DAG instance creation and API 
call expected_update_request = {k: v for k, v in test_case["input"].items() if k != "dag_id"} mock_dag_class.assert_called_once_with(**expected_update_request) mock_dag_api.patch_dag.assert_called_once_with( dag_id=test_case["input"]["dag_id"], dag=mock_dag_instance, update_mask=test_case["expected_update_mask"], ) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "clear_task_instances_minimal", "input": {"dag_id": "test_dag"}, "mock_response": {"message": "Task instances cleared"}, "expected_clear_request": {}, }, { "name": "clear_task_instances_full", "input": { "dag_id": "test_dag", "task_ids": ["task1", "task2"], "start_date": "2023-01-01", "end_date": "2023-01-31", "include_subdags": True, "include_upstream": True, "dry_run": True, }, "mock_response": {"message": "Dry run completed"}, "expected_clear_request": { "task_ids": ["task1", "task2"], "start_date": "2023-01-01", "end_date": "2023-01-31", "include_subdags": True, "include_upstream": True, "dry_run": True, }, }, ], ) async def test_clear_task_instances_table_driven(self, test_case, mock_dag_api): """Table-driven test for clear_task_instances function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.post_clear_task_instances.return_value = mock_response # Execute function with patch("src.airflow.dag.ClearTaskInstances") as mock_clear_class: mock_clear_instance = MagicMock() mock_clear_class.return_value = mock_clear_instance result = await clear_task_instances(**test_case["input"]) # Verify ClearTaskInstances creation and API call mock_clear_class.assert_called_once_with(**test_case["expected_clear_request"]) mock_dag_api.post_clear_task_instances.assert_called_once_with( dag_id=test_case["input"]["dag_id"], clear_task_instances=mock_clear_instance ) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "set_task_state_minimal", "input": {"dag_id": "test_dag", "state": "success"}, "mock_response": {"message": "Task state updated"}, "expected_state_request": {"state": "success"}, }, { "name": "set_task_state_full", "input": { "dag_id": "test_dag", "state": "failed", "task_ids": ["task1"], "execution_date": "2023-01-01T00:00:00Z", "include_upstream": True, "include_downstream": False, "dry_run": True, }, "mock_response": {"message": "Dry run state update"}, "expected_state_request": { "state": "failed", "task_ids": ["task1"], "execution_date": "2023-01-01T00:00:00Z", "include_upstream": True, "include_downstream": False, "dry_run": True, }, }, ], ) async def test_set_task_instances_state_table_driven(self, test_case, mock_dag_api): """Table-driven test for set_task_instances_state function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.post_set_task_instances_state.return_value = mock_response # Execute function with patch("src.airflow.dag.UpdateTaskInstancesState") as mock_state_class: mock_state_instance = MagicMock() mock_state_class.return_value = mock_state_instance result = await set_task_instances_state(**test_case["input"]) # Verify UpdateTaskInstancesState creation and API call mock_state_class.assert_called_once_with(**test_case["expected_state_request"]) mock_dag_api.post_set_task_instances_state.assert_called_once_with( dag_id=test_case["input"]["dag_id"], update_task_instances_state=mock_state_instance ) assert 
len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "simple_functions_get_dag_source", "function": get_dag_source, "api_method": "get_dag_source", "input": {"file_token": "test_token"}, "mock_response": {"content": "DAG source code"}, "expected_call_kwargs": {"file_token": "test_token"}, }, { "name": "simple_functions_get_dag_tasks", "function": get_dag_tasks, "api_method": "get_tasks", "input": {"dag_id": "test_dag"}, "mock_response": {"tasks": []}, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "simple_functions_get_task", "function": get_task, "api_method": "get_task", "input": {"dag_id": "test_dag", "task_id": "test_task"}, "mock_response": {"task_id": "test_task", "operator": "BashOperator"}, "expected_call_kwargs": {"dag_id": "test_dag", "task_id": "test_task"}, }, { "name": "simple_functions_delete_dag", "function": delete_dag, "api_method": "delete_dag", "input": {"dag_id": "test_dag"}, "mock_response": {"message": "DAG deleted"}, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "simple_functions_reparse_dag_file", "function": reparse_dag_file, "api_method": "reparse_dag_file", "input": {"file_token": "test_token"}, "mock_response": {"message": "DAG file reparsed"}, "expected_call_kwargs": {"file_token": "test_token"}, }, ], ) async def test_simple_functions_table_driven(self, test_case, mock_dag_api): """Table-driven test for simple functions that directly call API methods.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] getattr(mock_dag_api, test_case["api_method"]).return_value = mock_response # Execute function result = await test_case["function"](**test_case["input"]) # Verify API call and result getattr(mock_dag_api, test_case["api_method"]).assert_called_once_with(**test_case["expected_call_kwargs"]) assert len(result) == 1 assert isinstance(result[0], types.TextContent) assert str(test_case["mock_response"]) in result[0].text @pytest.mark.integration async def test_dag_functions_integration_flow(self, mock_dag_api): """Integration test showing typical DAG management workflow.""" # Test data for a complete workflow dag_id = "integration_test_dag" # Mock responses for each step mock_responses = { "get_dag": {"dag_id": dag_id, "is_paused": True}, "patch_dag": {"dag_id": dag_id, "is_paused": False}, "get_tasks": {"tasks": [{"task_id": "task1"}, {"task_id": "task2"}]}, "delete_dag": {"message": "DAG deleted successfully"}, } # Setup mock responses for method, response in mock_responses.items(): mock_response = MagicMock() mock_response.to_dict.return_value = response getattr(mock_dag_api, method).return_value = mock_response # Execute workflow steps with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"): # 1. Get DAG info dag_info = await get_dag(dag_id) assert len(dag_info) == 1 # 2. Unpause DAG with patch("src.airflow.dag.DAG") as mock_dag_class: mock_dag_class.return_value = MagicMock() unpause_result = await patch_dag(dag_id, is_paused=False) assert len(unpause_result) == 1 # 3. Get tasks tasks_result = await get_tasks(dag_id) assert len(tasks_result) == 1 # 4. Delete DAG delete_result = await delete_dag(dag_id) assert len(delete_result) == 1 # Verify all API calls were made mock_dag_api.get_dag.assert_called_once_with(dag_id=dag_id) mock_dag_api.patch_dag.assert_called_once() mock_dag_api.get_tasks.assert_called_once_with(dag_id=dag_id) mock_dag_api.delete_dag.assert_called_once_with(dag_id=dag_id) ```
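The table-driven tests above all depend on a shared `mock_dag_api` fixture that is defined outside this file (per the directory structure, in `test/conftest.py`). Below is a minimal sketch of what such a fixture could look like, assuming the DAG module exposes its Airflow client object as `src.airflow.dag.dag_api`; the exact patch target is an assumption for illustration, not confirmed by this file.

```python
# Hypothetical sketch of the shared fixture -- the real definition lives in
# test/conftest.py, and the patch target "src.airflow.dag.dag_api" is assumed.
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def mock_dag_api():
    """Replace the DAG API client with a MagicMock so the table-driven
    tests never contact a live Airflow server."""
    with patch("src.airflow.dag.dag_api", new_callable=MagicMock) as mock_api:
        yield mock_api
```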