This is page 2 of 4. Use http://codebase.md/dbt-labs/dbt-mcp?page={x} to view the full context.
# Directory Structure
```
├── .changes
│ ├── header.tpl.md
│ ├── unreleased
│ │ ├── .gitkeep
│ │ ├── Bug Fix-20251028-143835.yaml
│ │ ├── Enhancement or New Feature-20251014-175047.yaml
│ │ └── Under the Hood-20251030-151902.yaml
│ ├── v0.1.3.md
│ ├── v0.10.0.md
│ ├── v0.10.1.md
│ ├── v0.10.2.md
│ ├── v0.10.3.md
│ ├── v0.2.0.md
│ ├── v0.2.1.md
│ ├── v0.2.10.md
│ ├── v0.2.11.md
│ ├── v0.2.12.md
│ ├── v0.2.13.md
│ ├── v0.2.14.md
│ ├── v0.2.15.md
│ ├── v0.2.16.md
│ ├── v0.2.17.md
│ ├── v0.2.18.md
│ ├── v0.2.19.md
│ ├── v0.2.2.md
│ ├── v0.2.20.md
│ ├── v0.2.3.md
│ ├── v0.2.4.md
│ ├── v0.2.5.md
│ ├── v0.2.6.md
│ ├── v0.2.7.md
│ ├── v0.2.8.md
│ ├── v0.2.9.md
│ ├── v0.3.0.md
│ ├── v0.4.0.md
│ ├── v0.4.1.md
│ ├── v0.4.2.md
│ ├── v0.5.0.md
│ ├── v0.6.0.md
│ ├── v0.6.1.md
│ ├── v0.6.2.md
│ ├── v0.7.0.md
│ ├── v0.8.0.md
│ ├── v0.8.1.md
│ ├── v0.8.2.md
│ ├── v0.8.3.md
│ ├── v0.8.4.md
│ ├── v0.9.0.md
│ ├── v0.9.1.md
│ └── v1.0.0.md
├── .changie.yaml
├── .env.example
├── .github
│ ├── actions
│ │ └── setup-python
│ │ └── action.yml
│ ├── CODEOWNERS
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.yml
│ │ └── feature_request.yml
│ ├── pull_request_template.md
│ └── workflows
│ ├── changelog-check.yml
│ ├── codeowners-check.yml
│ ├── create-release-pr.yml
│ ├── release.yml
│ └── run-checks-pr.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .task
│ └── checksum
│ └── d2
├── .tool-versions
├── .vscode
│ ├── launch.json
│ └── settings.json
├── CHANGELOG.md
├── CONTRIBUTING.md
├── docs
│ ├── d2.png
│ └── diagram.d2
├── evals
│ └── semantic_layer
│ └── test_eval_semantic_layer.py
├── examples
│ ├── .DS_Store
│ ├── aws_strands_agent
│ │ ├── __init__.py
│ │ ├── .DS_Store
│ │ ├── dbt_data_scientist
│ │ │ ├── __init__.py
│ │ │ ├── .env.example
│ │ │ ├── agent.py
│ │ │ ├── prompts.py
│ │ │ ├── quick_mcp_test.py
│ │ │ ├── test_all_tools.py
│ │ │ └── tools
│ │ │ ├── __init__.py
│ │ │ ├── dbt_compile.py
│ │ │ ├── dbt_mcp.py
│ │ │ └── dbt_model_analyzer.py
│ │ ├── LICENSE
│ │ ├── README.md
│ │ └── requirements.txt
│ ├── google_adk_agent
│ │ ├── __init__.py
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ └── README.md
│ ├── langgraph_agent
│ │ ├── __init__.py
│ │ ├── .python-version
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── openai_agent
│ │ ├── __init__.py
│ │ ├── .gitignore
│ │ ├── .python-version
│ │ ├── main_streamable.py
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── openai_responses
│ │ ├── __init__.py
│ │ ├── .gitignore
│ │ ├── .python-version
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── pydantic_ai_agent
│ │ ├── __init__.py
│ │ ├── .gitignore
│ │ ├── .python-version
│ │ ├── main.py
│ │ ├── pyproject.toml
│ │ └── README.md
│ └── remote_mcp
│ ├── .python-version
│ ├── main.py
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── LICENSE
├── pyproject.toml
├── README.md
├── src
│ ├── client
│ │ ├── __init__.py
│ │ ├── main.py
│ │ └── tools.py
│ ├── dbt_mcp
│ │ ├── __init__.py
│ │ ├── .gitignore
│ │ ├── config
│ │ │ ├── config_providers.py
│ │ │ ├── config.py
│ │ │ ├── dbt_project.py
│ │ │ ├── dbt_yaml.py
│ │ │ ├── headers.py
│ │ │ ├── settings.py
│ │ │ └── transport.py
│ │ ├── dbt_admin
│ │ │ ├── __init__.py
│ │ │ ├── client.py
│ │ │ ├── constants.py
│ │ │ ├── run_results_errors
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ └── parser.py
│ │ │ └── tools.py
│ │ ├── dbt_cli
│ │ │ ├── binary_type.py
│ │ │ └── tools.py
│ │ ├── dbt_codegen
│ │ │ ├── __init__.py
│ │ │ └── tools.py
│ │ ├── discovery
│ │ │ ├── client.py
│ │ │ └── tools.py
│ │ ├── errors
│ │ │ ├── __init__.py
│ │ │ ├── admin_api.py
│ │ │ ├── base.py
│ │ │ ├── cli.py
│ │ │ ├── common.py
│ │ │ ├── discovery.py
│ │ │ ├── semantic_layer.py
│ │ │ └── sql.py
│ │ ├── gql
│ │ │ └── errors.py
│ │ ├── lsp
│ │ │ ├── __init__.py
│ │ │ ├── lsp_binary_manager.py
│ │ │ ├── lsp_client.py
│ │ │ ├── lsp_connection.py
│ │ │ └── tools.py
│ │ ├── main.py
│ │ ├── mcp
│ │ │ ├── create.py
│ │ │ └── server.py
│ │ ├── oauth
│ │ │ ├── client_id.py
│ │ │ ├── context_manager.py
│ │ │ ├── dbt_platform.py
│ │ │ ├── fastapi_app.py
│ │ │ ├── logging.py
│ │ │ ├── login.py
│ │ │ ├── refresh_strategy.py
│ │ │ ├── token_provider.py
│ │ │ └── token.py
│ │ ├── prompts
│ │ │ ├── __init__.py
│ │ │ ├── admin_api
│ │ │ │ ├── cancel_job_run.md
│ │ │ │ ├── get_job_details.md
│ │ │ │ ├── get_job_run_artifact.md
│ │ │ │ ├── get_job_run_details.md
│ │ │ │ ├── get_job_run_error.md
│ │ │ │ ├── list_job_run_artifacts.md
│ │ │ │ ├── list_jobs_runs.md
│ │ │ │ ├── list_jobs.md
│ │ │ │ ├── retry_job_run.md
│ │ │ │ └── trigger_job_run.md
│ │ │ ├── dbt_cli
│ │ │ │ ├── args
│ │ │ │ │ ├── full_refresh.md
│ │ │ │ │ ├── limit.md
│ │ │ │ │ ├── resource_type.md
│ │ │ │ │ ├── selectors.md
│ │ │ │ │ ├── sql_query.md
│ │ │ │ │ └── vars.md
│ │ │ │ ├── build.md
│ │ │ │ ├── compile.md
│ │ │ │ ├── docs.md
│ │ │ │ ├── list.md
│ │ │ │ ├── parse.md
│ │ │ │ ├── run.md
│ │ │ │ ├── show.md
│ │ │ │ └── test.md
│ │ │ ├── dbt_codegen
│ │ │ │ ├── args
│ │ │ │ │ ├── case_sensitive_cols.md
│ │ │ │ │ ├── database_name.md
│ │ │ │ │ ├── generate_columns.md
│ │ │ │ │ ├── include_data_types.md
│ │ │ │ │ ├── include_descriptions.md
│ │ │ │ │ ├── leading_commas.md
│ │ │ │ │ ├── materialized.md
│ │ │ │ │ ├── model_name.md
│ │ │ │ │ ├── model_names.md
│ │ │ │ │ ├── schema_name.md
│ │ │ │ │ ├── source_name.md
│ │ │ │ │ ├── table_name.md
│ │ │ │ │ ├── table_names.md
│ │ │ │ │ ├── tables.md
│ │ │ │ │ └── upstream_descriptions.md
│ │ │ │ ├── generate_model_yaml.md
│ │ │ │ ├── generate_source.md
│ │ │ │ └── generate_staging_model.md
│ │ │ ├── discovery
│ │ │ │ ├── get_all_models.md
│ │ │ │ ├── get_all_sources.md
│ │ │ │ ├── get_exposure_details.md
│ │ │ │ ├── get_exposures.md
│ │ │ │ ├── get_mart_models.md
│ │ │ │ ├── get_model_children.md
│ │ │ │ ├── get_model_details.md
│ │ │ │ ├── get_model_health.md
│ │ │ │ └── get_model_parents.md
│ │ │ ├── lsp
│ │ │ │ ├── args
│ │ │ │ │ ├── column_name.md
│ │ │ │ │ └── model_id.md
│ │ │ │ └── get_column_lineage.md
│ │ │ ├── prompts.py
│ │ │ └── semantic_layer
│ │ │ ├── get_dimensions.md
│ │ │ ├── get_entities.md
│ │ │ ├── get_metrics_compiled_sql.md
│ │ │ ├── list_metrics.md
│ │ │ └── query_metrics.md
│ │ ├── py.typed
│ │ ├── semantic_layer
│ │ │ ├── client.py
│ │ │ ├── gql
│ │ │ │ ├── gql_request.py
│ │ │ │ └── gql.py
│ │ │ ├── levenshtein.py
│ │ │ ├── tools.py
│ │ │ └── types.py
│ │ ├── sql
│ │ │ └── tools.py
│ │ ├── telemetry
│ │ │ └── logging.py
│ │ ├── tools
│ │ │ ├── annotations.py
│ │ │ ├── definitions.py
│ │ │ ├── policy.py
│ │ │ ├── register.py
│ │ │ ├── tool_names.py
│ │ │ └── toolsets.py
│ │ └── tracking
│ │ └── tracking.py
│ └── remote_mcp
│ ├── __init__.py
│ └── session.py
├── Taskfile.yml
├── tests
│ ├── __init__.py
│ ├── env_vars.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── dbt_codegen
│ │ │ ├── __init__.py
│ │ │ └── test_dbt_codegen.py
│ │ ├── discovery
│ │ │ └── test_discovery.py
│ │ ├── initialization
│ │ │ ├── __init__.py
│ │ │ └── test_initialization.py
│ │ ├── lsp
│ │ │ └── test_lsp_connection.py
│ │ ├── remote_mcp
│ │ │ └── test_remote_mcp.py
│ │ ├── remote_tools
│ │ │ └── test_remote_tools.py
│ │ ├── semantic_layer
│ │ │ └── test_semantic_layer.py
│ │ └── tracking
│ │ └── test_tracking.py
│ ├── mocks
│ │ └── config.py
│ └── unit
│ ├── __init__.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── test_config.py
│ │ └── test_transport.py
│ ├── dbt_admin
│ │ ├── __init__.py
│ │ ├── test_client.py
│ │ ├── test_error_fetcher.py
│ │ └── test_tools.py
│ ├── dbt_cli
│ │ ├── __init__.py
│ │ ├── test_cli_integration.py
│ │ └── test_tools.py
│ ├── dbt_codegen
│ │ ├── __init__.py
│ │ └── test_tools.py
│ ├── discovery
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_exposures_fetcher.py
│ │ └── test_sources_fetcher.py
│ ├── lsp
│ │ ├── __init__.py
│ │ ├── test_lsp_client.py
│ │ ├── test_lsp_connection.py
│ │ └── test_lsp_tools.py
│ ├── oauth
│ │ ├── test_credentials_provider.py
│ │ ├── test_fastapi_app_pagination.py
│ │ └── test_token.py
│ ├── tools
│ │ ├── test_disable_tools.py
│ │ ├── test_tool_names.py
│ │ ├── test_tool_policies.py
│ │ └── test_toolsets.py
│ └── tracking
│ └── test_tracking.py
├── ui
│ ├── .gitignore
│ ├── assets
│ │ ├── dbt_logo BLK.svg
│ │ └── dbt_logo WHT.svg
│ ├── eslint.config.js
│ ├── index.html
│ ├── package.json
│ ├── pnpm-lock.yaml
│ ├── pnpm-workspace.yaml
│ ├── README.md
│ ├── src
│ │ ├── App.css
│ │ ├── App.tsx
│ │ ├── global.d.ts
│ │ ├── index.css
│ │ ├── main.tsx
│ │ └── vite-env.d.ts
│ ├── tsconfig.app.json
│ ├── tsconfig.json
│ ├── tsconfig.node.json
│ └── vite.config.ts
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/src/client/main.py:
--------------------------------------------------------------------------------
```python
import asyncio
import json
from time import time
from openai import OpenAI
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.response_output_message import ResponseOutputMessage
from client.tools import get_tools
from dbt_mcp.config.config import load_config
from dbt_mcp.mcp.server import create_dbt_mcp
LLM_MODEL = "gpt-4o-mini"
TOOL_RESPONSE_TRUNCATION = 100 # set to None for no truncation
llm_client = OpenAI()
config = load_config()
messages = []
async def main():
dbt_mcp = await create_dbt_mcp(config)
user_role = "user"
available_tools = await get_tools(dbt_mcp)
tools_str = "\n".join(
[
f"- {t['name']}({', '.join(t['parameters']['properties'].keys())})"
for t in available_tools
]
)
print(f"Available tools:\n{tools_str}")
while True:
user_input = input(f"{user_role} > ")
messages.append({"role": user_role, "content": user_input})
response_output = None
tool_call_error = None
while (
response_output is None
or response_output.type == "function_call"
or tool_call_error is not None
):
tool_call_error = None
response = llm_client.responses.create(
model=LLM_MODEL,
input=messages,
tools=available_tools,
parallel_tool_calls=False,
)
response_output = response.output[0]
if isinstance(response_output, ResponseOutputMessage):
print(f"{response_output.role} > {response_output.content[0].text}")
messages.append(response_output)
if response_output.type != "function_call":
continue
print(
f"Calling tool: {response_output.name} with arguments: {response_output.arguments}"
)
start_time = time()
try:
tool_response = await dbt_mcp.call_tool(
response_output.name,
json.loads(response_output.arguments),
)
except Exception as e:
tool_call_error = e
print(f"Error calling tool: {e}")
messages.append(
FunctionCallOutput(
type="function_call_output",
call_id=response_output.call_id,
output=str(e),
)
)
continue
tool_response_str = str(tool_response)
print(
f"Tool responded in {time() - start_time} seconds: "
+ (
f"{tool_response_str[:TOOL_RESPONSE_TRUNCATION]} [TRUNCATED]..."
if TOOL_RESPONSE_TRUNCATION
and len(tool_response_str) > TOOL_RESPONSE_TRUNCATION
else tool_response_str
)
)
messages.append(
FunctionCallOutput(
type="function_call_output",
call_id=response_output.call_id,
output=str(tool_response),
)
)
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\nExiting.")
```
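The same imports can also be used non-interactively. A minimal sketch (assuming the semantic layer toolset is enabled and the usual dbt platform environment variables are set) that calls a single tool directly:

```python
import asyncio

from dbt_mcp.config.config import load_config
from dbt_mcp.mcp.server import create_dbt_mcp


async def call_one_tool() -> None:
    config = load_config()
    dbt_mcp = await create_dbt_mcp(config)
    # list_metrics is one of the semantic layer tools registered on the server
    result = await dbt_mcp.call_tool("list_metrics", {})
    print(result)


if __name__ == "__main__":
    asyncio.run(call_one_tool())
```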
--------------------------------------------------------------------------------
/examples/openai_agent/main_streamable.py:
--------------------------------------------------------------------------------
```python
# mypy: ignore-errors
import asyncio
import os
from agents import Agent, Runner, trace
from agents.mcp import create_static_tool_filter
from agents.mcp.server import MCPServerStreamableHttp
from agents.stream_events import RawResponsesStreamEvent, RunItemStreamEvent
from openai.types.responses import ResponseCompletedEvent, ResponseOutputMessage
def print_tool_call(tool_name, params, color="yellow", show_params=True):
# Define color codes for different colors
# we could use a library like colorama but this avoids adding a dependency
color_codes = {
"grey": "\033[37m",
"yellow": "\033[93m",
}
color_code_reset = "\033[0m"
color_code = color_codes.get(color, color_codes["yellow"])
msg = f"Calling the tool {tool_name}"
if show_params:
msg += f" with params {params}"
print(f"{color_code}# {msg}{color_code_reset}")
def handle_event_printing(event, show_tools_calls=True):
if type(event) is RunItemStreamEvent and show_tools_calls:
if event.name == "tool_called":
print_tool_call(
event.item.raw_item.name,
event.item.raw_item.arguments,
color="grey",
show_params=True,
)
if type(event) is RawResponsesStreamEvent:
if type(event.data) is ResponseCompletedEvent:
for output in event.data.response.output:
if type(output) is ResponseOutputMessage:
print(output.content[0].text)
async def main(inspect_events_tools_calls=False):
prod_environment_id = os.environ.get("DBT_PROD_ENV_ID", os.getenv("DBT_ENV_ID"))
token = os.environ.get("DBT_TOKEN")
host = os.environ.get("DBT_HOST", "cloud.getdbt.com")
async with MCPServerStreamableHttp(
name="dbt",
params={
"url": f"https://{host}/api/ai/v1/mcp/",
"headers": {
"Authorization": f"token {token}",
"x-dbt-prod-environment-id": prod_environment_id,
},
},
client_session_timeout_seconds=20,
cache_tools_list=True,
tool_filter=create_static_tool_filter(
allowed_tool_names=[
"list_metrics",
"get_dimensions",
"get_entities",
"query_metrics",
"get_metrics_compiled_sql",
],
),
) as server:
agent = Agent(
name="Assistant",
instructions="Use the tools to answer the user's questions. Do not invent data or sample data.",
mcp_servers=[server],
model="gpt-5",
)
with trace(workflow_name="Conversation"):
conversation = []
result = None
while True:
if result:
conversation = result.to_input_list()
conversation.append({"role": "user", "content": input("User > ")})
if inspect_events_tools_calls:
async for event in Runner.run_streamed(
agent, conversation
).stream_events():
handle_event_printing(event, show_tools_calls=True)
else:
result = await Runner.run(agent, conversation)
print(result.final_output)
if __name__ == "__main__":
try:
asyncio.run(main(inspect_events_tools_calls=True))
except KeyboardInterrupt:
print("\nExiting.")
```
--------------------------------------------------------------------------------
/ui/assets/dbt_logo BLK.svg:
--------------------------------------------------------------------------------
```
<svg width="490" height="190" viewBox="0 0 490 190" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_33_23)">
<path d="M455.26 148.48C444.41 148.48 436.47 139.69 436.47 128.84V82.57H423.04V61.9H437.55L437.33 40.92L459.04 33.41V61.91H477.82V82.58H459.04V127.43H475.67V148.49H455.25L455.26 148.48Z" fill="black"/>
<path d="M381.33 149.48C369.96 149.48 361.35 145.8 355.14 135.81L353.94 135.98C354.11 140.29 354.28 144.03 354.28 148.51H331.37V33.53L354.63 25.67V53.97C354.63 62.24 354.63 66.72 353.77 74.47L355.15 74.81C361.53 64.47 371 59.82 382.03 59.82C405.46 59.82 420.1 79.8 420.1 105.3C420.1 130.8 404.94 149.46 381.34 149.46L381.33 149.48ZM375.82 128.63C388.74 128.63 397.01 118.92 397.01 104.97C397.01 91.02 388.74 80.68 375.82 80.68C362.9 80.68 354.12 90.84 354.12 105.14C354.12 119.44 362.56 128.63 375.82 128.63Z" fill="black"/>
<path d="M262.51 149.48C239.6 149.48 224.44 130.13 224.44 104.97C224.44 79.81 239.43 59.83 263.37 59.83C274.22 59.83 283.36 64.14 289.56 74.3L290.76 73.96C290.07 66.72 290.07 61.73 290.07 53.98V33.54L313.33 25.68V148.52H290.42C290.42 144.21 290.42 140.29 290.76 135.98L289.56 135.64C283.18 145.8 273.88 149.48 262.51 149.48ZM268.88 128.63C282.15 128.63 290.59 119.27 290.59 104.97C290.59 90.67 282.15 80.68 268.88 80.68C255.79 80.85 247.69 91.19 247.69 105.14C247.69 119.09 255.79 128.63 268.88 128.63Z" fill="black"/>
<path d="M485.216 148.816C482.608 148.816 480.8 146.976 480.8 144.368C480.8 141.776 482.624 139.92 485.216 139.92C487.808 139.92 489.632 141.776 489.632 144.368C489.632 146.976 487.824 148.816 485.216 148.816ZM485.216 148.032C487.328 148.032 488.752 146.48 488.752 144.368C488.752 142.256 487.328 140.72 485.216 140.72C483.104 140.72 481.68 142.256 481.68 144.368C481.68 146.48 483.104 148.032 485.216 148.032ZM483.744 146.56V142.112H485.52C486.448 142.112 486.928 142.64 486.928 143.424C486.928 144.144 486.448 144.656 485.696 144.656H485.664L487.2 146.544V146.56H486.144L484.592 144.608H484.576V146.56H483.744ZM484.576 144.064H485.408C485.84 144.064 486.08 143.808 486.08 143.44C486.08 143.088 485.824 142.848 485.408 142.848H484.576V144.064Z" fill="black"/>
<path d="M158.184 2.16438C166.564 -2.6797 175.59 1.19557 182.359 7.97729C189.45 15.082 192.351 22.8325 187.839 31.5518C186.227 34.7812 167.209 67.721 161.407 77.0863C158.184 82.2533 156.572 88.7121 156.572 94.8479C156.572 100.984 158.184 107.443 161.407 112.933C167.209 121.975 186.227 155.238 187.839 158.467C192.351 167.509 189.128 174.291 182.681 181.396C175.267 188.823 167.854 192.698 158.828 187.854C155.605 185.917 65.3511 133.924 65.3511 133.924C66.9627 144.581 72.7648 154.269 80.1785 160.082C79.2115 160.405 34.5761 186.232 31.5058 187.854C23.0444 192.326 15.3286 189.336 8.62001 183.01C1.04465 175.867 -2.66173 167.509 2.1733 158.79C3.78498 155.56 22.8028 122.298 28.2825 113.255C31.5058 107.765 33.4398 101.63 33.4398 95.1709C33.4398 88.7121 31.5058 82.5762 28.2825 77.4092C22.8028 67.721 3.78498 34.1354 2.1733 31.2289C-2.66173 22.5096 1.22016 13.1436 7.97534 7.00847C15.6327 0.0538926 22.8028 -2.03382 31.5058 2.16438C34.0845 3.1332 124.016 56.4182 124.016 56.4182C123.049 46.0841 117.892 36.7189 109.511 30.2601C110.156 29.9372 154.96 3.45614 158.184 2.16438ZM98.2293 110.995L111.123 98.0773C112.734 96.4626 112.734 93.8791 111.123 91.9415L98.2293 79.0239C96.2953 77.0863 93.7166 77.0863 91.7826 79.0239L78.8892 91.9415C77.2775 93.5562 77.2775 96.4626 78.8892 98.0773L91.7826 110.995C93.3942 112.61 96.2953 112.61 98.2293 110.995Z" fill="#FE6703"/>
</g>
<defs>
<clipPath id="clip0_33_23">
<rect width="490" height="190" fill="white"/>
</clipPath>
</defs>
</svg>
```
--------------------------------------------------------------------------------
/ui/assets/dbt_logo WHT.svg:
--------------------------------------------------------------------------------
```
<svg width="490" height="190" viewBox="0 0 490 190" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_33_49)">
<path d="M455.26 148.48C444.41 148.48 436.47 139.69 436.47 128.84V82.57H423.04V61.9H437.55L437.33 40.92L459.04 33.41V61.91H477.82V82.58H459.04V127.43H475.67V148.49H455.25L455.26 148.48Z" fill="white"/>
<path d="M381.33 149.48C369.96 149.48 361.35 145.8 355.14 135.81L353.94 135.98C354.11 140.29 354.28 144.03 354.28 148.51H331.37V33.53L354.63 25.67V53.97C354.63 62.24 354.63 66.72 353.77 74.47L355.15 74.81C361.53 64.47 371 59.82 382.03 59.82C405.46 59.82 420.1 79.8 420.1 105.3C420.1 130.8 404.94 149.46 381.34 149.46L381.33 149.48ZM375.82 128.63C388.74 128.63 397.01 118.92 397.01 104.97C397.01 91.02 388.74 80.68 375.82 80.68C362.9 80.68 354.12 90.84 354.12 105.14C354.12 119.44 362.56 128.63 375.82 128.63Z" fill="white"/>
<path d="M262.51 149.48C239.6 149.48 224.44 130.13 224.44 104.97C224.44 79.81 239.43 59.83 263.37 59.83C274.22 59.83 283.36 64.14 289.56 74.3L290.76 73.96C290.07 66.72 290.07 61.73 290.07 53.98V33.54L313.33 25.68V148.52H290.42C290.42 144.21 290.42 140.29 290.76 135.98L289.56 135.64C283.18 145.8 273.88 149.48 262.51 149.48ZM268.88 128.63C282.15 128.63 290.59 119.27 290.59 104.97C290.59 90.67 282.15 80.68 268.88 80.68C255.79 80.85 247.69 91.19 247.69 105.14C247.69 119.09 255.79 128.63 268.88 128.63Z" fill="white"/>
<path d="M485.216 148.816C482.608 148.816 480.8 146.976 480.8 144.368C480.8 141.776 482.624 139.92 485.216 139.92C487.808 139.92 489.632 141.776 489.632 144.368C489.632 146.976 487.824 148.816 485.216 148.816ZM485.216 148.032C487.328 148.032 488.752 146.48 488.752 144.368C488.752 142.256 487.328 140.72 485.216 140.72C483.104 140.72 481.68 142.256 481.68 144.368C481.68 146.48 483.104 148.032 485.216 148.032ZM483.744 146.56V142.112H485.52C486.448 142.112 486.928 142.64 486.928 143.424C486.928 144.144 486.448 144.656 485.696 144.656H485.664L487.2 146.544V146.56H486.144L484.592 144.608H484.576V146.56H483.744ZM484.576 144.064H485.408C485.84 144.064 486.08 143.808 486.08 143.44C486.08 143.088 485.824 142.848 485.408 142.848H484.576V144.064Z" fill="white"/>
<path d="M158.184 2.16438C166.564 -2.6797 175.59 1.19557 182.359 7.97729C189.45 15.082 192.351 22.8325 187.839 31.5518C186.227 34.7812 167.209 67.721 161.407 77.0863C158.184 82.2533 156.572 88.7121 156.572 94.8479C156.572 100.984 158.184 107.443 161.407 112.933C167.209 121.975 186.227 155.238 187.839 158.467C192.351 167.509 189.128 174.291 182.681 181.396C175.267 188.823 167.854 192.698 158.828 187.854C155.605 185.917 65.3511 133.924 65.3511 133.924C66.9627 144.581 72.7648 154.269 80.1785 160.082C79.2115 160.405 34.5761 186.232 31.5058 187.854C23.0444 192.326 15.3286 189.336 8.62001 183.01C1.04465 175.867 -2.66173 167.509 2.1733 158.79C3.78498 155.56 22.8028 122.298 28.2825 113.255C31.5058 107.765 33.4398 101.63 33.4398 95.1709C33.4398 88.7121 31.5058 82.5762 28.2825 77.4092C22.8028 67.721 3.78498 34.1354 2.1733 31.2289C-2.66173 22.5096 1.22016 13.1436 7.97534 7.00847C15.6327 0.0538926 22.8028 -2.03382 31.5058 2.16438C34.0845 3.1332 124.016 56.4182 124.016 56.4182C123.049 46.0841 117.892 36.7189 109.511 30.2601C110.156 29.9372 154.96 3.45614 158.184 2.16438ZM98.2293 110.995L111.123 98.0773C112.734 96.4626 112.734 93.8791 111.123 91.9415L98.2293 79.0239C96.2953 77.0863 93.7166 77.0863 91.7826 79.0239L78.8892 91.9415C77.2775 93.5562 77.2775 96.4626 78.8892 98.0773L91.7826 110.995C93.3942 112.61 96.2953 112.61 98.2293 110.995Z" fill="#FE6703"/>
</g>
<defs>
<clipPath id="clip0_33_49">
<rect width="490" height="190" fill="white"/>
</clipPath>
</defs>
</svg>
```
--------------------------------------------------------------------------------
/tests/unit/oauth/test_fastapi_app_pagination.py:
--------------------------------------------------------------------------------
```python
from unittest.mock import Mock, patch
import pytest
from dbt_mcp.oauth.dbt_platform import DbtPlatformAccount
from dbt_mcp.oauth.fastapi_app import (
_get_all_environments_for_project,
_get_all_projects_for_account,
)
@pytest.fixture
def base_headers():
return {"Accept": "application/json", "Authorization": "Bearer token"}
@pytest.fixture
def account():
return DbtPlatformAccount(
id=1,
name="Account 1",
locked=False,
state=1,
static_subdomain=None,
vanity_subdomain=None,
)
@patch("dbt_mcp.oauth.fastapi_app.requests.get")
def test_get_all_projects_for_account_paginates(mock_get: Mock, base_headers, account):
# Two pages: first full page (limit=2), second partial page (1 item) -> stop
first_page_resp = Mock()
first_page_resp.json.return_value = {
"data": [
{"id": 101, "name": "Proj A", "account_id": account.id},
{"id": 102, "name": "Proj B", "account_id": account.id},
]
}
first_page_resp.raise_for_status.return_value = None
second_page_resp = Mock()
second_page_resp.json.return_value = {
"data": [
{"id": 103, "name": "Proj C", "account_id": account.id},
]
}
second_page_resp.raise_for_status.return_value = None
mock_get.side_effect = [first_page_resp, second_page_resp]
result = _get_all_projects_for_account(
dbt_platform_url="https://cloud.getdbt.com",
account=account,
headers=base_headers,
page_size=2,
)
# Should aggregate 3 projects and include account_name field
assert len(result) == 3
assert {p.id for p in result} == {101, 102, 103}
assert all(p.account_name == account.name for p in result)
# Verify correct pagination URLs called
expected_urls = [
"https://cloud.getdbt.com/api/v3/accounts/1/projects/?state=1&offset=0&limit=2",
"https://cloud.getdbt.com/api/v3/accounts/1/projects/?state=1&offset=2&limit=2",
]
actual_urls = [
call.kwargs["url"] if "url" in call.kwargs else call.args[0]
for call in mock_get.call_args_list
]
assert actual_urls == expected_urls
@patch("dbt_mcp.oauth.fastapi_app.requests.get")
def test_get_all_environments_for_project_paginates(mock_get: Mock, base_headers):
# Two pages: first full page (limit=2), second partial (1 item)
first_page_resp = Mock()
first_page_resp.json.return_value = {
"data": [
{"id": 201, "name": "Dev", "deployment_type": "development"},
{"id": 202, "name": "Prod", "deployment_type": "production"},
]
}
first_page_resp.raise_for_status.return_value = None
second_page_resp = Mock()
second_page_resp.json.return_value = {
"data": [
{"id": 203, "name": "Staging", "deployment_type": "development"},
]
}
second_page_resp.raise_for_status.return_value = None
mock_get.side_effect = [first_page_resp, second_page_resp]
result = _get_all_environments_for_project(
dbt_platform_url="https://cloud.getdbt.com",
account_id=1,
project_id=9,
headers=base_headers,
page_size=2,
)
assert len(result) == 3
assert {e.id for e in result} == {201, 202, 203}
expected_urls = [
"https://cloud.getdbt.com/api/v3/accounts/1/projects/9/environments/?state=1&offset=0&limit=2",
"https://cloud.getdbt.com/api/v3/accounts/1/projects/9/environments/?state=1&offset=2&limit=2",
]
actual_urls = [
call.kwargs["url"] if "url" in call.kwargs else call.args[0]
for call in mock_get.call_args_list
]
assert actual_urls == expected_urls
```
--------------------------------------------------------------------------------
/tests/unit/dbt_cli/test_cli_integration.py:
--------------------------------------------------------------------------------
```python
import unittest
from unittest.mock import MagicMock, patch
from tests.mocks.config import mock_config
class TestDbtCliIntegration(unittest.TestCase):
@patch("subprocess.Popen")
def test_dbt_command_execution(self, mock_popen):
"""
Tests the full execution path for dbt commands, ensuring they are properly
executed with the right arguments.
"""
# Import here to prevent circular import issues during patching
from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools
# Mock setup
mock_process = MagicMock()
mock_process.communicate.return_value = ("command output", None)
mock_popen.return_value = mock_process
# Create a mock FastMCP and Config
mock_fastmcp = MagicMock()
# Patch the tool decorator to capture functions
tools = {}
def mock_tool_decorator(**kwargs):
def decorator(func):
tools[func.__name__] = func
return func
return decorator
mock_fastmcp.tool = mock_tool_decorator
# Register the tools
register_dbt_cli_tools(mock_fastmcp, mock_config.dbt_cli_config)
# Test cases for different command types
test_cases = [
# Command name, args, expected command list
("build", [], ["/path/to/dbt", "--no-use-colors", "build", "--quiet"]),
(
"compile",
[],
["/path/to/dbt", "--no-use-colors", "compile", "--quiet"],
),
(
"docs",
[],
["/path/to/dbt", "--no-use-colors", "docs", "--quiet", "generate"],
),
(
"ls",
[],
["/path/to/dbt", "--no-use-colors", "list", "--quiet"],
),
("parse", [], ["/path/to/dbt", "--no-use-colors", "parse", "--quiet"]),
("run", [], ["/path/to/dbt", "--no-use-colors", "run", "--quiet"]),
("test", [], ["/path/to/dbt", "--no-use-colors", "test", "--quiet"]),
(
"show",
["SELECT * FROM model"],
[
"/path/to/dbt",
"--no-use-colors",
"show",
"--inline",
"SELECT * FROM model",
"--favor-state",
"--output",
"json",
],
),
(
"show",
["SELECT * FROM model", 10],
[
"/path/to/dbt",
"--no-use-colors",
"show",
"--inline",
"SELECT * FROM model",
"--favor-state",
"--limit",
"10",
"--output",
"json",
],
),
]
# Run each test case
for command_name, args, expected_args in test_cases:
mock_popen.reset_mock()
# Call the function
result = tools[command_name](*args)
# Verify the command was called correctly
mock_popen.assert_called_once()
actual_args = mock_popen.call_args.kwargs.get("args")
num_params = 4
self.assertEqual(actual_args[:num_params], expected_args[:num_params])
# Verify correct working directory
self.assertEqual(mock_popen.call_args.kwargs.get("cwd"), "/test/project")
# Verify the output is returned correctly
self.assertEqual(result, "command output")
if __name__ == "__main__":
unittest.main()
```
--------------------------------------------------------------------------------
/tests/mocks/config.py:
--------------------------------------------------------------------------------
```python
from dbt_mcp.config.config import (
Config,
DbtCliConfig,
DbtCodegenConfig,
LspConfig,
)
from dbt_mcp.config.config_providers import (
AdminApiConfig,
DefaultAdminApiConfigProvider,
DefaultDiscoveryConfigProvider,
DefaultSemanticLayerConfigProvider,
DefaultSqlConfigProvider,
DiscoveryConfig,
SemanticLayerConfig,
SqlConfig,
)
from dbt_mcp.config.headers import (
AdminApiHeadersProvider,
DiscoveryHeadersProvider,
SemanticLayerHeadersProvider,
SqlHeadersProvider,
)
from dbt_mcp.config.settings import CredentialsProvider, DbtMcpSettings
from dbt_mcp.dbt_cli.binary_type import BinaryType
from dbt_mcp.oauth.token_provider import StaticTokenProvider
mock_settings = DbtMcpSettings.model_construct()
mock_sql_config = SqlConfig(
url="http://localhost:8000",
prod_environment_id=1,
dev_environment_id=1,
user_id=1,
headers_provider=SqlHeadersProvider(
token_provider=StaticTokenProvider(token="token")
),
)
mock_dbt_cli_config = DbtCliConfig(
project_dir="/test/project",
dbt_path="/path/to/dbt",
dbt_cli_timeout=10,
binary_type=BinaryType.DBT_CORE,
)
mock_dbt_codegen_config = DbtCodegenConfig(
project_dir="/test/project",
dbt_path="/path/to/dbt",
dbt_cli_timeout=10,
binary_type=BinaryType.DBT_CORE,
)
mock_lsp_config = LspConfig(
project_dir="/test/project",
lsp_path="/path/to/lsp",
)
mock_discovery_config = DiscoveryConfig(
url="http://localhost:8000",
headers_provider=DiscoveryHeadersProvider(
token_provider=StaticTokenProvider(token="token")
),
environment_id=1,
)
mock_semantic_layer_config = SemanticLayerConfig(
host="localhost",
token="token",
url="http://localhost:8000",
headers_provider=SemanticLayerHeadersProvider(
token_provider=StaticTokenProvider(token="token")
),
prod_environment_id=1,
)
mock_admin_api_config = AdminApiConfig(
url="http://localhost:8000",
headers_provider=AdminApiHeadersProvider(
token_provider=StaticTokenProvider(token="token")
),
account_id=12345,
)
# Create mock config providers
class MockSqlConfigProvider(DefaultSqlConfigProvider):
def __init__(self):
pass # Skip the base class __init__
async def get_config(self):
return mock_sql_config
class MockDiscoveryConfigProvider(DefaultDiscoveryConfigProvider):
def __init__(self):
pass # Skip the base class __init__
async def get_config(self):
return mock_discovery_config
class MockSemanticLayerConfigProvider(DefaultSemanticLayerConfigProvider):
def __init__(self):
pass # Skip the base class __init__
async def get_config(self):
return mock_semantic_layer_config
class MockAdminApiConfigProvider(DefaultAdminApiConfigProvider):
def __init__(self):
pass # Skip the base class __init__
async def get_config(self):
return mock_admin_api_config
class MockCredentialsProvider(CredentialsProvider):
def __init__(self, settings: DbtMcpSettings | None = None):
super().__init__(settings or mock_settings)
self.token_provider = StaticTokenProvider(token=self.settings.dbt_token)
async def get_credentials(self):
return self.settings, self.token_provider
mock_config = Config(
sql_config_provider=MockSqlConfigProvider(),
dbt_cli_config=mock_dbt_cli_config,
dbt_codegen_config=mock_dbt_codegen_config,
discovery_config_provider=MockDiscoveryConfigProvider(),
semantic_layer_config_provider=MockSemanticLayerConfigProvider(),
admin_api_config_provider=MockAdminApiConfigProvider(),
lsp_config=mock_lsp_config,
disable_tools=[],
credentials_provider=MockCredentialsProvider(),
)
```
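A sketch of how these mocks are consumed, mirroring tests/unit/dbt_cli/test_cli_integration.py above; the bare MagicMock server is an assumption for illustration:

```python
from unittest.mock import MagicMock

from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools
from tests.mocks.config import mock_config

# Register the dbt CLI tools against a mocked FastMCP server, using the
# mock_dbt_cli_config defined above (project_dir="/test/project").
mock_fastmcp = MagicMock()
register_dbt_cli_tools(mock_fastmcp, mock_config.dbt_cli_config)
```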
--------------------------------------------------------------------------------
/src/dbt_mcp/oauth/token_provider.py:
--------------------------------------------------------------------------------
```python
import asyncio
import logging
from typing import Protocol
from authlib.integrations.requests_client import OAuth2Session
from dbt_mcp.oauth.client_id import OAUTH_CLIENT_ID
from dbt_mcp.oauth.context_manager import DbtPlatformContextManager
from dbt_mcp.oauth.dbt_platform import dbt_platform_context_from_token_response
from dbt_mcp.oauth.refresh_strategy import DefaultRefreshStrategy, RefreshStrategy
from dbt_mcp.oauth.token import AccessTokenResponse
logger = logging.getLogger(__name__)
class TokenProvider(Protocol):
def get_token(self) -> str: ...
class OAuthTokenProvider:
"""
Token provider for an OAuth access token with periodic refresh.
"""
def __init__(
self,
access_token_response: AccessTokenResponse,
dbt_platform_url: str,
context_manager: DbtPlatformContextManager,
refresh_strategy: RefreshStrategy | None = None,
):
self.access_token_response = access_token_response
self.context_manager = context_manager
self.dbt_platform_url = dbt_platform_url
self.refresh_strategy = refresh_strategy or DefaultRefreshStrategy()
self.token_url = f"{self.dbt_platform_url}/oauth/token"
self.oauth_client = OAuth2Session(
client_id=OAUTH_CLIENT_ID,
token_endpoint=self.token_url,
)
self.refresh_started = False
def _get_access_token_response(self) -> AccessTokenResponse:
dbt_platform_context = self.context_manager.read_context()
if not dbt_platform_context or not dbt_platform_context.decoded_access_token:
raise ValueError("No decoded access token found in context")
return dbt_platform_context.decoded_access_token.access_token_response
def get_token(self) -> str:
if not self.refresh_started:
self.start_background_refresh()
self.refresh_started = True
return self.access_token_response.access_token
def start_background_refresh(self) -> asyncio.Task[None]:
logger.info("Starting oauth token background refresh")
return asyncio.create_task(
self._background_refresh_worker(), name="oauth-token-refresh"
)
async def _refresh_token(self) -> None:
logger.info("Refreshing OAuth access token using authlib")
token_response = self.oauth_client.refresh_token(
url=self.token_url,
refresh_token=self.access_token_response.refresh_token,
)
dbt_platform_context = dbt_platform_context_from_token_response(
token_response, self.dbt_platform_url
)
self.context_manager.update_context(dbt_platform_context)
if not dbt_platform_context.decoded_access_token:
raise ValueError("No decoded access token found in context")
self.access_token_response = (
dbt_platform_context.decoded_access_token.access_token_response
)
logger.info("OAuth access token refreshed and context updated successfully")
async def _background_refresh_worker(self) -> None:
"""Background worker that periodically refreshes tokens before expiry."""
logger.info("Background token refresh worker started")
while True:
try:
await self.refresh_strategy.wait_until_refresh_needed(
self.access_token_response.expires_at
)
await self._refresh_token()
except Exception as e:
logger.error(f"Error in background refresh worker: {e}")
await self.refresh_strategy.wait_after_error()
class StaticTokenProvider:
"""
Token provider for tokens that aren't refreshed (e.g. service tokens and PATs)
"""
def __init__(self, token: str | None = None):
self.token = token
def get_token(self) -> str:
if not self.token:
raise ValueError("No token provided")
return self.token
```
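Because `TokenProvider` is a structural `Protocol`, any object exposing `get_token()` can be passed where a provider is expected. A minimal sketch with a hypothetical helper (the `token <value>` header shape mirrors the Authorization headers used in the examples above):

```python
from dbt_mcp.oauth.token_provider import StaticTokenProvider, TokenProvider


def build_auth_header(provider: TokenProvider) -> dict[str, str]:
    # Hypothetical helper for illustration only.
    return {"Authorization": f"token {provider.get_token()}"}


# StaticTokenProvider satisfies the TokenProvider protocol.
headers = build_auth_header(StaticTokenProvider(token="<service-token-or-PAT>"))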
--------------------------------------------------------------------------------
/tests/unit/oauth/test_credentials_provider.py:
--------------------------------------------------------------------------------
```python
from unittest.mock import MagicMock, patch
import pytest
from dbt_mcp.config.settings import (
AuthenticationMethod,
CredentialsProvider,
DbtMcpSettings,
)
class TestCredentialsProviderAuthenticationMethod:
"""Test the authentication_method field on CredentialsProvider"""
@pytest.mark.asyncio
async def test_authentication_method_oauth(self):
"""Test that authentication_method is set to OAUTH when using OAuth flow"""
mock_settings = DbtMcpSettings.model_construct(
dbt_host="cloud.getdbt.com",
dbt_prod_env_id=123,
dbt_account_id=456,
dbt_token=None, # No token means OAuth
)
credentials_provider = CredentialsProvider(mock_settings)
# Mock OAuth flow - create a properly structured context
mock_dbt_context = MagicMock()
mock_dbt_context.account_id = 456
mock_dbt_context.host_prefix = ""
mock_dbt_context.user_id = 789
mock_dbt_context.dev_environment.id = 111
mock_dbt_context.prod_environment.id = 123
mock_decoded_token = MagicMock()
mock_decoded_token.access_token_response.access_token = "mock_token"
mock_dbt_context.decoded_access_token = mock_decoded_token
with (
patch(
"dbt_mcp.config.settings.get_dbt_platform_context",
return_value=mock_dbt_context,
),
patch(
"dbt_mcp.config.settings.get_dbt_host", return_value="cloud.getdbt.com"
),
patch("dbt_mcp.config.settings.OAuthTokenProvider") as mock_token_provider,
patch("dbt_mcp.config.settings.validate_settings"),
):
mock_token_provider.return_value = MagicMock()
settings, token_provider = await credentials_provider.get_credentials()
assert (
credentials_provider.authentication_method == AuthenticationMethod.OAUTH
)
assert token_provider is not None
@pytest.mark.asyncio
async def test_authentication_method_env_var(self):
"""Test that authentication_method is set to ENV_VAR when using token from env"""
mock_settings = DbtMcpSettings.model_construct(
dbt_host="test.dbt.com",
dbt_prod_env_id=123,
dbt_token="test_token", # Token provided
)
credentials_provider = CredentialsProvider(mock_settings)
with patch("dbt_mcp.config.settings.validate_settings"):
settings, token_provider = await credentials_provider.get_credentials()
assert (
credentials_provider.authentication_method
== AuthenticationMethod.ENV_VAR
)
assert token_provider is not None
@pytest.mark.asyncio
async def test_authentication_method_initially_none(self):
"""Test that authentication_method starts as None before get_credentials is called"""
mock_settings = DbtMcpSettings.model_construct(
dbt_token="test_token",
)
credentials_provider = CredentialsProvider(mock_settings)
assert credentials_provider.authentication_method is None
@pytest.mark.asyncio
async def test_authentication_method_persists_after_get_credentials(self):
"""Test that authentication_method persists after get_credentials is called"""
mock_settings = DbtMcpSettings.model_construct(
dbt_host="test.dbt.com",
dbt_prod_env_id=123,
dbt_token="test_token",
)
credentials_provider = CredentialsProvider(mock_settings)
with patch("dbt_mcp.config.settings.validate_settings"):
# First call
await credentials_provider.get_credentials()
assert (
credentials_provider.authentication_method
== AuthenticationMethod.ENV_VAR
)
# Second call - should still be set
await credentials_provider.get_credentials()
assert (
credentials_provider.authentication_method
== AuthenticationMethod.ENV_VAR
)
```
--------------------------------------------------------------------------------
/src/dbt_mcp/config/config.py:
--------------------------------------------------------------------------------
```python
import os
from dataclasses import dataclass
from dbt_mcp.config.config_providers import (
DefaultAdminApiConfigProvider,
DefaultDiscoveryConfigProvider,
DefaultSemanticLayerConfigProvider,
DefaultSqlConfigProvider,
)
from dbt_mcp.config.settings import (
CredentialsProvider,
DbtMcpSettings,
)
from dbt_mcp.dbt_cli.binary_type import BinaryType, detect_binary_type
from dbt_mcp.telemetry.logging import configure_logging
from dbt_mcp.tools.tool_names import ToolName
PACKAGE_NAME = "dbt-mcp"
@dataclass
class DbtCliConfig:
project_dir: str
dbt_path: str
dbt_cli_timeout: int
binary_type: BinaryType
@dataclass
class DbtCodegenConfig:
project_dir: str
dbt_path: str
dbt_cli_timeout: int
binary_type: BinaryType
@dataclass
class LspConfig:
project_dir: str
lsp_path: str | None
@dataclass
class Config:
disable_tools: list[ToolName]
sql_config_provider: DefaultSqlConfigProvider | None
dbt_cli_config: DbtCliConfig | None
dbt_codegen_config: DbtCodegenConfig | None
discovery_config_provider: DefaultDiscoveryConfigProvider | None
semantic_layer_config_provider: DefaultSemanticLayerConfigProvider | None
admin_api_config_provider: DefaultAdminApiConfigProvider | None
credentials_provider: CredentialsProvider
lsp_config: LspConfig | None
def load_config() -> Config:
settings = DbtMcpSettings() # type: ignore
configure_logging(settings.file_logging)
credentials_provider = CredentialsProvider(settings)
# Set default warn error options if not provided
if settings.dbt_warn_error_options is None:
warn_error_options = '{"error": ["NoNodesForSelectionCriteria"]}'
os.environ["DBT_WARN_ERROR_OPTIONS"] = warn_error_options
# Build configurations
sql_config_provider = None
if not settings.actual_disable_sql:
sql_config_provider = DefaultSqlConfigProvider(
credentials_provider=credentials_provider,
)
admin_api_config_provider = None
if not settings.disable_admin_api:
admin_api_config_provider = DefaultAdminApiConfigProvider(
credentials_provider=credentials_provider,
)
dbt_cli_config = None
if not settings.disable_dbt_cli and settings.dbt_project_dir and settings.dbt_path:
binary_type = detect_binary_type(settings.dbt_path)
dbt_cli_config = DbtCliConfig(
project_dir=settings.dbt_project_dir,
dbt_path=settings.dbt_path,
dbt_cli_timeout=settings.dbt_cli_timeout,
binary_type=binary_type,
)
dbt_codegen_config = None
if (
not settings.disable_dbt_codegen
and settings.dbt_project_dir
and settings.dbt_path
):
binary_type = detect_binary_type(settings.dbt_path)
dbt_codegen_config = DbtCodegenConfig(
project_dir=settings.dbt_project_dir,
dbt_path=settings.dbt_path,
dbt_cli_timeout=settings.dbt_cli_timeout,
binary_type=binary_type,
)
discovery_config_provider = None
if not settings.disable_discovery:
discovery_config_provider = DefaultDiscoveryConfigProvider(
credentials_provider=credentials_provider,
)
semantic_layer_config_provider = None
if not settings.disable_semantic_layer:
semantic_layer_config_provider = DefaultSemanticLayerConfigProvider(
credentials_provider=credentials_provider,
)
lsp_config = None
if not settings.disable_lsp and settings.dbt_project_dir:
lsp_config = LspConfig(
project_dir=settings.dbt_project_dir,
lsp_path=settings.dbt_lsp_path,
)
return Config(
disable_tools=settings.disable_tools or [],
sql_config_provider=sql_config_provider,
dbt_cli_config=dbt_cli_config,
dbt_codegen_config=dbt_codegen_config,
discovery_config_provider=discovery_config_provider,
semantic_layer_config_provider=semantic_layer_config_provider,
admin_api_config_provider=admin_api_config_provider,
credentials_provider=credentials_provider,
lsp_config=lsp_config,
)
```
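A small sketch of how the toggles in `load_config()` behave; the `DISABLE_DBT_CLI` spelling is an assumption about the env var backing `settings.disable_dbt_cli`:

```python
import os

from dbt_mcp.config.config import load_config

# Assumed env var name for DbtMcpSettings.disable_dbt_cli (pydantic settings).
os.environ["DISABLE_DBT_CLI"] = "true"

config = load_config()
# With the CLI toolset disabled, load_config() never builds a DbtCliConfig.
assert config.dbt_cli_config is None
```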
--------------------------------------------------------------------------------
/tests/integration/semantic_layer/test_semantic_layer.py:
--------------------------------------------------------------------------------
```python
import pytest
from dbtsl.api.shared.query_params import GroupByParam, GroupByType
import pyarrow as pa
from dbt_mcp.config.config import load_config
from dbt_mcp.semantic_layer.client import (
DefaultSemanticLayerClientProvider,
SemanticLayerFetcher,
)
from dbt_mcp.semantic_layer.types import OrderByParam
config = load_config()
@pytest.fixture
def semantic_layer_fetcher() -> SemanticLayerFetcher:
assert config.semantic_layer_config_provider is not None
return SemanticLayerFetcher(
config_provider=config.semantic_layer_config_provider,
client_provider=DefaultSemanticLayerClientProvider(
config_provider=config.semantic_layer_config_provider,
),
)
async def test_semantic_layer_list_metrics(
semantic_layer_fetcher: SemanticLayerFetcher,
):
metrics = await semantic_layer_fetcher.list_metrics()
assert len(metrics) > 0
async def test_semantic_layer_list_dimensions(
semantic_layer_fetcher: SemanticLayerFetcher,
):
metrics = await semantic_layer_fetcher.list_metrics()
dimensions = await semantic_layer_fetcher.get_dimensions(metrics=[metrics[0].name])
assert len(dimensions) > 0
async def test_semantic_layer_query_metrics(
semantic_layer_fetcher: SemanticLayerFetcher,
):
result = await semantic_layer_fetcher.query_metrics(
metrics=["revenue"],
group_by=[
GroupByParam(
name="metric_time",
type=GroupByType.TIME_DIMENSION,
grain=None,
)
],
)
assert result is not None
async def test_semantic_layer_query_metrics_invalid_query(
semantic_layer_fetcher: SemanticLayerFetcher,
):
result = await semantic_layer_fetcher.query_metrics(
metrics=["food_revenue"],
group_by=[
GroupByParam(
name="order_id__location__location_name",
type=GroupByType.DIMENSION,
grain=None,
),
GroupByParam(
name="metric_time",
type=GroupByType.TIME_DIMENSION,
grain="MONTH",
),
],
order_by=[
OrderByParam(
name="metric_time",
descending=True,
),
OrderByParam(
name="food_revenue",
descending=True,
),
],
limit=5,
)
assert result is not None
async def test_semantic_layer_query_metrics_with_group_by_grain(
semantic_layer_fetcher: SemanticLayerFetcher,
):
result = await semantic_layer_fetcher.query_metrics(
metrics=["revenue"],
group_by=[
GroupByParam(
name="metric_time",
type=GroupByType.TIME_DIMENSION,
grain="day",
)
],
)
assert result is not None
async def test_semantic_layer_query_metrics_with_order_by(
semantic_layer_fetcher: SemanticLayerFetcher,
):
result = await semantic_layer_fetcher.query_metrics(
metrics=["revenue"],
group_by=[
GroupByParam(
name="metric_time",
type=GroupByType.TIME_DIMENSION,
grain=None,
)
],
order_by=[OrderByParam(name="metric_time", descending=True)],
)
assert result is not None
async def test_semantic_layer_query_metrics_with_misspellings(
semantic_layer_fetcher: SemanticLayerFetcher,
):
result = await semantic_layer_fetcher.query_metrics(["revehue"])
assert result.result is not None
assert "revenue" in result.result
async def test_semantic_layer_get_entities(
semantic_layer_fetcher: SemanticLayerFetcher,
):
entities = await semantic_layer_fetcher.get_entities(
metrics=["count_dbt_copilot_requests"]
)
assert len(entities) > 0
async def test_semantic_layer_query_metrics_with_csv_formatter(
semantic_layer_fetcher: SemanticLayerFetcher,
):
def csv_formatter(table: pa.Table) -> str:
return table.to_pandas().to_csv(index=False)
result = await semantic_layer_fetcher.query_metrics(
metrics=["revenue"],
group_by=[
GroupByParam(
name="metric_time",
type=GroupByType.TIME_DIMENSION,
grain=None,
)
],
result_formatter=csv_formatter,
)
assert result.result is not None
assert "revenue" in result.result.casefold()
# CSV format should have comma separators
assert "," in result.result
```
--------------------------------------------------------------------------------
/examples/aws_strands_agent/dbt_data_scientist/tools/dbt_mcp.py:
--------------------------------------------------------------------------------
```python
"""dbt MCP Tool - Remote dbt MCP server connection for AWS Bedrock Agent Core."""
import os
import sys
from strands import Agent, tool
from mcp import ClientSession
from dotenv import load_dotenv
from mcp.client.streamable_http import streamablehttp_client
from strands.tools.mcp.mcp_client import MCPClient
# Load environment variables
load_dotenv()
DBT_MCP_URL = os.environ.get("DBT_MCP_URL")
DBT_USER_ID = os.environ.get("DBT_USER_ID")
DBT_PROD_ENV_ID = os.environ.get("DBT_PROD_ENV_ID")
DBT_DEV_ENV_ID = os.environ.get("DBT_DEV_ENV_ID")
DBT_ACCOUNT_ID = os.environ.get("DBT_ACCOUNT_ID")
DBT_TOKEN = os.environ.get("DBT_TOKEN")
DBT_MCP_AGENT_SYSTEM_PROMPT = """
You are a dbt MCP server expert, a specialized assistant for dbt MCP server analysis and troubleshooting.
When asked to 'view features available on the dbt MCP server', or asked about a specific tool or function, inspect the dbt MCP server and return a result based on the available tools and functions.
"""
# Create MCP client once at module level
def create_dbt_mcp_client():
"""Create the dbt MCP client with proper configuration."""
load_dotenv()
if not DBT_MCP_URL:
raise ValueError("DBT_MCP_URL environment variable is required")
return MCPClient(lambda: streamablehttp_client(
url=DBT_MCP_URL,
headers={
"x-dbt-user-id": DBT_USER_ID,
"x-dbt-prod-environment-id": DBT_PROD_ENV_ID,
"x-dbt-dev-environment-id": DBT_DEV_ENV_ID,
"x-dbt-account-id": DBT_ACCOUNT_ID,
"Authorization": f"token {DBT_TOKEN}",
},
))
# Global MCP client instance
dbt_mcp_client = create_dbt_mcp_client()
@tool
def dbt_mcp_tool(query: str) -> str:
"""
Connects to remote dbt MCP server and executes queries.
Args:
query: The user's question about dbt MCP server functionality
Returns:
String response with dbt MCP server results
"""
try:
print(f"Connecting to dbt MCP server for query: {query}")
with dbt_mcp_client:
# Get available tools from MCP server
tools = dbt_mcp_client.list_tools_sync()
if not tools:
return "No tools available on the dbt MCP server."
# If user asks to list tools, return them
if "list" in query.lower() and ("tool" in query.lower() or "feature" in query.lower()):
tool_list = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
return f"Available dbt MCP tools:\n{tool_list}"
# For other queries, try to find and execute the most relevant tool
# This is a simplified approach - in practice you'd want more sophisticated routing
if tools:
# Try to call the first available tool as an example
first_tool = tools[0]
try:
result = dbt_mcp_client.call_tool_sync(first_tool.name, {})
return f"Executed {first_tool.name}: {result}"
except Exception as e:
return f"Error executing {first_tool.name}: {str(e)}"
return f"Found {len(tools)} tools on dbt MCP server. Use 'list tools' to see them."
except Exception as e:
return f"Error connecting to dbt MCP server: {str(e)}"
def test_connection():
"""Test function to verify MCP connectivity."""
print("🧪 Testing dbt MCP connection...")
try:
with dbt_mcp_client:
tools = dbt_mcp_client.list_tools_sync()
print(f"✅ Successfully connected to dbt MCP server!")
print(f"📋 Found {len(tools)} available tools:")
for i, tool in enumerate(tools, 1):
print(f" {i}. {tool.name}: {tool.description}")
return True
except Exception as e:
print(f"❌ Connection failed: {e}")
return False
# Test the connection when this module is run directly
if __name__ == "__main__":
print("🔌 dbt MCP Server Connection Test")
print("=" * 40)
# Check environment variables
load_dotenv()
required_vars = ["DBT_MCP_URL", "DBT_TOKEN", "DBT_USER_ID", "DBT_PROD_ENV_ID"]
missing_vars = [var for var in required_vars if not os.environ.get(var)]
if missing_vars:
print(f"❌ Missing required environment variables: {', '.join(missing_vars)}")
print("Please set these in your .env file or environment.")
sys.exit(1)
# Test connection
success = test_connection()
if success:
print("\n🎉 MCP connection test passed!")
print("You can now run the agent: python dbt_data_scientist/agent.py")
else:
print("\n💥 MCP connection test failed!")
print("Please check your configuration and try again.")
sys.exit(0 if success else 1)
```
--------------------------------------------------------------------------------
/src/dbt_mcp/semantic_layer/tools.py:
--------------------------------------------------------------------------------
```python
import logging
from collections.abc import Sequence
from dbtsl.api.shared.query_params import GroupByParam
from mcp.server.fastmcp import FastMCP
from dbt_mcp.config.config_providers import (
ConfigProvider,
SemanticLayerConfig,
)
from dbt_mcp.prompts.prompts import get_prompt
from dbt_mcp.semantic_layer.client import (
SemanticLayerClientProvider,
SemanticLayerFetcher,
)
from dbt_mcp.semantic_layer.types import (
DimensionToolResponse,
EntityToolResponse,
GetMetricsCompiledSqlSuccess,
MetricToolResponse,
OrderByParam,
QueryMetricsSuccess,
)
from dbt_mcp.tools.annotations import create_tool_annotations
from dbt_mcp.tools.definitions import ToolDefinition
from dbt_mcp.tools.register import register_tools
from dbt_mcp.tools.tool_names import ToolName
logger = logging.getLogger(__name__)
def create_sl_tool_definitions(
config_provider: ConfigProvider[SemanticLayerConfig],
client_provider: SemanticLayerClientProvider,
) -> list[ToolDefinition]:
semantic_layer_fetcher = SemanticLayerFetcher(
config_provider=config_provider,
client_provider=client_provider,
)
async def list_metrics(search: str | None = None) -> list[MetricToolResponse]:
return await semantic_layer_fetcher.list_metrics(search=search)
async def get_dimensions(
metrics: list[str], search: str | None = None
) -> list[DimensionToolResponse]:
return await semantic_layer_fetcher.get_dimensions(
metrics=metrics, search=search
)
async def get_entities(
metrics: list[str], search: str | None = None
) -> list[EntityToolResponse]:
return await semantic_layer_fetcher.get_entities(metrics=metrics, search=search)
async def query_metrics(
metrics: list[str],
group_by: list[GroupByParam] | None = None,
order_by: list[OrderByParam] | None = None,
where: str | None = None,
limit: int | None = None,
) -> str:
result = await semantic_layer_fetcher.query_metrics(
metrics=metrics,
group_by=group_by,
order_by=order_by,
where=where,
limit=limit,
)
if isinstance(result, QueryMetricsSuccess):
return result.result
else:
return result.error
async def get_metrics_compiled_sql(
metrics: list[str],
group_by: list[GroupByParam] | None = None,
order_by: list[OrderByParam] | None = None,
where: str | None = None,
limit: int | None = None,
) -> str:
result = await semantic_layer_fetcher.get_metrics_compiled_sql(
metrics=metrics,
group_by=group_by,
order_by=order_by,
where=where,
limit=limit,
)
if isinstance(result, GetMetricsCompiledSqlSuccess):
return result.sql
else:
return result.error
return [
ToolDefinition(
description=get_prompt("semantic_layer/list_metrics"),
fn=list_metrics,
annotations=create_tool_annotations(
title="List Metrics",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("semantic_layer/get_dimensions"),
fn=get_dimensions,
annotations=create_tool_annotations(
title="Get Dimensions",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("semantic_layer/get_entities"),
fn=get_entities,
annotations=create_tool_annotations(
title="Get Entities",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("semantic_layer/query_metrics"),
fn=query_metrics,
annotations=create_tool_annotations(
title="Query Metrics",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("semantic_layer/get_metrics_compiled_sql"),
fn=get_metrics_compiled_sql,
annotations=create_tool_annotations(
title="Compile SQL",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
]
def register_sl_tools(
dbt_mcp: FastMCP,
config_provider: ConfigProvider[SemanticLayerConfig],
client_provider: SemanticLayerClientProvider,
exclude_tools: Sequence[ToolName] = [],
) -> None:
register_tools(
dbt_mcp,
create_sl_tool_definitions(config_provider, client_provider),
exclude_tools,
)
```
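A wiring sketch that combines this module with the `DefaultSemanticLayerClientProvider` used in the integration tests above to register the five semantic layer tools on a FastMCP server:

```python
from mcp.server.fastmcp import FastMCP

from dbt_mcp.config.config_providers import (
    ConfigProvider,
    SemanticLayerConfig,
)
from dbt_mcp.semantic_layer.client import DefaultSemanticLayerClientProvider
from dbt_mcp.semantic_layer.tools import register_sl_tools


def wire_semantic_layer(
    server: FastMCP,
    config_provider: ConfigProvider[SemanticLayerConfig],
) -> None:
    # Build the default client provider from the same config provider,
    # then register the semantic layer tools on the FastMCP server.
    client_provider = DefaultSemanticLayerClientProvider(
        config_provider=config_provider,
    )
    register_sl_tools(server, config_provider, client_provider)
```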
--------------------------------------------------------------------------------
/src/dbt_mcp/lsp/lsp_client.py:
--------------------------------------------------------------------------------
```python
"""LSP Client for dbt Fusion.
This module provides a high-level client interface for interacting with the
dbt Fusion LSP server, wrapping low-level JSON-RPC communication with
typed methods for dbt-specific operations.
"""
import asyncio
import logging
from typing import Any
from dbt_mcp.lsp.lsp_connection import LSPConnection, LspEventName
logger = logging.getLogger(__name__)
# Default timeout for LSP operations (in seconds)
DEFAULT_LSP_TIMEOUT = 30
class LSPClient:
"""High-level client for dbt Fusion LSP operations.
This class provides typed methods for dbt-specific LSP operations
such as column lineage, model references, and more.
"""
def __init__(self, lsp_connection: LSPConnection, timeout: float | None = None):
"""Initialize the dbt LSP client.
Args:
lsp_connection: The LSP connection to use
timeout: Default timeout for LSP operations in seconds. If not specified,
uses DEFAULT_LSP_TIMEOUT (30 seconds).
"""
self.lsp_connection = lsp_connection
self.timeout = timeout if timeout is not None else DEFAULT_LSP_TIMEOUT
async def compile(self, timeout: float | None = None) -> dict[str, Any]:
"""Compile the dbt project.
Returns the compilation log as a dictionary.
"""
# Register for the notification BEFORE sending the command to avoid race conditions
compile_complete_future = self.lsp_connection.wait_for_notification(
LspEventName.compileComplete
)
async with asyncio.timeout(timeout or self.timeout):
await self.lsp_connection.send_request(
"workspace/executeCommand",
{"command": "dbt.compileLsp", "arguments": []},
)
# wait for compilation to complete
result = await compile_complete_future
if "error" in result and result["error"] is not None:
return {"error": result["error"]}
if "log" in result and result["log"] is not None:
return {"log": result["log"]}
return result
async def get_column_lineage(
self,
model_id: str,
column_name: str,
timeout: float | None = None,
) -> dict[str, Any]:
"""Get column lineage information for a specific model column.
Args:
model_id: The dbt model identifier
column_name: The column name to trace lineage for
Returns:
Dictionary containing lineage information with 'nodes' key
"""
if not self.lsp_connection.state.compiled:
await self.compile()
logger.info(f"Requesting column lineage for {model_id}.{column_name}")
selector = f"+column:{model_id}.{column_name.upper()}+"
async with asyncio.timeout(timeout or self.timeout):
result = await self.lsp_connection.send_request(
"workspace/executeCommand",
{"command": "dbt.listNodes", "arguments": [selector]},
)
if not result:
return {"error": "No result from LSP"}
if "error" in result and result["error"] is not None:
return {"error": result["error"]}
if "nodes" in result and result["nodes"] is not None:
return {"nodes": result["nodes"]}
return result
async def get_model_lineage(self, model_selector: str) -> dict[str, Any]:
nodes = []
response = await self._list_nodes(model_selector)
if not response:
return {"error": "No result from LSP"}
if "error" in response and response["error"] is not None:
return {"error": response["error"]}
if "nodes" in response and response["nodes"] is not None:
for node in response["nodes"]:
nodes.append(
{
"depends_on": node["depends_on"],
"name": node["name"],
"unique_id": node["unique_id"],
"path": node["path"],
}
)
return {"nodes": nodes}
async def _list_nodes(
self, model_selector: str, timeout: float | None = None
) -> dict[str, Any]:
"""List nodes in the dbt project."""
if not self.lsp_connection.state.compiled:
await self.compile()
logger.info("Listing nodes", extra={"model_selector": model_selector})
async with asyncio.timeout(timeout or self.timeout):
result = await self.lsp_connection.send_request(
"workspace/executeCommand",
{"command": "dbt.listNodes", "arguments": [model_selector]},
)
if not result:
return {"error": "No result from LSP"}
if "error" in result and result["error"] is not None:
return {"error": result["error"]}
if "nodes" in result and result["nodes"] is not None:
return {"nodes": result["nodes"]}
return result
```
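A minimal usage sketch for `LSPClient`, assuming a dbt-lsp binary at the given path and a local dbt project directory; the constructor arguments mirror how `/src/dbt_mcp/lsp/tools.py` builds the connection.

```python
import asyncio

from dbt_mcp.lsp.lsp_client import LSPClient
from dbt_mcp.lsp.lsp_connection import LSPConnection


async def main() -> None:
    # Assumed binary and project paths, for illustration only.
    connection = LSPConnection(
        binary_path="/usr/local/bin/dbt-lsp",
        args=[],
        cwd="/path/to/dbt_project",
    )
    await connection.start()
    await connection.initialize()
    try:
        client = LSPClient(connection, timeout=60)
        lineage = await client.get_column_lineage(
            model_id="model.my_project.orders",  # hypothetical model id
            column_name="order_id",
        )
        print(lineage)
    finally:
        await connection.stop()


if __name__ == "__main__":
    asyncio.run(main())
```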
--------------------------------------------------------------------------------
/src/dbt_mcp/config/config_providers.py:
--------------------------------------------------------------------------------
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dbt_mcp.config.headers import (
AdminApiHeadersProvider,
DiscoveryHeadersProvider,
HeadersProvider,
SemanticLayerHeadersProvider,
SqlHeadersProvider,
)
from dbt_mcp.config.settings import CredentialsProvider
@dataclass
class SemanticLayerConfig:
url: str
host: str
prod_environment_id: int
token: str
headers_provider: HeadersProvider
@dataclass
class DiscoveryConfig:
url: str
headers_provider: HeadersProvider
environment_id: int
@dataclass
class AdminApiConfig:
url: str
headers_provider: HeadersProvider
account_id: int
prod_environment_id: int | None = None
@dataclass
class SqlConfig:
user_id: int
dev_environment_id: int
prod_environment_id: int
url: str
headers_provider: HeadersProvider
class ConfigProvider[ConfigType](ABC):
@abstractmethod
async def get_config(self) -> ConfigType: ...
class DefaultSemanticLayerConfigProvider(ConfigProvider[SemanticLayerConfig]):
def __init__(self, credentials_provider: CredentialsProvider):
self.credentials_provider = credentials_provider
async def get_config(self) -> SemanticLayerConfig:
settings, token_provider = await self.credentials_provider.get_credentials()
assert (
settings.actual_host
and settings.actual_prod_environment_id
and settings.dbt_token
)
is_local = settings.actual_host and settings.actual_host.startswith("localhost")
if is_local:
host = settings.actual_host
elif settings.actual_host_prefix:
host = (
f"{settings.actual_host_prefix}.semantic-layer.{settings.actual_host}"
)
else:
host = f"semantic-layer.{settings.actual_host}"
assert host is not None
return SemanticLayerConfig(
url=f"http://{host}" if is_local else f"https://{host}" + "/api/graphql",
host=host,
prod_environment_id=settings.actual_prod_environment_id,
token=settings.dbt_token,
headers_provider=SemanticLayerHeadersProvider(
token_provider=token_provider
),
)
class DefaultDiscoveryConfigProvider(ConfigProvider[DiscoveryConfig]):
def __init__(self, credentials_provider: CredentialsProvider):
self.credentials_provider = credentials_provider
async def get_config(self) -> DiscoveryConfig:
settings, token_provider = await self.credentials_provider.get_credentials()
assert (
settings.actual_host
and settings.actual_prod_environment_id
and settings.dbt_token
)
if settings.actual_host_prefix:
url = f"https://{settings.actual_host_prefix}.metadata.{settings.actual_host}/graphql"
else:
url = f"https://metadata.{settings.actual_host}/graphql"
return DiscoveryConfig(
url=url,
headers_provider=DiscoveryHeadersProvider(token_provider=token_provider),
environment_id=settings.actual_prod_environment_id,
)
class DefaultAdminApiConfigProvider(ConfigProvider[AdminApiConfig]):
def __init__(self, credentials_provider: CredentialsProvider):
self.credentials_provider = credentials_provider
async def get_config(self) -> AdminApiConfig:
settings, token_provider = await self.credentials_provider.get_credentials()
assert settings.dbt_token and settings.actual_host and settings.dbt_account_id
if settings.actual_host_prefix:
url = f"https://{settings.actual_host_prefix}.{settings.actual_host}"
else:
url = f"https://{settings.actual_host}"
return AdminApiConfig(
url=url,
headers_provider=AdminApiHeadersProvider(token_provider=token_provider),
account_id=settings.dbt_account_id,
prod_environment_id=settings.actual_prod_environment_id,
)
class DefaultSqlConfigProvider(ConfigProvider[SqlConfig]):
def __init__(self, credentials_provider: CredentialsProvider):
self.credentials_provider = credentials_provider
async def get_config(self) -> SqlConfig:
settings, token_provider = await self.credentials_provider.get_credentials()
assert (
settings.dbt_user_id
and settings.dbt_token
and settings.dbt_dev_env_id
and settings.actual_prod_environment_id
and settings.actual_host
)
is_local = settings.actual_host and settings.actual_host.startswith("localhost")
path = "/v1/mcp/" if is_local else "/api/ai/v1/mcp/"
scheme = "http://" if is_local else "https://"
host_prefix = (
f"{settings.actual_host_prefix}." if settings.actual_host_prefix else ""
)
url = f"{scheme}{host_prefix}{settings.actual_host}{path}"
return SqlConfig(
user_id=settings.dbt_user_id,
dev_environment_id=settings.dbt_dev_env_id,
prod_environment_id=settings.actual_prod_environment_id,
url=url,
headers_provider=SqlHeadersProvider(token_provider=token_provider),
)
```
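All of the providers above derive their service hosts from `actual_host` plus an optional `actual_host_prefix`. A standalone sketch of the semantic layer host shapes (host values are illustrative):

```python
def semantic_layer_host(actual_host: str, host_prefix: str | None) -> str:
    # Mirrors the branching in DefaultSemanticLayerConfigProvider.get_config.
    if actual_host.startswith("localhost"):
        return actual_host
    if host_prefix:
        return f"{host_prefix}.semantic-layer.{actual_host}"
    return f"semantic-layer.{actual_host}"


assert semantic_layer_host("cloud.getdbt.com", None) == "semantic-layer.cloud.getdbt.com"
assert semantic_layer_host("us1.dbt.com", "abc123") == "abc123.semantic-layer.us1.dbt.com"
assert semantic_layer_host("localhost:8000", None) == "localhost:8000"
```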
--------------------------------------------------------------------------------
/src/dbt_mcp/sql/tools.py:
--------------------------------------------------------------------------------
```python
import logging
from collections.abc import Sequence
from contextlib import AsyncExitStack
from typing import (
Annotated,
Any,
)
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.tools.base import Tool as InternalTool
from mcp.server.fastmcp.utilities.func_metadata import (
ArgModelBase,
FuncMetadata,
_get_typed_annotation,
)
from mcp.shared.message import SessionMessage
from mcp.types import (
ContentBlock,
Tool,
)
from pydantic import Field, WithJsonSchema, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from dbt_mcp.config.config_providers import ConfigProvider, SqlConfig
from dbt_mcp.errors import RemoteToolError
from dbt_mcp.tools.tool_names import ToolName
from dbt_mcp.tools.toolsets import Toolset, toolsets
logger = logging.getLogger(__name__)
# Based on this: https://github.com/modelcontextprotocol/python-sdk/blob/9ae4df85fbab97bf476ddd160b766ca4c208cd13/src/mcp/server/fastmcp/utilities/func_metadata.py#L105
def get_remote_tool_fn_metadata(tool: Tool) -> FuncMetadata:
dynamic_pydantic_model_params: dict[str, Any] = {}
for key in tool.inputSchema["properties"]:
        # Remote tool arguments don't come with Python type annotations or
        # default values, so we synthesize generic string-typed fields here.
field_info = FieldInfo.from_annotated_attribute(
annotation=_get_typed_annotation(
annotation=Annotated[
Any,
Field(),
WithJsonSchema({"title": key, "type": "string"}),
],
globalns={},
),
default=PydanticUndefined,
)
dynamic_pydantic_model_params[key] = (field_info.annotation, None)
return FuncMetadata(
arg_model=create_model(
f"{tool.name}Arguments",
**dynamic_pydantic_model_params,
__base__=ArgModelBase,
)
)
async def _get_sql_tools(session: ClientSession) -> list[Tool]:
try:
sql_tool_names = {t.value for t in toolsets[Toolset.SQL]}
return [
t for t in (await session.list_tools()).tools if t.name in sql_tool_names
]
except Exception as e:
logger.error(f"Error getting SQL tools: {e}")
return []
class SqlToolsManager:
_stack = AsyncExitStack()
async def get_remote_mcp_session(
self, url: str, headers: dict[str, str]
) -> ClientSession:
streamablehttp_client_context: tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback,
] = await self._stack.enter_async_context(
streamablehttp_client(
url=url,
headers=headers,
)
)
read_stream, write_stream, _ = streamablehttp_client_context
return await self._stack.enter_async_context(
ClientSession(read_stream, write_stream)
)
@classmethod
async def close(cls) -> None:
await cls._stack.aclose()
async def register_sql_tools(
dbt_mcp: FastMCP,
config_provider: ConfigProvider[SqlConfig],
exclude_tools: Sequence[ToolName] = [],
) -> None:
"""
Register SQL MCP tools.
SQL tools are hosted remotely, so their definitions aren't found in this repo.
"""
config = await config_provider.get_config()
headers = {
"x-dbt-prod-environment-id": str(config.prod_environment_id),
"x-dbt-dev-environment-id": str(config.dev_environment_id),
"x-dbt-user-id": str(config.user_id),
} | config.headers_provider.get_headers()
sql_tools_manager = SqlToolsManager()
session = await sql_tools_manager.get_remote_mcp_session(config.url, headers)
await session.initialize()
sql_tools = await _get_sql_tools(session)
logger.info(f"Loaded sql tools: {', '.join([tool.name for tool in sql_tools])}")
for tool in sql_tools:
        if tool.name.lower() in [excluded.value.lower() for excluded in exclude_tools]:
continue
# Create a new function using a factory to avoid closure issues
def create_tool_function(tool_name: str):
async def tool_function(*args, **kwargs) -> Sequence[ContentBlock]:
tool_call_result = await session.call_tool(
tool_name,
kwargs,
)
if tool_call_result.isError:
raise RemoteToolError(
f"Tool {tool_name} reported an error: "
+ f"{tool_call_result.content}"
)
return tool_call_result.content
return tool_function
dbt_mcp._tool_manager._tools[tool.name] = InternalTool(
fn=create_tool_function(tool.name),
title=tool.title,
name=tool.name,
annotations=tool.annotations,
description=tool.description or "",
parameters=tool.inputSchema,
fn_metadata=get_remote_tool_fn_metadata(tool),
is_async=True,
context_kwarg=None,
)
```
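`register_sql_tools` builds each proxy through the `create_tool_function` factory rather than capturing `tool.name` directly in the loop body. A generic illustration (not dbt-mcp code) of the late-binding pitfall the factory avoids:

```python
def build_without_factory(names: list[str]):
    fns = []
    for name in names:
        fns.append(lambda: name)  # every lambda shares the same `name` cell
    return fns


def build_with_factory(names: list[str]):
    def make(bound_name: str):
        return lambda: bound_name  # each call binds its own value
    return [make(name) for name in names]


print([f() for f in build_without_factory(["a", "b", "c"])])  # ['c', 'c', 'c']
print([f() for f in build_with_factory(["a", "b", "c"])])     # ['a', 'b', 'c']
```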
--------------------------------------------------------------------------------
/tests/unit/dbt_admin/test_error_fetcher.py:
--------------------------------------------------------------------------------
```python
import json
from unittest.mock import AsyncMock, Mock
import pytest
from dbt_mcp.config.config_providers import AdminApiConfig
from dbt_mcp.dbt_admin.run_results_errors.parser import ErrorFetcher
class MockHeadersProvider:
"""Mock headers provider for testing."""
def get_headers(self) -> dict[str, str]:
return {"Authorization": "Bearer test_token"}
@pytest.fixture
def admin_config():
"""Admin API config for testing."""
return AdminApiConfig(
account_id=12345,
headers_provider=MockHeadersProvider(),
url="https://cloud.getdbt.com",
)
@pytest.fixture
def mock_client():
"""Base mock client - behavior configured per test."""
return Mock()
@pytest.mark.parametrize(
"run_details,artifact_responses,expected_step_count,expected_error_messages",
[
# Cancelled run
(
{
"id": 300,
"status": 30,
"is_cancelled": True,
"finished_at": "2024-01-01T09:00:00Z",
"run_steps": [],
},
[],
1,
["Job run was cancelled"],
),
# Source freshness fails (doesn't stop job) + model error downstream
(
{
"id": 400,
"status": 20,
"is_cancelled": False,
"finished_at": "2024-01-01T10:00:00Z",
"run_steps": [
{
"index": 1,
"name": "Source freshness",
"status": 20,
"finished_at": "2024-01-01T09:30:00Z",
},
{
"index": 2,
"name": "Invoke dbt with `dbt build`",
"status": 20,
"finished_at": "2024-01-01T10:00:00Z",
},
],
},
[
None, # Source freshness artifact not available
{
"results": [
{
"unique_id": "model.test_model",
"status": "error",
"message": "Model compilation failed",
"relation_name": "analytics.test_model",
}
],
"args": {"target": "prod"},
},
],
2,
["Source freshness error - returning logs", "Model compilation failed"],
),
],
)
async def test_error_scenarios(
mock_client,
admin_config,
run_details,
artifact_responses,
expected_step_count,
expected_error_messages,
):
"""Test various error scenarios with parametrized data."""
# Map step_index to run_results_content
step_index_to_run_results = {}
for i, failed_step in enumerate(run_details.get("run_steps", [])):
if i < len(artifact_responses):
step_index = failed_step["index"]
step_index_to_run_results[step_index] = artifact_responses[i]
async def mock_get_artifact(account_id, run_id, artifact_path, step=None):
run_results_content = step_index_to_run_results.get(step)
if run_results_content is None:
raise Exception("Artifact not available")
return json.dumps(run_results_content)
mock_client.get_job_run_artifact = AsyncMock(side_effect=mock_get_artifact)
error_fetcher = ErrorFetcher(
run_id=run_details["id"],
run_details=run_details,
client=mock_client,
admin_api_config=admin_config,
)
result = await error_fetcher.analyze_run_errors()
assert len(result["failed_steps"]) == expected_step_count
for i, expected_msg in enumerate(expected_error_messages):
assert expected_msg in result["failed_steps"][i]["errors"][0]["message"]
async def test_schema_validation_failure(mock_client, admin_config):
"""Test handling of run_results.json schema changes - should fallback to logs."""
run_details = {
"id": 400,
"status": 20,
"is_cancelled": False,
"finished_at": "2024-01-01T11:00:00Z",
"run_steps": [
{
"index": 1,
"name": "Invoke dbt with `dbt build`",
"status": 20,
"finished_at": "2024-01-01T11:00:00Z",
"logs": "Model compilation failed due to missing table",
}
],
}
# Return valid JSON but with missing required fields (schema mismatch)
# Expected schema: {"results": [...], "args": {...}, "metadata": {...}}
mock_client.get_job_run_artifact = AsyncMock(
return_value='{"metadata": {"some": "value"}, "invalid_field": true}'
)
error_fetcher = ErrorFetcher(
run_id=400,
run_details=run_details,
client=mock_client,
admin_api_config=admin_config,
)
result = await error_fetcher.analyze_run_errors()
# Should fallback to logs when schema validation fails
assert len(result["failed_steps"]) == 1
step = result["failed_steps"][0]
assert step["step_name"] == "Invoke dbt with `dbt build`"
assert "run_results.json not available" in step["errors"][0]["message"]
assert "Model compilation failed" in step["errors"][0]["truncated_logs"]
```
--------------------------------------------------------------------------------
/src/dbt_mcp/mcp/server.py:
--------------------------------------------------------------------------------
```python
import logging
import time
import uuid
from collections.abc import AsyncIterator, Callable, Sequence
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from typing import Any
from dbtlabs_vortex.producer import shutdown
from mcp.server.fastmcp import FastMCP
from mcp.server.lowlevel.server import LifespanResultT
from mcp.types import (
ContentBlock,
TextContent,
)
from dbt_mcp.config.config import Config
from dbt_mcp.dbt_admin.tools import register_admin_api_tools
from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools
from dbt_mcp.dbt_codegen.tools import register_dbt_codegen_tools
from dbt_mcp.discovery.tools import register_discovery_tools
from dbt_mcp.semantic_layer.client import DefaultSemanticLayerClientProvider
from dbt_mcp.semantic_layer.tools import register_sl_tools
from dbt_mcp.sql.tools import SqlToolsManager, register_sql_tools
from dbt_mcp.tracking.tracking import DefaultUsageTracker, ToolCalledEvent, UsageTracker
from dbt_mcp.lsp.tools import cleanup_lsp_connection, register_lsp_tools
logger = logging.getLogger(__name__)
class DbtMCP(FastMCP):
def __init__(
self,
config: Config,
usage_tracker: UsageTracker,
lifespan: Callable[["DbtMCP"], AbstractAsyncContextManager[LifespanResultT]],
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs, lifespan=lifespan)
self.usage_tracker = usage_tracker
self.config = config
async def call_tool(
self, name: str, arguments: dict[str, Any]
) -> Sequence[ContentBlock] | dict[str, Any]:
logger.info(f"Calling tool: {name} with arguments: {arguments}")
result = None
start_time = int(time.time() * 1000)
try:
result = await super().call_tool(
name,
arguments,
)
except Exception as e:
end_time = int(time.time() * 1000)
logger.error(
f"Error calling tool: {name} with arguments: {arguments} "
+ f"in {end_time - start_time}ms: {e}"
)
await self.usage_tracker.emit_tool_called_event(
tool_called_event=ToolCalledEvent(
tool_name=name,
arguments=arguments,
start_time_ms=start_time,
end_time_ms=end_time,
error_message=str(e),
),
)
return [
TextContent(
type="text",
text=str(e),
)
]
end_time = int(time.time() * 1000)
logger.info(f"Tool {name} called successfully in {end_time - start_time}ms")
await self.usage_tracker.emit_tool_called_event(
tool_called_event=ToolCalledEvent(
tool_name=name,
arguments=arguments,
start_time_ms=start_time,
end_time_ms=end_time,
error_message=None,
),
)
return result
@asynccontextmanager
async def app_lifespan(server: DbtMCP) -> AsyncIterator[None]:
logger.info("Starting MCP server")
try:
yield
except Exception as e:
logger.error(f"Error in MCP server: {e}")
raise e
finally:
logger.info("Shutting down MCP server")
try:
await SqlToolsManager.close()
except Exception:
logger.exception("Error closing SQL tools manager")
try:
await cleanup_lsp_connection()
except Exception:
logger.exception("Error cleaning up LSP connection")
try:
shutdown()
except Exception:
logger.exception("Error shutting down MCP server")
async def create_dbt_mcp(config: Config) -> DbtMCP:
dbt_mcp = DbtMCP(
config=config,
usage_tracker=DefaultUsageTracker(
credentials_provider=config.credentials_provider,
session_id=uuid.uuid4(),
),
name="dbt",
lifespan=app_lifespan,
)
if config.semantic_layer_config_provider:
logger.info("Registering semantic layer tools")
register_sl_tools(
dbt_mcp,
config_provider=config.semantic_layer_config_provider,
client_provider=DefaultSemanticLayerClientProvider(
config_provider=config.semantic_layer_config_provider,
),
exclude_tools=config.disable_tools,
)
if config.discovery_config_provider:
logger.info("Registering discovery tools")
register_discovery_tools(
dbt_mcp, config.discovery_config_provider, config.disable_tools
)
if config.dbt_cli_config:
logger.info("Registering dbt cli tools")
register_dbt_cli_tools(dbt_mcp, config.dbt_cli_config, config.disable_tools)
if config.dbt_codegen_config:
logger.info("Registering dbt codegen tools")
register_dbt_codegen_tools(
dbt_mcp, config.dbt_codegen_config, config.disable_tools
)
if config.admin_api_config_provider:
logger.info("Registering dbt admin API tools")
register_admin_api_tools(
dbt_mcp, config.admin_api_config_provider, config.disable_tools
)
if config.sql_config_provider:
logger.info("Registering SQL tools")
await register_sql_tools(
dbt_mcp, config.sql_config_provider, config.disable_tools
)
if config.lsp_config:
logger.info("Registering LSP tools")
await register_lsp_tools(dbt_mcp, config.lsp_config, config.disable_tools)
return dbt_mcp
```
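A sketch of how `create_dbt_mcp` might be driven from an entrypoint. The `load_config()` call and the stdio transport are assumptions here; the packaged entrypoint may wire this differently.

```python
import asyncio

from dbt_mcp.config.config import load_config
from dbt_mcp.mcp.server import create_dbt_mcp


async def main() -> None:
    config = load_config()  # assumes the required environment variables are set
    dbt_mcp = await create_dbt_mcp(config)
    await dbt_mcp.run_stdio_async()  # FastMCP's stdio transport (assumed choice)


if __name__ == "__main__":
    asyncio.run(main())
```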
--------------------------------------------------------------------------------
/src/dbt_mcp/lsp/tools.py:
--------------------------------------------------------------------------------
```python
import functools
import inspect
import logging
from collections.abc import Callable, Sequence
from typing import Any
from mcp.server.fastmcp import FastMCP
from pydantic import Field
from dbt_mcp.config.config import LspConfig
from dbt_mcp.lsp.lsp_binary_manager import dbt_lsp_binary_info
from dbt_mcp.lsp.lsp_client import LSPClient
from dbt_mcp.lsp.lsp_connection import LSPConnection
from dbt_mcp.prompts.prompts import get_prompt
from dbt_mcp.tools.annotations import create_tool_annotations
from dbt_mcp.tools.definitions import ToolDefinition
from dbt_mcp.tools.register import register_tools
from dbt_mcp.tools.tool_names import ToolName
logger = logging.getLogger(__name__)
# Module-level LSP connection to manage lifecycle
_lsp_connection: LSPConnection | None = None
async def register_lsp_tools(
server: FastMCP,
config: LspConfig,
exclude_tools: Sequence[ToolName] | None = None,
) -> None:
register_tools(
server,
await list_lsp_tools(config),
exclude_tools or [],
)
async def list_lsp_tools(config: LspConfig) -> list[ToolDefinition]:
"""Register dbt Fusion tools with the MCP server.
Args:
config: LSP configuration containing LSP settings
Returns:
List of tool definitions for LSP tools
"""
global _lsp_connection
# Only initialize if not already initialized
if _lsp_connection is None:
lsp_binary_path = dbt_lsp_binary_info(config.lsp_path)
if not lsp_binary_path:
logger.warning("No LSP binary path found")
return []
logger.info(
f"Using LSP binary in {lsp_binary_path.path} with version {lsp_binary_path.version}"
)
_lsp_connection = LSPConnection(
binary_path=lsp_binary_path.path,
args=[],
cwd=config.project_dir,
)
def call_with_lsp_client(func: Callable) -> Callable:
"""Call a function with the LSP connection manager."""
@functools.wraps(func)
async def wrapper(*args, **kwargs) -> Any:
global _lsp_connection
if _lsp_connection is None:
return "LSP connection not initialized"
if not _lsp_connection.state.initialized:
try:
await _lsp_connection.start()
await _lsp_connection.initialize()
logger.info("LSP connection started and initialized successfully")
except Exception as e:
logger.error(f"Error starting LSP connection: {e}")
# Clean up failed connection
_lsp_connection = None
return "Error: Failed to establish LSP connection"
lsp_client = LSPClient(_lsp_connection)
return await func(lsp_client, *args, **kwargs)
# remove the lsp_client argument from the signature
wrapper.__signature__ = inspect.signature(func).replace( # type: ignore
parameters=[
param
for param in inspect.signature(func).parameters.values()
if param.name != "lsp_client"
]
)
return wrapper
return [
ToolDefinition(
fn=call_with_lsp_client(get_column_lineage),
description=get_prompt("lsp/get_column_lineage"),
annotations=create_tool_annotations(
title="get_column_lineage",
read_only_hint=False,
destructive_hint=False,
idempotent_hint=True,
),
),
]
async def get_column_lineage(
lsp_client: LSPClient,
model_id: str = Field(description=get_prompt("lsp/args/model_id")),
column_name: str = Field(description=get_prompt("lsp/args/column_name")),
) -> dict[str, Any]:
"""Get column lineage for a specific model column.
Args:
lsp_client: The LSP client instance
model_id: The dbt model identifier
column_name: The column name to trace lineage for
Returns:
Dictionary with either:
- 'nodes' key containing lineage information on success
- 'error' key containing error message on failure
"""
try:
response = await lsp_client.get_column_lineage(
model_id=model_id,
column_name=column_name,
)
# Check for LSP-level errors
if "error" in response:
logger.error(f"LSP error getting column lineage: {response['error']}")
return {"error": f"LSP error: {response['error']}"}
# Validate response has expected data
if "nodes" not in response or not response["nodes"]:
logger.warning(f"No column lineage found for {model_id}.{column_name}")
return {
"error": f"No column lineage found for model {model_id} and column {column_name}"
}
return {"nodes": response["nodes"]}
except TimeoutError:
error_msg = f"Timeout waiting for column lineage (model: {model_id}, column: {column_name})"
logger.error(error_msg)
return {"error": error_msg}
except Exception as e:
error_msg = (
f"Failed to get column lineage for {model_id}.{column_name}: {str(e)}"
)
logger.error(error_msg)
return {"error": error_msg}
async def cleanup_lsp_connection() -> None:
"""Clean up the LSP connection when shutting down."""
global _lsp_connection
if _lsp_connection:
try:
logger.info("Cleaning up LSP connection")
await _lsp_connection.stop()
except Exception as e:
logger.error(f"Error cleaning up LSP connection: {e}")
finally:
_lsp_connection = None
```
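`call_with_lsp_client` hides the injected `lsp_client` parameter from the wrapper's reported signature so FastMCP only exposes the user-facing arguments. A generic illustration of that `inspect.signature(...).replace(...)` trick (not dbt-mcp code):

```python
import functools
import inspect


def inject_greeting(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func("hello", *args, **kwargs)

    # Advertise a signature without the injected first parameter.
    wrapper.__signature__ = inspect.signature(func).replace(
        parameters=[
            p
            for p in inspect.signature(func).parameters.values()
            if p.name != "greeting"
        ]
    )
    return wrapper


@inject_greeting
def greet(greeting: str, name: str) -> str:
    return f"{greeting}, {name}"


print(greet("world"))            # hello, world
print(inspect.signature(greet))  # (name: str) -> str
```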
--------------------------------------------------------------------------------
/src/dbt_mcp/tools/policy.py:
--------------------------------------------------------------------------------
```python
from enum import Enum
from pydantic.dataclasses import dataclass
from dbt_mcp.tools.tool_names import ToolName
class ToolBehavior(Enum):
"""Behavior of the tool."""
# The tool can return row-level data.
RESULT_SET = "result_set"
# The tool only returns metadata.
METADATA = "metadata"
@dataclass
class ToolPolicy:
"""Policy for a tool."""
name: str
behavior: ToolBehavior
# Defining tool policies is important for our internal usage of dbt-mcp.
# Our policies dictate that we do not send row-level data to LLMs.
tool_policies = {
# CLI tools
ToolName.SHOW.value: ToolPolicy(
name=ToolName.SHOW.value, behavior=ToolBehavior.RESULT_SET
),
ToolName.LIST.value: ToolPolicy(
name=ToolName.LIST.value, behavior=ToolBehavior.METADATA
),
ToolName.DOCS.value: ToolPolicy(
name=ToolName.DOCS.value, behavior=ToolBehavior.METADATA
),
# Compile tool can have result_set behavior because of macros like print_table
ToolName.COMPILE.value: ToolPolicy(
name=ToolName.COMPILE.value, behavior=ToolBehavior.RESULT_SET
),
ToolName.TEST.value: ToolPolicy(
name=ToolName.TEST.value, behavior=ToolBehavior.METADATA
),
# Run tool can have result_set behavior because of macros like print_table
ToolName.RUN.value: ToolPolicy(
name=ToolName.RUN.value, behavior=ToolBehavior.RESULT_SET
),
# Build tool can have result_set behavior because of macros like print_table
ToolName.BUILD.value: ToolPolicy(
name=ToolName.BUILD.value, behavior=ToolBehavior.RESULT_SET
),
ToolName.PARSE.value: ToolPolicy(
name=ToolName.PARSE.value, behavior=ToolBehavior.METADATA
),
# Semantic Layer tools
ToolName.LIST_METRICS.value: ToolPolicy(
name=ToolName.LIST_METRICS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_DIMENSIONS.value: ToolPolicy(
name=ToolName.GET_DIMENSIONS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_ENTITIES.value: ToolPolicy(
name=ToolName.GET_ENTITIES.value, behavior=ToolBehavior.METADATA
),
ToolName.QUERY_METRICS.value: ToolPolicy(
name=ToolName.QUERY_METRICS.value, behavior=ToolBehavior.RESULT_SET
),
ToolName.GET_METRICS_COMPILED_SQL.value: ToolPolicy(
name=ToolName.GET_METRICS_COMPILED_SQL.value, behavior=ToolBehavior.METADATA
),
# Discovery tools
ToolName.GET_MODEL_PARENTS.value: ToolPolicy(
name=ToolName.GET_MODEL_PARENTS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_MODEL_CHILDREN.value: ToolPolicy(
name=ToolName.GET_MODEL_CHILDREN.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_MODEL_DETAILS.value: ToolPolicy(
name=ToolName.GET_MODEL_DETAILS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_MODEL_HEALTH.value: ToolPolicy(
name=ToolName.GET_MODEL_HEALTH.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_MART_MODELS.value: ToolPolicy(
name=ToolName.GET_MART_MODELS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_ALL_MODELS.value: ToolPolicy(
name=ToolName.GET_ALL_MODELS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_ALL_SOURCES.value: ToolPolicy(
name=ToolName.GET_ALL_SOURCES.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_EXPOSURES.value: ToolPolicy(
name=ToolName.GET_EXPOSURES.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_EXPOSURE_DETAILS.value: ToolPolicy(
name=ToolName.GET_EXPOSURE_DETAILS.value, behavior=ToolBehavior.METADATA
),
# SQL tools
ToolName.TEXT_TO_SQL.value: ToolPolicy(
name=ToolName.TEXT_TO_SQL.value, behavior=ToolBehavior.METADATA
),
ToolName.EXECUTE_SQL.value: ToolPolicy(
name=ToolName.EXECUTE_SQL.value, behavior=ToolBehavior.RESULT_SET
),
# Admin API tools
ToolName.LIST_JOBS.value: ToolPolicy(
name=ToolName.LIST_JOBS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_JOB_DETAILS.value: ToolPolicy(
name=ToolName.GET_JOB_DETAILS.value, behavior=ToolBehavior.METADATA
),
ToolName.TRIGGER_JOB_RUN.value: ToolPolicy(
name=ToolName.TRIGGER_JOB_RUN.value, behavior=ToolBehavior.METADATA
),
ToolName.LIST_JOBS_RUNS.value: ToolPolicy(
name=ToolName.LIST_JOBS_RUNS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_JOB_RUN_DETAILS.value: ToolPolicy(
name=ToolName.GET_JOB_RUN_DETAILS.value, behavior=ToolBehavior.METADATA
),
ToolName.CANCEL_JOB_RUN.value: ToolPolicy(
name=ToolName.CANCEL_JOB_RUN.value, behavior=ToolBehavior.METADATA
),
ToolName.RETRY_JOB_RUN.value: ToolPolicy(
name=ToolName.RETRY_JOB_RUN.value, behavior=ToolBehavior.METADATA
),
ToolName.LIST_JOB_RUN_ARTIFACTS.value: ToolPolicy(
name=ToolName.LIST_JOB_RUN_ARTIFACTS.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_JOB_RUN_ARTIFACT.value: ToolPolicy(
name=ToolName.GET_JOB_RUN_ARTIFACT.value, behavior=ToolBehavior.METADATA
),
ToolName.GET_JOB_RUN_ERROR.value: ToolPolicy(
name=ToolName.GET_JOB_RUN_ERROR.value, behavior=ToolBehavior.METADATA
),
# dbt-codegen tools
ToolName.GENERATE_SOURCE.value: ToolPolicy(
name=ToolName.GENERATE_SOURCE.value, behavior=ToolBehavior.METADATA
),
ToolName.GENERATE_MODEL_YAML.value: ToolPolicy(
name=ToolName.GENERATE_MODEL_YAML.value, behavior=ToolBehavior.METADATA
),
ToolName.GENERATE_STAGING_MODEL.value: ToolPolicy(
name=ToolName.GENERATE_STAGING_MODEL.value, behavior=ToolBehavior.METADATA
),
# LSP tools
ToolName.GET_COLUMN_LINEAGE.value: ToolPolicy(
name=ToolName.GET_COLUMN_LINEAGE.value, behavior=ToolBehavior.METADATA
),
}
```
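A short sketch of consuming this table, e.g. to build a disable list for deployments that must never expose row-level data:

```python
from dbt_mcp.tools.policy import ToolBehavior, tool_policies

# Names of tools whose policy allows row-level data in results.
row_level_tools = sorted(
    name
    for name, policy in tool_policies.items()
    if policy.behavior is ToolBehavior.RESULT_SET
)
print(row_level_tools)
```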
--------------------------------------------------------------------------------
/src/dbt_mcp/discovery/tools.py:
--------------------------------------------------------------------------------
```python
import logging
from collections.abc import Sequence
from mcp.server.fastmcp import FastMCP
from dbt_mcp.config.config_providers import (
ConfigProvider,
DiscoveryConfig,
)
from dbt_mcp.discovery.client import (
ExposuresFetcher,
MetadataAPIClient,
ModelsFetcher,
SourcesFetcher,
)
from dbt_mcp.prompts.prompts import get_prompt
from dbt_mcp.tools.annotations import create_tool_annotations
from dbt_mcp.tools.definitions import ToolDefinition
from dbt_mcp.tools.register import register_tools
from dbt_mcp.tools.tool_names import ToolName
logger = logging.getLogger(__name__)
def create_discovery_tool_definitions(
config_provider: ConfigProvider[DiscoveryConfig],
) -> list[ToolDefinition]:
api_client = MetadataAPIClient(config_provider=config_provider)
models_fetcher = ModelsFetcher(api_client=api_client)
exposures_fetcher = ExposuresFetcher(api_client=api_client)
sources_fetcher = SourcesFetcher(api_client=api_client)
async def get_mart_models() -> list[dict]:
mart_models = await models_fetcher.fetch_models(
model_filter={"modelingLayer": "marts"}
)
return [m for m in mart_models if m["name"] != "metricflow_time_spine"]
async def get_all_models() -> list[dict]:
return await models_fetcher.fetch_models()
async def get_model_details(
model_name: str | None = None, unique_id: str | None = None
) -> dict:
return await models_fetcher.fetch_model_details(model_name, unique_id)
async def get_model_parents(
model_name: str | None = None, unique_id: str | None = None
) -> list[dict]:
return await models_fetcher.fetch_model_parents(model_name, unique_id)
async def get_model_children(
model_name: str | None = None, unique_id: str | None = None
) -> list[dict]:
return await models_fetcher.fetch_model_children(model_name, unique_id)
async def get_model_health(
model_name: str | None = None, unique_id: str | None = None
) -> list[dict]:
return await models_fetcher.fetch_model_health(model_name, unique_id)
async def get_exposures() -> list[dict]:
return await exposures_fetcher.fetch_exposures()
async def get_exposure_details(
exposure_name: str | None = None, unique_ids: list[str] | None = None
) -> list[dict]:
return await exposures_fetcher.fetch_exposure_details(exposure_name, unique_ids)
async def get_all_sources(
source_names: list[str] | None = None,
unique_ids: list[str] | None = None,
) -> list[dict]:
return await sources_fetcher.fetch_sources(source_names, unique_ids)
return [
ToolDefinition(
description=get_prompt("discovery/get_mart_models"),
fn=get_mart_models,
annotations=create_tool_annotations(
title="Get Mart Models",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("discovery/get_all_models"),
fn=get_all_models,
annotations=create_tool_annotations(
title="Get All Models",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("discovery/get_model_details"),
fn=get_model_details,
annotations=create_tool_annotations(
title="Get Model Details",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("discovery/get_model_parents"),
fn=get_model_parents,
annotations=create_tool_annotations(
title="Get Model Parents",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("discovery/get_model_children"),
fn=get_model_children,
annotations=create_tool_annotations(
title="Get Model Children",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("discovery/get_model_health"),
fn=get_model_health,
annotations=create_tool_annotations(
title="Get Model Health",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("discovery/get_exposures"),
fn=get_exposures,
annotations=create_tool_annotations(
title="Get Exposures",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("discovery/get_exposure_details"),
fn=get_exposure_details,
annotations=create_tool_annotations(
title="Get Exposure Details",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
description=get_prompt("discovery/get_all_sources"),
fn=get_all_sources,
annotations=create_tool_annotations(
title="Get All Sources",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
]
def register_discovery_tools(
dbt_mcp: FastMCP,
config_provider: ConfigProvider[DiscoveryConfig],
exclude_tools: Sequence[ToolName] = [],
) -> None:
register_tools(
dbt_mcp,
create_discovery_tool_definitions(config_provider),
exclude_tools,
)
```
--------------------------------------------------------------------------------
/tests/unit/lsp/test_lsp_tools.py:
--------------------------------------------------------------------------------
```python
from pathlib import Path
from unittest.mock import AsyncMock, patch
import pytest
from dbt_mcp.config.config import LspConfig
from dbt_mcp.lsp.lsp_binary_manager import LspBinaryInfo
from dbt_mcp.lsp.tools import (
cleanup_lsp_connection,
get_column_lineage,
register_lsp_tools,
)
from dbt_mcp.lsp.lsp_client import LSPClient
from dbt_mcp.mcp.server import FastMCP
from dbt_mcp.tools.tool_names import ToolName
@pytest.fixture
def test_mcp_server() -> FastMCP:
"""Create a mock FastMCP server."""
server = FastMCP(
name="test",
)
return server
@pytest.fixture
def lsp_config(tmp_path: Path) -> LspConfig:
"""Create a test LSP configuration."""
return LspConfig(
lsp_path="/usr/local/bin/dbt-lsp",
project_dir=str(tmp_path),
)
@pytest.mark.asyncio
async def test_register_lsp_tools_no_binary(
test_mcp_server: FastMCP, lsp_config: LspConfig
) -> None:
"""Test that registration fails gracefully when no LSP binary is found."""
with patch("dbt_mcp.lsp.tools.dbt_lsp_binary_info", return_value=None):
await register_lsp_tools(test_mcp_server, lsp_config)
assert not await test_mcp_server.list_tools()
@pytest.mark.asyncio
async def test_register_lsp_tools_success(
test_mcp_server: FastMCP, lsp_config: LspConfig
) -> None:
"""Test successful registration of LSP tools."""
lsp_connection_mock = AsyncMock()
lsp_connection_mock.start = AsyncMock()
lsp_connection_mock.initialize = AsyncMock()
with (
patch(
"dbt_mcp.lsp.tools.dbt_lsp_binary_info",
return_value=LspBinaryInfo(path="/path/to/lsp", version="1.0.0"),
),
patch("dbt_mcp.lsp.tools.LSPConnection", return_value=lsp_connection_mock),
):
await register_lsp_tools(test_mcp_server, lsp_config)
# Verify correct tools were registered
tool_names = [tool.name for tool in await test_mcp_server.list_tools()]
assert ToolName.GET_COLUMN_LINEAGE.value in tool_names
@pytest.mark.asyncio
async def test_cleanup_lsp_connection() -> None:
"""Test that cleanup_lsp_connection properly stops the LSP connection."""
mock_connection = AsyncMock()
mock_connection.stop = AsyncMock()
with patch("dbt_mcp.lsp.tools._lsp_connection", mock_connection):
await cleanup_lsp_connection()
mock_connection.stop.assert_called_once()
@pytest.mark.asyncio
async def test_cleanup_lsp_connection_no_connection() -> None:
"""Test that cleanup_lsp_connection handles no connection gracefully."""
with patch("dbt_mcp.lsp.tools._lsp_connection", None):
# Should not raise any exceptions
await cleanup_lsp_connection()
@pytest.mark.asyncio
async def test_cleanup_lsp_connection_error() -> None:
"""Test that cleanup_lsp_connection handles errors gracefully."""
mock_connection = AsyncMock()
mock_connection.stop = AsyncMock(side_effect=Exception("Stop failed"))
with patch("dbt_mcp.lsp.tools._lsp_connection", mock_connection):
# Should not raise the exception, but log it
await cleanup_lsp_connection()
mock_connection.stop.assert_called_once()
@pytest.mark.asyncio
async def test_register_lsp_tools_idempotent(
test_mcp_server: FastMCP, lsp_config: LspConfig
) -> None:
"""Test that registering LSP tools twice doesn't create duplicate connections."""
import dbt_mcp.lsp.tools as tools_module
lsp_connection_mock = AsyncMock()
lsp_connection_mock.start = AsyncMock()
lsp_connection_mock.initialize = AsyncMock()
# Reset the module-level connection
tools_module._lsp_connection = None
try:
with (
patch(
"dbt_mcp.lsp.tools.dbt_lsp_binary_info",
return_value=LspBinaryInfo(path="/path/to/lsp", version="1.0.0"),
),
patch(
"dbt_mcp.lsp.tools.LSPConnection", return_value=lsp_connection_mock
) as connection_constructor,
):
# Register twice
await register_lsp_tools(test_mcp_server, lsp_config)
await register_lsp_tools(test_mcp_server, lsp_config)
# Connection should only be created once (idempotent)
assert connection_constructor.call_count == 1
finally:
# Clean up module state
tools_module._lsp_connection = None
@pytest.mark.asyncio
async def test_get_column_lineage_success() -> None:
"""Test successful column lineage retrieval."""
mock_lsp_client = AsyncMock(spec=LSPClient)
mock_lsp_client.get_column_lineage = AsyncMock(
return_value={"nodes": [{"id": "model.project.table", "column": "id"}]}
)
result = await get_column_lineage(mock_lsp_client, "model.project.table", "id")
assert "nodes" in result
assert len(result["nodes"]) == 1
assert result["nodes"][0]["id"] == "model.project.table"
mock_lsp_client.get_column_lineage.assert_called_once_with(
model_id="model.project.table", column_name="id"
)
@pytest.mark.asyncio
async def test_get_column_lineage_lsp_error() -> None:
"""Test column lineage with LSP error response."""
mock_lsp_client = AsyncMock(spec=LSPClient)
mock_lsp_client.get_column_lineage = AsyncMock(
return_value={"error": "Model not found"}
)
result = await get_column_lineage(mock_lsp_client, "invalid_model", "column")
assert "error" in result
assert "LSP error: Model not found" in result["error"]
@pytest.mark.asyncio
async def test_get_column_lineage_no_results() -> None:
"""Test column lineage when no lineage is found."""
mock_lsp_client = AsyncMock(spec=LSPClient)
mock_lsp_client.get_column_lineage = AsyncMock(return_value={"nodes": []})
result = await get_column_lineage(mock_lsp_client, "model.project.table", "column")
assert "error" in result
assert "No column lineage found" in result["error"]
@pytest.mark.asyncio
async def test_get_column_lineage_timeout() -> None:
"""Test column lineage with timeout error."""
mock_lsp_client = AsyncMock(spec=LSPClient)
mock_lsp_client.get_column_lineage = AsyncMock(side_effect=TimeoutError())
result = await get_column_lineage(mock_lsp_client, "model.project.table", "column")
assert "error" in result
assert "Timeout waiting for column lineage" in result["error"]
@pytest.mark.asyncio
async def test_get_column_lineage_generic_exception() -> None:
"""Test column lineage with generic exception."""
mock_lsp_client = AsyncMock(spec=LSPClient)
mock_lsp_client.get_column_lineage = AsyncMock(
side_effect=Exception("Connection lost")
)
result = await get_column_lineage(mock_lsp_client, "model.project.table", "column")
assert "error" in result
assert "Failed to get column lineage" in result["error"]
assert "Connection lost" in result["error"]
```
--------------------------------------------------------------------------------
/examples/aws_strands_agent/dbt_data_scientist/tools/dbt_model_analyzer.py:
--------------------------------------------------------------------------------
```python
"""dbt Model Analyzer Tool - Data model analysis and recommendations."""
import os
import json
import subprocess
from typing import Dict, Any, List, Optional
from strands import Agent, tool
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
DBT_MODEL_ANALYZER_SYSTEM_PROMPT = """
You are a dbt data modeling expert and analyst. Your capabilities include:
1. **Model Structure Analysis**: Analyze dbt model structure, dependencies, and relationships
2. **Data Quality Assessment**: Evaluate data quality patterns, test coverage, and validation rules
3. **Performance Optimization**: Identify performance bottlenecks and optimization opportunities
4. **Best Practices Review**: Check adherence to dbt best practices and naming conventions
5. **Dependency Analysis**: Map model dependencies and identify circular dependencies or issues
6. **Documentation Review**: Assess model documentation completeness and quality
When analyzing models, provide:
- Clear summary of findings
- Specific recommendations for improvement
- Priority levels for each recommendation
- Code examples where applicable
- Impact assessment for suggested changes
Focus on actionable insights that help improve data modeling practices and model quality.
"""
@tool
def dbt_model_analyzer_agent(query: str) -> str:
"""
Analyzes dbt models and provides recommendations for data modeling improvements.
This tool can:
- Analyze model structure and dependencies
- Assess data quality patterns and test coverage
- Review adherence to dbt best practices
- Provide optimization recommendations
- Generate model documentation suggestions
Args:
query: The user's question about data modeling analysis or specific model to analyze
Returns:
String response with analysis results and recommendations
"""
try:
# Load environment variables
load_dotenv()
# Get dbt project location
dbt_project_location = os.getenv("DBT_PROJECT_LOCATION", os.getcwd())
dbt_executable = os.getenv("DBT_EXECUTABLE", "dbt")
print(f"Analyzing dbt models in: {dbt_project_location}")
# Parse the query to determine analysis type
query_lower = query.lower()
# Determine what type of analysis to perform
analysis_type = "comprehensive"
if "dependency" in query_lower:
analysis_type = "dependencies"
elif "quality" in query_lower or "test" in query_lower:
analysis_type = "data_quality"
elif "performance" in query_lower or "optimize" in query_lower:
analysis_type = "performance"
elif "documentation" in query_lower or "docs" in query_lower:
analysis_type = "documentation"
elif "best practice" in query_lower or "convention" in query_lower:
analysis_type = "best_practices"
# Gather dbt project information
project_info = gather_dbt_project_info(dbt_project_location, dbt_executable)
# Format the analysis query
formatted_query = f"""
User wants to analyze their dbt data modeling approach. Analysis type: {analysis_type}
Project information:
- Project location: {dbt_project_location}
- Models count: {project_info.get('models_count', 'Unknown')}
- Tests count: {project_info.get('tests_count', 'Unknown')}
- Sources count: {project_info.get('sources_count', 'Unknown')}
User's specific question: {query}
Please provide a comprehensive analysis focusing on {analysis_type} and give actionable recommendations.
"""
# Create the model analyzer agent
model_analyzer_agent = Agent(
system_prompt=DBT_MODEL_ANALYZER_SYSTEM_PROMPT,
tools=[],
)
# Get analysis from the agent
agent_response = model_analyzer_agent(formatted_query)
text_response = str(agent_response)
if len(text_response) > 0:
return text_response
return "I apologize, but I couldn't process your dbt model analysis request. Please try rephrasing or providing more specific details about what you'd like to analyze."
except Exception as e:
return f"Error processing your dbt model analysis query: {str(e)}"
def gather_dbt_project_info(project_location: str, dbt_executable: str) -> Dict[str, Any]:
"""
Gather basic information about the dbt project.
Args:
project_location: Path to dbt project
dbt_executable: Path to dbt executable
Returns:
Dictionary with project information
"""
info = {
"models_count": 0,
"tests_count": 0,
"sources_count": 0,
"project_name": "Unknown"
}
try:
# Try to get project name from dbt_project.yml
dbt_project_file = os.path.join(project_location, "dbt_project.yml")
if os.path.exists(dbt_project_file):
with open(dbt_project_file, 'r') as f:
content = f.read()
if 'name:' in content:
# Simple extraction of project name
for line in content.split('\n'):
if line.strip().startswith('name:'):
info["project_name"] = line.split(':')[1].strip().strip('"\'')
break
# Try to count models, tests, and sources by running dbt commands
try:
# List models
result = subprocess.run(
[dbt_executable, "list", "--resource-type", "model"],
cwd=project_location,
text=True,
capture_output=True,
timeout=30
)
if result.returncode == 0:
info["models_count"] = len([line for line in result.stdout.split('\n') if line.strip()])
        except Exception:
pass
try:
# List tests
result = subprocess.run(
[dbt_executable, "list", "--resource-type", "test"],
cwd=project_location,
text=True,
capture_output=True,
timeout=30
)
if result.returncode == 0:
info["tests_count"] = len([line for line in result.stdout.split('\n') if line.strip()])
        except Exception:
pass
try:
# List sources
result = subprocess.run(
[dbt_executable, "list", "--resource-type", "source"],
cwd=project_location,
text=True,
capture_output=True,
timeout=30
)
if result.returncode == 0:
info["sources_count"] = len([line for line in result.stdout.split('\n') if line.strip()])
        except Exception:
pass
except Exception as e:
print(f"Error gathering project info: {e}")
return info
```
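The project-name extraction above splits raw text line by line. A less brittle sketch using PyYAML (an assumption for this example environment, though dbt itself depends on it); `read_project_name` is a hypothetical helper, not part of the example agent:

```python
import os

import yaml


def read_project_name(project_location: str) -> str:
    project_file = os.path.join(project_location, "dbt_project.yml")
    if not os.path.exists(project_file):
        return "Unknown"
    with open(project_file) as f:
        data = yaml.safe_load(f) or {}
    return str(data.get("name", "Unknown"))
```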
--------------------------------------------------------------------------------
/src/dbt_mcp/dbt_codegen/tools.py:
--------------------------------------------------------------------------------
```python
import json
import os
import subprocess
from collections.abc import Sequence
from typing import Any
from mcp.server.fastmcp import FastMCP
from pydantic import Field
from dbt_mcp.config.config import DbtCodegenConfig
from dbt_mcp.dbt_cli.binary_type import get_color_disable_flag
from dbt_mcp.prompts.prompts import get_prompt
from dbt_mcp.tools.definitions import ToolDefinition
from dbt_mcp.tools.register import register_tools
from dbt_mcp.tools.tool_names import ToolName
from dbt_mcp.tools.annotations import create_tool_annotations
def create_dbt_codegen_tool_definitions(
config: DbtCodegenConfig,
) -> list[ToolDefinition]:
def _run_codegen_operation(
macro_name: str,
args: dict[str, Any] | None = None,
) -> str:
"""Execute a dbt-codegen macro using dbt run-operation."""
try:
# Build the dbt run-operation command
command = ["run-operation", macro_name]
# Add arguments if provided
if args:
# Convert args to JSON string for dbt
args_json = json.dumps(args)
command.extend(["--args", args_json])
full_command = command.copy()
# Add --quiet flag to reduce output verbosity
main_command = full_command[0]
command_args = full_command[1:] if len(full_command) > 1 else []
full_command = [main_command, "--quiet", *command_args]
            # Only set cwd when project_dir is an absolute path; otherwise relative
            # paths can be applied twice, because DBT_PROJECT_DIR is honored by
            # dbt Core and Fusion (but not by the dbt Cloud CLI).
cwd_path = config.project_dir if os.path.isabs(config.project_dir) else None
# Add appropriate color disable flag based on binary type
color_flag = get_color_disable_flag(config.binary_type)
args_list = [config.dbt_path, color_flag, *full_command]
process = subprocess.Popen(
args=args_list,
cwd=cwd_path,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.DEVNULL,
text=True,
)
output, _ = process.communicate(timeout=config.dbt_cli_timeout)
# Return the output directly or handle errors
if process.returncode != 0:
if "dbt found" in output and "resource" in output:
return f"Error: dbt-codegen package may not be installed. Run 'dbt deps' to install it.\n{output}"
return f"Error running dbt-codegen macro: {output}"
return output or "OK"
except subprocess.TimeoutExpired:
return f"Timeout: dbt-codegen operation took longer than {config.dbt_cli_timeout} seconds."
except Exception as e:
return str(e)
def generate_source(
schema_name: str = Field(
description=get_prompt("dbt_codegen/args/schema_name")
),
database_name: str | None = Field(
default=None, description=get_prompt("dbt_codegen/args/database_name")
),
table_names: list[str] | None = Field(
default=None, description=get_prompt("dbt_codegen/args/table_names")
),
generate_columns: bool = Field(
default=False, description=get_prompt("dbt_codegen/args/generate_columns")
),
include_descriptions: bool = Field(
default=False,
description=get_prompt("dbt_codegen/args/include_descriptions"),
),
) -> str:
args: dict[str, Any] = {"schema_name": schema_name}
if database_name:
args["database_name"] = database_name
if table_names:
args["table_names"] = table_names
args["generate_columns"] = generate_columns
args["include_descriptions"] = include_descriptions
return _run_codegen_operation("generate_source", args)
def generate_model_yaml(
model_names: list[str] = Field(
description=get_prompt("dbt_codegen/args/model_names")
),
upstream_descriptions: bool = Field(
default=False,
description=get_prompt("dbt_codegen/args/upstream_descriptions"),
),
include_data_types: bool = Field(
default=True, description=get_prompt("dbt_codegen/args/include_data_types")
),
) -> str:
args: dict[str, Any] = {
"model_names": model_names,
"upstream_descriptions": upstream_descriptions,
"include_data_types": include_data_types,
}
return _run_codegen_operation("generate_model_yaml", args)
def generate_staging_model(
source_name: str = Field(
description=get_prompt("dbt_codegen/args/source_name")
),
table_name: str = Field(description=get_prompt("dbt_codegen/args/table_name")),
leading_commas: bool = Field(
default=False, description=get_prompt("dbt_codegen/args/leading_commas")
),
case_sensitive_cols: bool = Field(
default=False,
description=get_prompt("dbt_codegen/args/case_sensitive_cols"),
),
materialized: str | None = Field(
default=None, description=get_prompt("dbt_codegen/args/materialized")
),
) -> str:
args: dict[str, Any] = {
"source_name": source_name,
"table_name": table_name,
"leading_commas": leading_commas,
"case_sensitive_cols": case_sensitive_cols,
}
if materialized:
args["materialized"] = materialized
return _run_codegen_operation("generate_base_model", args)
return [
ToolDefinition(
fn=generate_source,
description=get_prompt("dbt_codegen/generate_source"),
annotations=create_tool_annotations(
title="dbt-codegen generate_source",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
fn=generate_model_yaml,
description=get_prompt("dbt_codegen/generate_model_yaml"),
annotations=create_tool_annotations(
title="dbt-codegen generate_model_yaml",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
fn=generate_staging_model,
description=get_prompt("dbt_codegen/generate_staging_model"),
annotations=create_tool_annotations(
title="dbt-codegen generate_staging_model",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
]
def register_dbt_codegen_tools(
dbt_mcp: FastMCP,
config: DbtCodegenConfig,
exclude_tools: Sequence[ToolName] = [],
) -> None:
register_tools(
dbt_mcp,
create_dbt_codegen_tool_definitions(config),
exclude_tools,
)
```
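For reference when reproducing a call by hand, this is the argv shape that `_run_codegen_operation` assembles for `generate_source`; the color-disable flag shown is a placeholder for whatever `get_color_disable_flag` returns for the detected binary type.

```python
import json

args_json = json.dumps({"schema_name": "public", "generate_columns": True})
argv = [
    "dbt",              # config.dbt_path
    "--no-use-colors",  # placeholder for get_color_disable_flag(config.binary_type)
    "run-operation",
    "--quiet",
    "generate_source",
    "--args",
    args_json,
]
print(" ".join(argv))
```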
--------------------------------------------------------------------------------
/tests/integration/dbt_codegen/test_dbt_codegen.py:
--------------------------------------------------------------------------------
```python
import os
import pytest
from dbt_mcp.config.config import DbtCodegenConfig, load_config
from dbt_mcp.dbt_cli.binary_type import BinaryType
from dbt_mcp.dbt_codegen.tools import create_dbt_codegen_tool_definitions
@pytest.fixture
def dbt_codegen_config():
"""Fixture for dbt-codegen configuration."""
# Try to load from full config first
try:
config = load_config()
if config.dbt_codegen_config:
return config.dbt_codegen_config
except Exception:
pass
# Fall back to environment variables
project_dir = os.getenv("DBT_PROJECT_DIR")
dbt_path = os.getenv("DBT_PATH", "dbt")
dbt_cli_timeout = os.getenv("DBT_CLI_TIMEOUT", "30")
if not project_dir:
pytest.skip(
"DBT_PROJECT_DIR environment variable is required for integration tests"
)
return DbtCodegenConfig(
project_dir=project_dir,
dbt_path=dbt_path,
dbt_cli_timeout=int(dbt_cli_timeout),
binary_type=BinaryType.DBT_CORE,
)
@pytest.fixture
def generate_source_tool(dbt_codegen_config):
"""Fixture for generate_source tool."""
tools = create_dbt_codegen_tool_definitions(dbt_codegen_config)
for tool in tools:
if tool.fn.__name__ == "generate_source":
return tool.fn
raise ValueError("generate_source tool not found")
@pytest.fixture
def generate_model_yaml_tool(dbt_codegen_config):
"""Fixture for generate_model_yaml tool."""
tools = create_dbt_codegen_tool_definitions(dbt_codegen_config)
for tool in tools:
if tool.fn.__name__ == "generate_model_yaml":
return tool.fn
raise ValueError("generate_model_yaml tool not found")
@pytest.fixture
def generate_staging_model_tool(dbt_codegen_config):
"""Fixture for generate_staging_model tool."""
tools = create_dbt_codegen_tool_definitions(dbt_codegen_config)
for tool in tools:
if tool.fn.__name__ == "generate_staging_model":
return tool.fn
raise ValueError("generate_staging_model tool not found")
def test_generate_source_basic(generate_source_tool):
"""Test basic source generation with minimal parameters."""
# This will fail if dbt-codegen is not installed
result = generate_source_tool(
schema_name="public", generate_columns=False, include_descriptions=False
)
# Check for error conditions
if "Error:" in result:
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
else:
pytest.fail(f"Unexpected error: {result}")
# Basic validation - should return YAML-like content
assert result is not None
assert len(result) > 0
def test_generate_source_with_columns(generate_source_tool):
"""Test source generation with column definitions."""
result = generate_source_tool(
schema_name="public", generate_columns=True, include_descriptions=True
)
if "Error:" in result:
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
else:
pytest.fail(f"Unexpected error: {result}")
assert result is not None
def test_generate_source_with_specific_tables(generate_source_tool):
"""Test source generation for specific tables."""
result = generate_source_tool(
schema_name="public",
table_names=["users", "orders"],
generate_columns=True,
include_descriptions=False,
)
if "Error:" in result:
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
assert result is not None
def test_generate_model_yaml(generate_model_yaml_tool):
"""Test model YAML generation."""
# This assumes there's at least one model in the project
result = generate_model_yaml_tool(
model_names=["stg_customers"],
upstream_descriptions=False,
include_data_types=True,
)
if "Error:" in result:
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
elif "Model" in result and "not found" in result:
pytest.skip("Test model not found in project")
assert result is not None
def test_generate_model_yaml_with_upstream(generate_model_yaml_tool):
"""Test model YAML generation with upstream descriptions."""
result = generate_model_yaml_tool(
model_names=["stg_customers"],
upstream_descriptions=True,
include_data_types=True,
)
if "Error:" in result:
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
elif "Model" in result and "not found" in result:
pytest.skip("Test model not found in project")
assert result is not None
def test_generate_staging_model(generate_staging_model_tool):
"""Test staging model SQL generation."""
# This assumes a source is defined
result = generate_staging_model_tool(
source_name="raw", # Common source name
table_name="customers",
leading_commas=False,
materialized="view",
)
if "Error:" in result:
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
elif "Source" in result and "not found" in result:
pytest.skip("Test source not found in project")
# Should generate SQL with SELECT statement
assert result is not None
def test_generate_staging_model_with_case_sensitive(generate_staging_model_tool):
"""Test staging model generation with case-sensitive columns."""
result = generate_staging_model_tool(
source_name="raw",
table_name="customers",
case_sensitive_cols=True,
leading_commas=True,
)
if "Error:" in result:
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
elif "Source" in result and "not found" in result:
pytest.skip("Test source not found in project")
assert result is not None
def test_error_handling_invalid_schema(generate_source_tool):
"""Test handling of invalid schema names."""
# Use a schema that definitely doesn't exist
result = generate_source_tool(
schema_name="definitely_nonexistent_schema_12345",
generate_columns=False,
include_descriptions=False,
)
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
# Should return an error but not crash
assert result is not None
def test_error_handling_invalid_model(generate_model_yaml_tool):
"""Test handling of non-existent model names."""
result = generate_model_yaml_tool(
model_names=["definitely_nonexistent_model_12345"]
)
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
# Should handle gracefully
assert result is not None
def test_error_handling_invalid_source(generate_staging_model_tool):
"""Test handling of invalid source references."""
result = generate_staging_model_tool(
source_name="nonexistent_source", table_name="nonexistent_table"
)
if "dbt-codegen package may not be installed" in result:
pytest.skip("dbt-codegen package not installed")
# Should return an error message
assert result is not None
```
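A minimal sketch (not part of the test suite) of invoking the dbt-codegen tools directly, assuming a local dbt Core project with the dbt-codegen package installed; the project path is hypothetical and the config fields mirror the fixture above.

```python
from dbt_mcp.config.config import DbtCodegenConfig
from dbt_mcp.dbt_cli.binary_type import BinaryType
from dbt_mcp.dbt_codegen.tools import create_dbt_codegen_tool_definitions

# Hypothetical project path; the remaining fields mirror the test fixture above.
config = DbtCodegenConfig(
    project_dir="/path/to/dbt/project",
    dbt_path="dbt",
    dbt_cli_timeout=30,
    binary_type=BinaryType.DBT_CORE,
)
tools = {t.fn.__name__: t.fn for t in create_dbt_codegen_tool_definitions(config)}
result = tools["generate_source"](
    schema_name="public", generate_columns=False, include_descriptions=False
)
print(result)
```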
--------------------------------------------------------------------------------
/src/dbt_mcp/lsp/lsp_binary_manager.py:
--------------------------------------------------------------------------------
```python
"""Binary detection and management for the dbt Language Server Protocol (LSP).
This module provides utilities to locate and validate the dbt LSP binary across
different operating systems and code editors (VS Code, Cursor, Windsurf). It handles
platform-specific paths and binary naming conventions.
"""
from enum import StrEnum
import os
from pathlib import Path
import platform
import subprocess
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class LspBinaryInfo:
"""Information about a detected dbt LSP binary.
Attributes:
path: Full filesystem path to the LSP binary executable.
version: Version string of the LSP binary.
"""
path: str
version: str
def get_platform_specific_binary_names(tag: str) -> str:
"""Generate platform-specific binary filename for the dbt LSP.
Creates a standardized binary filename based on the current platform's
operating system and architecture. This follows the naming convention
used by dbt LSP releases.
Args:
tag: Version tag or identifier for the LSP binary.
Returns:
Platform-specific binary filename including extension.
Format: fs-lsp-{tag}-{arch}-{platform}{extension}
Raises:
ValueError: If the current platform or architecture is not supported.
Examples:
>>> get_platform_specific_binary_names("v1.0.0")
'fs-lsp-v1.0.0-x86_64-apple-darwin.tar.gz' # on macOS Intel
"""
system = platform.system().lower()
machine = platform.machine().lower()
if system == "windows":
platform_name = "pc-windows-msvc"
extension = ".zip"
elif system == "darwin":
platform_name = "apple-darwin"
extension = ".tar.gz"
elif system == "linux":
platform_name = "unknown-linux-gnu"
extension = ".tar.gz"
else:
raise ValueError(f"Unsupported platform: {system}")
if machine in ("x86_64", "amd64"):
arch_name = "x86_64"
elif machine in ("arm64", "aarch64"):
arch_name = "aarch64"
else:
raise ValueError(f"Unsupported architecture: {machine}")
return f"fs-lsp-{tag}-{arch_name}-{platform_name}{extension}"
class CodeEditor(StrEnum):
"""Supported code editors that can install the dbt LSP.
These editors use similar global storage patterns for VSCode extensions
and can install the dbt Labs extension with the LSP binary.
"""
CODE = "code" # Visual Studio Code
CURSOR = "cursor" # Cursor editor
WINDSURF = "windsurf" # Windsurf editor
def get_storage_path(editor: CodeEditor) -> Path:
"""Get the storage path for dbt LSP binary based on editor and OS.
Determines the platform-specific path where code editors store the dbt LSP
binary. Follows standard conventions for each operating system and editor.
Platform-specific paths:
- Windows: %APPDATA%\\{editor}\\User\\globalStorage\\dbtlabsinc.dbt\\bin\\dbt-lsp
- macOS: ~/Library/Application Support/{editor}/User/globalStorage/dbtlabsinc.dbt/bin/dbt-lsp
- Linux: ~/.config/{editor}/User/globalStorage/dbtlabsinc.dbt/bin/dbt-lsp
Args:
editor: The code editor to get the storage path for.
Returns:
Path object pointing to the expected location of the dbt-lsp binary.
Raises:
ValueError: If the operating system is not supported (Windows, macOS, or Linux).
Note:
This function returns the expected path regardless of whether the binary
actually exists at that location. Use Path.exists() to verify.
"""
system = platform.system()
home = Path.home()
if system == "Windows":
appdata = os.environ.get("APPDATA", home / "AppData" / "Roaming")
base = Path(appdata) / editor.value
elif system == "Darwin": # macOS
base = home / "Library" / "Application Support" / editor.value
elif system == "Linux":
config_home = os.environ.get("XDG_CONFIG_HOME", home / ".config")
base = Path(config_home) / editor.value
else:
raise ValueError(f"Unsupported OS: {system}")
return Path(base, "User", "globalStorage", "dbtlabsinc.dbt", "bin", "dbt-lsp")
def dbt_lsp_binary_info(lsp_path: str | None = None) -> LspBinaryInfo | None:
"""Get dbt LSP binary information from a custom path or auto-detect it.
Attempts to locate and validate the dbt LSP binary. If a custom path is provided,
it will be validated first. If the custom path is invalid or not provided, the
function will attempt to auto-detect the binary in standard editor locations.
Args:
lsp_path: Optional custom path to the dbt LSP binary. If provided, this
path will be validated and used if it exists. If None or invalid,
auto-detection will be attempted.
Returns:
LspBinaryInfo object containing the path and version of the found binary,
or None if no valid binary could be found.
Note:
If a custom path is provided but invalid, a warning will be logged before
falling back to auto-detection.
"""
if lsp_path:
logger.debug(f"Using custom LSP binary path: {lsp_path}")
if Path(lsp_path).exists() and Path(lsp_path).is_file():
version = get_lsp_binary_version(lsp_path)
return LspBinaryInfo(path=lsp_path, version=version)
logger.warning(
f"Provided LSP binary path {lsp_path} does not exist or is not a file, falling back to detecting LSP binary"
)
return detect_lsp_binary()
def detect_lsp_binary() -> LspBinaryInfo | None:
"""Auto-detect dbt LSP binary in standard code editor locations.
Searches through all supported code editors (VS Code, Cursor, Windsurf) to find
an installed dbt LSP binary. Returns the first valid binary found.
Returns:
LspBinaryInfo object containing the path and version of the first found binary,
or None if no binary is found in any of the standard locations.
Note:
The detection checks editors in the order defined by the CodeEditor enum.
Debug logging is used to track the search process.
"""
for editor in CodeEditor:
path = get_storage_path(editor)
logger.debug(f"Checking for LSP binary in {path}")
if path.exists() and path.is_file():
version = get_lsp_binary_version(path.as_posix())
logger.debug(f"Found LSP binary in {path} with version {version}")
return LspBinaryInfo(path=path.as_posix(), version=version)
return None
def get_lsp_binary_version(path: str) -> str:
"""Extract the version string from a dbt LSP binary.
Retrieves the version of the dbt LSP binary using one of two methods:
1. For standard 'dbt-lsp' binaries, reads from the adjacent .version file
2. For other binaries, executes the binary with --version flag
Args:
path: Full filesystem path to the dbt LSP binary.
Returns:
Version string of the binary (whitespace stripped).
Raises:
FileNotFoundError: If the .version file doesn't exist (for dbt-lsp binaries).
subprocess.SubprocessError: If the binary execution fails (for non-dbt-lsp binaries).
Note:
The .version file is expected to be in the same directory as the dbt-lsp
binary and should be named '.version'.
"""
if path.endswith("dbt-lsp"):
return Path(path[:-7], ".version").read_text().strip()
else:
return subprocess.run(
[path, "--version"], capture_output=True, text=True
).stdout.strip()
```
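A brief usage sketch, assuming an optional custom binary path; as the docstrings above describe, an invalid path falls back to auto-detection across the supported editors.

```python
from dbt_mcp.lsp.lsp_binary_manager import dbt_lsp_binary_info

# The custom path is hypothetical; if it does not exist, detection falls back to
# the VS Code / Cursor / Windsurf global storage locations.
info = dbt_lsp_binary_info("/custom/path/to/dbt-lsp")
if info is not None:
    print(f"Found dbt LSP {info.version} at {info.path}")
else:
    print("No dbt LSP binary found")
```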
--------------------------------------------------------------------------------
/evals/semantic_layer/test_eval_semantic_layer.py:
--------------------------------------------------------------------------------
```python
import json
from typing import Any
import pytest
from dbtsl.api.shared.query_params import GroupByParam
from openai import OpenAI
from openai.types.responses import (
FunctionToolParam,
ResponseFunctionToolCall,
ResponseInputParam,
ResponseOutputItem,
)
from openai.types.responses.response_input_param import FunctionCallOutput
from client.tools import get_tools
from dbt_mcp.config.config import load_config
from dbt_mcp.mcp.server import create_dbt_mcp
from dbt_mcp.semantic_layer.client import (
DefaultSemanticLayerClientProvider,
SemanticLayerFetcher,
)
from dbt_mcp.semantic_layer.types import OrderByParam, QueryMetricsSuccess
LLM_MODEL = "gpt-4o-mini"
llm_client = OpenAI()
config = load_config()
async def expect_metadata_tool_call(
messages: list,
tools: list[FunctionToolParam],
expected_tool: str,
expected_arguments: str | None = None,
) -> ResponseOutputItem:
response = llm_client.responses.create(
model=LLM_MODEL,
input=messages,
tools=tools,
parallel_tool_calls=False,
)
assert len(response.output) == 1
tool_call = response.output[0]
assert isinstance(tool_call, ResponseFunctionToolCall)
function_name = tool_call.name
function_arguments = tool_call.arguments
assert tool_call.type == "function_call"
assert function_name == expected_tool
assert expected_arguments is None or function_arguments == expected_arguments
tool_response = await (await create_dbt_mcp(config)).call_tool(
function_name,
json.loads(function_arguments),
)
messages.append(tool_call)
messages.append(
FunctionCallOutput(
type="function_call_output",
call_id=tool_call.call_id,
output=str(tool_response),
)
)
return tool_call
def deep_equal(dict1: Any, dict2: Any) -> bool:
if isinstance(dict1, dict) and isinstance(dict2, dict):
return dict1.keys() == dict2.keys() and all(
deep_equal(dict1[k], dict2[k]) for k in dict1
)
elif isinstance(dict1, list) and isinstance(dict2, list):
return len(dict1) == len(dict2) and all(
deep_equal(x, y) for x, y in zip(dict1, dict2, strict=False)
)
else:
return dict1 == dict2
async def expect_query_metrics_tool_call(
messages: list,
tools: list[FunctionToolParam],
expected_metrics: list[str],
expected_group_by: list[dict] | None = None,
expected_order_by: list[dict] | None = None,
expected_where: str | None = None,
expected_limit: int | None = None,
):
response = llm_client.responses.create(
model=LLM_MODEL,
input=messages,
tools=tools,
parallel_tool_calls=False,
)
assert len(response.output) == 1
tool_call = response.output[0]
assert isinstance(tool_call, ResponseFunctionToolCall)
assert tool_call.name == "query_metrics"
args_dict = json.loads(tool_call.arguments)
assert set(args_dict["metrics"]) == set(expected_metrics)
if expected_group_by is not None:
assert deep_equal(args_dict["group_by"], expected_group_by)
else:
assert args_dict.get("group_by", []) == []
if expected_order_by is not None:
assert deep_equal(args_dict["order_by"], expected_order_by)
else:
assert args_dict.get("order_by", []) == []
if expected_where is not None:
assert args_dict["where"] == expected_where
else:
assert args_dict.get("where", None) is None
if expected_limit is not None:
assert args_dict["limit"] == expected_limit
else:
assert args_dict.get("limit", None) is None
assert config.semantic_layer_config_provider is not None
semantic_layer_fetcher = SemanticLayerFetcher(
config_provider=config.semantic_layer_config_provider,
client_provider=DefaultSemanticLayerClientProvider(
config_provider=config.semantic_layer_config_provider,
),
)
tool_response = await semantic_layer_fetcher.query_metrics(
metrics=args_dict["metrics"],
group_by=[
GroupByParam(name=g["name"], type=g["type"], grain=g.get("grain"))
for g in args_dict.get("group_by", [])
],
order_by=[
OrderByParam(name=o["name"], descending=o["descending"])
for o in args_dict.get("order_by", [])
],
where=args_dict.get("where"),
limit=args_dict.get("limit"),
)
assert isinstance(tool_response, QueryMetricsSuccess)
def initial_messages(content: str) -> ResponseInputParam:
return [
{
"role": "user",
"content": content,
}
]
@pytest.mark.parametrize(
"content, expected_tool",
[
(
"What metrics are available? Use the list_metrics tool",
"list_metrics",
),
(
"What dimensions are available for the order metric? Use the get_dimensions tool",
"get_dimensions",
),
(
"What entities are available for the order metric? Use the get_entities tool",
"get_entities",
),
],
)
async def test_explicit_tool_request(content: str, expected_tool: str):
dbt_mcp = await create_dbt_mcp(config)
response = llm_client.responses.create(
model=LLM_MODEL,
input=initial_messages(content),
tools=await get_tools(dbt_mcp),
parallel_tool_calls=False,
)
assert len(response.output) == 1
assert response.output[0].type == "function_call"
assert response.output[0].name == expected_tool
async def test_semantic_layer_fulfillment_query():
tools = await get_tools()
messages = initial_messages(
"How many orders did we fulfill this month last year?",
)
await expect_metadata_tool_call(
messages,
tools,
"list_metrics",
"{}",
)
await expect_metadata_tool_call(
messages,
tools,
"get_dimensions",
'{"metrics":["orders"]}',
)
await expect_query_metrics_tool_call(
messages,
tools,
expected_metrics=["orders"],
)
async def test_semantic_layer_food_revenue_per_month():
tools = await get_tools()
messages = initial_messages(
"What is our food revenue per location per month?",
)
await expect_metadata_tool_call(
messages,
tools,
"list_metrics",
"{}",
)
await expect_metadata_tool_call(
messages,
tools,
"get_dimensions",
'{"metrics":["food_revenue"]}',
)
await expect_metadata_tool_call(
messages,
tools,
"get_entities",
'{"metrics":["food_revenue"]}',
)
await expect_query_metrics_tool_call(
messages=messages,
tools=tools,
expected_metrics=["food_revenue"],
expected_group_by=[
{
"name": "order_id__location__location_name",
"type": "entity",
},
{
"name": "metric_time",
"type": "time_dimension",
"grain": "MONTH",
},
],
expected_order_by=[
{
"name": "metric_time",
"descending": True,
},
],
expected_limit=5,
)
async def test_semantic_layer_what_percentage_of_orders_were_large():
tools = await get_tools()
messages = initial_messages(
"What percentage of orders were large this year?",
)
await expect_metadata_tool_call(
messages,
tools,
"list_metrics",
"{}",
)
await expect_query_metrics_tool_call(
messages=messages,
tools=tools,
expected_metrics=["orders", "large_orders"],
expected_where="metric_time >= '2024-01-01' and metric_time < '2025-01-01'",
)
```
--------------------------------------------------------------------------------
/tests/unit/oauth/test_token.py:
--------------------------------------------------------------------------------
```python
"""
Tests for OAuth token models.
"""
import pytest
from pydantic import ValidationError
from dbt_mcp.oauth.token import AccessTokenResponse, DecodedAccessToken
class TestAccessTokenResponse:
"""Test the AccessTokenResponse model."""
def test_valid_token_response(self):
"""Test creating a valid access token response."""
token_data = {
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
"expires_in": 3600,
"scope": "read write",
"token_type": "Bearer",
"expires_at": 1609459200, # 2021-01-01 00:00:00 UTC
}
token_response = AccessTokenResponse(**token_data)
assert token_response.access_token == "test_access_token"
assert token_response.refresh_token == "test_refresh_token"
assert token_response.expires_in == 3600
assert token_response.scope == "read write"
assert token_response.token_type == "Bearer"
assert token_response.expires_at == 1609459200
def test_missing_required_field(self):
"""Test that missing required fields raise validation errors."""
incomplete_data = {
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
# Missing expires_in, scope, token_type, expires_at
}
with pytest.raises(ValidationError) as exc_info:
AccessTokenResponse(**incomplete_data)
error = exc_info.value
assert len(error.errors()) >= 4 # At least 4 missing fields
missing_fields = {
err["loc"][0] for err in error.errors() if err["type"] == "missing"
}
expected_missing = {"expires_in", "scope", "token_type", "expires_at"}
assert expected_missing.issubset(missing_fields)
def test_invalid_data_types(self):
"""Test that invalid data types raise validation errors."""
invalid_data = {
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
"expires_in": "not_an_int", # Should be int
"scope": "read write",
"token_type": "Bearer",
"expires_at": "not_an_int", # Should be int
}
with pytest.raises(ValidationError) as exc_info:
AccessTokenResponse(**invalid_data)
error = exc_info.value
# Should have validation errors for expires_in and expires_at
assert len(error.errors()) >= 2
invalid_fields = {err["loc"][0] for err in error.errors()}
assert "expires_in" in invalid_fields
assert "expires_at" in invalid_fields
def test_model_dict_conversion(self):
"""Test converting model to dict and back."""
token_data = {
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
"expires_in": 3600,
"scope": "read write",
"token_type": "Bearer",
"expires_at": 1609459200,
}
token_response = AccessTokenResponse(**token_data)
token_dict = token_response.model_dump()
# Should be able to recreate from dict
recreated_token = AccessTokenResponse(**token_dict)
assert recreated_token == token_response
def test_model_json_serialization(self):
"""Test JSON serialization and deserialization."""
token_data = {
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
"expires_in": 3600,
"scope": "read write",
"token_type": "Bearer",
"expires_at": 1609459200,
}
token_response = AccessTokenResponse(**token_data)
json_str = token_response.model_dump_json()
# Should be valid JSON that can be parsed back
import json
parsed_data = json.loads(json_str)
recreated_token = AccessTokenResponse(**parsed_data)
assert recreated_token == token_response
class TestDecodedAccessToken:
"""Test the DecodedAccessToken model."""
def test_valid_decoded_token(self):
"""Test creating a valid decoded access token."""
access_token_response = AccessTokenResponse(
access_token="test_access_token",
refresh_token="test_refresh_token",
expires_in=3600,
scope="read write",
token_type="Bearer",
expires_at=1609459200,
)
decoded_claims = {
"sub": "user123",
"iss": "https://auth.example.com",
"exp": 1609459200,
"iat": 1609455600,
"scope": "read write",
}
decoded_token = DecodedAccessToken(
access_token_response=access_token_response, decoded_claims=decoded_claims
)
assert decoded_token.access_token_response == access_token_response
assert decoded_token.decoded_claims == decoded_claims
assert decoded_token.decoded_claims["sub"] == "user123"
assert decoded_token.decoded_claims["scope"] == "read write"
def test_empty_decoded_claims(self):
"""Test that empty decoded claims are allowed."""
access_token_response = AccessTokenResponse(
access_token="test_access_token",
refresh_token="test_refresh_token",
expires_in=3600,
scope="read write",
token_type="Bearer",
expires_at=1609459200,
)
decoded_token = DecodedAccessToken(
access_token_response=access_token_response, decoded_claims={}
)
assert decoded_token.access_token_response == access_token_response
assert decoded_token.decoded_claims == {}
def test_missing_access_token_response(self):
"""Test that missing access_token_response raises validation error."""
decoded_claims = {"sub": "user123"}
with pytest.raises(ValidationError) as exc_info:
DecodedAccessToken(decoded_claims=decoded_claims)
error = exc_info.value
assert len(error.errors()) == 1
assert error.errors()[0]["loc"] == ("access_token_response",)
assert error.errors()[0]["type"] == "missing"
def test_invalid_access_token_response_type(self):
"""Test that invalid access_token_response type raises validation error."""
with pytest.raises(ValidationError) as exc_info:
DecodedAccessToken(
access_token_response="not_a_token_response", # Should be AccessTokenResponse
decoded_claims={"sub": "user123"},
)
error = exc_info.value
assert len(error.errors()) >= 1
# Should have validation error for access_token_response field
assert any(err["loc"] == ("access_token_response",) for err in error.errors())
def test_complex_decoded_claims(self):
"""Test with complex nested decoded claims."""
access_token_response = AccessTokenResponse(
access_token="test_access_token",
refresh_token="test_refresh_token",
expires_in=3600,
scope="read write",
token_type="Bearer",
expires_at=1609459200,
)
complex_claims = {
"sub": "user123",
"roles": ["admin", "user"],
"permissions": {"read": ["resource1", "resource2"], "write": ["resource1"]},
"metadata": {
"created_at": "2021-01-01T00:00:00Z",
"last_login": "2021-01-01T12:00:00Z",
},
}
decoded_token = DecodedAccessToken(
access_token_response=access_token_response, decoded_claims=complex_claims
)
assert decoded_token.decoded_claims["roles"] == ["admin", "user"]
assert decoded_token.decoded_claims["permissions"]["read"] == [
"resource1",
"resource2",
]
assert (
decoded_token.decoded_claims["metadata"]["created_at"]
== "2021-01-01T00:00:00Z"
)
```
--------------------------------------------------------------------------------
/src/dbt_mcp/tracking/tracking.py:
--------------------------------------------------------------------------------
```python
import json
import logging
import uuid
from collections.abc import Mapping
from contextlib import suppress
from dataclasses import dataclass
from importlib.metadata import version
from pathlib import Path
from typing import Any, Protocol, assert_never
import yaml
from dbtlabs.proto.public.v1.common.vortex_telemetry_contexts_pb2 import (
VortexTelemetryDbtCloudContext,
)
from dbtlabs.proto.public.v1.events.mcp_pb2 import ToolCalled
from dbtlabs_vortex.producer import log_proto
from dbt_mcp.config.config import PACKAGE_NAME
from dbt_mcp.config.dbt_yaml import try_read_yaml
from dbt_mcp.config.settings import (
CredentialsProvider,
DbtMcpSettings,
get_dbt_profiles_path,
)
from dbt_mcp.tools.toolsets import Toolset, proxied_tools
logger = logging.getLogger(__name__)
@dataclass
class ToolCalledEvent:
tool_name: str
arguments: dict[str, Any]
error_message: str | None
start_time_ms: int
end_time_ms: int
class UsageTracker(Protocol):
async def emit_tool_called_event(
self, tool_called_event: ToolCalledEvent
) -> None: ...
class DefaultUsageTracker:
def __init__(
self,
credentials_provider: CredentialsProvider,
session_id: uuid.UUID,
):
self.credentials_provider = credentials_provider
self.session_id = session_id
self.dbt_mcp_version = version(PACKAGE_NAME)
self._settings_cache: DbtMcpSettings | None = None
self._local_user_id: str | None = None
def _get_disabled_toolsets(self, settings: DbtMcpSettings) -> list[Toolset]:
disabled_toolsets: list[Toolset] = []
# Looping over the Toolset enum to ensure that type validation
# accounts for additions to the Toolset enum with `assert_never`
for toolset in Toolset:
match toolset:
case Toolset.SQL:
if settings.disable_sql:
disabled_toolsets.append(toolset)
case Toolset.SEMANTIC_LAYER:
if settings.disable_semantic_layer:
disabled_toolsets.append(toolset)
case Toolset.DISCOVERY:
if settings.disable_discovery:
disabled_toolsets.append(toolset)
case Toolset.DBT_CLI:
if settings.disable_dbt_cli:
disabled_toolsets.append(toolset)
case Toolset.ADMIN_API:
if settings.disable_admin_api:
disabled_toolsets.append(toolset)
case Toolset.DBT_CODEGEN:
if settings.disable_dbt_codegen:
disabled_toolsets.append(toolset)
case Toolset.DBT_LSP:
if settings.disable_lsp:
disabled_toolsets.append(toolset)
case _:
assert_never(toolset)
return disabled_toolsets
def _get_local_user_id(self, settings: DbtMcpSettings) -> str:
if self._local_user_id is None:
# Load local user ID from dbt profile
user_dir = get_dbt_profiles_path(settings.dbt_profiles_dir)
user_yaml_path = user_dir / ".user.yml"
user_yaml = try_read_yaml(user_yaml_path)
if user_yaml:
try:
self._local_user_id = str(user_yaml.get("id"))
except Exception:
# dbt Fusion may have a different format for
# the .user.yml file which is handled here
self._local_user_id = str(user_yaml)
else:
self._local_user_id = str(uuid.uuid4())
with suppress(Exception):
Path(user_yaml_path).write_text(
yaml.dump({"id": self._local_user_id})
)
return self._local_user_id
async def _get_settings(self) -> DbtMcpSettings:
        # Cache in memory instead of reading from disk every time
if self._settings_cache is None:
settings, _ = await self.credentials_provider.get_credentials()
self._settings_cache = settings
return self._settings_cache
async def emit_tool_called_event(
self,
tool_called_event: ToolCalledEvent,
) -> None:
settings = await self._get_settings()
if not settings.usage_tracking_enabled:
return
# Proxied tools are tracked on our backend, so we don't want
# to double count them here.
if tool_called_event.tool_name in [tool.value for tool in proxied_tools]:
return
try:
arguments_mapping: Mapping[str, str] = {
k: json.dumps(v) for k, v in tool_called_event.arguments.items()
}
event_id = str(uuid.uuid4())
dbt_cloud_account_id = (
str(settings.dbt_account_id) if settings.dbt_account_id else ""
)
dbt_cloud_environment_id_prod = (
str(settings.dbt_prod_env_id) if settings.dbt_prod_env_id else ""
)
dbt_cloud_environment_id_dev = (
str(settings.dbt_dev_env_id) if settings.dbt_dev_env_id else ""
)
dbt_cloud_user_id = (
str(settings.dbt_user_id) if settings.dbt_user_id else ""
)
authentication_method = (
self.credentials_provider.authentication_method.value
if self.credentials_provider.authentication_method
else ""
)
log_proto(
ToolCalled(
event_id=event_id,
start_time_ms=tool_called_event.start_time_ms,
end_time_ms=tool_called_event.end_time_ms,
tool_name=tool_called_event.tool_name,
arguments=arguments_mapping,
error_message=tool_called_event.error_message or "",
dbt_cloud_environment_id_dev=dbt_cloud_environment_id_dev,
dbt_cloud_environment_id_prod=dbt_cloud_environment_id_prod,
dbt_cloud_user_id=dbt_cloud_user_id,
local_user_id=self._get_local_user_id(settings) or "",
host=settings.actual_host or "",
multicell_account_prefix=settings.actual_host_prefix or "",
# Some of the fields of VortexTelemetryDbtCloudContext are
# duplicates of the top-level ToolCalled fields because we didn't
# know about VortexTelemetryDbtCloudContext or it didn't exist when
# we created the original event.
ctx=VortexTelemetryDbtCloudContext(
event_id=event_id,
feature="dbt-mcp",
snowplow_domain_session_id="",
snowplow_domain_user_id="",
session_id=str(self.session_id),
referrer_url="",
dbt_cloud_account_id=dbt_cloud_account_id,
dbt_cloud_account_identifier="",
dbt_cloud_project_id="",
dbt_cloud_environment_id="",
dbt_cloud_user_id=dbt_cloud_user_id,
),
dbt_mcp_version=self.dbt_mcp_version,
authentication_method=authentication_method,
trace_id="", # Only used for internal agents
disabled_toolsets=[
tool.value
for tool in self._get_disabled_toolsets(settings) or []
],
disabled_tools=[
tool.value for tool in settings.disable_tools or []
],
user_agent="", # Only used for remote MCP
)
)
except Exception as e:
logger.error(f"Error emitting tool called event: {e}")
```
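A hedged sketch of how a caller might time a tool invocation and report it through the tracker; the credentials_provider argument is assumed to be an already-constructed CredentialsProvider and is not built here.

```python
import time
import uuid

from dbt_mcp.tracking.tracking import DefaultUsageTracker, ToolCalledEvent

async def track_one_call(credentials_provider) -> None:
    # credentials_provider: an existing CredentialsProvider instance (assumed).
    tracker = DefaultUsageTracker(
        credentials_provider=credentials_provider,
        session_id=uuid.uuid4(),
    )
    start_ms = int(time.time() * 1000)
    # ... invoke a tool here ...
    end_ms = int(time.time() * 1000)
    await tracker.emit_tool_called_event(
        ToolCalledEvent(
            tool_name="list_metrics",
            arguments={},
            error_message=None,
            start_time_ms=start_ms,
            end_time_ms=end_ms,
        )
    )
```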
--------------------------------------------------------------------------------
/src/dbt_mcp/prompts/semantic_layer/query_metrics.md:
--------------------------------------------------------------------------------
```markdown
<instructions>
Queries the dbt Semantic Layer to answer business questions from the data warehouse.
This tool allows ordering and grouping by dimensions and entities.
To use this tool, you must first know which specific metrics, dimensions, and
entities to provide. You can call the list_metrics, get_dimensions,
and get_entities tools to get information about which metrics, dimensions,
and entities to use.
When using the `order_by` parameter, you must ensure that the dimension or
entity also appears in the `group_by` parameter. When fulfilling a lookback
query, prefer using order_by and limit instead of using the where parameter.
A lookback query requires that the `order_by` parameter includes a descending
order for a time dimension.
The `where` parameter should use database-agnostic SQL syntax; however, dimensions
and entities are referenced differently. For categorical dimensions,
use `{{ Dimension('<name>') }}` and for time dimensions add the grain
like `{{ TimeDimension('<name>', '<grain>') }}`. For entities,
use `{{ Entity('<name>') }}`. When referencing dates in the `where`
parameter, only use the format `yyyy-mm-dd`.
Don't call this tool if the user's question cannot be answered with the provided
metrics, dimensions, and entities. Instead, clarify what metrics, dimensions,
and entities are available and suggest a new question that can be answered
and is approximately the same as the user's question.
For queries that may return large amounts of data, it's recommended
to use a two-step approach:
1. First make a query with a small limit to verify the results are what you expect
2. Then make a follow-up query without a limit (or with a larger limit) to get the full dataset
IMPORTANT:
Do the following if the GET_MODEL_HEALTH tool is enabled.
When responding to user requests to pull metrics data, check the health of the dbt models that
are the parents of the dbt semantic models. Use the instructions from the "ASSESSING MODEL HEALTH"
section of the get_model_health() prompts to do this.
</instructions>
<examples>
<example>
Question: "What were our total sales last month?"
Thinking step-by-step:
- I know "total_sales" is the metric I need
- I know "metric_time" is a valid dimension for this metric and supports MONTH grain
- I need to group by metric_time to get monthly data
- Since this is time-based data, I should order by metric_time. I am also grouping by metric_time, so this is valid.
- The user is asking for a lookback query, so I should set descending to true so the most recent month is at the top of the results.
- The user is asking for just the last month, so I should limit to 1 month of data
Parameters:
metrics=["total_sales"]
group_by=[{"name": "metric_time", "grain": "MONTH", "type": "time_dimension"}]
order_by=[{"name": "metric_time", "descending": true}]
limit=1
</example>
<example>
Question: "Show me our top customers by revenue in the last quarter"
Thinking step-by-step:
- First, I need to find the revenue metric
- Using list_metrics(), I find "revenue" is available
- I need to check what dimensions are available for revenue
- Using get_dimensions(["revenue"]), I see "customer_name" and "metric_time" supports QUARTER grain
- I need to check what entities are available
- Using get_entities(["revenue"]), I confirm "customer" is an entity
- I need quarterly time grain
- Since this is time-based data, I should order by metric_time. I am grouping by metric_time, so this is valid. This is a lookback query, so I should set descending to true.
- I should also order by revenue to see top customers. Revenue is one of the queried metrics, so ordering by it is valid. The user is asking for the highest-revenue customers, so I should set descending to true.
- I should limit to top 5 results to verify the query works
Parameters:
metrics=["revenue"]
group_by=[{"name": "customer_name", "type": "dimension"}, {"name": "metric_time", "grain": "QUARTER", "type": "time_dimension"}]
order_by=[{"name": "metric_time", "descending": true}, {"name": "revenue", "descending": true}]
limit=5
Follow-up Query (after verifying results):
metrics=["revenue"]
group_by=[{"name": "customer_name", "type": "dimension"}, {"name": "metric_time", "grain": "QUARTER", "type": "time_dimension"}]
order_by=[{"name": "metric_time", "descending": true}, {"name": "revenue", "descending": true}]
limit=null
</example>
<example>
Question: "What's our average order value by product category for orders over $100?"
Thinking step-by-step:
- I know "average_order_value" is the metric I need
- I know "product_category" is a valid dimension
- I need to filter for orders over $100
- No time dimension needed
- I should first limit results to verify the query works
Parameters (initial query):
metrics=["average_order_value"]
group_by=[{"name": "product_category", "type": "dimension", "grain": null}]
where="{{ Dimension('order_value') }} > 100"
limit=10
Follow-up Query (after verifying results):
metrics=["average_order_value"]
group_by=[{"name": "product_category", "type": "dimension", "grain": null}]
where="{{ Dimension('order_value') }} > 100"
limit=null
</example>
<example>
Question: "How many new users did we get each week last year?"
Thinking step-by-step:
- First, I need to find the new users metric
- Using list_metrics(), I find "new_users" is available
- I need to check what dimensions are available
- Using get_dimensions(["new_users"]), I see "metric_time" supports WEEK grain
- I need to check what entities are available
- Using get_entities(["new_users"]), I confirm "user" is an entity
- I need weekly time grain
- I need to group by metric_time
- Since this is time-based data, I should order by metric_time to show progression
- I need to filter for the previous year's data using proper time dimension syntax
- Should first get a few weeks to verify the query works
Parameters (initial query):
metrics=["new_users"]
group_by=[{"name": "metric_time", "grain": "WEEK", "type": "time_dimension"}]
order_by=[{"name": "metric_time", "descending": false}]
where="{{ TimeDimension('metric_time', 'WEEK') }} >= '2023-01-01' AND {{ TimeDimension('metric_time', 'WEEK') }} < '2024-01-01'"
limit=4
Follow-up Query (after verifying results):
metrics=["new_users"]
group_by=[{"name": "metric_time", "grain": "WEEK", "type": "time_dimension"}]
order_by=[{"name": "metric_time", "descending": false}]
where="{{ TimeDimension('metric_time', 'WEEK') }} >= '2023-01-01' AND {{ TimeDimension('metric_time', 'WEEK') }} < '2024-01-01'"
limit=null
</example>
<example>
Question: "What's our customer satisfaction score by region?"
Thinking step-by-step:
- First, I need to check if we have a customer satisfaction metric
- Using list_metrics(), I find we don't have a direct "customer_satisfaction" metric
- I should check what metrics we do have that might be related
- I see we have "net_promoter_score" and "customer_retention_rate"
- I should inform the user that we don't have a direct customer satisfaction metric
- I can suggest using NPS as a proxy for customer satisfaction
Response to user:
"I don't have a direct customer satisfaction metric, but I can show you Net Promoter Score (NPS) by region, which is often used as a proxy for customer satisfaction. Would you like to see that instead?"
If user agrees, then:
Parameters:
metrics=["net_promoter_score"]
group_by=[{"name": "region", "type": "dimension", "grain": null}]
order_by=[{"name": "net_promoter_score", "descending": true}]
limit=10
</example>
</examples>
<parameters>
metrics: List of metric names (strings) to query for.
group_by: Optional list of objects with name (string), type ("dimension" or "time_dimension"), and grain (string or null for time dimensions only).
order_by: Optional list of objects with name (string) and descending (boolean, default false).
where: Optional SQL WHERE clause (string) to filter results.
limit: Optional limit (integer) for number of results.
</parameters>
```
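For reference, a hedged sketch of how the parameter shapes described above map onto the Python semantic layer client, following the pattern used in evals/semantic_layer/test_eval_semantic_layer.py; it assumes a loaded config with semantic_layer_config_provider set, and the metric name is taken from the first example.

```python
from dbtsl.api.shared.query_params import GroupByParam

from dbt_mcp.config.config import load_config
from dbt_mcp.semantic_layer.client import (
    DefaultSemanticLayerClientProvider,
    SemanticLayerFetcher,
)
from dbt_mcp.semantic_layer.types import OrderByParam

config = load_config()
fetcher = SemanticLayerFetcher(
    config_provider=config.semantic_layer_config_provider,
    client_provider=DefaultSemanticLayerClientProvider(
        config_provider=config.semantic_layer_config_provider,
    ),
)

async def total_sales_last_month():
    # Mirrors the first example above: one metric, MONTH grain, latest month first.
    return await fetcher.query_metrics(
        metrics=["total_sales"],
        group_by=[
            GroupByParam(name="metric_time", type="time_dimension", grain="MONTH")
        ],
        order_by=[OrderByParam(name="metric_time", descending=True)],
        where=None,
        limit=1,
    )
```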
--------------------------------------------------------------------------------
/tests/unit/dbt_cli/test_tools.py:
--------------------------------------------------------------------------------
```python
import subprocess
import pytest
from pytest import MonkeyPatch
from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools
from tests.mocks.config import mock_dbt_cli_config
@pytest.fixture
def mock_process():
class MockProcess:
def communicate(self, timeout=None):
return "command output", None
return MockProcess()
@pytest.fixture
def mock_fastmcp():
class MockFastMCP:
def __init__(self):
self.tools = {}
def tool(self, **kwargs):
def decorator(func):
self.tools[func.__name__] = func
return func
return decorator
fastmcp = MockFastMCP()
return fastmcp, fastmcp.tools
@pytest.mark.parametrize(
"sql_query,limit_param,expected_args",
[
# SQL with explicit LIMIT - should set --limit=-1
(
"SELECT * FROM my_model LIMIT 10",
None,
[
"--no-use-colors",
"show",
"--inline",
"SELECT * FROM my_model LIMIT 10",
"--favor-state",
"--limit",
"-1",
"--output",
"json",
],
),
# SQL with lowercase limit - should set --limit=-1
(
"select * from my_model limit 5",
None,
[
"--no-use-colors",
"show",
"--inline",
"select * from my_model limit 5",
"--favor-state",
"--limit",
"-1",
"--output",
"json",
],
),
# No SQL LIMIT but with limit parameter - should use provided limit
(
"SELECT * FROM my_model",
10,
[
"--no-use-colors",
"show",
"--inline",
"SELECT * FROM my_model",
"--favor-state",
"--limit",
"10",
"--output",
"json",
],
),
# No limits at all - should not include --limit flag
(
"SELECT * FROM my_model",
None,
[
"--no-use-colors",
"show",
"--inline",
"SELECT * FROM my_model",
"--favor-state",
"--output",
"json",
],
),
],
)
def test_show_command_limit_logic(
monkeypatch: MonkeyPatch,
mock_process,
mock_fastmcp,
sql_query,
limit_param,
expected_args,
):
# Mock Popen
mock_calls = []
def mock_popen(args, **kwargs):
mock_calls.append(args)
return mock_process
monkeypatch.setattr("subprocess.Popen", mock_popen)
# Register tools and get show tool
fastmcp, tools = mock_fastmcp
register_dbt_cli_tools(fastmcp, mock_dbt_cli_config)
show_tool = tools["show"]
# Call show tool with test parameters
show_tool(sql_query=sql_query, limit=limit_param)
# Verify the command was called with expected arguments
assert mock_calls
args_list = mock_calls[0][1:] # Skip the dbt path
assert args_list == expected_args
def test_run_command_adds_quiet_flag_to_verbose_commands(
monkeypatch: MonkeyPatch, mock_process, mock_fastmcp
):
# Mock Popen
mock_calls = []
def mock_popen(args, **kwargs):
mock_calls.append(args)
return mock_process
monkeypatch.setattr("subprocess.Popen", mock_popen)
# Setup
mock_fastmcp_obj, tools = mock_fastmcp
register_dbt_cli_tools(mock_fastmcp_obj, mock_dbt_cli_config)
run_tool = tools["run"]
# Execute
run_tool()
# Verify
assert mock_calls
args_list = mock_calls[0]
assert "--quiet" in args_list
def test_run_command_correctly_formatted(
monkeypatch: MonkeyPatch, mock_process, mock_fastmcp
):
# Mock Popen
mock_calls = []
def mock_popen(args, **kwargs):
mock_calls.append(args)
return mock_process
monkeypatch.setattr("subprocess.Popen", mock_popen)
fastmcp, tools = mock_fastmcp
# Register the tools
register_dbt_cli_tools(fastmcp, mock_dbt_cli_config)
run_tool = tools["run"]
# Run the command with a selector
run_tool(selector="my_model")
# Verify the command is correctly formatted
assert mock_calls
args_list = mock_calls[0]
assert args_list == [
"/path/to/dbt",
"--no-use-colors",
"run",
"--quiet",
"--select",
"my_model",
]
def test_show_command_correctly_formatted(
monkeypatch: MonkeyPatch, mock_process, mock_fastmcp
):
# Mock Popen
mock_calls = []
def mock_popen(args, **kwargs):
mock_calls.append(args)
return mock_process
monkeypatch.setattr("subprocess.Popen", mock_popen)
# Setup
mock_fastmcp_obj, tools = mock_fastmcp
register_dbt_cli_tools(mock_fastmcp_obj, mock_dbt_cli_config)
show_tool = tools["show"]
# Execute
show_tool(sql_query="SELECT * FROM my_model")
# Verify
assert mock_calls
args_list = mock_calls[0]
assert args_list[0].endswith("dbt")
assert args_list[1] == "--no-use-colors"
assert args_list[2] == "show"
assert args_list[3] == "--inline"
assert args_list[4] == "SELECT * FROM my_model"
assert args_list[5] == "--favor-state"
def test_list_command_timeout_handling(monkeypatch: MonkeyPatch, mock_fastmcp):
# Mock Popen
class MockProcessWithTimeout:
def communicate(self, timeout=None):
raise subprocess.TimeoutExpired(cmd=["dbt", "list"], timeout=10)
def mock_popen(*args, **kwargs):
return MockProcessWithTimeout()
monkeypatch.setattr("subprocess.Popen", mock_popen)
# Setup
mock_fastmcp_obj, tools = mock_fastmcp
register_dbt_cli_tools(mock_fastmcp_obj, mock_dbt_cli_config)
list_tool = tools["ls"]
# Test timeout case
result = list_tool(resource_type=["model", "snapshot"])
assert "Timeout: dbt command took too long to complete" in result
assert "Try using a specific selector to narrow down the results" in result
# Test with selector - should still timeout
result = list_tool(selector="my_model", resource_type=["model"])
assert "Timeout: dbt command took too long to complete" in result
assert "Try using a specific selector to narrow down the results" in result
@pytest.mark.parametrize("command_name", ["run", "build"])
def test_full_refresh_flag_added_to_command(
monkeypatch: MonkeyPatch, mock_process, mock_fastmcp, command_name
):
mock_calls = []
def mock_popen(args, **kwargs):
mock_calls.append(args)
return mock_process
monkeypatch.setattr("subprocess.Popen", mock_popen)
fastmcp, tools = mock_fastmcp
register_dbt_cli_tools(fastmcp, mock_dbt_cli_config)
tool = tools[command_name]
tool(is_full_refresh=True)
assert mock_calls
args_list = mock_calls[0]
assert "--full-refresh" in args_list
@pytest.mark.parametrize("command_name", ["build", "run", "test"])
def test_vars_flag_added_to_command(
monkeypatch: MonkeyPatch, mock_process, mock_fastmcp, command_name
):
mock_calls = []
def mock_popen(args, **kwargs):
mock_calls.append(args)
return mock_process
monkeypatch.setattr("subprocess.Popen", mock_popen)
fastmcp, tools = mock_fastmcp
register_dbt_cli_tools(fastmcp, mock_dbt_cli_config)
tool = tools[command_name]
tool(vars="environment: production")
assert mock_calls
args_list = mock_calls[0]
assert "--vars" in args_list
assert "environment: production" in args_list
def test_vars_not_added_when_none(monkeypatch: MonkeyPatch, mock_process, mock_fastmcp):
mock_calls = []
def mock_popen(args, **kwargs):
mock_calls.append(args)
return mock_process
monkeypatch.setattr("subprocess.Popen", mock_popen)
fastmcp, tools = mock_fastmcp
register_dbt_cli_tools(fastmcp, mock_dbt_cli_config)
build_tool = tools["build"]
build_tool() # Non-explicit
assert mock_calls
args_list = mock_calls[0]
assert "--vars" not in args_list
```
--------------------------------------------------------------------------------
/src/dbt_mcp/dbt_cli/tools.py:
--------------------------------------------------------------------------------
```python
import os
import subprocess
from collections.abc import Iterable, Sequence
from mcp.server.fastmcp import FastMCP
from pydantic import Field
from dbt_mcp.config.config import DbtCliConfig
from dbt_mcp.dbt_cli.binary_type import get_color_disable_flag
from dbt_mcp.prompts.prompts import get_prompt
from dbt_mcp.tools.definitions import ToolDefinition
from dbt_mcp.tools.register import register_tools
from dbt_mcp.tools.tool_names import ToolName
from dbt_mcp.tools.annotations import create_tool_annotations
def create_dbt_cli_tool_definitions(config: DbtCliConfig) -> list[ToolDefinition]:
def _run_dbt_command(
command: list[str],
selector: str | None = None,
resource_type: list[str] | None = None,
is_selectable: bool = False,
is_full_refresh: bool | None = False,
vars: str | None = None,
) -> str:
try:
# Commands that should always be quiet to reduce output verbosity
verbose_commands = [
"build",
"compile",
"docs",
"parse",
"run",
"test",
"list",
]
if is_full_refresh is True:
command.append("--full-refresh")
if vars and isinstance(vars, str):
command.extend(["--vars", vars])
if selector:
selector_params = str(selector).split(" ")
command.extend(["--select"] + selector_params)
if isinstance(resource_type, Iterable):
command.extend(["--resource-type"] + resource_type)
full_command = command.copy()
# Add --quiet flag to specific commands to reduce context window usage
if len(full_command) > 0 and full_command[0] in verbose_commands:
main_command = full_command[0]
command_args = full_command[1:] if len(full_command) > 1 else []
full_command = [main_command, "--quiet", *command_args]
            # Set the working directory only when project_dir is an absolute path;
            # otherwise relative paths could be applied more than once, since
            # DBT_PROJECT_DIR is also applied by dbt Core and Fusion (but not by the dbt Cloud CLI)
cwd_path = config.project_dir if os.path.isabs(config.project_dir) else None
# Add appropriate color disable flag based on binary type
color_flag = get_color_disable_flag(config.binary_type)
args = [config.dbt_path, color_flag, *full_command]
process = subprocess.Popen(
args=args,
cwd=cwd_path,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.DEVNULL,
text=True,
)
output, _ = process.communicate(timeout=config.dbt_cli_timeout)
return output or "OK"
except subprocess.TimeoutExpired:
return "Timeout: dbt command took too long to complete." + (
" Try using a specific selector to narrow down the results."
if is_selectable
else ""
)
def build(
selector: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/selectors")
),
is_full_refresh: bool | None = Field(
default=None, description=get_prompt("dbt_cli/args/full_refresh")
),
vars: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/vars")
),
) -> str:
return _run_dbt_command(
["build"],
selector,
is_selectable=True,
is_full_refresh=is_full_refresh,
vars=vars,
)
def compile() -> str:
return _run_dbt_command(["compile"])
def docs() -> str:
return _run_dbt_command(["docs", "generate"])
def ls(
selector: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/selectors")
),
resource_type: list[str] | None = Field(
default=None,
description=get_prompt("dbt_cli/args/resource_type"),
),
) -> str:
return _run_dbt_command(
["list"],
selector,
resource_type=resource_type,
is_selectable=True,
)
def parse() -> str:
return _run_dbt_command(["parse"])
def run(
selector: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/selectors")
),
is_full_refresh: bool | None = Field(
default=None, description=get_prompt("dbt_cli/args/full_refresh")
),
vars: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/vars")
),
) -> str:
return _run_dbt_command(
["run"],
selector,
is_selectable=True,
is_full_refresh=is_full_refresh,
vars=vars,
)
def test(
selector: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/selectors")
),
vars: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/vars")
),
) -> str:
return _run_dbt_command(["test"], selector, is_selectable=True, vars=vars)
def show(
sql_query: str = Field(description=get_prompt("dbt_cli/args/sql_query")),
limit: int = Field(default=5, description=get_prompt("dbt_cli/args/limit")),
) -> str:
args = ["show", "--inline", sql_query, "--favor-state"]
# This is quite crude, but it should be okay for now
# until we have a dbt Fusion integration.
cli_limit = None
if "limit" in sql_query.lower():
# When --limit=-1, dbt won't apply a separate limit.
cli_limit = -1
elif limit:
# This can be problematic if the LLM provides
            # a SQL limit and a `limit` argument. However, preferring the limit
# in the SQL query leads to a better experience when the LLM
# makes that mistake.
cli_limit = limit
if cli_limit is not None:
args.extend(["--limit", str(cli_limit)])
args.extend(["--output", "json"])
return _run_dbt_command(args)
return [
ToolDefinition(
fn=build,
description=get_prompt("dbt_cli/build"),
annotations=create_tool_annotations(
title="dbt build",
read_only_hint=False,
destructive_hint=True,
idempotent_hint=False,
),
),
ToolDefinition(
fn=compile,
description=get_prompt("dbt_cli/compile"),
annotations=create_tool_annotations(
title="dbt compile",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
fn=docs,
description=get_prompt("dbt_cli/docs"),
annotations=create_tool_annotations(
title="dbt docs",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
name="list",
fn=ls,
description=get_prompt("dbt_cli/list"),
annotations=create_tool_annotations(
title="dbt list",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
fn=parse,
description=get_prompt("dbt_cli/parse"),
annotations=create_tool_annotations(
title="dbt parse",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
ToolDefinition(
fn=run,
description=get_prompt("dbt_cli/run"),
annotations=create_tool_annotations(
title="dbt run",
read_only_hint=False,
destructive_hint=True,
idempotent_hint=False,
),
),
ToolDefinition(
fn=test,
description=get_prompt("dbt_cli/test"),
annotations=create_tool_annotations(
title="dbt test",
read_only_hint=False,
destructive_hint=True,
idempotent_hint=False,
),
),
ToolDefinition(
fn=show,
description=get_prompt("dbt_cli/show"),
annotations=create_tool_annotations(
title="dbt show",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
),
),
]
def register_dbt_cli_tools(
dbt_mcp: FastMCP,
config: DbtCliConfig,
exclude_tools: Sequence[ToolName] = [],
) -> None:
register_tools(
dbt_mcp,
create_dbt_cli_tool_definitions(config),
exclude_tools,
)
```
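A hedged sketch of wiring these tools onto a FastMCP server; the DbtCliConfig fields mirror the attributes referenced above (project_dir, dbt_path, dbt_cli_timeout, binary_type), but the exact constructor signature and the project path are assumptions.

```python
from mcp.server.fastmcp import FastMCP

from dbt_mcp.config.config import DbtCliConfig
from dbt_mcp.dbt_cli.binary_type import BinaryType
from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools

server = FastMCP("dbt")
# Hypothetical values; field names follow the attributes used in _run_dbt_command.
cli_config = DbtCliConfig(
    project_dir="/path/to/dbt/project",
    dbt_path="dbt",
    dbt_cli_timeout=10,
    binary_type=BinaryType.DBT_CORE,
)
register_dbt_cli_tools(server, cli_config)
```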
--------------------------------------------------------------------------------
/src/dbt_mcp/dbt_admin/run_results_errors/parser.py:
--------------------------------------------------------------------------------
```python
import logging
from typing import Any
from pydantic import ValidationError
from dbt_mcp.config.config_providers import AdminApiConfig
from dbt_mcp.dbt_admin.client import DbtAdminAPIClient
from dbt_mcp.dbt_admin.constants import (
SOURCE_FRESHNESS_STEP_NAME,
STATUS_MAP,
JobRunStatus,
RunResultsStatus,
)
from dbt_mcp.dbt_admin.run_results_errors.config import (
ErrorResultSchema,
ErrorStepSchema,
RunDetailsSchema,
RunResultsArtifactSchema,
RunResultSchema,
RunStepSchema,
)
from dbt_mcp.errors import ArtifactRetrievalError
logger = logging.getLogger(__name__)
class ErrorFetcher:
"""Parses dbt Cloud job run data to extract focused error information."""
def __init__(
self,
run_id: int,
run_details: dict[str, Any],
client: DbtAdminAPIClient,
admin_api_config: AdminApiConfig,
):
"""
Initialize parser with run data.
Args:
run_id: dbt Cloud job run ID
run_details: Raw run details from get_job_run_details()
client: DbtAdminAPIClient instance for fetching artifacts
admin_api_config: Admin API configuration
"""
self.run_id = run_id
self.run_details = run_details
self.client = client
self.admin_api_config = admin_api_config
async def analyze_run_errors(self) -> dict[str, Any]:
"""Parse the run data and return all failed steps with their details."""
try:
run_details = RunDetailsSchema.model_validate(self.run_details)
failed_steps = self._find_all_failed_steps(run_details)
if run_details.is_cancelled:
error_result = self._create_error_result(
message="Job run was cancelled",
finished_at=run_details.finished_at,
)
return {"failed_steps": [error_result]}
if not failed_steps:
error_result = self._create_error_result("No failed step found")
return {"failed_steps": [error_result]}
processed_steps = []
for step in failed_steps:
step_result = await self._get_failure_details(step)
processed_steps.append(step_result)
return {"failed_steps": processed_steps}
except ValidationError as e:
logger.error(f"Schema validation failed for run {self.run_id}: {e}")
error_result = self._create_error_result(f"Validation failed: {e!s}")
return {"failed_steps": [error_result]}
except Exception as e:
logger.error(f"Error analyzing run {self.run_id}: {e}")
error_result = self._create_error_result(str(e))
return {"failed_steps": [error_result]}
def _find_all_failed_steps(
self, run_details: RunDetailsSchema
) -> list[RunStepSchema]:
"""Find all failed steps in the run."""
failed_steps = []
for step in run_details.run_steps:
if step.status == STATUS_MAP[JobRunStatus.ERROR]:
failed_steps.append(step)
return failed_steps
async def _get_failure_details(self, failed_step: RunStepSchema) -> dict[str, Any]:
"""Get simplified failure information from failed step."""
run_results_content = await self._fetch_run_results_artifact(failed_step)
if not run_results_content:
return self._handle_artifact_error(failed_step)
return self._parse_run_results(run_results_content, failed_step)
async def _fetch_run_results_artifact(
self, failed_step: RunStepSchema
) -> str | None:
"""Fetch run_results.json artifact for the failed step."""
step_index = failed_step.index
try:
if step_index is not None:
run_results_content = await self.client.get_job_run_artifact(
self.admin_api_config.account_id,
self.run_id,
"run_results.json",
step=step_index,
)
logger.info(f"Got run_results.json from failed step {step_index}")
return run_results_content
else:
raise ArtifactRetrievalError(
"No step index available for artifact retrieval"
)
except Exception as e:
logger.error(f"Failed to get run_results.json from step {step_index}: {e}")
return None
def _parse_run_results(
self, run_results_content: str, failed_step: RunStepSchema
) -> dict[str, Any]:
"""Parse run_results.json content and extract errors."""
try:
run_results = RunResultsArtifactSchema.model_validate_json(
run_results_content
)
errors = self._extract_errors_from_results(run_results.results)
return self._build_error_response(errors, failed_step, run_results.args)
except ValidationError as e:
logger.warning(f"run_results.json validation failed: {e}")
return self._handle_artifact_error(failed_step, e)
except Exception as e:
return self._handle_artifact_error(failed_step, e)
def _extract_errors_from_results(
self, results: list[RunResultSchema]
) -> list[ErrorResultSchema]:
"""Extract error results from run results."""
errors = []
for result in results:
if result.status in [
RunResultsStatus.ERROR.value,
RunResultsStatus.FAIL.value,
]:
relation_name = (
result.relation_name
if result.relation_name is not None
else "No database relation"
)
error = ErrorResultSchema(
unique_id=result.unique_id,
relation_name=relation_name,
message=result.message or "",
compiled_code=result.compiled_code,
)
errors.append(error)
return errors
def _build_error_response(
self,
errors: list[ErrorResultSchema],
failed_step: RunStepSchema,
args: Any | None,
) -> dict[str, Any]:
"""Build the final error response structure."""
target = args.target if args else None
step_name = failed_step.name
finished_at = failed_step.finished_at
truncated_logs = self._truncated_logs(failed_step)
if errors:
return ErrorStepSchema(
errors=errors,
step_name=step_name,
finished_at=finished_at,
target=target,
).model_dump()
message = "No failures found in run_results.json"
return self._create_error_result(
message=message,
target=target,
step_name=step_name,
finished_at=finished_at,
truncated_logs=truncated_logs,
)
def _create_error_result(
self,
message: str,
unique_id: str | None = None,
relation_name: str | None = None,
target: str | None = None,
step_name: str | None = None,
finished_at: str | None = None,
compiled_code: str | None = None,
truncated_logs: str | None = None,
) -> dict[str, Any]:
"""Create a standardized error results using ErrorStepSchema."""
error = ErrorResultSchema(
unique_id=unique_id,
relation_name=relation_name,
message=message,
compiled_code=compiled_code,
truncated_logs=truncated_logs,
)
return ErrorStepSchema(
errors=[error],
step_name=step_name,
finished_at=finished_at,
target=target,
).model_dump()
def _handle_artifact_error(
self, failed_step: RunStepSchema, error: Exception | None = None
) -> dict[str, Any]:
"""Handle cases where run_results.json is not available."""
relation_name = "No database relation"
step_name = failed_step.name
finished_at = failed_step.finished_at
truncated_logs = self._truncated_logs(failed_step)
# Special handling for source freshness steps
if SOURCE_FRESHNESS_STEP_NAME.lower() in step_name.lower():
message = "Source freshness error - returning logs"
else:
message = "run_results.json not available - returning logs"
return self._create_error_result(
message=message,
relation_name=relation_name,
step_name=step_name,
finished_at=finished_at,
truncated_logs=truncated_logs,
)
def _truncated_logs(self, failed_step: RunStepSchema) -> str | None:
"""Truncate logs to the last 50 lines."""
TRUNCATED_LOGS_LENGTH = 50
split_logs = failed_step.logs.splitlines() if failed_step.logs else []
if len(split_logs) > TRUNCATED_LOGS_LENGTH:
split_logs = [
f"Logs truncated to last {TRUNCATED_LOGS_LENGTH} lines..."
] + split_logs[-TRUNCATED_LOGS_LENGTH:]
return "\n".join(split_logs)
```
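A hedged sketch of driving ErrorFetcher from async code; the client and admin_api_config are assumed to be a DbtAdminAPIClient and AdminApiConfig constructed elsewhere by the admin API tooling, and run_details comes from get_job_run_details().

```python
from typing import Any

from dbt_mcp.dbt_admin.run_results_errors.parser import ErrorFetcher

async def summarize_failures(
    client, admin_api_config, run_id: int, run_details: dict[str, Any]
) -> None:
    # client: DbtAdminAPIClient; admin_api_config: AdminApiConfig (both assumed
    # to be constructed elsewhere).
    fetcher = ErrorFetcher(
        run_id=run_id,
        run_details=run_details,
        client=client,
        admin_api_config=admin_api_config,
    )
    result = await fetcher.analyze_run_errors()
    for step in result["failed_steps"]:
        messages = [error["message"] for error in step["errors"]]
        print(step.get("step_name"), messages)
```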