This is page 4 of 4. Use http://codebase.md/dbt-labs/dbt-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .changes
│ ├── header.tpl.md
│ ├── unreleased
│ │ ├── .gitkeep
│ │ └── Under the Hood-20251104-091321.yaml
│ ├── v0.1.3.md
│ ├── v0.10.0.md
│ ├── v0.10.1.md
│ ├── v0.10.2.md
│ ├── v0.10.3.md
│ ├── v0.2.0.md
│ ├── v0.2.1.md
│ ├── v0.2.10.md
│ ├── v0.2.11.md
│ ├── v0.2.12.md
│ ├── v0.2.13.md
│ ├── v0.2.14.md
│ ├── v0.2.15.md
│ ├── v0.2.16.md
│ ├── v0.2.17.md
│ ├── v0.2.18.md
│ ├── v0.2.19.md
│ ├── v0.2.2.md
│ ├── v0.2.20.md
│ ├── v0.2.3.md
│ ├── v0.2.4.md
│ ├── v0.2.5.md
│ ├── v0.2.6.md
│ ├── v0.2.7.md
│ ├── v0.2.8.md
│ ├── v0.2.9.md
│ ├── v0.3.0.md
│ ├── v0.4.0.md
│ ├── v0.4.1.md
│ ├── v0.4.2.md
│ ├── v0.5.0.md
│ ├── v0.6.0.md
│ ├── v0.6.1.md
│ ├── v0.6.2.md
│ ├── v0.7.0.md
│ ├── v0.8.0.md
│ ├── v0.8.1.md
│ ├── v0.8.2.md
│ ├── v0.8.3.md
│ ├── v0.8.4.md
│ ├── v0.9.0.md
│ ├── v0.9.1.md
│ ├── v1.0.0.md
│ └── v1.1.0.md
├── .changie.yaml
├── .env.example
├── .github
│ ├── actions
│ │ └── setup-python
│ │ └── action.yml
│ ├── CODEOWNERS
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.yml
│ │ └── feature_request.yml
│ ├── pull_request_template.md
│ └── workflows
│ ├── changelog-check.yml
│ ├── codeowners-check.yml
│ ├── create-release-pr.yml
│ ├── release.yml
│ └── run-checks-pr.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .task
│ └── checksum
│ └── d2
├── .tool-versions
├── .vscode
│ ├── launch.json
│ └── settings.json
├── CHANGELOG.md
├── CONTRIBUTING.md
├── docs
│ ├── d2.png
│ └── diagram.d2
├── evals
│ └── semantic_layer
│ └── test_eval_semantic_layer.py
├── examples
│ ├── .DS_Store
│ ├── aws_strands_agent
│ │ ├── __init__.py
│ │ ├── .DS_Store
│ │ ├── dbt_data_scientist
│ │ │ ├── __init__.py
│ │ │ ├── .env.example
│ │ │ ├── agent.py
│ │ │ ├── prompts.py
│ │ │ ├── quick_mcp_test.py
│ │ │ ├── test_all_tools.py
│ │ │ └── tools
│ │ │ ├── __init__.py
│ │ │ ├── dbt_compile.py
│ │ │ ├── dbt_mcp.py
│ │ │ └── dbt_model_analyzer.py
│ │ ├── LICENSE
│ │ ├── README.md
│ │ └── requirements.txt
│ ├── google_adk_agent
│ │ ├── __init__.py
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ └── README.md
│ ├── langgraph_agent
│ │ ├── __init__.py
│ │ ├── .python-version
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── openai_agent
│ │ ├── __init__.py
│ │ ├── .gitignore
│ │ ├── .python-version
│ │ ├── main_streamable.py
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── openai_responses
│ │ ├── __init__.py
│ │ ├── .gitignore
│ │ ├── .python-version
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── pydantic_ai_agent
│ │ ├── __init__.py
│ │ ├── .gitignore
│ │ ├── .python-version
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ └── README.md
│ └── remote_mcp
│ ├── .python-version
│ ├── main.py
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── LICENSE
├── pyproject.toml
├── README.md
├── src
│ ├── client
│ │ ├── __init__.py
│ │ ├── main.py
│ │ └── tools.py
│ ├── dbt_mcp
│ │ ├── __init__.py
│ │ ├── .gitignore
│ │ ├── config
│ │ │ ├── config_providers.py
│ │ │ ├── config.py
│ │ │ ├── dbt_project.py
│ │ │ ├── dbt_yaml.py
│ │ │ ├── headers.py
│ │ │ ├── settings.py
│ │ │ └── transport.py
│ │ ├── dbt_admin
│ │ │ ├── __init__.py
│ │ │ ├── client.py
│ │ │ ├── constants.py
│ │ │ ├── run_results_errors
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ └── parser.py
│ │ │ └── tools.py
│ │ ├── dbt_cli
│ │ │ ├── binary_type.py
│ │ │ └── tools.py
│ │ ├── dbt_codegen
│ │ │ ├── __init__.py
│ │ │ └── tools.py
│ │ ├── discovery
│ │ │ ├── client.py
│ │ │ └── tools.py
│ │ ├── errors
│ │ │ ├── __init__.py
│ │ │ ├── admin_api.py
│ │ │ ├── base.py
│ │ │ ├── cli.py
│ │ │ ├── common.py
│ │ │ ├── discovery.py
│ │ │ ├── semantic_layer.py
│ │ │ └── sql.py
│ │ ├── gql
│ │ │ └── errors.py
│ │ ├── lsp
│ │ │ ├── __init__.py
│ │ │ ├── lsp_binary_manager.py
│ │ │ ├── lsp_client.py
│ │ │ ├── lsp_connection.py
│ │ │ ├── providers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── local_lsp_client_provider.py
│ │ │ │ ├── local_lsp_connection_provider.py
│ │ │ │ ├── lsp_client_provider.py
│ │ │ │ └── lsp_connection_provider.py
│ │ │ └── tools.py
│ │ ├── main.py
│ │ ├── mcp
│ │ │ ├── create.py
│ │ │ └── server.py
│ │ ├── oauth
│ │ │ ├── client_id.py
│ │ │ ├── context_manager.py
│ │ │ ├── dbt_platform.py
│ │ │ ├── fastapi_app.py
│ │ │ ├── logging.py
│ │ │ ├── login.py
│ │ │ ├── refresh_strategy.py
│ │ │ ├── token_provider.py
│ │ │ └── token.py
│ │ ├── prompts
│ │ │ ├── __init__.py
│ │ │ ├── admin_api
│ │ │ │ ├── cancel_job_run.md
│ │ │ │ ├── get_job_details.md
│ │ │ │ ├── get_job_run_artifact.md
│ │ │ │ ├── get_job_run_details.md
│ │ │ │ ├── get_job_run_error.md
│ │ │ │ ├── list_job_run_artifacts.md
│ │ │ │ ├── list_jobs_runs.md
│ │ │ │ ├── list_jobs.md
│ │ │ │ ├── retry_job_run.md
│ │ │ │ └── trigger_job_run.md
│ │ │ ├── dbt_cli
│ │ │ │ ├── args
│ │ │ │ │ ├── full_refresh.md
│ │ │ │ │ ├── limit.md
│ │ │ │ │ ├── resource_type.md
│ │ │ │ │ ├── selectors.md
│ │ │ │ │ ├── sql_query.md
│ │ │ │ │ └── vars.md
│ │ │ │ ├── build.md
│ │ │ │ ├── compile.md
│ │ │ │ ├── docs.md
│ │ │ │ ├── list.md
│ │ │ │ ├── parse.md
│ │ │ │ ├── run.md
│ │ │ │ ├── show.md
│ │ │ │ └── test.md
│ │ │ ├── dbt_codegen
│ │ │ │ ├── args
│ │ │ │ │ ├── case_sensitive_cols.md
│ │ │ │ │ ├── database_name.md
│ │ │ │ │ ├── generate_columns.md
│ │ │ │ │ ├── include_data_types.md
│ │ │ │ │ ├── include_descriptions.md
│ │ │ │ │ ├── leading_commas.md
│ │ │ │ │ ├── materialized.md
│ │ │ │ │ ├── model_name.md
│ │ │ │ │ ├── model_names.md
│ │ │ │ │ ├── schema_name.md
│ │ │ │ │ ├── source_name.md
│ │ │ │ │ ├── table_name.md
│ │ │ │ │ ├── table_names.md
│ │ │ │ │ ├── tables.md
│ │ │ │ │ └── upstream_descriptions.md
│ │ │ │ ├── generate_model_yaml.md
│ │ │ │ ├── generate_source.md
│ │ │ │ └── generate_staging_model.md
│ │ │ ├── discovery
│ │ │ │ ├── get_all_models.md
│ │ │ │ ├── get_all_sources.md
│ │ │ │ ├── get_exposure_details.md
│ │ │ │ ├── get_exposures.md
│ │ │ │ ├── get_mart_models.md
│ │ │ │ ├── get_model_children.md
│ │ │ │ ├── get_model_details.md
│ │ │ │ ├── get_model_health.md
│ │ │ │ └── get_model_parents.md
│ │ │ ├── lsp
│ │ │ │ ├── args
│ │ │ │ │ ├── column_name.md
│ │ │ │ │ └── model_id.md
│ │ │ │ └── get_column_lineage.md
│ │ │ ├── prompts.py
│ │ │ └── semantic_layer
│ │ │ ├── get_dimensions.md
│ │ │ ├── get_entities.md
│ │ │ ├── get_metrics_compiled_sql.md
│ │ │ ├── list_metrics.md
│ │ │ ├── list_saved_queries.md
│ │ │ └── query_metrics.md
│ │ ├── py.typed
│ │ ├── semantic_layer
│ │ │ ├── client.py
│ │ │ ├── gql
│ │ │ │ ├── gql_request.py
│ │ │ │ └── gql.py
│ │ │ ├── levenshtein.py
│ │ │ ├── tools.py
│ │ │ └── types.py
│ │ ├── sql
│ │ │ └── tools.py
│ │ ├── telemetry
│ │ │ └── logging.py
│ │ ├── tools
│ │ │ ├── annotations.py
│ │ │ ├── definitions.py
│ │ │ ├── policy.py
│ │ │ ├── register.py
│ │ │ ├── tool_names.py
│ │ │ └── toolsets.py
│ │ └── tracking
│ │ └── tracking.py
│ └── remote_mcp
│ ├── __init__.py
│ └── session.py
├── Taskfile.yml
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── env_vars.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── dbt_codegen
│ │ │ ├── __init__.py
│ │ │ └── test_dbt_codegen.py
│ │ ├── discovery
│ │ │ └── test_discovery.py
│ │ ├── initialization
│ │ │ ├── __init__.py
│ │ │ └── test_initialization.py
│ │ ├── lsp
│ │ │ └── test_lsp_connection.py
│ │ ├── remote_mcp
│ │ │ └── test_remote_mcp.py
│ │ ├── remote_tools
│ │ │ └── test_remote_tools.py
│ │ ├── semantic_layer
│ │ │ └── test_semantic_layer.py
│ │ └── tracking
│ │ └── test_tracking.py
│ ├── mocks
│ │ └── config.py
│ └── unit
│ ├── __init__.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── test_config.py
│ │ └── test_transport.py
│ ├── dbt_admin
│ │ ├── __init__.py
│ │ ├── test_client.py
│ │ ├── test_error_fetcher.py
│ │ └── test_tools.py
│ ├── dbt_cli
│ │ ├── __init__.py
│ │ ├── test_cli_integration.py
│ │ └── test_tools.py
│ ├── dbt_codegen
│ │ ├── __init__.py
│ │ └── test_tools.py
│ ├── discovery
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_exposures_fetcher.py
│ │ └── test_sources_fetcher.py
│ ├── lsp
│ │ ├── __init__.py
│ │ ├── test_local_lsp_client_provider.py
│ │ ├── test_local_lsp_connection_provider.py
│ │ ├── test_lsp_client.py
│ │ ├── test_lsp_connection.py
│ │ └── test_lsp_tools.py
│ ├── oauth
│ │ ├── test_credentials_provider.py
│ │ ├── test_fastapi_app_pagination.py
│ │ └── test_token.py
│ ├── semantic_layer
│ │ ├── __init__.py
│ │ └── test_saved_queries.py
│ ├── tools
│ │ ├── test_disable_tools.py
│ │ ├── test_tool_names.py
│ │ ├── test_tool_policies.py
│ │ └── test_toolsets.py
│ └── tracking
│ └── test_tracking.py
├── ui
│ ├── .gitignore
│ ├── assets
│ │ ├── dbt_logo BLK.svg
│ │ └── dbt_logo WHT.svg
│ ├── eslint.config.js
│ ├── index.html
│ ├── package.json
│ ├── pnpm-lock.yaml
│ ├── pnpm-workspace.yaml
│ ├── README.md
│ ├── src
│ │ ├── App.css
│ │ ├── App.tsx
│ │ ├── global.d.ts
│ │ ├── index.css
│ │ ├── main.tsx
│ │ └── vite-env.d.ts
│ ├── tsconfig.app.json
│ ├── tsconfig.json
│ ├── tsconfig.node.json
│ └── vite.config.ts
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/src/dbt_mcp/config/settings.py:
--------------------------------------------------------------------------------
```python
import logging
import socket
import time
import shutil
from enum import Enum
from pathlib import Path
from typing import Annotated
from filelock import FileLock
from pydantic import Field, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
from dbt_mcp.config.dbt_project import DbtProjectYaml
from dbt_mcp.config.dbt_yaml import try_read_yaml
from dbt_mcp.config.headers import (
TokenProvider,
)
from dbt_mcp.oauth.context_manager import DbtPlatformContextManager
from dbt_mcp.oauth.dbt_platform import DbtPlatformContext
from dbt_mcp.oauth.login import login
from dbt_mcp.oauth.token_provider import (
OAuthTokenProvider,
StaticTokenProvider,
)
from dbt_mcp.tools.tool_names import ToolName
logger = logging.getLogger(__name__)
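# Starting port for the local OAuth redirect server; _find_available_port scans upward from here.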
OAUTH_REDIRECT_STARTING_PORT = 6785
DEFAULT_DBT_CLI_TIMEOUT = 60
class AuthenticationMethod(Enum):
OAUTH = "oauth"
ENV_VAR = "env_var"
class DbtMcpSettings(BaseSettings):
model_config = SettingsConfigDict(
env_prefix="",
case_sensitive=False,
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# dbt Platform settings
dbt_host: str | None = Field(None, alias="DBT_HOST")
dbt_mcp_host: str | None = Field(None, alias="DBT_MCP_HOST")
dbt_prod_env_id: int | None = Field(None, alias="DBT_PROD_ENV_ID")
dbt_env_id: int | None = Field(None, alias="DBT_ENV_ID") # legacy support
dbt_dev_env_id: int | None = Field(None, alias="DBT_DEV_ENV_ID")
dbt_user_id: int | None = Field(None, alias="DBT_USER_ID")
dbt_account_id: int | None = Field(None, alias="DBT_ACCOUNT_ID")
dbt_token: str | None = Field(None, alias="DBT_TOKEN")
multicell_account_prefix: str | None = Field(None, alias="MULTICELL_ACCOUNT_PREFIX")
host_prefix: str | None = Field(None, alias="DBT_HOST_PREFIX")
dbt_lsp_path: str | None = Field(None, alias="DBT_LSP_PATH")
# dbt CLI settings
dbt_project_dir: str | None = Field(None, alias="DBT_PROJECT_DIR")
dbt_path: str = Field("dbt", alias="DBT_PATH")
dbt_cli_timeout: int = Field(DEFAULT_DBT_CLI_TIMEOUT, alias="DBT_CLI_TIMEOUT")
dbt_warn_error_options: str | None = Field(None, alias="DBT_WARN_ERROR_OPTIONS")
dbt_profiles_dir: str | None = Field(None, alias="DBT_PROFILES_DIR")
# Disable tool settings
disable_dbt_cli: bool = Field(False, alias="DISABLE_DBT_CLI")
disable_dbt_codegen: bool = Field(True, alias="DISABLE_DBT_CODEGEN")
disable_semantic_layer: bool = Field(False, alias="DISABLE_SEMANTIC_LAYER")
disable_discovery: bool = Field(False, alias="DISABLE_DISCOVERY")
disable_remote: bool | None = Field(None, alias="DISABLE_REMOTE")
disable_admin_api: bool = Field(False, alias="DISABLE_ADMIN_API")
disable_sql: bool | None = Field(None, alias="DISABLE_SQL")
disable_tools: Annotated[list[ToolName] | None, NoDecode] = Field(
None, alias="DISABLE_TOOLS"
)
disable_lsp: bool | None = Field(None, alias="DISABLE_LSP")
# Tracking settings
do_not_track: str | None = Field(None, alias="DO_NOT_TRACK")
send_anonymous_usage_data: str | None = Field(
None, alias="DBT_SEND_ANONYMOUS_USAGE_STATS"
)
# Developer settings
file_logging: bool = Field(False, alias="DBT_MCP_SERVER_FILE_LOGGING")
def __repr__(self):
"""Custom repr to bring most important settings to front. Redact sensitive info."""
return (
# auto-disable settings
f"DbtMcpSettings(dbt_host={self.dbt_host}, "
f"dbt_path={self.dbt_path}, "
f"dbt_project_dir={self.dbt_project_dir}, "
# disable settings
f"disable_dbt_cli={self.disable_dbt_cli}, "
f"disable_dbt_codegen={self.disable_dbt_codegen}, "
f"disable_semantic_layer={self.disable_semantic_layer}, "
f"disable_discovery={self.disable_discovery}, "
f"disable_admin_api={self.disable_admin_api}, "
f"disable_sql={self.disable_sql}, "
f"disable_tools={self.disable_tools}, "
f"disable_lsp={self.disable_lsp}, "
# everything else
f"dbt_prod_env_id={self.dbt_prod_env_id}, "
f"dbt_dev_env_id={self.dbt_dev_env_id}, "
f"dbt_user_id={self.dbt_user_id}, "
f"dbt_account_id={self.dbt_account_id}, "
f"dbt_token={'***redacted***' if self.dbt_token else None}, "
f"send_anonymous_usage_data={self.send_anonymous_usage_data}, "
f"file_logging={self.file_logging})"
)
@property
def actual_host(self) -> str | None:
host = self.dbt_host or self.dbt_mcp_host
if host is None:
return None
return host.rstrip("/").removeprefix("https://").removeprefix("http://")
@property
def actual_prod_environment_id(self) -> int | None:
return self.dbt_prod_env_id or self.dbt_env_id
@property
def actual_disable_sql(self) -> bool:
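        # DISABLE_SQL takes precedence; DISABLE_REMOTE is honored for backward compatibility.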
if self.disable_sql is not None:
return self.disable_sql
if self.disable_remote is not None:
return self.disable_remote
return True
@property
def actual_host_prefix(self) -> str | None:
if self.host_prefix is not None:
return self.host_prefix
if self.multicell_account_prefix is not None:
return self.multicell_account_prefix
return None
@property
def dbt_project_yml(self) -> DbtProjectYaml | None:
if not self.dbt_project_dir:
return None
dbt_project_yml = try_read_yaml(Path(self.dbt_project_dir) / "dbt_project.yml")
if dbt_project_yml is None:
return None
return DbtProjectYaml.model_validate(dbt_project_yml)
@property
def usage_tracking_enabled(self) -> bool:
# dbt environment variables take precedence over dbt_project.yml
if (
self.send_anonymous_usage_data is not None
and (
self.send_anonymous_usage_data.lower() == "false"
or self.send_anonymous_usage_data == "0"
)
) or (
self.do_not_track is not None
and (self.do_not_track.lower() == "true" or self.do_not_track == "1")
):
return False
dbt_project_yml = self.dbt_project_yml
if (
dbt_project_yml
and dbt_project_yml.flags
and dbt_project_yml.flags.send_anonymous_usage_stats is not None
):
return dbt_project_yml.flags.send_anonymous_usage_stats
return True
@field_validator("dbt_host", "dbt_mcp_host", mode="after")
@classmethod
def validate_host(cls, v: str | None, info: ValidationInfo) -> str | None:
"""Intentionally error on misconfigured host-like env vars (DBT_HOST and DBT_MCP_HOST)."""
host = (
v.rstrip("/").removeprefix("https://").removeprefix("http://") if v else v
)
if host and (host.startswith("metadata") or host.startswith("semantic-layer")):
field_name = (
getattr(info, "field_name", "None") if info is not None else "None"
).upper()
raise ValueError(
f"{field_name} must not start with 'metadata' or 'semantic-layer': {v}"
)
return v
@field_validator("dbt_path", mode="after")
@classmethod
def validate_file_exists(cls, v: str | None, info: ValidationInfo) -> str | None:
"""Validate a path exists in the system.
This will only fail if the path is explicitly set to a non-existing path.
It will auto-disable upon model validation if it can't be found AND it's not $PATH.
"""
# Allow 'dbt' and 'dbtf' as special cases as they're expected to be on PATH
if v in ["dbt", "dbtf"]:
return v
if v:
p = Path(v)
if p.exists():
return v
field_name = (
getattr(info, "field_name", "None") if info is not None else "None"
).upper()
raise ValueError(f"{field_name} path does not exist: {v}")
return v
@field_validator("dbt_project_dir", "dbt_profiles_dir", mode="after")
@classmethod
def validate_dir_exists(cls, v: str | None, info: ValidationInfo) -> str | None:
"""Validate a directory path exists in the system."""
if v:
path = Path(v)
if not path.is_dir():
field_name = (
getattr(info, "field_name", "None") if info is not None else "None"
).upper()
raise ValueError(f"{field_name} directory does not exist: {v}")
return v
@field_validator("disable_tools", mode="before")
@classmethod
def parse_disable_tools(cls, env_var: str | None) -> list[ToolName]:
if not env_var:
return []
errors: list[str] = []
tool_names: list[ToolName] = []
for tool_name in env_var.split(","):
tool_name_stripped = tool_name.strip()
if tool_name_stripped == "":
continue
try:
tool_names.append(ToolName(tool_name_stripped))
except ValueError:
errors.append(
f"Invalid tool name in DISABLE_TOOLS: {tool_name_stripped}."
+ " Must be a valid tool name."
)
if errors:
raise ValueError("\n".join(errors))
return tool_names
@model_validator(mode="after")
def auto_disable(self) -> "DbtMcpSettings":
"""Auto-disable features based on required settings."""
# platform features
if (
not self.actual_host
): # host is the only truly required setting for platform features
# object.__setattr__ is used in case we want to set values on a frozen model
object.__setattr__(self, "disable_semantic_layer", True)
object.__setattr__(self, "disable_discovery", True)
object.__setattr__(self, "disable_admin_api", True)
object.__setattr__(self, "disable_sql", True)
logger.warning(
"Platform features have been automatically disabled due to missing DBT_HOST."
)
# CLI features
cli_errors = validate_dbt_cli_settings(self)
if cli_errors:
object.__setattr__(self, "disable_dbt_cli", True)
object.__setattr__(self, "disable_dbt_codegen", True)
logger.warning(
f"CLI features have been automatically disabled due to misconfigurations:\n {'\n '.join(cli_errors)}."
)
return self
def _find_available_port(*, start_port: int, max_attempts: int = 20) -> int:
"""
Return the first available port on 127.0.0.1 starting at start_port.
Raises RuntimeError if no port is found within the attempted range.
"""
for candidate_port in range(start_port, start_port + max_attempts):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.bind(("127.0.0.1", candidate_port))
except OSError:
continue
return candidate_port
raise RuntimeError(
"No available port found starting at "
f"{start_port} after {max_attempts} attempts."
)
def get_dbt_profiles_path(dbt_profiles_dir: str | None = None) -> Path:
    # Respect DBT_PROFILES_DIR if set; otherwise default to ~/.dbt
if dbt_profiles_dir:
return Path(dbt_profiles_dir).expanduser()
else:
return Path.home() / ".dbt"
async def get_dbt_platform_context(
*,
dbt_user_dir: Path,
dbt_platform_url: str,
dbt_platform_context_manager: DbtPlatformContextManager,
) -> DbtPlatformContext:
    # Some MCP hosts (e.g. Claude Desktop) tend to run multiple MCP server instances.
    # We need to lock so that only one can run the OAuth flow.
with FileLock(dbt_user_dir / "mcp.lock"):
dbt_ctx = dbt_platform_context_manager.read_context()
if (
dbt_ctx
and dbt_ctx.account_id
and dbt_ctx.host_prefix
and dbt_ctx.dev_environment
and dbt_ctx.prod_environment
and dbt_ctx.decoded_access_token
and dbt_ctx.decoded_access_token.access_token_response.expires_at
> time.time() + 120 # 2 minutes buffer
):
return dbt_ctx
# Find an available port for the local OAuth redirect server
selected_port = _find_available_port(start_port=OAUTH_REDIRECT_STARTING_PORT)
return await login(
dbt_platform_url=dbt_platform_url,
port=selected_port,
dbt_platform_context_manager=dbt_platform_context_manager,
)
def get_dbt_host(
settings: DbtMcpSettings, dbt_platform_context: DbtPlatformContext
) -> str:
actual_host = settings.actual_host
if not actual_host:
raise ValueError("DBT_HOST is a required environment variable")
host_prefix_with_period = f"{dbt_platform_context.host_prefix}."
if not actual_host.startswith(host_prefix_with_period):
raise ValueError(
f"The DBT_HOST environment variable is expected to start with the {dbt_platform_context.host_prefix} custom subdomain."
)
# We have to remove the custom subdomain prefix
# so that the metadata and semantic-layer URLs can be constructed correctly.
return actual_host.removeprefix(host_prefix_with_period)
def validate_settings(settings: DbtMcpSettings):
errors: list[str] = []
errors.extend(validate_dbt_platform_settings(settings))
errors.extend(validate_dbt_cli_settings(settings))
if errors:
raise ValueError("Errors found in configuration:\n\n" + "\n".join(errors))
def validate_dbt_platform_settings(
settings: DbtMcpSettings,
) -> list[str]:
errors: list[str] = []
if (
not settings.disable_semantic_layer
or not settings.disable_discovery
or not settings.actual_disable_sql
or not settings.disable_admin_api
):
if not settings.actual_host:
errors.append(
"DBT_HOST environment variable is required when semantic layer, discovery, SQL or admin API tools are enabled."
)
if not settings.actual_prod_environment_id:
errors.append(
"DBT_PROD_ENV_ID environment variable is required when semantic layer, discovery, SQL or admin API tools are enabled."
)
if not settings.dbt_token:
errors.append(
"DBT_TOKEN environment variable is required when semantic layer, discovery, SQL or admin API tools are enabled."
)
if settings.actual_host and (
settings.actual_host.startswith("metadata")
or settings.actual_host.startswith("semantic-layer")
):
errors.append(
"DBT_HOST must not start with 'metadata' or 'semantic-layer'."
)
if (
not settings.actual_disable_sql
and ToolName.TEXT_TO_SQL not in (settings.disable_tools or [])
and not settings.actual_prod_environment_id
):
errors.append(
"DBT_PROD_ENV_ID environment variable is required when text_to_sql is enabled."
)
if not settings.actual_disable_sql and ToolName.EXECUTE_SQL not in (
settings.disable_tools or []
):
if not settings.dbt_dev_env_id:
errors.append(
"DBT_DEV_ENV_ID environment variable is required when execute_sql is enabled."
)
if not settings.dbt_user_id:
errors.append(
"DBT_USER_ID environment variable is required when execute_sql is enabled."
)
return errors
def validate_dbt_cli_settings(settings: DbtMcpSettings) -> list[str]:
errors: list[str] = []
if not settings.disable_dbt_cli:
if not settings.dbt_project_dir:
errors.append(
"DBT_PROJECT_DIR environment variable is required when dbt CLI tools are enabled."
)
if not settings.dbt_path:
errors.append(
"DBT_PATH environment variable is required when dbt CLI tools are enabled."
)
else:
dbt_path = Path(settings.dbt_path)
if not (dbt_path.exists() or shutil.which(dbt_path)):
errors.append(
f"DBT_PATH executable can't be found: {settings.dbt_path}"
)
return errors
class CredentialsProvider:
def __init__(self, settings: DbtMcpSettings):
self.settings = settings
self.token_provider: TokenProvider | None = None
self.authentication_method: AuthenticationMethod | None = None
def _log_settings(self) -> None:
settings = self.settings.model_dump()
if settings.get("dbt_token") is not None:
settings["dbt_token"] = "***redacted***"
logger.info(f"Settings: {settings}")
async def get_credentials(self) -> tuple[DbtMcpSettings, TokenProvider]:
if self.token_provider is not None:
# If token provider is already set, just return the cached values
return self.settings, self.token_provider
# Load settings from environment variables using pydantic_settings
dbt_platform_errors = validate_dbt_platform_settings(self.settings)
if dbt_platform_errors:
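            # Platform env vars are incomplete, so fall back to the browser-based OAuth login flow.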
dbt_user_dir = get_dbt_profiles_path(
dbt_profiles_dir=self.settings.dbt_profiles_dir
)
config_location = dbt_user_dir / "mcp.yml"
dbt_platform_url = f"https://{self.settings.actual_host}"
dbt_platform_context_manager = DbtPlatformContextManager(config_location)
dbt_platform_context = await get_dbt_platform_context(
dbt_platform_context_manager=dbt_platform_context_manager,
dbt_user_dir=dbt_user_dir,
dbt_platform_url=dbt_platform_url,
)
            # Override settings with values obtained from login or mcp.yml
self.settings.dbt_user_id = dbt_platform_context.user_id
self.settings.dbt_dev_env_id = (
dbt_platform_context.dev_environment.id
if dbt_platform_context.dev_environment
else None
)
self.settings.dbt_prod_env_id = (
dbt_platform_context.prod_environment.id
if dbt_platform_context.prod_environment
else None
)
self.settings.dbt_account_id = dbt_platform_context.account_id
self.settings.host_prefix = dbt_platform_context.host_prefix
self.settings.dbt_host = get_dbt_host(self.settings, dbt_platform_context)
if not dbt_platform_context.decoded_access_token:
raise ValueError("No decoded access token found in OAuth context")
self.settings.dbt_token = dbt_platform_context.decoded_access_token.access_token_response.access_token
self.token_provider = OAuthTokenProvider(
access_token_response=dbt_platform_context.decoded_access_token.access_token_response,
dbt_platform_url=dbt_platform_url,
context_manager=dbt_platform_context_manager,
)
validate_settings(self.settings)
self.authentication_method = AuthenticationMethod.OAUTH
self._log_settings()
return self.settings, self.token_provider
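        # Platform settings were fully provided via env vars, so authenticate with the static DBT_TOKEN.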
self.token_provider = StaticTokenProvider(token=self.settings.dbt_token)
validate_settings(self.settings)
self.authentication_method = AuthenticationMethod.ENV_VAR
self._log_settings()
return self.settings, self.token_provider
```
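
For orientation, here is a minimal usage sketch (not part of the repository) of how `DbtMcpSettings` resolves values; the environment variable names come from the field aliases above, and the assertions mirror the `actual_host` and `actual_prod_environment_id` properties.

```python
import os

from dbt_mcp.config.settings import DbtMcpSettings

# Supply platform settings the way an MCP host would.
os.environ["DBT_HOST"] = "https://cloud.getdbt.com/"
os.environ["DBT_ENV_ID"] = "123"  # legacy alias for DBT_PROD_ENV_ID

settings = DbtMcpSettings(_env_file=None)  # skip .env loading, as the unit tests do

# actual_host strips the scheme and any trailing slash.
assert settings.actual_host == "cloud.getdbt.com"
# actual_prod_environment_id falls back to the legacy DBT_ENV_ID.
assert settings.actual_prod_environment_id == 123
```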
--------------------------------------------------------------------------------
/tests/unit/config/test_config.py:
--------------------------------------------------------------------------------
```python
import os
from unittest.mock import patch
import pytest
from dbt_mcp.config.config import (
DbtMcpSettings,
load_config,
)
from dbt_mcp.config.settings import DEFAULT_DBT_CLI_TIMEOUT
from dbt_mcp.dbt_cli.binary_type import BinaryType
from dbt_mcp.tools.tool_names import ToolName
class TestDbtMcpSettings:
def setup_method(self):
# Clear environment variables that could interfere with default value tests
env_vars_to_clear = [
"DBT_HOST",
"DBT_MCP_HOST",
"DBT_PROD_ENV_ID",
"DBT_ENV_ID",
"DBT_DEV_ENV_ID",
"DBT_USER_ID",
"DBT_TOKEN",
"DBT_PROJECT_DIR",
"DBT_PATH",
"DBT_CLI_TIMEOUT",
"DISABLE_DBT_CLI",
"DISABLE_DBT_CODEGEN",
"DISABLE_SEMANTIC_LAYER",
"DISABLE_DISCOVERY",
"DISABLE_REMOTE",
"DISABLE_ADMIN_API",
"MULTICELL_ACCOUNT_PREFIX",
"DBT_WARN_ERROR_OPTIONS",
"DISABLE_TOOLS",
"DBT_ACCOUNT_ID",
]
for var in env_vars_to_clear:
os.environ.pop(var, None)
def test_default_values(self, env_setup):
# Test with clean environment and no .env file
clean_env = {
"HOME": os.environ.get("HOME", ""),
} # Keep HOME for potential path resolution
with env_setup(env_vars=clean_env):
settings = DbtMcpSettings(_env_file=None)
assert settings.dbt_path == "dbt"
assert settings.dbt_cli_timeout == DEFAULT_DBT_CLI_TIMEOUT
assert settings.disable_remote is None, "disable_remote"
assert settings.disable_dbt_cli is False, "disable_dbt_cli"
assert settings.disable_dbt_codegen is True, "disable_dbt_codegen"
assert settings.disable_admin_api is False, "disable_admin_api"
assert settings.disable_semantic_layer is False, "disable_semantic_layer"
assert settings.disable_discovery is False, "disable_discovery"
assert settings.disable_sql is None, "disable_sql"
assert settings.disable_tools == [], "disable_tools"
def test_usage_tracking_disabled_by_env_vars(self):
env_vars = {
"DO_NOT_TRACK": "true",
"DBT_SEND_ANONYMOUS_USAGE_STATS": "1",
}
with patch.dict(os.environ, env_vars, clear=True):
settings = DbtMcpSettings(_env_file=None)
assert settings.usage_tracking_enabled is False
def test_usage_tracking_respects_dbt_project_yaml(self, env_setup):
with env_setup() as (project_dir, helpers):
(project_dir / "dbt_project.yml").write_text(
"flags:\n send_anonymous_usage_stats: false\n"
)
settings = DbtMcpSettings(_env_file=None)
assert settings.usage_tracking_enabled is False
def test_usage_tracking_env_var_precedence_over_yaml(self, env_setup):
env_vars = {
"DBT_SEND_ANONYMOUS_USAGE_STATS": "false",
}
with env_setup(env_vars=env_vars) as (project_dir, helpers):
(project_dir / "dbt_project.yml").write_text(
"flags:\n send_anonymous_usage_stats: true\n"
)
settings = DbtMcpSettings(_env_file=None)
assert settings.usage_tracking_enabled is False
@pytest.mark.parametrize(
"do_not_track, send_anonymous_usage_stats",
[
("true", "1"),
("1", "true"),
("true", None),
("1", None),
(None, "false"),
(None, "0"),
],
)
def test_usage_tracking_conflicting_env_vars_bias_off(
self, do_not_track, send_anonymous_usage_stats
):
env_vars = {}
if do_not_track is not None:
env_vars["DO_NOT_TRACK"] = do_not_track
if send_anonymous_usage_stats is not None:
env_vars["DBT_SEND_ANONYMOUS_USAGE_STATS"] = send_anonymous_usage_stats
with patch.dict(os.environ, env_vars, clear=True):
settings = DbtMcpSettings(_env_file=None)
assert settings.usage_tracking_enabled is False
def test_env_var_parsing(self, env_setup):
env_vars = {
"DBT_HOST": "test.dbt.com",
"DBT_PROD_ENV_ID": "123",
"DBT_TOKEN": "test_token",
"DISABLE_DBT_CLI": "true",
"DISABLE_TOOLS": "build,compile,docs",
}
with env_setup(env_vars=env_vars) as (project_dir, helpers):
settings = DbtMcpSettings(_env_file=None)
assert settings.dbt_host == "test.dbt.com"
assert settings.dbt_prod_env_id == 123
assert settings.dbt_token == "test_token"
assert settings.dbt_project_dir == str(project_dir)
assert settings.disable_dbt_cli is True
assert settings.disable_tools == [
ToolName.BUILD,
ToolName.COMPILE,
ToolName.DOCS,
]
def test_disable_tools_parsing_edge_cases(self):
test_cases = [
("build,compile,docs", [ToolName.BUILD, ToolName.COMPILE, ToolName.DOCS]),
(
"build, compile , docs",
[ToolName.BUILD, ToolName.COMPILE, ToolName.DOCS],
),
("build,,docs", [ToolName.BUILD, ToolName.DOCS]),
("", []),
("run", [ToolName.RUN]),
]
for input_val, expected in test_cases:
with patch.dict(os.environ, {"DISABLE_TOOLS": input_val}):
settings = DbtMcpSettings(_env_file=None)
assert settings.disable_tools == expected
def test_actual_host_property(self):
with patch.dict(os.environ, {"DBT_HOST": "host1.com"}):
settings = DbtMcpSettings(_env_file=None)
assert settings.actual_host == "host1.com"
with patch.dict(os.environ, {"DBT_MCP_HOST": "host2.com"}):
settings = DbtMcpSettings(_env_file=None)
assert settings.actual_host == "host2.com"
with patch.dict(
os.environ, {"DBT_HOST": "host1.com", "DBT_MCP_HOST": "host2.com"}
):
settings = DbtMcpSettings(_env_file=None)
assert settings.actual_host == "host1.com" # DBT_HOST takes precedence
def test_actual_prod_environment_id_property(self):
with patch.dict(os.environ, {"DBT_PROD_ENV_ID": "123"}):
settings = DbtMcpSettings(_env_file=None)
assert settings.actual_prod_environment_id == 123
with patch.dict(os.environ, {"DBT_ENV_ID": "456"}):
settings = DbtMcpSettings(_env_file=None)
assert settings.actual_prod_environment_id == 456
with patch.dict(os.environ, {"DBT_PROD_ENV_ID": "123", "DBT_ENV_ID": "456"}):
settings = DbtMcpSettings(_env_file=None)
assert (
settings.actual_prod_environment_id == 123
) # DBT_PROD_ENV_ID takes precedence
def test_auto_disable_platform_features_logging(self):
with patch.dict(os.environ, {}, clear=True):
settings = DbtMcpSettings(_env_file=None)
# When DBT_HOST is missing, platform features should be disabled
assert settings.disable_admin_api is True
assert settings.disable_sql is True
assert settings.disable_semantic_layer is True
assert settings.disable_discovery is True
assert settings.disable_dbt_cli is True
assert settings.disable_dbt_codegen is True
class TestLoadConfig:
def setup_method(self):
# Clear any existing environment variables that might interfere
env_vars_to_clear = [
"DBT_HOST",
"DBT_MCP_HOST",
"DBT_PROD_ENV_ID",
"DBT_ENV_ID",
"DBT_DEV_ENV_ID",
"DBT_USER_ID",
"DBT_TOKEN",
"DBT_PROJECT_DIR",
"DBT_PATH",
"DBT_CLI_TIMEOUT",
"DISABLE_DBT_CLI",
"DISABLE_SEMANTIC_LAYER",
"DISABLE_DISCOVERY",
"DISABLE_REMOTE",
"DISABLE_ADMIN_API",
"MULTICELL_ACCOUNT_PREFIX",
"DBT_WARN_ERROR_OPTIONS",
"DISABLE_TOOLS",
"DBT_ACCOUNT_ID",
]
for var in env_vars_to_clear:
os.environ.pop(var, None)
def _load_config_with_env(self, env_vars):
"""Helper method to load config with test environment variables, avoiding .env file interference"""
with (
patch.dict(os.environ, env_vars),
patch("dbt_mcp.config.config.DbtMcpSettings") as mock_settings_class,
patch(
"dbt_mcp.config.config.detect_binary_type",
return_value=BinaryType.DBT_CORE,
),
):
# Create a real instance with test values, but without .env file loading
with patch.dict(os.environ, env_vars, clear=True):
settings_instance = DbtMcpSettings(_env_file=None)
mock_settings_class.return_value = settings_instance
return load_config()
def test_valid_config_all_services_enabled(self, env_setup):
env_vars = {
"DBT_HOST": "test.dbt.com",
"DBT_PROD_ENV_ID": "123",
"DBT_DEV_ENV_ID": "456",
"DBT_USER_ID": "789",
"DBT_ACCOUNT_ID": "123",
"DBT_TOKEN": "test_token",
"DISABLE_SEMANTIC_LAYER": "false",
"DISABLE_DISCOVERY": "false",
"DISABLE_REMOTE": "false",
"DISABLE_ADMIN_API": "false",
"DISABLE_DBT_CODEGEN": "false",
}
with env_setup(env_vars=env_vars) as (project_dir, helpers):
config = load_config()
assert config.sql_config_provider is not None, (
"sql_config_provider should be set"
)
assert config.dbt_cli_config is not None, "dbt_cli_config should be set"
assert config.discovery_config_provider is not None, (
"discovery_config_provider should be set"
)
assert config.semantic_layer_config_provider is not None, (
"semantic_layer_config_provider should be set"
)
assert config.admin_api_config_provider is not None, (
"admin_api_config_provider should be set"
)
assert config.credentials_provider is not None, (
"credentials_provider should be set"
)
assert config.dbt_codegen_config is not None, (
"dbt_codegen_config should be set"
)
def test_valid_config_all_services_disabled(self):
env_vars = {
"DBT_TOKEN": "test_token",
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_DISCOVERY": "true",
"DISABLE_REMOTE": "true",
"DISABLE_ADMIN_API": "true",
}
config = self._load_config_with_env(env_vars)
assert config.sql_config_provider is None
assert config.dbt_cli_config is None
assert config.discovery_config_provider is None
assert config.semantic_layer_config_provider is None
def test_invalid_environment_variable_types(self):
# Test invalid integer types
env_vars = {
"DBT_HOST": "test.dbt.com",
"DBT_PROD_ENV_ID": "not_an_integer",
"DBT_TOKEN": "test_token",
"DISABLE_DISCOVERY": "false",
}
with pytest.raises(ValueError):
self._load_config_with_env(env_vars)
def test_multicell_account_prefix_configurations(self):
env_vars = {
"DBT_HOST": "test.dbt.com",
"DBT_PROD_ENV_ID": "123",
"DBT_TOKEN": "test_token",
"MULTICELL_ACCOUNT_PREFIX": "prefix",
"DISABLE_DISCOVERY": "false",
"DISABLE_SEMANTIC_LAYER": "false",
"DISABLE_DBT_CLI": "true",
"DISABLE_REMOTE": "true",
}
config = self._load_config_with_env(env_vars)
assert config.discovery_config_provider is not None
assert config.semantic_layer_config_provider is not None
def test_localhost_semantic_layer_config(self):
env_vars = {
"DBT_HOST": "localhost:8080",
"DBT_PROD_ENV_ID": "123",
"DBT_TOKEN": "test_token",
"DISABLE_SEMANTIC_LAYER": "false",
"DISABLE_DBT_CLI": "true",
"DISABLE_DISCOVERY": "true",
"DISABLE_REMOTE": "true",
}
config = self._load_config_with_env(env_vars)
assert config.semantic_layer_config_provider is not None
def test_warn_error_options_default_setting(self):
env_vars = {
"DBT_TOKEN": "test_token",
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_DISCOVERY": "true",
"DISABLE_REMOTE": "true",
"DISABLE_ADMIN_API": "true",
}
# For this test, we need to call load_config directly to see environment side effects
with patch.dict(os.environ, env_vars, clear=True):
with patch("dbt_mcp.config.config.DbtMcpSettings") as mock_settings_class:
settings_instance = DbtMcpSettings(_env_file=None)
mock_settings_class.return_value = settings_instance
load_config()
assert (
os.environ["DBT_WARN_ERROR_OPTIONS"]
== '{"error": ["NoNodesForSelectionCriteria"]}'
)
def test_warn_error_options_not_overridden_if_set(self):
env_vars = {
"DBT_TOKEN": "test_token",
"DBT_WARN_ERROR_OPTIONS": "custom_options",
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_DISCOVERY": "true",
"DISABLE_REMOTE": "true",
"DISABLE_ADMIN_API": "true",
}
# For this test, we need to call load_config directly to see environment side effects
with patch.dict(os.environ, env_vars, clear=True):
with patch("dbt_mcp.config.config.DbtMcpSettings") as mock_settings_class:
settings_instance = DbtMcpSettings(_env_file=None)
mock_settings_class.return_value = settings_instance
load_config()
assert os.environ["DBT_WARN_ERROR_OPTIONS"] == "custom_options"
def test_local_user_id_loading_from_dbt_profile(self):
user_data = {"id": "local_user_123"}
env_vars = {
"DBT_TOKEN": "test_token",
"HOME": "/fake/home",
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_DISCOVERY": "true",
"DISABLE_REMOTE": "true",
"DISABLE_ADMIN_API": "true",
}
with (
patch.dict(os.environ, env_vars),
patch("dbt_mcp.tracking.tracking.try_read_yaml", return_value=user_data),
):
config = self._load_config_with_env(env_vars)
# local_user_id is now loaded by UsageTracker, not Config
assert config.credentials_provider is not None
def test_local_user_id_loading_failure_handling(self):
env_vars = {
"DBT_TOKEN": "test_token",
"HOME": "/fake/home",
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_DISCOVERY": "true",
"DISABLE_REMOTE": "true",
"DISABLE_ADMIN_API": "true",
}
with (
patch.dict(os.environ, env_vars),
patch("dbt_mcp.tracking.tracking.try_read_yaml", return_value=None),
):
config = self._load_config_with_env(env_vars)
# local_user_id is now loaded by UsageTracker, not Config
assert config.credentials_provider is not None
def test_remote_requirements(self):
# Test that remote_config is only created when remote tools are enabled
# and all required fields are present
env_vars = {
"DBT_HOST": "test.dbt.com",
"DBT_PROD_ENV_ID": "123",
"DBT_TOKEN": "test_token",
"DISABLE_REMOTE": "true",
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_DISCOVERY": "true",
"DISABLE_ADMIN_API": "true",
}
config = self._load_config_with_env(env_vars)
# Remote config should not be created when remote tools are disabled
assert config.sql_config_provider is None
# Test remote requirements (needs user_id and dev_env_id too)
env_vars.update(
{
"DBT_USER_ID": "789",
"DBT_DEV_ENV_ID": "456",
"DISABLE_REMOTE": "false",
}
)
config = self._load_config_with_env(env_vars)
assert config.sql_config_provider is not None
def test_disable_flags_combinations(self, env_setup):
base_env = {
"DBT_HOST": "test.dbt.com",
"DBT_PROD_ENV_ID": "123",
"DBT_TOKEN": "test_token",
}
test_cases = [
# Only CLI enabled
{
"DISABLE_DBT_CLI": "false",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_DISCOVERY": "true",
"DISABLE_REMOTE": "true",
},
# Only semantic layer enabled
{
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "false",
"DISABLE_DISCOVERY": "true",
"DISABLE_REMOTE": "true",
},
# Multiple services enabled
{
"DISABLE_DBT_CLI": "false",
"DISABLE_SEMANTIC_LAYER": "false",
"DISABLE_DISCOVERY": "false",
"DISABLE_REMOTE": "true",
},
]
for disable_flags in test_cases:
env_vars = {**base_env, **disable_flags}
with env_setup(env_vars=env_vars) as (project_dir, helpers):
config = load_config()
# Verify configs are created only when services are enabled
assert (config.dbt_cli_config is not None) == (
disable_flags["DISABLE_DBT_CLI"] == "false"
)
assert (config.semantic_layer_config_provider is not None) == (
disable_flags["DISABLE_SEMANTIC_LAYER"] == "false"
)
assert (config.discovery_config_provider is not None) == (
disable_flags["DISABLE_DISCOVERY"] == "false"
)
def test_legacy_env_id_support(self):
# Test that DBT_ENV_ID still works for backward compatibility
env_vars = {
"DBT_HOST": "test.dbt.com",
"DBT_ENV_ID": "123", # Using legacy variable
"DBT_TOKEN": "test_token",
"DISABLE_DISCOVERY": "false",
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_REMOTE": "true",
}
config = self._load_config_with_env(env_vars)
assert config.discovery_config_provider is not None
assert config.credentials_provider is not None
def test_case_insensitive_environment_variables(self):
# pydantic_settings should handle case insensitivity based on config
env_vars = {
"dbt_host": "test.dbt.com", # lowercase
"DBT_PROD_ENV_ID": "123", # uppercase
"dbt_token": "test_token", # lowercase
"DISABLE_DISCOVERY": "false",
"DISABLE_DBT_CLI": "true",
"DISABLE_SEMANTIC_LAYER": "true",
"DISABLE_REMOTE": "true",
}
config = self._load_config_with_env(env_vars)
assert config.discovery_config_provider is not None
assert config.credentials_provider is not None
```
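
The `env_setup` fixture used throughout these tests lives in the shared test configuration (see `tests/conftest.py` in the directory tree), which is not shown on this page. A rough stand-in with the same calling shape, offered purely as an illustrative assumption rather than the project's actual fixture, might look like:

```python
import os
from contextlib import contextmanager
from unittest.mock import patch

import pytest


@pytest.fixture
def env_setup(tmp_path):
    """Hypothetical stand-in: patch os.environ and provide a temp dbt project dir."""

    @contextmanager
    def _env_setup(env_vars: dict[str, str] | None = None):
        project_dir = tmp_path / "project"
        project_dir.mkdir(exist_ok=True)
        merged = {"DBT_PROJECT_DIR": str(project_dir), **(env_vars or {})}
        with patch.dict(os.environ, merged, clear=True):
            yield project_dir, None  # second element stands in for `helpers`

    return _env_setup
```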
--------------------------------------------------------------------------------
/tests/integration/lsp/test_lsp_connection.py:
--------------------------------------------------------------------------------
```python
"""Integration-style tests for LSP connection using real instances instead of mocks.
These tests use real sockets, asyncio primitives, and actual data flow
to provide more realistic test coverage compared to heavily mocked unit tests.
"""
import asyncio
import json
import socket
import pytest
from dbt_mcp.lsp.lsp_connection import (
SocketLSPConnection,
LspEventName,
JsonRpcMessage,
)
class TestRealSocketOperations:
"""Tests using real sockets to verify actual network communication."""
def test_setup_socket_real(self, tmp_path):
"""Test socket setup with real socket binding."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Use real socket
conn.setup_socket()
try:
# Verify real socket was created and bound
assert conn._socket is not None
assert isinstance(conn._socket, socket.socket)
assert conn.port > 0 # OS assigned a port
assert conn.host == "127.0.0.1"
# Verify socket is actually listening
sockname = conn._socket.getsockname()
assert sockname[0] == "127.0.0.1"
assert sockname[1] == conn.port
finally:
# Cleanup
if conn._socket:
conn._socket.close()
def test_socket_reuse_address(self, tmp_path):
"""Test that SO_REUSEADDR is set on real socket."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.setup_socket()
try:
# Verify SO_REUSEADDR is set (value varies by platform, just check it's non-zero)
reuse = conn._socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
assert reuse != 0
finally:
if conn._socket:
conn._socket.close()
@pytest.mark.asyncio
async def test_socket_accept_with_real_client(self, tmp_path):
"""Test socket accept with real client connection."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test", connection_timeout=2.0)
conn.setup_socket()
try:
# Create real client socket that connects to the server
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
async def client_connect():
await asyncio.sleep(0.1) # Let server start listening
await asyncio.get_running_loop().run_in_executor(
None, client_socket.connect, (conn.host, conn.port)
)
async def server_accept():
conn._socket.settimeout(conn.connection_timeout)
connection, addr = await asyncio.get_running_loop().run_in_executor(
None, conn._socket.accept
)
return connection, addr
# Run both concurrently
client_task = asyncio.create_task(client_connect())
server_result = await server_accept()
await client_task
connection, client_addr = server_result
assert connection is not None
assert client_addr[0] in ("127.0.0.1", "::1") # IPv4 or IPv6 localhost
# Cleanup
connection.close()
client_socket.close()
finally:
if conn._socket:
conn._socket.close()
class TestRealAsyncioQueues:
"""Tests using real asyncio queues to verify message queueing."""
def test_send_message_with_real_queue(self, tmp_path):
"""Test message sending with real asyncio queue."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# The _outgoing_queue is already a real asyncio.Queue
assert isinstance(conn._outgoing_queue, asyncio.Queue)
assert conn._outgoing_queue.empty()
# Send a message
message = JsonRpcMessage(id=1, method="test", params={"key": "value"})
conn._send_message(message)
# Verify message was actually queued
assert not conn._outgoing_queue.empty()
data = conn._outgoing_queue.get_nowait()
# Verify LSP protocol format
assert isinstance(data, bytes)
assert b"Content-Length:" in data
assert b"\r\n\r\n" in data
assert b'"jsonrpc"' in data
assert b'"2.0"' in data
assert b'"test"' in data
def test_multiple_messages_queue_order(self, tmp_path):
"""Test that multiple messages maintain FIFO order in real queue."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Send multiple messages
msg1 = JsonRpcMessage(id=1, method="first")
msg2 = JsonRpcMessage(id=2, method="second")
msg3 = JsonRpcMessage(id=3, method="third")
conn._send_message(msg1)
conn._send_message(msg2)
conn._send_message(msg3)
# Verify queue size
assert conn._outgoing_queue.qsize() == 3
# Verify FIFO order
data1 = conn._outgoing_queue.get_nowait()
data2 = conn._outgoing_queue.get_nowait()
data3 = conn._outgoing_queue.get_nowait()
assert b'"first"' in data1
assert b'"second"' in data2
assert b'"third"' in data3
# Queue should be empty
assert conn._outgoing_queue.empty()
class TestRealAsyncioFutures:
"""Tests using real asyncio futures to verify async behavior."""
@pytest.mark.asyncio
async def test_handle_response_with_real_future(self, tmp_path):
"""Test handling response with real asyncio future."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create real future in current event loop
future = asyncio.get_running_loop().create_future()
conn.state.pending_requests[42] = future
# Handle response in the same loop
message = JsonRpcMessage(id=42, result={"success": True, "data": "test"})
conn._handle_incoming_message(message)
# Wait for future to be resolved (should be immediate via call_soon_threadsafe)
result = await asyncio.wait_for(future, timeout=1.0)
assert result == {"success": True, "data": "test"}
assert 42 not in conn.state.pending_requests
@pytest.mark.asyncio
async def test_handle_error_with_real_future(self, tmp_path):
"""Test handling error response with real asyncio future."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create real future
future = asyncio.get_running_loop().create_future()
conn.state.pending_requests[42] = future
# Handle error response
message = JsonRpcMessage(
id=42, error={"code": -32601, "message": "Method not found"}
)
conn._handle_incoming_message(message)
# Future should be rejected with exception
with pytest.raises(RuntimeError, match="LSP error"):
await asyncio.wait_for(future, timeout=1.0)
assert 42 not in conn.state.pending_requests
@pytest.mark.asyncio
async def test_notification_futures_real(self, tmp_path):
"""Test waiting for notifications with real futures."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Register to wait for a notification
future = conn.wait_for_notification(LspEventName.compileComplete)
# Verify it's a real future
assert isinstance(future, asyncio.Future)
assert not future.done()
# Simulate receiving the notification
message = JsonRpcMessage(
method="dbt/lspCompileComplete", params={"success": True, "errors": []}
)
conn._handle_incoming_message(message)
# Wait for notification
result = await asyncio.wait_for(future, timeout=1.0)
assert result == {"success": True, "errors": []}
assert conn.state.compiled is True
class TestRealSocketCommunication:
"""Tests using real socket pairs to verify end-to-end communication."""
@pytest.mark.asyncio
async def test_socket_pair_communication(self, tmp_path):
"""Test bidirectional communication using socketpair."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create a real socket pair (connected sockets)
server_socket, client_socket = socket.socketpair()
conn._connection = server_socket
try:
# Send a message through the connection
message = JsonRpcMessage(id=1, method="test", params={"foo": "bar"})
conn._send_message(message)
# Get the data from the queue
data = conn._outgoing_queue.get_nowait()
# Actually send it through the socket
await asyncio.get_running_loop().run_in_executor(
None, server_socket.sendall, data
)
# Read it back on the client side
received_data = await asyncio.get_running_loop().run_in_executor(
None, client_socket.recv, 4096
)
# Verify we got the complete LSP message
assert b"Content-Length:" in received_data
assert b"\r\n\r\n" in received_data
assert b'"test"' in received_data
assert b'"foo"' in received_data
assert b'"bar"' in received_data
finally:
server_socket.close()
client_socket.close()
@pytest.mark.asyncio
async def test_message_roundtrip_real(self, tmp_path):
"""Test complete message send and parse roundtrip."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create socket pair
server_socket, client_socket = socket.socketpair()
conn._connection = server_socket
try:
# Original message
original_message = JsonRpcMessage(
id=123,
method="textDocument/completion",
params={
"textDocument": {"uri": "file:///test.sql"},
"position": {"line": 10, "character": 5},
},
)
# Send through connection
conn._send_message(original_message)
data = conn._outgoing_queue.get_nowait()
await asyncio.get_running_loop().run_in_executor(
None, server_socket.sendall, data
)
# Receive on client side
received_data = await asyncio.get_running_loop().run_in_executor(
None, client_socket.recv, 4096
)
# Parse it back
parsed_message, remaining = conn._parse_message(received_data)
# Verify roundtrip integrity
assert parsed_message is not None
assert parsed_message.id == original_message.id
assert parsed_message.method == original_message.method
assert parsed_message.params == original_message.params
assert remaining == b""
finally:
server_socket.close()
client_socket.close()
@pytest.mark.asyncio
async def test_multiple_messages_streaming(self, tmp_path):
"""Test streaming multiple messages through real socket."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create socket pair
server_socket, client_socket = socket.socketpair()
conn._connection = server_socket
try:
# Set non-blocking for client to avoid hangs
client_socket.setblocking(False)
# Send multiple messages
messages = [
JsonRpcMessage(id=1, method="initialize"),
JsonRpcMessage(method="initialized", params={}),
JsonRpcMessage(id=2, method="textDocument/didOpen"),
]
for msg in messages:
conn._send_message(msg)
data = conn._outgoing_queue.get_nowait()
await asyncio.get_running_loop().run_in_executor(
None, server_socket.sendall, data
)
# Receive all data on client side with timeout
received_data = b""
client_socket.setblocking(True)
client_socket.settimeout(1.0)
try:
while True:
chunk = await asyncio.get_running_loop().run_in_executor(
None, client_socket.recv, 4096
)
if not chunk:
break
received_data += chunk
# Try to parse - if we have all 3 messages, we're done
temp_buffer = received_data
temp_count = 0
while True:
msg, temp_buffer = conn._parse_message(temp_buffer)
if msg is None:
break
temp_count += 1
if temp_count >= 3:
break
except TimeoutError:
pass # Expected when all data is received
# Parse all messages
buffer = received_data
parsed_messages = []
while buffer:
msg, buffer = conn._parse_message(buffer)
if msg is None:
break
parsed_messages.append(msg)
# Verify all messages were received and parsed correctly
assert len(parsed_messages) == 3
assert parsed_messages[0].id == 1
assert parsed_messages[0].method == "initialize"
assert parsed_messages[1].method == "initialized"
assert parsed_messages[2].id == 2
finally:
server_socket.close()
client_socket.close()
class TestRealMessageParsing:
"""Tests parsing with real byte streams."""
def test_parse_real_lsp_message(self, tmp_path):
"""Test parsing a real LSP protocol message."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create a real LSP message exactly as it would be sent
content = json.dumps(
{
"jsonrpc": "2.0",
"id": 1,
"result": {
"capabilities": {
"textDocumentSync": 2,
"completionProvider": {"triggerCharacters": ["."]},
}
},
}
)
content_bytes = content.encode("utf-8")
header = f"Content-Length: {len(content_bytes)}\r\n\r\n"
full_message = header.encode("utf-8") + content_bytes
# Parse it
message, remaining = conn._parse_message(full_message)
assert message is not None
assert message.id == 1
assert "capabilities" in message.result
assert message.result["capabilities"]["textDocumentSync"] == 2
assert remaining == b""
def test_parse_chunked_message_real(self, tmp_path):
"""Test parsing message that arrives in multiple chunks."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create a message
content = json.dumps({"jsonrpc": "2.0", "id": 1, "method": "test"})
content_bytes = content.encode("utf-8")
header = f"Content-Length: {len(content_bytes)}\r\n\r\n"
full_message = header.encode("utf-8") + content_bytes
# Split into chunks (simulate network chunking)
chunk1 = full_message[:20]
chunk2 = full_message[20:40]
chunk3 = full_message[40:]
# Parse first chunk - should be incomplete
msg1, buffer = conn._parse_message(chunk1)
assert msg1 is None
assert buffer == chunk1
# Add second chunk - still incomplete
buffer += chunk2
msg2, buffer = conn._parse_message(buffer)
assert msg2 is None
# Add final chunk - should complete
buffer += chunk3
msg3, buffer = conn._parse_message(buffer)
assert msg3 is not None
assert msg3.id == 1
assert msg3.method == "test"
assert buffer == b""
class TestRealConcurrentOperations:
"""Tests with real concurrent async operations."""
@pytest.mark.asyncio
async def test_concurrent_request_futures(self, tmp_path):
"""Test handling multiple concurrent requests with real futures."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create multiple real futures for concurrent requests
futures = {}
for i in range(10):
future = asyncio.get_running_loop().create_future()
futures[i] = future
conn.state.pending_requests[i] = future
# Simulate responses arriving concurrently
async def respond(request_id: int, delay: float):
await asyncio.sleep(delay)
message = JsonRpcMessage(id=request_id, result={"request_id": request_id})
conn._handle_incoming_message(message)
# Start all responses with random delays
response_tasks = [asyncio.create_task(respond(i, i * 0.01)) for i in range(10)]
# Wait for all futures to resolve
results = await asyncio.gather(*[futures[i] for i in range(10)])
# Verify all completed correctly
assert len(results) == 10
for i, result in enumerate(results):
assert result["request_id"] == i
# All requests should be removed
assert len(conn.state.pending_requests) == 0
# Cleanup
await asyncio.gather(*response_tasks)
@pytest.mark.asyncio
async def test_concurrent_notifications_real(self, tmp_path):
"""Test multiple futures waiting for the same notification."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create multiple waiters for the same event
future1 = conn.wait_for_notification(LspEventName.compileComplete)
future2 = conn.wait_for_notification(LspEventName.compileComplete)
future3 = conn.wait_for_notification(LspEventName.compileComplete)
# All should be real futures
assert all(isinstance(f, asyncio.Future) for f in [future1, future2, future3])
# Send the notification
message = JsonRpcMessage(
method="dbt/lspCompileComplete", params={"status": "success"}
)
conn._handle_incoming_message(message)
# All futures should resolve
results = await asyncio.wait_for(
asyncio.gather(future1, future2, future3), timeout=1.0
)
assert all(r == {"status": "success"} for r in results)
class TestRealStateManagement:
"""Tests using real state objects."""
def test_real_state_initialization(self, tmp_path):
"""Test that connection uses real LspConnectionState."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Verify state is a real instance
from dbt_mcp.lsp.lsp_connection import LspConnectionState
assert isinstance(conn.state, LspConnectionState)
assert conn.state.initialized is False
assert conn.state.compiled is False
assert isinstance(conn.state.pending_requests, dict)
assert isinstance(conn.state.pending_notifications, dict)
def test_real_request_id_generation(self, tmp_path):
"""Test real request ID counter."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Get sequential IDs
ids = [conn.state.get_next_request_id() for _ in range(100)]
        # Verify the IDs are sequential and unique; the starting value is an
        # implementation detail, so anchor on the first observed ID
first_id = ids[0]
assert ids[-1] == first_id + 99
assert ids == list(range(first_id, first_id + 100))
assert len(set(ids)) == 100 # All unique
def test_real_state_updates(self, tmp_path):
"""Test that state updates work with real instances."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Update state
conn.state.initialized = True
conn.state.capabilities = {"test": True}
conn.state.compiled = True
# Verify updates persist
assert conn.state.initialized is True
assert conn.state.capabilities == {"test": True}
assert conn.state.compiled is True
```
--------------------------------------------------------------------------------
/ui/src/App.css:
--------------------------------------------------------------------------------
```css
/* Reset and base styles */
* {
box-sizing: border-box;
}
body {
width: 100%;
margin: 0;
padding: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',
sans-serif;
line-height: 1.6;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
background-color: #fff;
color: #1c1a19;
}
p {
margin-bottom: 1rem;
}
@media (prefers-color-scheme: dark) {
body {
background-color: #1c1a19;
color: #f6f6f6;
}
p {
margin-bottom: 1rem;
}
}
/* Logo */
.logo-container {
position: fixed;
top: 1rem;
left: 1rem;
z-index: 1000;
}
.logo {
height: 2rem;
width: auto;
transition: opacity 0.2s ease-in-out;
}
.logo-light {
display: block;
}
.logo-dark {
display: none;
}
/* Main layout */
.app-container {
min-height: 100vh;
display: flex;
justify-content: center;
align-items: flex-start;
padding: 2rem 1rem;
}
.app-content {
width: 100%;
max-width: 600px;
display: flex;
flex-direction: column;
gap: 2rem;
}
/* Header */
.app-header {
text-align: center;
margin-bottom: 1rem;
}
.app-header h1 {
margin: 0 0 0.5rem 0;
font-size: 2.5rem;
font-weight: 700;
letter-spacing: -0.025em;
}
.app-header p {
margin: 0;
font-size: 1.125rem;
opacity: 0.7;
}
/* Sections */
section {
background: #fff;
border-radius: 12px;
border: 1px solid #ebe9e9;
overflow: visible;
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06);
}
/* Specific overflow handling for sections with dropdowns */
.project-selection-section {
overflow: visible;
}
.section-header {
padding: 1.5rem 1.5rem 0 1.5rem;
border-bottom: 1px solid #ebe9e9;
margin-bottom: 1.5rem;
}
.section-header h2 {
margin: 0 0 0.5rem 0;
font-size: 1.5rem;
font-weight: 600;
}
.section-header h3 {
margin: 0 0 0.5rem 0;
font-size: 1.25rem;
font-weight: 600;
}
.section-header p {
margin: 0 0 1.5rem 0;
opacity: 0.7;
font-size: 0.875rem;
}
/* Form content */
.form-content {
padding: 0 1.5rem 1.5rem 1.5rem;
}
.form-group {
margin-bottom: 1rem;
}
.form-label {
display: block;
margin-bottom: 0.5rem;
font-weight: 500;
font-size: 0.875rem;
}
.form-select {
width: 100%;
padding: 0.875rem 3rem 0.875rem 1rem;
border: 1.5px solid #ebe9e9;
border-radius: 12px;
font-size: 1rem;
font-weight: 500;
background-color: #fff;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%231c1a19' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
background-position: right 0.75rem center;
background-repeat: no-repeat;
background-size: 1.25rem 1.25rem;
cursor: pointer;
transition: all 0.2s ease-in-out;
appearance: none;
-webkit-appearance: none;
-moz-appearance: none;
}
.form-select:focus {
outline: none;
border-color: #3b82f6;
background-color: white;
box-shadow:
0 0 0 3px rgba(59, 130, 246, 0.12),
0 4px 6px -1px rgba(0, 0, 0, 0.1),
0 2px 4px -1px rgba(0, 0, 0, 0.06);
transform: translateY(-1px);
}
.form-select:hover:not(:focus) {
border-color: #9ca3af;
background-color: white;
box-shadow:
0 2px 4px -1px rgba(0, 0, 0, 0.1),
0 1px 2px -1px rgba(0, 0, 0, 0.06);
}
.form-select:disabled {
background-color: #f3f4f6;
border-color: #e5e7eb;
cursor: not-allowed;
opacity: 0.7;
}
/* Custom dropdown */
.custom-dropdown {
position: relative;
width: 100%;
z-index: 999999;
/* Ensure proper stacking context */
isolation: isolate;
}
.dropdown-trigger {
width: 100%;
padding: 0.875rem 3rem 0.875rem 1rem;
border: 1.5px solid #ebe9e9;
border-radius: 12px;
font-size: 1rem;
font-weight: 500;
background-color: #fff;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%231c1a19' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
background-position: right 0.75rem center;
background-repeat: no-repeat;
background-size: 1.25rem 1.25rem;
cursor: pointer;
transition: all 0.2s ease-in-out;
text-align: left;
color: #1c1a19;
}
.dropdown-trigger:focus {
outline: none;
border-color: #3b82f6;
background-color: white;
box-shadow:
0 0 0 3px rgba(59, 130, 246, 0.12),
0 4px 6px -1px rgba(0, 0, 0, 0.1),
0 2px 4px -1px rgba(0, 0, 0, 0.06);
transform: translateY(-1px);
}
.dropdown-trigger:hover:not(:focus) {
border-color: #9ca3af;
background-color: white;
box-shadow:
0 2px 4px -1px rgba(0, 0, 0, 0.1),
0 1px 2px -1px rgba(0, 0, 0, 0.06);
}
.dropdown-trigger.open {
border-color: #3b82f6;
background-color: white;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%233b82f6' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M14 12l-4-4-4 4'/%3e%3c/svg%3e");
box-shadow:
0 0 0 3px rgba(59, 130, 246, 0.12),
0 4px 6px -1px rgba(0, 0, 0, 0.1),
0 2px 4px -1px rgba(0, 0, 0, 0.06);
border-bottom-left-radius: 4px;
border-bottom-right-radius: 4px;
}
.dropdown-trigger.placeholder {
color: #9ca3af;
font-weight: 400;
}
.dropdown-options {
position: absolute;
top: 100%;
left: 0;
right: 0;
background: white;
border: 1.5px solid;
border-top: none;
border-bottom-left-radius: 12px;
border-bottom-right-radius: 12px;
box-shadow:
0 0 0 3px rgba(59, 130, 246, 0.12),
0 10px 15px -3px rgba(0, 0, 0, 0.1),
0 4px 6px -2px rgba(0, 0, 0, 0.05);
z-index: 999999;
animation: dropdownSlideIn 0.15s ease-out;
/* Ensure proper rendering and isolation */
isolation: isolate;
contain: layout style;
/* Add scrolling for long lists */
max-height: 300px;
overflow-y: auto;
}
/* Removed dropdown-options-fixed - using simple absolute positioning */
@keyframes dropdownSlideIn {
0% {
opacity: 0;
transform: translateY(-8px);
}
100% {
opacity: 1;
transform: translateY(0);
}
}
.dropdown-option {
padding: 0.875rem 1rem;
cursor: pointer;
transition: all 0.15s ease-in-out;
border: none;
background: none;
width: 100%;
text-align: left;
font-size: 1rem;
color: #374151;
display: flex;
flex-direction: column;
gap: 0.125rem;
}
.dropdown-option:hover {
background-color: #f8fafc;
color: #1f2937;
}
.dropdown-option:focus {
outline: none;
background-color: #eff6ff;
color: #1e40af;
}
.dropdown-option:active {
background-color: #dbeafe;
}
.dropdown-option.selected {
background-color: #f3f4f6;
color: #374151;
}
.dropdown-option.selected:hover {
background-color: #e5e7eb;
}
.option-primary {
font-weight: 500;
line-height: 1.4;
}
.option-secondary {
font-size: 0.875rem;
opacity: 0.7;
font-weight: 400;
}
.dropdown-option.selected .option-secondary {
opacity: 0.9;
}
/* Dropdown scrollbar styling */
.dropdown-options::-webkit-scrollbar {
width: 8px;
}
.dropdown-options::-webkit-scrollbar-track {
background: #f8fafc;
border-radius: 4px;
}
.dropdown-options::-webkit-scrollbar-thumb {
background: #cbd5e1;
border-radius: 4px;
border: 1px solid #f8fafc;
}
.dropdown-options::-webkit-scrollbar-thumb:hover {
background: #94a3b8;
}
/* Loading state */
.loading-state {
display: flex;
align-items: center;
gap: 0.75rem;
padding: 1rem;
background-color: #fff;
border: 1px solid #ebe9e9;
border-radius: 8px;
margin: 1rem 1.5rem;
}
.spinner {
width: 20px;
height: 20px;
border: 2px solid #ebe9e9;
border-top: 2px solid #1c1a19;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
/* Error state */
.error-state {
padding: 1rem;
background-color: #fef2f2;
border: 1px solid #fecaca;
border-radius: 8px;
margin: 1rem 1.5rem;
}
.error-state strong {
display: block;
margin-bottom: 0.25rem;
font-weight: 600;
}
.error-state p {
margin: 0;
font-size: 0.875rem;
opacity: 0.8;
}
/* OAuth Error Section */
.error-section {
background: #fff;
border: 1px solid #fecaca;
}
.error-details {
padding: 0 1.5rem 1.5rem 1.5rem;
display: flex;
flex-direction: column;
gap: 1rem;
}
.error-item {
display: flex;
flex-direction: column;
gap: 0.5rem;
}
.error-item strong {
font-weight: 500;
font-size: 0.875rem;
color: #991b1b;
}
.error-code {
display: inline-block;
padding: 0.5rem 0.75rem;
background-color: #fef2f2;
border: 1px solid #fecaca;
border-radius: 6px;
font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, 'Courier New', monospace;
font-size: 0.875rem;
color: #991b1b;
font-weight: 500;
}
.error-description {
margin: 0;
padding: 0.75rem;
background-color: #fef2f2;
border: 1px solid #fecaca;
border-radius: 6px;
color: #991b1b;
font-size: 0.875rem;
line-height: 1.5;
}
.error-actions {
margin-top: 0.5rem;
padding: 1rem;
background-color: #fffbeb;
border: 1px solid #fde68a;
border-radius: 6px;
}
.error-actions p {
margin: 0;
color: #92400e;
font-size: 0.875rem;
line-height: 1.5;
}
/* Context details */
.context-details {
padding: 0 1.5rem 1.5rem 1.5rem;
display: flex;
flex-direction: column;
gap: 1rem;
}
.context-item {
display: flex;
flex-direction: column;
gap: 0.25rem;
}
.context-item strong {
font-weight: 500;
font-size: 0.875rem;
opacity: 0.7;
}
.environment-details {
display: flex;
align-items: center;
gap: 0.5rem;
}
.env-name {
font-weight: 500;
}
.env-type {
font-size: 0.875rem;
opacity: 0.6;
}
/* Actions section */
.actions-section {
padding: 1.5rem;
text-align: center;
background-color: #f9fafb;
}
/* Button container */
.button-container {
display: flex;
justify-content: center;
align-items: center;
}
/* Button */
.primary-button {
display: inline-flex;
align-items: center;
padding: 0.75rem 1.5rem;
background-color: #1c1a19;
color: #fff;
border: 1px solid #1c1a19;
border-radius: 8px;
font-size: 1rem;
font-weight: 500;
cursor: pointer;
transition:
background-color 0.15s ease-in-out,
transform 0.15s ease-in-out,
opacity 0.15s ease-in-out;
}
.primary-button:hover {
background-color: #2d2a28;
border-color: #2d2a28;
transform: translateY(-1px);
}
.primary-button:focus {
outline: none;
box-shadow: 0 0 0 3px rgba(28, 26, 25, 0.2);
border-color: #2d2a28;
}
.primary-button:active {
transform: translateY(0);
background-color: #3d3a38;
}
.primary-button:disabled {
background-color: #d1d5db;
border-color: #d1d5db;
color: #6b7280;
cursor: not-allowed;
transform: none;
box-shadow: none;
opacity: 0.65;
}
.primary-button:disabled:hover,
.primary-button:disabled:focus,
.primary-button:disabled:active {
background-color: #d1d5db;
border-color: #d1d5db;
color: #6b7280;
transform: none;
box-shadow: none;
}
/* Completion section */
.completion-section {
padding: 0;
}
.completion-card {
padding: 2rem 1.5rem;
text-align: center;
}
.completion-card h2 {
margin: 0 0 1rem 0;
font-size: 1.75rem;
font-weight: 600;
}
.completion-card p {
margin: 0;
font-size: 1rem;
line-height: 1.6;
}
/* Response section */
.response-section {
padding: 1.5rem;
}
.response-text {
background-color: #fff;
border: 1px solid #ebe9e9;
border-radius: 8px;
padding: 1rem;
font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, 'Courier New', monospace;
font-size: 0.875rem;
line-height: 1.5;
white-space: pre-wrap;
word-break: break-word;
overflow-x: auto;
}
/* Responsive design */
@media (max-width: 768px) {
.logo-container {
top: 0.5rem;
left: 0.5rem;
}
.logo {
height: 1.5rem;
}
.app-container {
padding: 1rem 0.5rem;
}
.app-content {
max-width: 100%;
}
.app-header h1 {
font-size: 2rem;
}
.section-header {
padding: 1rem 1rem 0 1rem;
margin-bottom: 1rem;
}
.form-content,
.context-details,
.actions-section,
.response-section {
padding-left: 1rem;
padding-right: 1rem;
}
.loading-state,
.error-state {
margin-left: 1rem;
margin-right: 1rem;
}
}
@media (max-width: 480px) {
.logo {
height: 1.25rem;
}
.app-container {
padding: 0.5rem 0.25rem;
}
.app-header h1 {
font-size: 1.75rem;
}
.primary-button {
width: 100%;
}
}
/* Light mode styles */
@media (prefers-color-scheme: light) {
body {
background-color: #fff;
color: #1c1a19;
}
/* Sections */
section {
background: #fff;
border-color: #ebe9e9;
}
.section-header {
border-bottom-color: #ebe9e9;
}
.section-header h2,
.section-header h3 {
color: #1c1a19;
}
.section-header p {
color: #1c1a19;
}
/* Form elements */
.form-label {
color: #1c1a19;
}
.form-select {
background-color: #fff;
border-color: #ebe9e9;
color: #1c1a19;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%231c1a19' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
}
.form-select:focus {
background-color: #fff;
border-color: #ebe9e9;
}
.form-select:hover:not(:focus) {
background-color: #fff;
border-color: #ebe9e9;
}
.form-select:disabled {
background-color: #f9f9f9;
border-color: #ebe9e9;
}
/* Custom dropdown */
.dropdown-trigger {
background-color: #fff;
border-color: #ebe9e9;
color: #1c1a19;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%231c1a19' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
}
.dropdown-trigger:focus {
background-color: #fff;
border-color: #ebe9e9;
}
.dropdown-trigger:hover:not(:focus) {
background-color: #fff;
border-color: #ebe9e9;
}
.dropdown-trigger.open {
background-color: #fff;
border-color: #ebe9e9;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%231c1a19' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M14 12l-4-4-4 4'/%3e%3c/svg%3e");
}
.dropdown-options {
background: #fff;
border-color: #ebe9e9;
}
.dropdown-option {
color: #1c1a19;
}
.dropdown-option:hover {
background-color: #f9f9f9;
color: #1c1a19;
}
.dropdown-option:focus {
background-color: #f9f9f9;
color: #1c1a19;
}
.dropdown-option:active {
background-color: #f3f3f3;
}
.dropdown-option.selected {
background-color: #f9f9f9;
color: #1c1a19;
}
.dropdown-option.selected:hover {
background-color: #f3f3f3;
}
/* Loading state */
.loading-state {
background-color: #fff;
border: 1px solid #ebe9e9;
color: #1c1a19;
}
.spinner {
border-color: #ebe9e9;
border-top-color: #1c1a19;
}
/* Error state */
.error-state {
background-color: #fef2f2;
border-color: #fecaca;
color: #991b1b;
}
.error-state strong {
color: #991b1b;
}
.error-state p {
color: #991b1b;
}
/* Context details */
.context-item strong {
color: #1c1a19;
}
.env-name {
color: #1c1a19;
}
.env-type {
color: #1c1a19;
}
/* Actions section */
.actions-section {
background-color: #fff;
}
/* Response section */
.response-text {
background-color: #fff;
border-color: #ebe9e9;
color: #1c1a19;
}
/* App header */
.app-header h1 {
color: #1c1a19;
}
.app-header p {
color: #1c1a19;
}
/* Button light mode */
.primary-button {
background-color: #1c1a19;
color: #fff;
border-color: #1c1a19;
}
.primary-button:hover {
background-color: #2d2a28;
border-color: #2d2a28;
}
.primary-button:focus {
box-shadow: 0 0 0 3px rgba(28, 26, 25, 0.2);
border-color: #2d2a28;
}
.primary-button:active {
background-color: #3d3a38;
}
.primary-button:disabled {
background-color: #d6d3d1;
border-color: #e7e5e4;
color: #78716c;
opacity: 0.7;
}
.primary-button:disabled:hover,
.primary-button:disabled:focus,
.primary-button:disabled:active {
background-color: #d6d3d1;
border-color: #e7e5e4;
color: #78716c;
}
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
/* Logo theme switching */
.logo-light {
display: none;
}
.logo-dark {
display: block;
}
/* Sections */
section {
background: #1c1a19;
border-color: #4e4a49;
}
.section-header {
border-bottom-color: #4e4a49;
}
.section-header h2,
.section-header h3 {
color: #f6f6f6;
}
.section-header p {
color: #f6f6f6;
}
/* Form elements */
.form-label {
color: #f6f6f6;
}
.form-select {
background-color: #1c1a19;
border-color: #4e4a49;
color: #f6f6f6;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%23f6f6f6' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
}
.form-select:focus {
background-color: #1c1a19;
border-color: #4e4a49;
}
.form-select:hover:not(:focus) {
background-color: #1c1a19;
border-color: #4e4a49;
}
.form-select:disabled {
background-color: #374151;
border-color: #4e4a49;
}
/* Custom dropdown */
.dropdown-trigger {
background-color: #1c1a19;
border-color: #4e4a49;
color: #f6f6f6;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%23f6f6f6' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
}
.dropdown-trigger:focus {
background-color: #1c1a19;
border-color: #4e4a49;
}
.dropdown-trigger:hover:not(:focus) {
background-color: #1c1a19;
border-color: #4e4a49;
}
.dropdown-trigger.open {
background-color: #1c1a19;
border-color: #4e4a49;
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%23f6f6f6' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M14 12l-4-4-4 4'/%3e%3c/svg%3e");
}
.dropdown-trigger.placeholder {
color: #9ca3af;
}
.dropdown-options {
background: #1c1a19;
border-color: #4e4a49;
}
.dropdown-option {
color: #f6f6f6;
}
.dropdown-option:hover {
background-color: #374151;
color: #f6f6f6;
}
.dropdown-option:focus {
background-color: #374151;
color: #f6f6f6;
}
.dropdown-option:active {
background-color: #4b5563;
}
.dropdown-option.selected {
background-color: #4e4a49;
color: #f6f6f6;
}
.dropdown-option.selected:hover {
background-color: #6b7280;
}
/* Dropdown scrollbar styling for dark mode */
.dropdown-options::-webkit-scrollbar-track {
background: #374151;
}
.dropdown-options::-webkit-scrollbar-thumb {
background: #6b7280;
border: 1px solid #374151;
}
.dropdown-options::-webkit-scrollbar-thumb:hover {
background: #9ca3af;
}
/* Loading state */
.loading-state {
background-color: #1c1a19;
border: 1px solid #4e4a49;
color: #f6f6f6;
}
.spinner {
border-color: #4e4a49;
border-top-color: #f6f6f6;
}
/* Error state */
.error-state {
background-color: #7f1d1d;
border-color: #4e4a49;
color: #f6f6f6;
}
.error-state strong {
color: #f6f6f6;
}
.error-state p {
color: #f6f6f6;
}
/* OAuth Error Section Dark Mode */
.error-section {
background: #1c1a19;
border-color: #991b1b;
}
.error-item strong {
color: #fca5a5;
}
.error-code {
background-color: #450a0a;
border-color: #7f1d1d;
color: #fca5a5;
}
.error-description {
background-color: #450a0a;
border-color: #7f1d1d;
color: #fca5a5;
}
.error-actions {
background-color: #422006;
border-color: #92400e;
}
.error-actions p {
color: #fde68a;
}
/* Context details */
.context-item strong {
color: #f6f6f6;
}
.env-name {
color: #f6f6f6;
}
.env-type {
color: #f6f6f6;
}
/* Actions section */
.actions-section {
background-color: #374151;
}
/* Response section */
.response-text {
background-color: #374151;
border-color: #4e4a49;
color: #f6f6f6;
}
/* App header */
.app-header h1 {
color: #f6f6f6;
}
.app-header p {
color: #f6f6f6;
}
/* Button dark mode */
.primary-button {
background-color: #fdfdfd;
color: #374151;
border-color: #4e4a49;
}
.primary-button:hover {
background-color: #f3f4f6;
border-color: #6b7280;
}
.primary-button:focus {
box-shadow: 0 0 0 3px rgba(246, 246, 246, 0.1);
border-color: #9ca3af;
}
.primary-button:active {
background-color: #e5e7eb;
}
.primary-button:disabled {
background-color: #2f2f30;
border-color: #3f3f40;
color: #8b949e;
opacity: 0.55;
}
.primary-button:disabled:hover,
.primary-button:disabled:focus,
.primary-button:disabled:active {
background-color: #2f2f30;
border-color: #3f3f40;
color: #8b949e;
}
}
```
--------------------------------------------------------------------------------
/tests/unit/discovery/test_exposures_fetcher.py:
--------------------------------------------------------------------------------
```python
from unittest.mock import patch
import pytest
from dbt_mcp.discovery.client import ExposuresFetcher
@pytest.fixture
def exposures_fetcher(mock_api_client):
return ExposuresFetcher(api_client=mock_api_client)
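# Note: mock_api_client is assumed to be a shared conftest fixture — a mock
# MetadataAPIClient whose environment ID resolves to 123, which is why the
# assertions below check environmentId == 123.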
async def test_fetch_exposures_single_page(exposures_fetcher, mock_api_client):
mock_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"pageInfo": {"hasNextPage": False, "endCursor": None},
"edges": [
{
"node": {
"name": "test_exposure",
"uniqueId": "exposure.test.test_exposure",
"exposureType": "application",
"maturity": "high",
"ownerEmail": "[email protected]",
"ownerName": "Test Owner",
"url": "https://example.com",
"meta": {},
"freshnessStatus": "Unknown",
"description": "Test exposure",
"label": None,
"parents": [
{"uniqueId": "model.test.parent_model"}
],
}
}
],
}
}
}
}
}
mock_api_client.execute_query.return_value = mock_response
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposures()
assert len(result) == 1
assert result[0]["name"] == "test_exposure"
assert result[0]["uniqueId"] == "exposure.test.test_exposure"
assert result[0]["exposureType"] == "application"
assert result[0]["maturity"] == "high"
assert result[0]["ownerEmail"] == "[email protected]"
assert result[0]["ownerName"] == "Test Owner"
assert result[0]["url"] == "https://example.com"
assert result[0]["meta"] == {}
assert result[0]["freshnessStatus"] == "Unknown"
assert result[0]["description"] == "Test exposure"
assert result[0]["parents"] == [{"uniqueId": "model.test.parent_model"}]
mock_api_client.execute_query.assert_called_once()
args, kwargs = mock_api_client.execute_query.call_args
assert args[1]["environmentId"] == 123
assert args[1]["first"] == 100
async def test_fetch_exposures_multiple_pages(exposures_fetcher, mock_api_client):
page1_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"pageInfo": {"hasNextPage": True, "endCursor": "cursor123"},
"edges": [
{
"node": {
"name": "exposure1",
"uniqueId": "exposure.test.exposure1",
"exposureType": "application",
"maturity": "high",
"ownerEmail": "[email protected]",
"ownerName": "Test Owner 1",
"url": "https://example1.com",
"meta": {},
"freshnessStatus": "Unknown",
"description": "Test exposure 1",
"label": None,
"parents": [],
}
}
],
}
}
}
}
}
page2_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"pageInfo": {"hasNextPage": False, "endCursor": "cursor456"},
"edges": [
{
"node": {
"name": "exposure2",
"uniqueId": "exposure.test.exposure2",
"exposureType": "dashboard",
"maturity": "medium",
"ownerEmail": "[email protected]",
"ownerName": "Test Owner 2",
"url": "https://example2.com",
"meta": {"key": "value"},
"freshnessStatus": "Fresh",
"description": "Test exposure 2",
"label": "Label 2",
"parents": [
{"uniqueId": "model.test.parent_model2"}
],
}
}
],
}
}
}
}
}
mock_api_client.execute_query.side_effect = [page1_response, page2_response]
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposures()
assert len(result) == 2
assert result[0]["name"] == "exposure1"
assert result[1]["name"] == "exposure2"
assert result[1]["meta"] == {"key": "value"}
assert result[1]["label"] == "Label 2"
assert mock_api_client.execute_query.call_count == 2
# Check first call (no cursor)
first_call = mock_api_client.execute_query.call_args_list[0]
assert first_call[0][1]["environmentId"] == 123
assert first_call[0][1]["first"] == 100
assert "after" not in first_call[0][1]
# Check second call (with cursor)
second_call = mock_api_client.execute_query.call_args_list[1]
assert second_call[0][1]["environmentId"] == 123
assert second_call[0][1]["first"] == 100
assert second_call[0][1]["after"] == "cursor123"
async def test_fetch_exposures_empty_response(exposures_fetcher, mock_api_client):
mock_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"pageInfo": {"hasNextPage": False, "endCursor": None},
"edges": [],
}
}
}
}
}
mock_api_client.execute_query.return_value = mock_response
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposures()
assert len(result) == 0
assert isinstance(result, list)
async def test_fetch_exposures_handles_malformed_edges(
exposures_fetcher, mock_api_client
):
mock_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"pageInfo": {"hasNextPage": False, "endCursor": None},
"edges": [
{
"node": {
"name": "valid_exposure",
"uniqueId": "exposure.test.valid_exposure",
"exposureType": "application",
"maturity": "high",
"ownerEmail": "[email protected]",
"ownerName": "Test Owner",
"url": "https://example.com",
"meta": {},
"freshnessStatus": "Unknown",
"description": "Valid exposure",
"label": None,
"parents": [],
}
},
{"invalid": "edge"}, # Missing "node" key
{"node": "not_a_dict"}, # Node is not a dict
{
"node": {
"name": "another_valid_exposure",
"uniqueId": "exposure.test.another_valid_exposure",
"exposureType": "dashboard",
"maturity": "low",
"ownerEmail": "[email protected]",
"ownerName": "Test Owner 2",
"url": "https://example2.com",
"meta": {},
"freshnessStatus": "Stale",
"description": "Another valid exposure",
"label": None,
"parents": [],
}
},
],
}
}
}
}
}
mock_api_client.execute_query.return_value = mock_response
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposures()
# Should only get the valid exposures (malformed edges should be filtered out)
assert len(result) == 2
assert result[0]["name"] == "valid_exposure"
assert result[1]["name"] == "another_valid_exposure"
async def test_fetch_exposure_details_by_unique_ids_single(
exposures_fetcher, mock_api_client
):
mock_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"edges": [
{
"node": {
"name": "customer_dashboard",
"uniqueId": "exposure.analytics.customer_dashboard",
"exposureType": "dashboard",
"maturity": "high",
"ownerEmail": "[email protected]",
"ownerName": "Analytics Team",
"url": "https://dashboard.example.com/customers",
"meta": {"team": "analytics", "priority": "high"},
"freshnessStatus": "Fresh",
"description": "Customer analytics dashboard",
"label": "Customer Dashboard",
"parents": [
{"uniqueId": "model.analytics.customers"},
{
"uniqueId": "model.analytics.customer_metrics"
},
],
}
}
]
}
}
}
}
}
mock_api_client.execute_query.return_value = mock_response
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposure_details(
unique_ids=["exposure.analytics.customer_dashboard"]
)
assert isinstance(result, list)
assert len(result) == 1
exposure = result[0]
assert exposure["name"] == "customer_dashboard"
assert exposure["uniqueId"] == "exposure.analytics.customer_dashboard"
assert exposure["exposureType"] == "dashboard"
assert exposure["maturity"] == "high"
assert exposure["ownerEmail"] == "[email protected]"
assert exposure["ownerName"] == "Analytics Team"
assert exposure["url"] == "https://dashboard.example.com/customers"
assert exposure["meta"] == {"team": "analytics", "priority": "high"}
assert exposure["freshnessStatus"] == "Fresh"
assert exposure["description"] == "Customer analytics dashboard"
assert exposure["label"] == "Customer Dashboard"
assert len(exposure["parents"]) == 2
assert exposure["parents"][0]["uniqueId"] == "model.analytics.customers"
assert exposure["parents"][1]["uniqueId"] == "model.analytics.customer_metrics"
mock_api_client.execute_query.assert_called_once()
args, kwargs = mock_api_client.execute_query.call_args
assert args[1]["environmentId"] == 123
assert args[1]["first"] == 1
assert args[1]["filter"] == {"uniqueIds": ["exposure.analytics.customer_dashboard"]}
async def test_fetch_exposure_details_by_unique_ids_multiple(
exposures_fetcher, mock_api_client
):
mock_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"edges": [
{
"node": {
"name": "customer_dashboard",
"uniqueId": "exposure.analytics.customer_dashboard",
"exposureType": "dashboard",
"maturity": "high",
"ownerEmail": "[email protected]",
"ownerName": "Analytics Team",
"url": "https://dashboard.example.com/customers",
"meta": {"team": "analytics", "priority": "high"},
"freshnessStatus": "Fresh",
"description": "Customer analytics dashboard",
"label": "Customer Dashboard",
"parents": [],
}
},
{
"node": {
"name": "sales_report",
"uniqueId": "exposure.sales.sales_report",
"exposureType": "analysis",
"maturity": "medium",
"ownerEmail": "[email protected]",
"ownerName": "Sales Team",
"url": None,
"meta": {},
"freshnessStatus": "Stale",
"description": "Monthly sales analysis report",
"label": None,
"parents": [{"uniqueId": "model.sales.sales_data"}],
}
},
]
}
}
}
}
}
mock_api_client.execute_query.return_value = mock_response
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposure_details(
unique_ids=[
"exposure.analytics.customer_dashboard",
"exposure.sales.sales_report",
]
)
assert isinstance(result, list)
assert len(result) == 2
# Check first exposure
exposure1 = result[0]
assert exposure1["name"] == "customer_dashboard"
assert exposure1["uniqueId"] == "exposure.analytics.customer_dashboard"
assert exposure1["exposureType"] == "dashboard"
# Check second exposure
exposure2 = result[1]
assert exposure2["name"] == "sales_report"
assert exposure2["uniqueId"] == "exposure.sales.sales_report"
assert exposure2["exposureType"] == "analysis"
mock_api_client.execute_query.assert_called_once()
args, kwargs = mock_api_client.execute_query.call_args
assert args[1]["environmentId"] == 123
assert args[1]["first"] == 2
assert args[1]["filter"] == {
"uniqueIds": [
"exposure.analytics.customer_dashboard",
"exposure.sales.sales_report",
]
}
async def test_fetch_exposure_details_by_name(exposures_fetcher, mock_api_client):
# Mock the response for fetch_exposures (which gets called when filtering by name)
mock_exposures_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"pageInfo": {"hasNextPage": False, "endCursor": None},
"edges": [
{
"node": {
"name": "sales_report",
"uniqueId": "exposure.sales.sales_report",
"exposureType": "analysis",
"maturity": "medium",
"ownerEmail": "[email protected]",
"ownerName": "Sales Team",
"url": None,
"meta": {},
"freshnessStatus": "Stale",
"description": "Monthly sales analysis report",
"label": None,
"parents": [{"uniqueId": "model.sales.sales_data"}],
}
},
{
"node": {
"name": "other_exposure",
"uniqueId": "exposure.other.other_exposure",
"exposureType": "dashboard",
"maturity": "high",
"ownerEmail": "[email protected]",
"ownerName": "Other Team",
"url": None,
"meta": {},
"freshnessStatus": "Fresh",
"description": "Other exposure",
"label": None,
"parents": [],
}
},
],
}
}
}
}
}
mock_api_client.execute_query.return_value = mock_exposures_response
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposure_details(
exposure_name="sales_report"
)
assert isinstance(result, list)
assert len(result) == 1
exposure = result[0]
assert exposure["name"] == "sales_report"
assert exposure["uniqueId"] == "exposure.sales.sales_report"
assert exposure["exposureType"] == "analysis"
assert exposure["maturity"] == "medium"
assert exposure["url"] is None
assert exposure["meta"] == {}
assert exposure["freshnessStatus"] == "Stale"
assert exposure["label"] is None
# Should have called the GET_EXPOSURES query (not GET_EXPOSURE_DETAILS)
mock_api_client.execute_query.assert_called_once()
args, kwargs = mock_api_client.execute_query.call_args
assert args[1]["environmentId"] == 123
assert args[1]["first"] == 100 # PAGE_SIZE for fetch_exposures
async def test_fetch_exposure_details_not_found(exposures_fetcher, mock_api_client):
mock_response = {
"data": {"environment": {"definition": {"exposures": {"edges": []}}}}
}
mock_api_client.execute_query.return_value = mock_response
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposure_details(
unique_ids=["exposure.nonexistent.exposure"]
)
assert result == []
async def test_get_exposure_filters_unique_ids(exposures_fetcher):
filters = exposures_fetcher._get_exposure_filters(
unique_ids=["exposure.test.test_exposure"]
)
assert filters == {"uniqueIds": ["exposure.test.test_exposure"]}
async def test_get_exposure_filters_multiple_unique_ids(exposures_fetcher):
filters = exposures_fetcher._get_exposure_filters(
unique_ids=["exposure.test.test1", "exposure.test.test2"]
)
assert filters == {"uniqueIds": ["exposure.test.test1", "exposure.test.test2"]}
async def test_get_exposure_filters_name_raises_error(exposures_fetcher):
from dbt_mcp.errors import InvalidParameterError
with pytest.raises(
InvalidParameterError, match="ExposureFilter only supports uniqueIds"
):
exposures_fetcher._get_exposure_filters(exposure_name="test_exposure")
async def test_get_exposure_filters_no_params(exposures_fetcher):
from dbt_mcp.errors import InvalidParameterError
with pytest.raises(
InvalidParameterError,
match="unique_ids must be provided for exposure filtering",
):
exposures_fetcher._get_exposure_filters()
async def test_fetch_exposure_details_by_name_not_found(
exposures_fetcher, mock_api_client
):
# Mock empty response for fetch_exposures
mock_response = {
"data": {
"environment": {
"definition": {
"exposures": {
"pageInfo": {"hasNextPage": False, "endCursor": None},
"edges": [],
}
}
}
}
}
mock_api_client.execute_query.return_value = mock_response
with patch("dbt_mcp.discovery.client.raise_gql_error"):
result = await exposures_fetcher.fetch_exposure_details(
exposure_name="nonexistent_exposure"
)
assert result == []
```
--------------------------------------------------------------------------------
/src/dbt_mcp/lsp/lsp_connection.py:
--------------------------------------------------------------------------------
```python
"""LSP Connection Manager for dbt Fusion LSP.
This module manages the lifecycle of LSP processes and handles JSON-RPC
communication according to the Language Server Protocol specification.
"""
import asyncio
import itertools
import json
import logging
import socket
import uuid
from collections.abc import Iterator, Sequence
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
from dbt_mcp.lsp.providers.lsp_connection_provider import (
LSPConnectionProviderProtocol,
LspEventName,
)
logger = logging.getLogger(__name__)
def event_name_from_string(string: str) -> LspEventName | None:
"""Create an LSP event name from a string."""
try:
return LspEventName(string)
except ValueError:
return None
@dataclass
class JsonRpcMessage:
"""Represents a JSON-RPC 2.0 message."""
jsonrpc: str = "2.0"
id: int | str | None = None
method: str | None = None
params: dict[str, Any] | list[Any] | None = None
result: Any = None
error: dict[str, Any] | None = None
def to_dict(self, none_values: bool = False) -> dict[str, Any]:
"""Convert the message to a dictionary."""
def dict_factory(x: list[tuple[str, Any]]) -> dict[str, Any]:
return dict(x) if none_values else {k: v for k, v in x if v is not None}
return asdict(self, dict_factory=dict_factory)
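    # Serialization sketch (illustrative): by default None-valued fields are
    # dropped, so JsonRpcMessage(id=1, method="ping").to_dict() yields
    #   {"jsonrpc": "2.0", "id": 1, "method": "ping"}
    # while to_dict(none_values=True) keeps explicit nulls — needed when an
    # empty-result response must still carry "result": null on the wire.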
@dataclass
class LspConnectionState:
"""Tracks the state of an LSP connection."""
initialized: bool = False
shutting_down: bool = False
capabilities: dict[str, Any] = field(default_factory=dict)
pending_requests: dict[int | str, asyncio.Future] = field(default_factory=dict)
pending_notifications: dict[LspEventName, list[asyncio.Future]] = field(
default_factory=dict
)
compiled: bool = False
    # Start at 20 to avoid collisions between the IDs of requests we issue and
    # the IDs the LSP server uses for requests it sends to us
request_id_counter: Iterator[int] = field(
default_factory=lambda: itertools.count(20)
)
def get_next_request_id(self) -> int:
return next(self.request_id_counter)
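    # Example: each state instance carries its own monotonic counter, e.g.
    #   state = LspConnectionState()
    #   state.get_next_request_id()  # -> 20
    #   state.get_next_request_id()  # -> 21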
class SocketLSPConnection(LSPConnectionProviderProtocol):
"""LSP process lifecycle and communication via socket.
This class handles:
- Starting and stopping LSP server processes
- Socket-based JSON-RPC communication
- Request/response correlation
- Error handling and cleanup
"""
def __init__(
self,
binary_path: str,
cwd: str,
args: Sequence[str] | None = None,
connection_timeout: float = 10,
default_request_timeout: float = 60,
):
"""Initialize the LSP connection manager.
Args:
binary_path: Path to the LSP server binary
cwd: Working directory for the LSP process
args: Optional command-line arguments for the LSP server
connection_timeout: Timeout in seconds for establishing the initial socket
connection (default: 10). Used during server startup.
default_request_timeout: Default timeout in seconds for LSP request operations
(default: 60). Used when no timeout is specified for
individual requests.
"""
self.binary_path = Path(binary_path)
self.args = list(args) if args else []
self.cwd = cwd
self.host = "127.0.0.1"
self.port = 0
self.process: asyncio.subprocess.Process | None = None
self.state = LspConnectionState()
# Socket components
self._socket: socket.socket | None = None
self._connection: socket.socket | None = None
# Asyncio components for I/O
self._reader_task: asyncio.Task | None = None
self._writer_task: asyncio.Task | None = None
self._stdout_reader_task: asyncio.Task | None = None
self._stderr_reader_task: asyncio.Task | None = None
self._stop_event = asyncio.Event()
self._outgoing_queue: asyncio.Queue[bytes] = asyncio.Queue()
# Timeouts
self.connection_timeout = connection_timeout
self.default_request_timeout = default_request_timeout
logger.debug(f"LSP Connection initialized with binary: {self.binary_path}")
def compiled(self) -> bool:
return self.state.compiled
def initialized(self) -> bool:
return self.state.initialized
async def start(self) -> None:
"""Start the LSP server process and socket communication tasks."""
if self.process is not None:
logger.warning("LSP process is already running")
return
try:
self.setup_socket()
await self.launch_lsp_process()
# Wait for connection with timeout (run socket.accept in executor)
if self._socket:
self._socket.settimeout(self.connection_timeout)
try:
(
self._connection,
client_addr,
) = await asyncio.get_running_loop().run_in_executor(
None, self._socket.accept
)
if self._connection:
self._connection.settimeout(
None
) # Set to blocking for read/write
logger.debug(f"LSP server connected from {client_addr}")
except TimeoutError:
raise RuntimeError("Timeout waiting for LSP server to connect")
# Start I/O tasks
self._stop_event.clear()
self._reader_task = asyncio.get_running_loop().create_task(
self._read_loop()
)
self._writer_task = asyncio.get_running_loop().create_task(
self._write_loop()
)
except Exception as e:
logger.error(f"Failed to start LSP server: {e}")
await self.stop()
raise
def setup_socket(self) -> None:
"""Set up the socket for LSP server communication.
Creates a TCP socket, binds it to the configured host and port,
and starts listening for incoming connections. If port is 0,
the OS will auto-assign an available port.
"""
# Create socket and bind
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._socket.bind((self.host, self.port))
self._socket.listen(1)
# Get the actual port if auto-assigned
_, actual_port = self._socket.getsockname()
self.port = actual_port
logger.debug(f"Socket listening on {self.host}:{self.port}")
async def launch_lsp_process(self) -> None:
"""Launch the LSP server process.
Starts the LSP server as a subprocess with socket communication enabled.
        The subprocess inherits the parent's stdout and stderr; all JSON-RPC
        traffic flows over the socket.
The server will connect back to the socket set up by setup_socket().
"""
# Prepare command with socket info
cmd = [
str(self.binary_path),
"--socket",
f"{self.port}",
"--project-dir",
self.cwd,
*self.args,
]
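        # Example resulting argv (illustrative values):
        #   ["/usr/local/bin/dbt-lsp", "--socket", "54321", "--project-dir", "/repo"]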
logger.debug(f"Starting LSP server: {' '.join(cmd)}")
self.process = await asyncio.create_subprocess_exec(*cmd)
logger.info(f"LSP server started with PID: {self.process.pid}")
async def stop(self) -> None:
"""Stop the LSP server process and cleanup resources."""
logger.info("Stopping LSP server...")
        # Send the shutdown request while the writer task is still running, so
        # the queued message is actually flushed to the server before we tear
        # the I/O tasks down
        if self.process and not self.state.shutting_down:
            self.state.shutting_down = True
            try:
                self._send_shutdown_request()
                await asyncio.sleep(0.5)  # Give server time to process shutdown
            except Exception as e:
                logger.warning(f"Error sending shutdown request: {e}")
        # Signal tasks to stop
        self._stop_event.set()
        # Cancel I/O tasks
        if self._reader_task and not self._reader_task.done():
            self._reader_task.cancel()
            try:
                await self._reader_task
            except asyncio.CancelledError:
                pass
        if self._writer_task and not self._writer_task.done():
            self._writer_task.cancel()
            try:
                await self._writer_task
            except asyncio.CancelledError:
                pass
        # Cancel stdout/stderr reader tasks
        if self._stdout_reader_task and not self._stdout_reader_task.done():
            self._stdout_reader_task.cancel()
            try:
                await self._stdout_reader_task
            except asyncio.CancelledError:
                pass
        if self._stderr_reader_task and not self._stderr_reader_task.done():
            self._stderr_reader_task.cancel()
            try:
                await self._stderr_reader_task
            except asyncio.CancelledError:
                pass
# Close socket connection
if self._connection:
try:
self._connection.close()
except Exception as e:
logger.error(f"Error closing socket connection: {e}")
finally:
self._connection = None
# Close listening socket
if self._socket:
try:
self._socket.close()
except Exception as e:
logger.error(f"Error closing socket: {e}")
finally:
self._socket = None
# Terminate the process
if self.process:
try:
self.process.terminate()
try:
                    # asyncio's Process.wait() takes no timeout and never raises
                    # subprocess.TimeoutExpired; bound it so a hung server is
                    # still killed below
                    await asyncio.wait_for(self.process.wait(), timeout=5)
                except TimeoutError:
logger.warning("LSP process didn't terminate, killing...")
self.process.kill()
await self.process.wait()
except Exception as e:
logger.error(f"Error terminating LSP process: {e}")
finally:
self.process = None
# Clear state
self.state = LspConnectionState()
logger.info("LSP server stopped")
async def initialize(self, timeout: float | None = None) -> None:
"""Initialize the LSP connection.
Sends the initialize request to the LSP server and waits for the response.
The server capabilities are stored in the connection state.
        Args:
            timeout: Timeout in seconds for the initialize request. Defaults to
                the connection's default_request_timeout when not provided.
"""
if self.state.initialized:
raise RuntimeError("LSP server is already initialized")
params = {
"processId": None,
"rootUri": None,
"clientInfo": {
"name": "dbt-mcp",
"version": "1.0.0",
},
"capabilities": {},
"initializationOptions": {
"project-dir": "file:///",
"command-prefix": str(uuid.uuid4()),
},
}
# Send initialize request
result = await self.send_request(
"initialize", params, timeout=timeout or self.default_request_timeout
)
# Store capabilities
self.state.capabilities = result.get("capabilities", {})
self.state.initialized = True
# Send initialized notification
self.send_notification("initialized", {})
logger.info("LSP server initialized successfully")
async def _read_loop(self) -> None:
"""Background task that reads messages from the LSP server via socket."""
if not self._connection:
logger.warning("LSP server socket is not available")
return
buffer = b""
while not self._stop_event.is_set():
try:
# Read data from socket (run in executor to avoid blocking)
self._connection.settimeout(0.1) # Short timeout to check stop event
try:
chunk = await asyncio.get_running_loop().run_in_executor(
None, self._connection.recv, 4096
)
except TimeoutError:
continue
if not chunk:
logger.warning("LSP server socket closed")
break
buffer += chunk
# Try to parse messages from buffer
while True:
message, remaining = self._parse_message(buffer)
if message is None:
break
buffer = remaining
# Process the message
self._handle_incoming_message(message)
except asyncio.CancelledError:
# Task was cancelled, exit cleanly
break
except Exception as e:
if not self._stop_event.is_set():
logger.error(f"Error in reader task: {e}")
break
async def _write_loop(self) -> None:
"""Background task that writes messages to the LSP server via socket."""
if not self._connection:
return
while not self._stop_event.is_set():
try:
# Get message from queue (with timeout to check stop event)
try:
data = await asyncio.wait_for(
self._outgoing_queue.get(), timeout=0.1
)
except TimeoutError:
continue
# Write to socket (run in executor to avoid blocking)
await asyncio.get_running_loop().run_in_executor(
None, self._connection.sendall, data
)
except asyncio.CancelledError:
# Task was cancelled, exit cleanly
break
except Exception as e:
if not self._stop_event.is_set():
logger.error(f"Error in writer task: {e}")
break
def _parse_message(self, buffer: bytes) -> tuple[JsonRpcMessage | None, bytes]:
"""Parse a JSON-RPC message from the buffer.
LSP uses HTTP-like headers followed by JSON content:
Content-Length: <length>\\r\\n
\\r\\n
<json-content>
"""
# Look for Content-Length header
header_end = buffer.find(b"\r\n\r\n")
if header_end == -1:
return None, buffer
# Parse headers
headers = buffer[:header_end].decode("utf-8")
content_length = None
for line in headers.split("\r\n"):
if line.startswith("Content-Length:"):
try:
content_length = int(line.split(":")[1].strip())
except (IndexError, ValueError):
logger.error(f"Invalid Content-Length header: {line}")
return None, buffer[header_end + 4 :]
if content_length is None:
logger.error("Missing Content-Length header")
return None, buffer[header_end + 4 :]
# Check if we have the full message
content_start = header_end + 4
content_end = content_start + content_length
if len(buffer) < content_end:
return None, buffer
# Parse JSON content
try:
content = buffer[content_start:content_end].decode("utf-8")
data = json.loads(content)
message = JsonRpcMessage(**data)
return message, buffer[content_end:]
except (json.JSONDecodeError, UnicodeDecodeError) as e:
logger.error(f"Failed to parse message: {e}")
return None, buffer[content_end:]
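    # Worked example (illustrative): for the 27-byte body
    #   b'{"jsonrpc": "2.0", "id": 1}'
    # the full frame is b'Content-Length: 27\r\n\r\n{"jsonrpc": "2.0", "id": 1}'.
    # Until all 27 body bytes have arrived, _parse_message returns (None, buffer)
    # unchanged; afterwards it returns the parsed message plus any trailing bytes.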
def _handle_incoming_message(self, message: JsonRpcMessage) -> None:
"""Handle an incoming message from the LSP server."""
# Handle responses to requests
if message.id is not None:
# Thread-safe: pop with default avoids race condition between check and pop
future = self.state.pending_requests.pop(message.id, None)
if future is not None:
logger.debug(f"Received response for request {message.to_dict()}")
# Use call_soon_threadsafe to safely resolve futures across event loop contexts
# This prevents "Task got Future attached to a different loop" errors when
# the future was created in one loop but is being resolved from another loop
# Get the loop from the future itself to ensure we schedule on the correct loop
future_loop = future.get_loop()
if message.error:
future_loop.call_soon_threadsafe(
future.set_exception,
RuntimeError(f"LSP error: {message.error}"),
)
else:
future_loop.call_soon_threadsafe(future.set_result, message.result)
return
else:
# it's an unknown request, we respond with an empty result
logger.debug(f"LSP request {message.to_dict()}")
self._send_message(
JsonRpcMessage(id=message.id, result=None), none_values=True
)
if message.method is None:
return
# it's a known event type we want to explicitly handle
if lsp_event_name := event_name_from_string(message.method):
# Check if this is an event we're waiting for
# Thread-safe: pop with default avoids race condition
futures = self.state.pending_notifications.pop(lsp_event_name, None)
if futures is not None:
logger.debug(f"Received event {lsp_event_name} - {message.to_dict()}")
# Use call_soon_threadsafe for notification futures as well
for future in futures:
future_loop = future.get_loop()
future_loop.call_soon_threadsafe(future.set_result, message.params)
match lsp_event_name:
case LspEventName.compileComplete:
logger.info("Recorded compile complete event")
self.state.compiled = True
case _:
logger.debug(f"LSP event {message.method}")
else:
# it's an unknown notification, log it and move on
logger.debug(f"LSP event {message.method}")
async def send_request(
self,
method: str,
params: dict[str, Any] | list[Any] | None = None,
timeout: float | None = None,
) -> dict[str, Any]:
"""Send a request to the LSP server.
Args:
method: The JSON-RPC method name
params: Optional parameters for the method
timeout: Timeout in seconds for this request. If not specified, uses
default_request_timeout from the connection configuration.
Returns:
A dictionary containing the response result or error information
"""
if not self.process:
raise RuntimeError("LSP server is not running")
# Create request message
request_id = self.state.get_next_request_id()
message = JsonRpcMessage(
id=request_id,
method=method,
params=params,
)
# Create future for response using the current running loop
# This prevents "Task got Future attached to a different loop" errors
# when send_request is called from a different loop context than where
# the connection was initialized
future = asyncio.get_running_loop().create_future()
self.state.pending_requests[request_id] = future
# Send the message
self._send_message(message)
try:
return await asyncio.wait_for(
future, timeout=timeout or self.default_request_timeout
)
except Exception as e:
return {"error": str(e)}
def send_notification(
self,
method: str,
params: dict[str, Any] | list[Any] | None = None,
) -> None:
"""Send a notification to the LSP server.
Args:
method: The JSON-RPC method name
params: Optional parameters for the method
"""
if not self.process:
raise RuntimeError("LSP server is not running")
# Create notification message (no ID)
message = JsonRpcMessage(
method=method,
params=params,
)
# Send the message
self._send_message(message)
def wait_for_notification(
self, event_name: LspEventName
) -> asyncio.Future[dict[str, Any]]:
"""Wait for a notification from the LSP server.
Args:
event_name: The LSP event name to wait for
Returns:
A Future that will be resolved with the notification params when received
"""
future = asyncio.get_running_loop().create_future()
self.state.pending_notifications.setdefault(event_name, []).append(future)
return future
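    # Usage sketch: register the waiter *before* triggering the work so the
    # notification cannot be missed, then await it with a timeout:
    #   fut = conn.wait_for_notification(LspEventName.compileComplete)
    #   ...kick off a compile...
    #   params = await asyncio.wait_for(fut, timeout=60)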
def _send_message(self, message: JsonRpcMessage, none_values: bool = False) -> None:
"""Send a message to the LSP server."""
# Serialize message
content = json.dumps(message.to_dict(none_values=none_values))
content_bytes = content.encode("utf-8")
# Create LSP message with headers
header = f"Content-Length: {len(content_bytes)}\r\n\r\n"
header_bytes = header.encode("utf-8")
data = header_bytes + content_bytes
logger.debug(f"Sending message: {content}")
# Queue for sending (put_nowait is safe from sync context)
self._outgoing_queue.put_nowait(data)
def _send_shutdown_request(self) -> None:
"""Send shutdown request to the LSP server."""
try:
# Send shutdown request
message = JsonRpcMessage(
id=self.state.get_next_request_id(),
method="shutdown",
)
self._send_message(message)
# Send exit notification
exit_message = JsonRpcMessage(
method="exit",
)
self._send_message(exit_message)
except Exception as e:
logger.error(f"Error sending shutdown: {e}")
def is_running(self) -> bool:
"""Check if the LSP server is running."""
return self.process is not None and self.process.returncode is None
```
--------------------------------------------------------------------------------
/src/dbt_mcp/discovery/client.py:
--------------------------------------------------------------------------------
```python
import textwrap
from typing import Literal, TypedDict
import requests
from dbt_mcp.config.config_providers import ConfigProvider, DiscoveryConfig
from dbt_mcp.errors import GraphQLError, InvalidParameterError
from dbt_mcp.gql.errors import raise_gql_error
PAGE_SIZE = 100
MAX_NODE_QUERY_LIMIT = 1000
class GraphQLQueries:
GET_MODELS = textwrap.dedent("""
query GetModels(
$environmentId: BigInt!,
$modelsFilter: ModelAppliedFilter,
$after: String,
$first: Int,
$sort: AppliedModelSort
) {
environment(id: $environmentId) {
applied {
models(filter: $modelsFilter, after: $after, first: $first, sort: $sort) {
pageInfo {
endCursor
}
edges {
node {
name
uniqueId
description
}
}
}
}
}
}
""")
GET_MODEL_HEALTH = textwrap.dedent("""
query GetModelDetails(
$environmentId: BigInt!,
$modelsFilter: ModelAppliedFilter
$first: Int,
) {
environment(id: $environmentId) {
applied {
models(filter: $modelsFilter, first: $first) {
edges {
node {
name
uniqueId
executionInfo {
lastRunGeneratedAt
lastRunStatus
executeCompletedAt
executeStartedAt
}
tests {
name
description
columnName
testType
executionInfo {
lastRunGeneratedAt
lastRunStatus
executeCompletedAt
executeStartedAt
}
}
ancestors(types: [Model, Source, Seed, Snapshot]) {
... on ModelAppliedStateNestedNode {
name
uniqueId
resourceType
materializedType
modelexecutionInfo: executionInfo {
lastRunStatus
executeCompletedAt
}
}
... on SnapshotAppliedStateNestedNode {
name
uniqueId
resourceType
snapshotExecutionInfo: executionInfo {
lastRunStatus
executeCompletedAt
}
}
... on SeedAppliedStateNestedNode {
name
uniqueId
resourceType
seedExecutionInfo: executionInfo {
lastRunStatus
executeCompletedAt
}
}
... on SourceAppliedStateNestedNode {
sourceName
name
resourceType
freshness {
maxLoadedAt
maxLoadedAtTimeAgoInS
freshnessStatus
}
}
}
}
}
}
}
}
}
""")
GET_MODEL_DETAILS = textwrap.dedent("""
query GetModelDetails(
$environmentId: BigInt!,
$modelsFilter: ModelAppliedFilter
$first: Int,
) {
environment(id: $environmentId) {
applied {
models(filter: $modelsFilter, first: $first) {
edges {
node {
name
uniqueId
compiledCode
description
database
schema
alias
catalog {
columns {
description
name
type
}
}
}
}
}
}
}
}
""")
COMMON_FIELDS_PARENTS_CHILDREN = textwrap.dedent("""
{
... on ExposureAppliedStateNestedNode {
resourceType
name
description
}
... on ExternalModelNode {
resourceType
description
name
}
... on MacroDefinitionNestedNode {
resourceType
name
description
}
... on MetricDefinitionNestedNode {
resourceType
name
description
}
... on ModelAppliedStateNestedNode {
resourceType
name
description
}
... on SavedQueryDefinitionNestedNode {
resourceType
name
description
}
... on SeedAppliedStateNestedNode {
resourceType
name
description
}
... on SemanticModelDefinitionNestedNode {
resourceType
name
description
}
... on SnapshotAppliedStateNestedNode {
resourceType
name
description
}
... on SourceAppliedStateNestedNode {
resourceType
sourceName
uniqueId
name
description
}
... on TestAppliedStateNestedNode {
resourceType
name
description
}
""")
GET_MODEL_PARENTS = (
textwrap.dedent("""
query GetModelParents(
$environmentId: BigInt!,
$modelsFilter: ModelAppliedFilter
$first: Int,
) {
environment(id: $environmentId) {
applied {
models(filter: $modelsFilter, first: $first) {
pageInfo {
endCursor
}
edges {
node {
parents
""")
+ COMMON_FIELDS_PARENTS_CHILDREN
+ textwrap.dedent("""
}
}
}
}
}
}
}
""")
)
GET_MODEL_CHILDREN = (
textwrap.dedent("""
query GetModelChildren(
$environmentId: BigInt!,
$modelsFilter: ModelAppliedFilter
$first: Int,
) {
environment(id: $environmentId) {
applied {
models(filter: $modelsFilter, first: $first) {
pageInfo {
endCursor
}
edges {
node {
children
""")
+ COMMON_FIELDS_PARENTS_CHILDREN
+ textwrap.dedent("""
}
}
}
}
}
}
}
""")
)
GET_SOURCES = textwrap.dedent("""
query GetSources(
$environmentId: BigInt!,
$sourcesFilter: SourceAppliedFilter,
$after: String,
$first: Int
) {
environment(id: $environmentId) {
applied {
sources(filter: $sourcesFilter, after: $after, first: $first) {
pageInfo {
hasNextPage
endCursor
}
edges {
node {
name
uniqueId
identifier
description
sourceName
resourceType
database
schema
freshness {
maxLoadedAt
maxLoadedAtTimeAgoInS
freshnessStatus
}
}
}
}
}
}
}
""")
GET_EXPOSURES = textwrap.dedent("""
query Exposures($environmentId: BigInt!, $first: Int, $after: String) {
environment(id: $environmentId) {
definition {
exposures(first: $first, after: $after) {
totalCount
pageInfo {
hasNextPage
endCursor
}
edges {
node {
name
uniqueId
url
description
}
}
}
}
}
}
""")
GET_EXPOSURE_DETAILS = textwrap.dedent("""
query ExposureDetails($environmentId: BigInt!, $filter: ExposureFilter, $first: Int) {
environment(id: $environmentId) {
definition {
exposures(first: $first, filter: $filter) {
edges {
node {
name
maturity
label
ownerEmail
ownerName
uniqueId
url
meta
freshnessStatus
exposureType
description
parents {
uniqueId
}
}
}
}
}
}
}
""")
class MetadataAPIClient:
def __init__(self, config_provider: ConfigProvider[DiscoveryConfig]):
self.config_provider = config_provider
async def execute_query(self, query: str, variables: dict) -> dict:
config = await self.config_provider.get_config()
url = config.url
headers = config.headers_provider.get_headers()
response = requests.post(
url=url,
json={"query": query, "variables": variables},
headers=headers,
)
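        # Note: the HTTP status code is not checked here; GraphQL-level errors
        # are surfaced later by raise_gql_error on the parsed payload.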
return response.json()
class ModelFilter(TypedDict, total=False):
modelingLayer: Literal["marts"] | None
class SourceFilter(TypedDict, total=False):
sourceNames: list[str]
uniqueIds: list[str] | None
class ModelsFetcher:
def __init__(self, api_client: MetadataAPIClient):
self.api_client = api_client
async def get_environment_id(self) -> int:
config = await self.api_client.config_provider.get_config()
return config.environment_id
    def _parse_response_to_json(self, result: dict) -> list[dict]:
        raise_gql_error(result)
        # Re-check errors before touching "data", which can be null on a
        # failed query; raise_gql_error should normally have raised already.
        if result.get("errors"):
            raise GraphQLError(f"GraphQL query failed: {result['errors']}")
        edges = result["data"]["environment"]["applied"]["models"]["edges"]
        parsed_edges: list[dict] = []
        if not edges:
            return parsed_edges
for edge in edges:
if not isinstance(edge, dict) or "node" not in edge:
continue
node = edge["node"]
if not isinstance(node, dict):
continue
parsed_edges.append(node)
return parsed_edges
def _get_model_filters(
self, model_name: str | None = None, unique_id: str | None = None
) -> dict[str, list[str] | str]:
if unique_id:
return {"uniqueIds": [unique_id]}
elif model_name:
return {"identifier": model_name}
else:
raise InvalidParameterError(
"Either model_name or unique_id must be provided"
)
async def fetch_models(self, model_filter: ModelFilter | None = None) -> list[dict]:
has_next_page = True
after_cursor: str = ""
all_edges: list[dict] = []
while has_next_page and len(all_edges) < MAX_NODE_QUERY_LIMIT:
variables = {
"environmentId": await self.get_environment_id(),
"after": after_cursor,
"first": PAGE_SIZE,
"modelsFilter": model_filter or {},
"sort": {"field": "queryUsageCount", "direction": "desc"},
}
result = await self.api_client.execute_query(
GraphQLQueries.GET_MODELS, variables
)
all_edges.extend(self._parse_response_to_json(result))
previous_after_cursor = after_cursor
after_cursor = result["data"]["environment"]["applied"]["models"][
"pageInfo"
]["endCursor"]
if previous_after_cursor == after_cursor:
has_next_page = False
return all_edges
async def fetch_model_details(
self, model_name: str | None = None, unique_id: str | None = None
) -> dict:
model_filters = self._get_model_filters(model_name, unique_id)
variables = {
"environmentId": await self.get_environment_id(),
"modelsFilter": model_filters,
"first": 1,
}
result = await self.api_client.execute_query(
GraphQLQueries.GET_MODEL_DETAILS, variables
)
raise_gql_error(result)
edges = result["data"]["environment"]["applied"]["models"]["edges"]
if not edges:
return {}
return edges[0]["node"]
async def fetch_model_parents(
self, model_name: str | None = None, unique_id: str | None = None
) -> list[dict]:
model_filters = self._get_model_filters(model_name, unique_id)
variables = {
"environmentId": await self.get_environment_id(),
"modelsFilter": model_filters,
"first": 1,
}
result = await self.api_client.execute_query(
GraphQLQueries.GET_MODEL_PARENTS, variables
)
raise_gql_error(result)
edges = result["data"]["environment"]["applied"]["models"]["edges"]
if not edges:
return []
return edges[0]["node"]["parents"]
async def fetch_model_children(
self, model_name: str | None = None, unique_id: str | None = None
) -> list[dict]:
model_filters = self._get_model_filters(model_name, unique_id)
variables = {
"environmentId": await self.get_environment_id(),
"modelsFilter": model_filters,
"first": 1,
}
result = await self.api_client.execute_query(
GraphQLQueries.GET_MODEL_CHILDREN, variables
)
raise_gql_error(result)
edges = result["data"]["environment"]["applied"]["models"]["edges"]
if not edges:
return []
return edges[0]["node"]["children"]
async def fetch_model_health(
self, model_name: str | None = None, unique_id: str | None = None
) -> list[dict]:
model_filters = self._get_model_filters(model_name, unique_id)
variables = {
"environmentId": await self.get_environment_id(),
"modelsFilter": model_filters,
"first": 1,
}
result = await self.api_client.execute_query(
GraphQLQueries.GET_MODEL_HEALTH, variables
)
raise_gql_error(result)
edges = result["data"]["environment"]["applied"]["models"]["edges"]
if not edges:
return []
return edges[0]["node"]
class ExposuresFetcher:
def __init__(self, api_client: MetadataAPIClient):
self.api_client = api_client
async def get_environment_id(self) -> int:
config = await self.api_client.config_provider.get_config()
return config.environment_id
    def _parse_response_to_json(self, result: dict) -> list[dict]:
        raise_gql_error(result)
        # Re-check errors before touching "data", which can be null on a
        # failed query; raise_gql_error should normally have raised already.
        if result.get("errors"):
            raise GraphQLError(f"GraphQL query failed: {result['errors']}")
        edges = result["data"]["environment"]["definition"]["exposures"]["edges"]
        parsed_edges: list[dict] = []
        if not edges:
            return parsed_edges
for edge in edges:
if not isinstance(edge, dict) or "node" not in edge:
continue
node = edge["node"]
if not isinstance(node, dict):
continue
parsed_edges.append(node)
return parsed_edges
async def fetch_exposures(self) -> list[dict]:
has_next_page = True
after_cursor: str | None = None
all_edges: list[dict] = []
while has_next_page:
variables: dict[str, int | str] = {
"environmentId": await self.get_environment_id(),
"first": PAGE_SIZE,
}
if after_cursor:
variables["after"] = after_cursor
result = await self.api_client.execute_query(
GraphQLQueries.GET_EXPOSURES, variables
)
new_edges = self._parse_response_to_json(result)
all_edges.extend(new_edges)
page_info = result["data"]["environment"]["definition"]["exposures"][
"pageInfo"
]
has_next_page = page_info.get("hasNextPage", False)
after_cursor = page_info.get("endCursor")
return all_edges
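    # Unlike ModelsFetcher.fetch_models, this paginates on pageInfo.hasNextPage
    # and applies no MAX_NODE_QUERY_LIMIT cap.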
def _get_exposure_filters(
self, exposure_name: str | None = None, unique_ids: list[str] | None = None
) -> dict[str, list[str]]:
if unique_ids:
return {"uniqueIds": unique_ids}
elif exposure_name:
raise InvalidParameterError(
"ExposureFilter only supports uniqueIds. Please use unique_ids parameter instead of exposure_name."
)
else:
raise InvalidParameterError(
"unique_ids must be provided for exposure filtering"
)
async def fetch_exposure_details(
self, exposure_name: str | None = None, unique_ids: list[str] | None = None
) -> list[dict]:
if exposure_name and not unique_ids:
# Since ExposureFilter doesn't support filtering by name,
# we need to fetch all exposures and find the one with matching name
all_exposures = await self.fetch_exposures()
for exposure in all_exposures:
if exposure.get("name") == exposure_name:
return [exposure]
return []
elif unique_ids:
exposure_filters = self._get_exposure_filters(unique_ids=unique_ids)
variables = {
"environmentId": await self.get_environment_id(),
"filter": exposure_filters,
"first": len(unique_ids), # Request as many as we're filtering for
}
result = await self.api_client.execute_query(
GraphQLQueries.GET_EXPOSURE_DETAILS, variables
)
raise_gql_error(result)
edges = result["data"]["environment"]["definition"]["exposures"]["edges"]
if not edges:
return []
return [edge["node"] for edge in edges]
else:
raise InvalidParameterError(
"Either exposure_name or unique_ids must be provided"
)
class SourcesFetcher:
def __init__(self, api_client: MetadataAPIClient):
self.api_client = api_client
async def get_environment_id(self) -> int:
config = await self.api_client.config_provider.get_config()
return config.environment_id
    def _parse_response_to_json(self, result: dict) -> list[dict]:
        raise_gql_error(result)
        # Re-check errors before touching "data", which can be null on a
        # failed query; raise_gql_error should normally have raised already.
        if result.get("errors"):
            raise GraphQLError(f"GraphQL query failed: {result['errors']}")
        edges = result["data"]["environment"]["applied"]["sources"]["edges"]
        parsed_edges: list[dict] = []
        if not edges:
            return parsed_edges
for edge in edges:
if not isinstance(edge, dict) or "node" not in edge:
continue
node = edge["node"]
if not isinstance(node, dict):
continue
parsed_edges.append(node)
return parsed_edges
async def fetch_sources(
self,
source_names: list[str] | None = None,
unique_ids: list[str] | None = None,
) -> list[dict]:
source_filter: SourceFilter = {}
if source_names is not None:
source_filter["sourceNames"] = source_names
if unique_ids is not None:
source_filter["uniqueIds"] = unique_ids
has_next_page = True
after_cursor: str = ""
all_edges: list[dict] = []
while has_next_page and len(all_edges) < MAX_NODE_QUERY_LIMIT:
variables = {
"environmentId": await self.get_environment_id(),
"after": after_cursor,
"first": PAGE_SIZE,
"sourcesFilter": source_filter,
}
result = await self.api_client.execute_query(
GraphQLQueries.GET_SOURCES, variables
)
all_edges.extend(self._parse_response_to_json(result))
page_info = result["data"]["environment"]["applied"]["sources"]["pageInfo"]
has_next_page = page_info.get("hasNextPage", False)
after_cursor = page_info.get("endCursor")
return all_edges
```
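For orientation, here is a minimal sketch of wiring these pieces together. The import path and the concrete `ConfigProvider` instance are assumptions inferred from the signatures above, not a documented entry point, and the printed field names follow the detail queries shown in this file.

```python
import asyncio

# Assumed import path for the discovery client module shown above.
from dbt_mcp.discovery.client import MetadataAPIClient, ModelsFetcher


async def list_marts_models(config_provider) -> None:
    # config_provider: any ConfigProvider[DiscoveryConfig] from the config package.
    client = MetadataAPIClient(config_provider)
    fetcher = ModelsFetcher(client)
    # ModelFilter is a total=False TypedDict, so a plain dict literal works.
    models = await fetcher.fetch_models(model_filter={"modelingLayer": "marts"})
    for model in models:
        # GET_MODELS itself is defined earlier in the file, so read fields
        # defensively with .get().
        print(model.get("uniqueId"), model.get("name"))


# asyncio.run(list_marts_models(my_config_provider))  # hypothetical provider
```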
--------------------------------------------------------------------------------
/tests/unit/lsp/test_lsp_connection.py:
--------------------------------------------------------------------------------
```python
"""Unit tests for the LSP connection module."""
import asyncio
import socket
import subprocess
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dbt_mcp.lsp.lsp_connection import (
SocketLSPConnection,
LspConnectionState,
LspEventName,
JsonRpcMessage,
event_name_from_string,
)
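# SocketLSPConnection (under test) launches an LSP server subprocess and
# exchanges Content-Length-framed JSON-RPC messages with it over a localhost
# TCP socket, as the fixtures below exercise.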
class TestJsonRpcMessage:
"""Test JsonRpcMessage dataclass."""
def test_to_dict_with_request(self):
"""Test converting a request message to dictionary."""
msg = JsonRpcMessage(id=1, method="initialize", params={"processId": None})
result = msg.to_dict()
assert result == {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {"processId": None},
}
def test_to_dict_with_response(self):
"""Test converting a response message to dictionary."""
msg = JsonRpcMessage(id=1, result={"capabilities": {}})
result = msg.to_dict()
assert result == {"jsonrpc": "2.0", "id": 1, "result": {"capabilities": {}}}
def test_to_dict_with_error(self):
"""Test converting an error message to dictionary."""
msg = JsonRpcMessage(
id=1, error={"code": -32601, "message": "Method not found"}
)
result = msg.to_dict()
assert result == {
"jsonrpc": "2.0",
"id": 1,
"error": {"code": -32601, "message": "Method not found"},
}
def test_to_dict_notification(self):
"""Test converting a notification message to dictionary."""
msg = JsonRpcMessage(
method="window/logMessage", params={"type": 3, "message": "Server started"}
)
result = msg.to_dict()
assert result == {
"jsonrpc": "2.0",
"method": "window/logMessage",
"params": {"type": 3, "message": "Server started"},
}
def test_from_dict(self):
"""Test creating message from dictionary."""
data = {
"jsonrpc": "2.0",
"id": 42,
"method": "textDocument/completion",
"params": {"textDocument": {"uri": "file:///test.sql"}},
}
msg = JsonRpcMessage(**data)
assert msg.jsonrpc == "2.0"
assert msg.id == 42
assert msg.method == "textDocument/completion"
assert msg.params == {"textDocument": {"uri": "file:///test.sql"}}
class TestLspEventName:
"""Test LspEventName enum and helpers."""
def test_event_name_from_string_valid(self):
"""Test converting valid string to event name."""
assert (
event_name_from_string("dbt/lspCompileComplete")
== LspEventName.compileComplete
)
assert event_name_from_string("window/logMessage") == LspEventName.logMessage
assert event_name_from_string("$/progress") == LspEventName.progress
def test_event_name_from_string_invalid(self):
"""Test converting invalid string returns None."""
assert event_name_from_string("invalid/event") is None
assert event_name_from_string("") is None
class TestLspConnectionState:
"""Test LspConnectionState dataclass."""
def test_initial_state(self):
"""Test initial state values."""
state = LspConnectionState()
assert state.initialized is False
assert state.shutting_down is False
assert state.capabilities is not None
assert len(state.capabilities) == 0
assert state.pending_requests == {}
assert state.pending_notifications == {}
assert state.compiled is False
def test_get_next_request_id(self):
"""Test request ID generation."""
state = LspConnectionState()
# Should start at 20 to avoid collisions
id1 = state.get_next_request_id()
id2 = state.get_next_request_id()
id3 = state.get_next_request_id()
assert id1 == 20
assert id2 == 21
assert id3 == 22
class TestLSPConnectionInitialization:
"""Test LSP connection initialization and validation."""
def test_init_valid_binary(self, tmp_path):
"""Test initialization with valid binary path."""
# Create a dummy binary file
binary_path = tmp_path / "lsp-server"
binary_path.touch()
conn = SocketLSPConnection(
binary_path=str(binary_path),
cwd="/test/dir",
args=["--arg1", "--arg2"],
connection_timeout=15,
default_request_timeout=60,
)
assert conn.binary_path == binary_path
assert conn.cwd == "/test/dir"
assert conn.args == ["--arg1", "--arg2"]
assert conn.host == "127.0.0.1"
assert conn.port == 0
assert conn.connection_timeout == 15
assert conn.default_request_timeout == 60
assert conn.process is None
assert isinstance(conn.state, LspConnectionState)
class TestSocketSetup:
"""Test socket setup and lifecycle."""
def test_setup_socket_success(self, tmp_path):
"""Test successful socket setup."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
with patch("socket.socket") as mock_socket_class:
mock_socket = MagicMock()
mock_socket.getsockname.return_value = ("127.0.0.1", 54321)
mock_socket_class.return_value = mock_socket
conn.setup_socket()
# Verify socket setup
mock_socket_class.assert_called_once_with(
socket.AF_INET, socket.SOCK_STREAM
)
mock_socket.setsockopt.assert_called_once_with(
socket.SOL_SOCKET, socket.SO_REUSEADDR, 1
)
mock_socket.bind.assert_called_once_with(("127.0.0.1", 0))
mock_socket.listen.assert_called_once_with(1)
assert conn.port == 54321
assert conn._socket == mock_socket
class TestProcessLaunching:
"""Test LSP process launching and termination."""
@pytest.mark.asyncio
async def test_launch_lsp_process_success(self, tmp_path):
"""Test successful process launch."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test/dir")
conn.port = 12345
with patch("asyncio.create_subprocess_exec") as mock_create_subprocess:
mock_process = MagicMock()
mock_process.pid = 9999
mock_create_subprocess.return_value = mock_process
await conn.launch_lsp_process()
# Verify process was started with correct arguments
mock_create_subprocess.assert_called_once_with(
str(binary_path), "--socket", "12345", "--project-dir", "/test/dir"
)
assert conn.process == mock_process
class TestStartStop:
"""Test start/stop lifecycle."""
@pytest.mark.asyncio
async def test_start_success(self, tmp_path):
"""Test successful server start."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Mock socket setup
mock_socket = MagicMock()
mock_connection = MagicMock()
mock_socket.getsockname.return_value = ("127.0.0.1", 54321)
# Mock process
mock_process = MagicMock()
mock_process.pid = 9999
with (
patch("socket.socket", return_value=mock_socket),
patch("asyncio.create_subprocess_exec", return_value=mock_process),
patch.object(conn, "_read_loop", new_callable=AsyncMock),
patch.object(conn, "_write_loop", new_callable=AsyncMock),
):
# Mock socket accept
async def mock_accept_wrapper():
return mock_connection, ("127.0.0.1", 12345)
with patch("asyncio.get_running_loop") as mock_loop:
mock_loop.return_value.run_in_executor.return_value = (
mock_accept_wrapper()
)
mock_loop.return_value.create_task.side_effect = (
lambda coro: asyncio.create_task(coro)
)
await conn.start()
assert conn.process == mock_process
assert conn._connection == mock_connection
assert conn._reader_task is not None
assert conn._writer_task is not None
@pytest.mark.asyncio
async def test_start_already_running(self, tmp_path):
"""Test starting when already running."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock() # Simulate already running
with (
patch("socket.socket"),
patch("asyncio.create_subprocess_exec") as mock_create_subprocess,
):
await conn.start()
# Should not create a new process
mock_create_subprocess.assert_not_called()
@pytest.mark.asyncio
async def test_start_timeout(self, tmp_path):
"""Test start timeout when server doesn't connect."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test", connection_timeout=0.1)
mock_socket = MagicMock()
mock_socket.getsockname.return_value = ("127.0.0.1", 54321)
mock_process = MagicMock()
with (
patch("socket.socket", return_value=mock_socket),
patch("asyncio.create_subprocess_exec", return_value=mock_process),
):
# Simulate timeout in socket.accept
mock_socket.accept.side_effect = TimeoutError
with patch("asyncio.get_running_loop") as mock_loop:
mock_loop.return_value.run_in_executor.side_effect = TimeoutError
with pytest.raises(
RuntimeError, match="Timeout waiting for LSP server to connect"
):
await conn.start()
@pytest.mark.asyncio
async def test_stop_complete_cleanup(self, tmp_path):
"""Test complete cleanup on stop."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Setup mocks for running state
conn.process = MagicMock()
conn.process.terminate = MagicMock()
conn.process.wait = AsyncMock()
conn.process.kill = MagicMock()
conn._socket = MagicMock()
conn._connection = MagicMock()
# Create mock tasks with proper async behavior
async def mock_task():
pass
conn._reader_task = asyncio.create_task(mock_task())
conn._writer_task = asyncio.create_task(mock_task())
# Let tasks complete
await asyncio.sleep(0.01)
# Store references before they are set to None
mock_connection = conn._connection
mock_socket = conn._socket
mock_process = conn.process
with patch.object(conn, "_send_shutdown_request") as mock_shutdown:
await conn.stop()
# Verify cleanup methods were called
mock_shutdown.assert_called_once()
mock_connection.close.assert_called_once()
mock_socket.close.assert_called_once()
mock_process.terminate.assert_called_once()
# Verify everything was set to None
assert conn.process is None
assert conn._socket is None
assert conn._connection is None
@pytest.mark.asyncio
async def test_stop_force_kill(self, tmp_path):
"""Test force kill when process doesn't terminate."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Setup mock process that doesn't terminate
mock_process = MagicMock()
mock_process.terminate = MagicMock()
mock_process.wait = AsyncMock(
side_effect=[subprocess.TimeoutExpired("cmd", 1), None]
)
mock_process.kill = MagicMock()
conn.process = mock_process
await conn.stop()
# Verify force kill was called
mock_process.terminate.assert_called_once()
mock_process.kill.assert_called_once()
class TestInitializeMethod:
"""Test LSP initialization handshake."""
@pytest.mark.asyncio
async def test_initialize_success(self, tmp_path):
"""Test successful initialization."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock() # Simulate running
# Mock send_request to return capabilities
mock_result = {
"capabilities": {
"textDocumentSync": 2,
"completionProvider": {"triggerCharacters": [".", ":"]},
}
}
with (
patch.object(
conn, "send_request", new_callable=AsyncMock
) as mock_send_request,
patch.object(conn, "send_notification") as mock_send_notification,
):
mock_send_request.return_value = mock_result
await conn.initialize(timeout=5)
# Verify initialize request was sent
mock_send_request.assert_called_once()
call_args = mock_send_request.call_args
assert call_args[0][0] == "initialize"
assert call_args[1]["timeout"] == 5
params = call_args[0][1]
assert params["rootUri"] is None # currently not using cwd
assert params["clientInfo"]["name"] == "dbt-mcp"
# Verify initialized notification was sent
mock_send_notification.assert_called_once_with("initialized", {})
# Verify state was updated
assert conn.state.initialized is True
assert conn.state.capabilities == mock_result["capabilities"]
@pytest.mark.asyncio
async def test_initialize_already_initialized(self, tmp_path):
"""Test initialization when already initialized."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock()
conn.state.initialized = True
with pytest.raises(RuntimeError, match="LSP server is already initialized"):
await conn.initialize()
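# The parsing tests below pin down the LSP base-protocol framing:
# "Content-Length: <bytes>\r\n\r\n" followed by a UTF-8 JSON body, with
# _parse_message returning a (message, remaining_buffer) tuple.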
class TestMessageParsing:
"""Test JSON-RPC message parsing."""
def test_parse_message_complete(self, tmp_path):
"""Test parsing a complete message."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create a valid LSP message
content = '{"jsonrpc":"2.0","id":1,"result":{"test":true}}'
header = f"Content-Length: {len(content)}\r\n\r\n"
buffer = (header + content).encode("utf-8")
message, remaining = conn._parse_message(buffer)
assert message is not None
assert message.id == 1
assert message.result == {"test": True}
assert remaining == b""
def test_parse_message_incomplete_header(self, tmp_path):
"""Test parsing with incomplete header."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
buffer = b"Content-Length: 50\r\n" # Missing \r\n\r\n
message, remaining = conn._parse_message(buffer)
assert message is None
assert remaining == buffer
def test_parse_message_incomplete_content(self, tmp_path):
"""Test parsing with incomplete content."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
content = '{"jsonrpc":"2.0","id":1,"result":{"test":true}}'
header = f"Content-Length: {len(content)}\r\n\r\n"
# Only include part of the content
buffer = (header + content[:10]).encode("utf-8")
message, remaining = conn._parse_message(buffer)
assert message is None
assert remaining == buffer
def test_parse_message_invalid_json(self, tmp_path):
"""Test parsing with invalid JSON content."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
content = '{"invalid json'
header = f"Content-Length: {len(content)}\r\n\r\n"
buffer = (header + content).encode("utf-8")
message, remaining = conn._parse_message(buffer)
assert message is None
assert remaining == b"" # Invalid message is discarded
def test_parse_message_missing_content_length(self, tmp_path):
"""Test parsing with missing Content-Length header."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
buffer = b'Some-Header: value\r\n\r\n{"test":true}'
message, remaining = conn._parse_message(buffer)
assert message is None
assert remaining == b'{"test":true}' # Header consumed, content remains
def test_parse_message_multiple_messages(self, tmp_path):
"""Test parsing multiple messages from buffer."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create two messages
content1 = '{"jsonrpc":"2.0","id":1,"result":true}'
content2 = '{"jsonrpc":"2.0","id":2,"result":false}'
header1 = f"Content-Length: {len(content1)}\r\n\r\n"
header2 = f"Content-Length: {len(content2)}\r\n\r\n"
buffer = (header1 + content1 + header2 + content2).encode("utf-8")
# Parse first message
message1, remaining1 = conn._parse_message(buffer)
assert message1 is not None
assert message1.id == 1
assert message1.result is True
# Parse second message
message2, remaining2 = conn._parse_message(remaining1)
assert message2 is not None
assert message2.id == 2
assert message2.result is False
assert remaining2 == b""
class TestMessageHandling:
"""Test incoming message handling."""
def test_handle_response_message(self, tmp_path):
"""Test handling response to a request."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create a pending request
future = asyncio.Future()
conn.state.pending_requests[42] = future
# Handle response message
message = JsonRpcMessage(id=42, result={"success": True})
with patch.object(future, "get_loop") as mock_get_loop:
mock_loop = MagicMock()
mock_get_loop.return_value = mock_loop
conn._handle_incoming_message(message)
# Verify future was resolved
mock_loop.call_soon_threadsafe.assert_called_once()
args = mock_loop.call_soon_threadsafe.call_args[0]
assert args[0] == future.set_result
assert args[1] == {"success": True}
# Verify request was removed from pending
assert 42 not in conn.state.pending_requests
def test_handle_error_response(self, tmp_path):
"""Test handling error response."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create a pending request
future = asyncio.Future()
conn.state.pending_requests[42] = future
# Handle error response
message = JsonRpcMessage(
id=42, error={"code": -32601, "message": "Method not found"}
)
with patch.object(future, "get_loop") as mock_get_loop:
mock_loop = MagicMock()
mock_get_loop.return_value = mock_loop
conn._handle_incoming_message(message)
# Verify future was rejected
mock_loop.call_soon_threadsafe.assert_called_once()
args = mock_loop.call_soon_threadsafe.call_args[0]
assert args[0] == future.set_exception
def test_handle_unknown_response(self, tmp_path):
"""Test handling response for unknown request ID."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Handle response with unknown ID
message = JsonRpcMessage(id=999, result={"test": True})
with patch.object(conn, "_send_message") as mock_send:
conn._handle_incoming_message(message)
# Should send empty response back
mock_send.assert_called_once()
sent_msg = mock_send.call_args[0][0]
assert isinstance(sent_msg, JsonRpcMessage)
assert sent_msg.id == 999
assert sent_msg.result is None
def test_handle_notification(self, tmp_path):
"""Test handling notification messages."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create futures waiting for compile complete event
future1 = asyncio.Future()
future2 = asyncio.Future()
conn.state.pending_notifications[LspEventName.compileComplete] = [
future1,
future2,
]
# Handle compile complete notification
message = JsonRpcMessage(
method="dbt/lspCompileComplete", params={"success": True}
)
with (
patch.object(future1, "get_loop") as mock_get_loop1,
patch.object(future2, "get_loop") as mock_get_loop2,
):
mock_loop1 = MagicMock()
mock_loop2 = MagicMock()
mock_get_loop1.return_value = mock_loop1
mock_get_loop2.return_value = mock_loop2
conn._handle_incoming_message(message)
# Verify futures were resolved
mock_loop1.call_soon_threadsafe.assert_called_once_with(
future1.set_result, {"success": True}
)
mock_loop2.call_soon_threadsafe.assert_called_once_with(
future2.set_result, {"success": True}
)
# Verify compile state was set
assert conn.state.compiled is True
def test_handle_unknown_notification(self, tmp_path):
"""Test handling unknown notification."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Handle unknown notification
message = JsonRpcMessage(method="unknown/notification", params={"data": "test"})
# Should not raise, just log
conn._handle_incoming_message(message)
class TestSendRequest:
"""Test sending requests to LSP server."""
@pytest.mark.asyncio
async def test_send_request_success(self, tmp_path):
"""Test successful request sending."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock() # Simulate running
with (
patch.object(conn, "_send_message") as mock_send,
patch("asyncio.wait_for", new_callable=AsyncMock) as mock_wait_for,
):
mock_wait_for.return_value = {"result": "success"}
result = await conn.send_request(
"testMethod", {"param": "value"}, timeout=5
)
# Verify message was sent
mock_send.assert_called_once()
sent_msg = mock_send.call_args[0][0]
assert isinstance(sent_msg, JsonRpcMessage)
assert sent_msg.method == "testMethod"
assert sent_msg.params == {"param": "value"}
assert sent_msg.id is not None
# Verify result
assert result == {"result": "success"}
@pytest.mark.asyncio
async def test_send_request_not_running(self, tmp_path):
"""Test sending request when server not running."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# process is None - not running
with pytest.raises(RuntimeError, match="LSP server is not running"):
await conn.send_request("testMethod")
@pytest.mark.asyncio
async def test_send_request_timeout(self, tmp_path):
"""Test request timeout."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test", default_request_timeout=1)
conn.process = MagicMock()
with patch.object(conn, "_send_message"):
# Create a future that never resolves
future = asyncio.Future()
conn.state.pending_requests[20] = future
# Use real wait_for to test timeout
result = await conn.send_request("testMethod", timeout=0.01)
assert "error" in result
class TestSendNotification:
"""Test sending notifications to LSP server."""
def test_send_notification_success(self, tmp_path):
"""Test successful notification sending."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock()
with patch.object(conn, "_send_message") as mock_send:
conn.send_notification(
"window/showMessage", {"type": 3, "message": "Hello"}
)
# Verify message was sent
mock_send.assert_called_once()
sent_msg = mock_send.call_args[0][0]
assert isinstance(sent_msg, JsonRpcMessage)
assert sent_msg.method == "window/showMessage"
assert sent_msg.params == {"type": 3, "message": "Hello"}
assert sent_msg.id is None # Notifications have no ID
def test_send_notification_not_running(self, tmp_path):
"""Test sending notification when server not running."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# process is None - not running
with pytest.raises(RuntimeError, match="LSP server is not running"):
conn.send_notification("testMethod")
class TestWaitForNotification:
"""Test waiting for notifications."""
def test_wait_for_notification(self, tmp_path):
"""Test registering to wait for a notification."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
with patch("asyncio.get_running_loop") as mock_get_loop:
mock_loop = MagicMock()
mock_future = MagicMock()
mock_loop.create_future.return_value = mock_future
mock_get_loop.return_value = mock_loop
result = conn.wait_for_notification(LspEventName.compileComplete)
# Verify future was created and registered
assert result == mock_future
assert LspEventName.compileComplete in conn.state.pending_notifications
assert (
mock_future
in conn.state.pending_notifications[LspEventName.compileComplete]
)
class TestSendMessage:
"""Test low-level message sending."""
def test_send_message_with_jsonrpc_message(self, tmp_path):
"""Test sending JsonRpcMessage."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn._outgoing_queue = MagicMock()
message = JsonRpcMessage(id=1, method="test", params={"key": "value"})
conn._send_message(message)
# Verify message was queued
conn._outgoing_queue.put_nowait.assert_called_once()
data = conn._outgoing_queue.put_nowait.call_args[0][0]
# Parse the data to verify format
assert b"Content-Length:" in data
assert b"\r\n\r\n" in data
# JSON might have spaces after colons, check for both variants
assert b'"jsonrpc"' in data and b'"2.0"' in data
assert b'"method"' in data and b'"test"' in data
class TestShutdown:
"""Test shutdown sequence."""
def test_send_shutdown_request(self, tmp_path):
"""Test sending shutdown and exit messages."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
with patch.object(conn, "_send_message") as mock_send:
conn._send_shutdown_request()
# Verify two messages were sent
assert mock_send.call_count == 2
# First should be shutdown request
shutdown_msg = mock_send.call_args_list[0][0][0]
assert isinstance(shutdown_msg, JsonRpcMessage)
assert shutdown_msg.method == "shutdown"
assert shutdown_msg.id is not None
# Second should be exit notification
exit_msg = mock_send.call_args_list[1][0][0]
assert isinstance(exit_msg, JsonRpcMessage)
assert exit_msg.method == "exit"
assert exit_msg.id is None
class TestIsRunning:
"""Test is_running method."""
def test_is_running_true(self, tmp_path):
"""Test when process is running."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock()
conn.process.returncode = None
assert conn.is_running() is True
def test_is_running_false_no_process(self, tmp_path):
"""Test when no process."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
assert conn.is_running() is False
def test_is_running_false_process_exited(self, tmp_path):
"""Test when process has exited."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock()
conn.process.returncode = 0
assert conn.is_running() is False
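# The I/O loop tests below rely on the blocking socket calls being bridged
# into asyncio via run_in_executor, so the mocks patch the running loop's
# executor rather than the socket itself.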
class TestReadWriteLoops:
"""Test async I/O loops."""
@pytest.mark.asyncio
async def test_read_loop_processes_messages(self, tmp_path):
"""Test read loop processes incoming messages."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Setup mock connection
mock_connection = MagicMock()
conn._connection = mock_connection
# Create test data
content = '{"jsonrpc":"2.0","id":1,"result":true}'
header = f"Content-Length: {len(content)}\r\n\r\n"
test_data = (header + content).encode("utf-8")
# Mock recv to return data once then empty
recv_calls = [test_data, b""]
async def mock_recv_wrapper(size):
if recv_calls:
return recv_calls.pop(0)
return b""
with (
patch("asyncio.get_running_loop") as mock_get_loop,
patch.object(conn, "_handle_incoming_message") as mock_handle,
):
mock_loop = MagicMock()
mock_get_loop.return_value = mock_loop
mock_loop.run_in_executor.side_effect = (
lambda _, func, *args: mock_recv_wrapper(*args)
)
# Run read loop (will exit when recv returns empty)
await conn._read_loop()
# Verify message was handled
mock_handle.assert_called_once()
handled_msg = mock_handle.call_args[0][0]
assert handled_msg.id == 1
assert handled_msg.result is True
@pytest.mark.asyncio
async def test_write_loop_sends_messages(self, tmp_path):
"""Test write loop sends queued messages."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Setup mock connection
mock_connection = MagicMock()
conn._connection = mock_connection
# Queue test data
test_data = b"test message data"
conn._outgoing_queue.put_nowait(test_data)
# Set stop event after first iteration
async def stop_after_one():
await asyncio.sleep(0.01)
conn._stop_event.set()
with patch("asyncio.get_running_loop") as mock_get_loop:
mock_loop = MagicMock()
mock_get_loop.return_value = mock_loop
mock_loop.run_in_executor.return_value = asyncio.sleep(0)
# Run both coroutines
await asyncio.gather(
conn._write_loop(), stop_after_one(), return_exceptions=True
)
# Verify data was sent
mock_loop.run_in_executor.assert_called()
call_args = mock_loop.run_in_executor.call_args_list[-1]
assert call_args[0][1] == mock_connection.sendall
assert call_args[0][2] == test_data
class TestEdgeCases:
"""Test edge cases and error conditions."""
@pytest.mark.asyncio
async def test_concurrent_requests(self, tmp_path):
"""Test handling concurrent requests."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock()
# Track sent messages
sent_messages = []
def track_message(msg):
sent_messages.append(msg)
with patch.object(conn, "_send_message", side_effect=track_message):
# Create futures for multiple requests
            future1 = asyncio.create_task(
                conn.send_request("method1", {"n": 1})
            )
            future2 = asyncio.create_task(
                conn.send_request("method2", {"n": 2})
            )
            future3 = asyncio.create_task(
                conn.send_request("method3", {"n": 3})
            )
# Let tasks start
await asyncio.sleep(0.01)
# Verify all messages were sent with unique IDs
assert len(sent_messages) == 3
ids = [msg.id for msg in sent_messages]
assert len(set(ids)) == 3 # All IDs are unique
# Simulate responses
for msg in sent_messages:
if msg.id in conn.state.pending_requests:
future = conn.state.pending_requests[msg.id]
future.set_result({"response": msg.id})
# Wait for all requests
results = await asyncio.gather(future1, future2, future3)
# Verify each got correct response
assert all("response" in r for r in results)
@pytest.mark.asyncio
async def test_stop_with_pending_requests(self, tmp_path):
"""Test stopping with pending requests."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
conn.process = MagicMock()
conn.process.terminate = MagicMock()
conn.process.wait = AsyncMock()
# Add pending requests
future1 = asyncio.Future()
future2 = asyncio.Future()
conn.state.pending_requests[1] = future1
conn.state.pending_requests[2] = future2
await conn.stop()
# Verify state was cleared
assert len(conn.state.pending_requests) == 0
def test_message_with_unicode(self, tmp_path):
"""Test handling messages with unicode content."""
binary_path = tmp_path / "lsp"
binary_path.touch()
conn = SocketLSPConnection(str(binary_path), "/test")
# Create message with unicode
content = '{"jsonrpc":"2.0","method":"test","params":{"text":"Hello 世界 🚀"}}'
header = f"Content-Length: {len(content.encode('utf-8'))}\r\n\r\n"
buffer = header.encode("utf-8") + content.encode("utf-8")
message, remaining = conn._parse_message(buffer)
assert message is not None
assert message.method == "test"
assert message.params["text"] == "Hello 世界 🚀"
assert remaining == b""
```
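As `test_message_with_unicode` underlines, `Content-Length` must count UTF-8 bytes, not characters. Below is a standalone sketch of the framing these tests assume; the helper name is hypothetical and not part of the codebase.

```python
import json


def frame_lsp_message(payload: dict) -> bytes:
    # The LSP base protocol counts the BYTES of the UTF-8 body, which differs
    # from len(content) whenever the JSON contains non-ASCII characters.
    body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    return f"Content-Length: {len(body)}\r\n\r\n".encode("ascii") + body


framed = frame_lsp_message(
    {"jsonrpc": "2.0", "method": "test", "params": {"text": "Hello 世界 🚀"}}
)
assert framed.startswith(b"Content-Length: ")
```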