This is page 31 of 35. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?page={x} to view the full context.
# Directory Structure
```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│ ├── __init__.py
│ ├── advanced_agent_flows_using_unified_memory_system_demo.py
│ ├── advanced_extraction_demo.py
│ ├── advanced_unified_memory_system_demo.py
│ ├── advanced_vector_search_demo.py
│ ├── analytics_reporting_demo.py
│ ├── audio_transcription_demo.py
│ ├── basic_completion_demo.py
│ ├── cache_demo.py
│ ├── claude_integration_demo.py
│ ├── compare_synthesize_demo.py
│ ├── cost_optimization.py
│ ├── data
│ │ ├── sample_event.txt
│ │ ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│ │ └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│ ├── docstring_refiner_demo.py
│ ├── document_conversion_and_processing_demo.py
│ ├── entity_relation_graph_demo.py
│ ├── filesystem_operations_demo.py
│ ├── grok_integration_demo.py
│ ├── local_text_tools_demo.py
│ ├── marqo_fused_search_demo.py
│ ├── measure_model_speeds.py
│ ├── meta_api_demo.py
│ ├── multi_provider_demo.py
│ ├── ollama_integration_demo.py
│ ├── prompt_templates_demo.py
│ ├── python_sandbox_demo.py
│ ├── rag_example.py
│ ├── research_workflow_demo.py
│ ├── sample
│ │ ├── article.txt
│ │ ├── backprop_paper.pdf
│ │ ├── buffett.pdf
│ │ ├── contract_link.txt
│ │ ├── legal_contract.txt
│ │ ├── medical_case.txt
│ │ ├── northwind.db
│ │ ├── research_paper.txt
│ │ ├── sample_data.json
│ │ └── text_classification_samples
│ │ ├── email_classification.txt
│ │ ├── news_samples.txt
│ │ ├── product_reviews.txt
│ │ └── support_tickets.txt
│ ├── sample_docs
│ │ └── downloaded
│ │ └── attention_is_all_you_need.pdf
│ ├── sentiment_analysis_demo.py
│ ├── simple_completion_demo.py
│ ├── single_shot_synthesis_demo.py
│ ├── smart_browser_demo.py
│ ├── sql_database_demo.py
│ ├── sse_client_demo.py
│ ├── test_code_extraction.py
│ ├── test_content_detection.py
│ ├── test_ollama.py
│ ├── text_classification_demo.py
│ ├── text_redline_demo.py
│ ├── tool_composition_examples.py
│ ├── tournament_code_demo.py
│ ├── tournament_text_demo.py
│ ├── unified_memory_system_demo.py
│ ├── vector_search_demo.py
│ ├── web_automation_instruction_packs.py
│ └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│ └── smart_browser_internal
│ ├── locator_cache.db
│ ├── readability.js
│ └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── integration
│ │ ├── __init__.py
│ │ └── test_server.py
│ ├── manual
│ │ ├── test_extraction_advanced.py
│ │ └── test_extraction.py
│ └── unit
│ ├── __init__.py
│ ├── test_cache.py
│ ├── test_providers.py
│ └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│ ├── __init__.py
│ ├── __main__.py
│ ├── cli
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── commands.py
│ │ ├── helpers.py
│ │ └── typer_cli.py
│ ├── clients
│ │ ├── __init__.py
│ │ ├── completion_client.py
│ │ └── rag_client.py
│ ├── config
│ │ └── examples
│ │ └── filesystem_config.yaml
│ ├── config.py
│ ├── constants.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── evaluation
│ │ │ ├── base.py
│ │ │ └── evaluators.py
│ │ ├── providers
│ │ │ ├── __init__.py
│ │ │ ├── anthropic.py
│ │ │ ├── base.py
│ │ │ ├── deepseek.py
│ │ │ ├── gemini.py
│ │ │ ├── grok.py
│ │ │ ├── ollama.py
│ │ │ ├── openai.py
│ │ │ └── openrouter.py
│ │ ├── server.py
│ │ ├── state_store.py
│ │ ├── tournaments
│ │ │ ├── manager.py
│ │ │ ├── tasks.py
│ │ │ └── utils.py
│ │ └── ums_api
│ │ ├── __init__.py
│ │ ├── ums_database.py
│ │ ├── ums_endpoints.py
│ │ ├── ums_models.py
│ │ └── ums_services.py
│ ├── exceptions.py
│ ├── graceful_shutdown.py
│ ├── services
│ │ ├── __init__.py
│ │ ├── analytics
│ │ │ ├── __init__.py
│ │ │ ├── metrics.py
│ │ │ └── reporting.py
│ │ ├── cache
│ │ │ ├── __init__.py
│ │ │ ├── cache_service.py
│ │ │ ├── persistence.py
│ │ │ ├── strategies.py
│ │ │ └── utils.py
│ │ ├── cache.py
│ │ ├── document.py
│ │ ├── knowledge_base
│ │ │ ├── __init__.py
│ │ │ ├── feedback.py
│ │ │ ├── manager.py
│ │ │ ├── rag_engine.py
│ │ │ ├── retriever.py
│ │ │ └── utils.py
│ │ ├── prompts
│ │ │ ├── __init__.py
│ │ │ ├── repository.py
│ │ │ └── templates.py
│ │ ├── prompts.py
│ │ └── vector
│ │ ├── __init__.py
│ │ ├── embeddings.py
│ │ └── vector_service.py
│ ├── tool_token_counter.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── audio_transcription.py
│ │ ├── base.py
│ │ ├── completion.py
│ │ ├── docstring_refiner.py
│ │ ├── document_conversion_and_processing.py
│ │ ├── enhanced-ums-lookbook.html
│ │ ├── entity_relation_graph.py
│ │ ├── excel_spreadsheet_automation.py
│ │ ├── extraction.py
│ │ ├── filesystem.py
│ │ ├── html_to_markdown.py
│ │ ├── local_text_tools.py
│ │ ├── marqo_fused_search.py
│ │ ├── meta_api_tool.py
│ │ ├── ocr_tools.py
│ │ ├── optimization.py
│ │ ├── provider.py
│ │ ├── pyodide_boot_template.html
│ │ ├── python_sandbox.py
│ │ ├── rag.py
│ │ ├── redline-compiled.css
│ │ ├── sentiment_analysis.py
│ │ ├── single_shot_synthesis.py
│ │ ├── smart_browser.py
│ │ ├── sql_databases.py
│ │ ├── text_classification.py
│ │ ├── text_redline_tools.py
│ │ ├── tournament.py
│ │ ├── ums_explorer.html
│ │ └── unified_memory_system.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── async_utils.py
│ │ ├── display.py
│ │ ├── logging
│ │ │ ├── __init__.py
│ │ │ ├── console.py
│ │ │ ├── emojis.py
│ │ │ ├── formatter.py
│ │ │ ├── logger.py
│ │ │ ├── panels.py
│ │ │ ├── progress.py
│ │ │ └── themes.py
│ │ ├── parse_yaml.py
│ │ ├── parsing.py
│ │ ├── security.py
│ │ └── text.py
│ └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/ultimate_mcp_server/tools/sql_databases.py:
--------------------------------------------------------------------------------
```python
# ultimate_mcp_server/tools/sql_databases.py
from __future__ import annotations
import asyncio
import datetime as dt
import hashlib
import json
import os
import re
import tempfile
import time
import uuid
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
# --- START: Expanded typing imports ---
from typing import Any, Dict, List, Optional, Set, Tuple, Union
# --- END: Expanded typing imports ---
# --- Removed BaseTool import ---
# SQLAlchemy imports
from sqlalchemy import inspect as sa_inspect
from sqlalchemy import text
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError, ProgrammingError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
# Local imports
from ultimate_mcp_server.exceptions import ToolError, ToolInputError
# --- START: Expanded base imports ---
from ultimate_mcp_server.tools.base import with_error_handling, with_tool_metrics
# --- END: Expanded base imports ---
from ultimate_mcp_server.tools.completion import generate_completion # For NL→SQL
from ultimate_mcp_server.utils import get_logger
# Optional imports with graceful fallbacks
try:
import boto3 # For AWS Secrets Manager
except ImportError:
boto3 = None
try:
import hvac # For HashiCorp Vault
except ImportError:
hvac = None
try:
import pandas as pd
except ImportError:
pd = None
try:
import pandera as pa
except ImportError:
pa = None
try:
import prometheus_client as prom
except ImportError:
prom = None
logger = get_logger("ultimate_mcp_server.tools.sql_databases")
# =============================================================================
# Global State and Configuration (Replaces instance variables)
# =============================================================================
# --- Connection Management ---
class ConnectionManager:
"""Manages database connections with automatic cleanup after inactivity."""
# (Keep ConnectionManager class as is - it's a helper utility)
def __init__(self, cleanup_interval_seconds=600, check_interval_seconds=60):
self.connections: Dict[str, Tuple[AsyncEngine, float]] = {}
self.cleanup_interval = cleanup_interval_seconds
self.check_interval = check_interval_seconds
self._cleanup_task: Optional[asyncio.Task] = None
self._lock = asyncio.Lock() # Added lock for thread-safe modifications
async def start_cleanup_task(self):
async with self._lock:
cleanup_task_is_none = self._cleanup_task is None
cleanup_task_is_done = self._cleanup_task is not None and self._cleanup_task.done()
if cleanup_task_is_none or cleanup_task_is_done:
try:
loop = asyncio.get_running_loop()
task_coro = self._cleanup_loop()
self._cleanup_task = loop.create_task(task_coro)
logger.info("Started connection cleanup task.")
except RuntimeError:
logger.warning("No running event loop found, cleanup task not started.")
async def _cleanup_loop(self):
log_msg = f"Cleanup loop started. Check interval: {self.check_interval}s, Inactivity threshold: {self.cleanup_interval}s"
logger.debug(log_msg)
while True:
await asyncio.sleep(self.check_interval)
try:
await self.cleanup_inactive_connections()
except asyncio.CancelledError:
logger.info("Cleanup loop cancelled.")
break # Exit loop cleanly on cancellation
except Exception as e:
logger.error(f"Error during connection cleanup: {e}", exc_info=True)
async def cleanup_inactive_connections(self):
current_time = time.time()
conn_ids_to_close = []
# Need lock here as we iterate over potentially changing dict
async with self._lock:
# Use items() for safe iteration while potentially modifying dict later
# Create a copy to avoid issues if the dict is modified elsewhere concurrently (though unlikely with lock)
current_connections = list(self.connections.items())
for conn_id, (_engine, last_accessed) in current_connections:
idle_time = current_time - last_accessed
is_inactive = idle_time > self.cleanup_interval
if is_inactive:
log_msg = f"Connection {conn_id} exceeded inactivity timeout ({idle_time:.1f}s > {self.cleanup_interval}s)"
logger.info(log_msg)
conn_ids_to_close.append(conn_id)
closed_count = 0
for conn_id in conn_ids_to_close:
# close_connection acquires its own lock
closed = await self.close_connection(conn_id)
if closed:
logger.info(f"Auto-closed inactive connection: {conn_id}")
closed_count += 1
if closed_count > 0:
logger.debug(f"Closed {closed_count} inactive connections.")
elif conn_ids_to_close:
num_attempted = len(conn_ids_to_close)
logger.debug(
f"Attempted to close {num_attempted} connections, but they might have been removed already."
)
async def get_connection(self, conn_id: str) -> AsyncEngine:
async with self._lock:
if conn_id not in self.connections:
details = {"error_type": "CONNECTION_NOT_FOUND"}
raise ToolInputError(
f"Unknown connection_id: {conn_id}", param_name="connection_id", details=details
)
engine, _ = self.connections[conn_id]
# Update last accessed time
current_time = time.time()
self.connections[conn_id] = (engine, current_time)
logger.debug(f"Accessed connection {conn_id}, updated last accessed time.")
return engine
async def add_connection(self, conn_id: str, engine: AsyncEngine):
# close_connection handles locking internally
has_existing = conn_id in self.connections
if has_existing:
logger.warning(f"Overwriting existing connection entry for {conn_id}.")
await self.close_connection(conn_id) # Close the old one first
async with self._lock:
current_time = time.time()
self.connections[conn_id] = (engine, current_time)
            url_str = engine.url.render_as_string(hide_password=True)
            log_msg = (
                f"Added connection {conn_id} for URL: {url_str}"  # Password hidden to avoid logging credentials
)
logger.info(log_msg)
await self.start_cleanup_task() # Ensure cleanup is running
async def close_connection(self, conn_id: str) -> bool:
engine = None
async with self._lock:
connection_exists = conn_id in self.connections
if connection_exists:
engine, _ = self.connections.pop(conn_id)
else:
logger.warning(f"Attempted to close non-existent connection ID: {conn_id}")
return False # Not found
if engine:
logger.info(f"Closing connection {conn_id}...")
try:
await engine.dispose()
logger.info(f"Connection {conn_id} disposed successfully.")
return True
except Exception as e:
log_msg = f"Error disposing engine for connection {conn_id}: {e}"
logger.error(log_msg, exc_info=True)
# Removed from dict, but disposal failed
return False
return False # Should not be reached if found
async def shutdown(self):
logger.info("Shutting down Connection Manager...")
# Cancel cleanup task first
cleanup_task = None
async with self._lock:
task_exists = self._cleanup_task is not None
task_not_done = task_exists and not self._cleanup_task.done()
if task_exists and task_not_done:
cleanup_task = self._cleanup_task # Get reference before clearing
self._cleanup_task = None # Prevent restarting
if cleanup_task:
cleanup_task.cancel()
try:
# Add timeout for task cancellation
await asyncio.wait_for(cleanup_task, timeout=2.0)
except asyncio.TimeoutError:
logger.warning("Cleanup task cancellation timed out after 2 seconds")
except asyncio.CancelledError:
logger.info("Cleanup task cancelled.")
except Exception as e:
logger.error(f"Error stopping cleanup task: {e}", exc_info=True)
# Close remaining connections
async with self._lock:
conn_ids = list(self.connections.keys())
if conn_ids:
num_conns = len(conn_ids)
logger.info(f"Closing {num_conns} active connections...")
# Call close_connection which handles locking and removal
close_tasks = []
for conn_id in conn_ids:
# Create a task that times out for each connection
async def close_with_timeout(conn_id):
try:
await asyncio.wait_for(self.close_connection(conn_id), timeout=2.0)
return True
except asyncio.TimeoutError:
logger.warning(f"Connection {conn_id} close timed out after 2 seconds")
return False
close_tasks.append(close_with_timeout(conn_id))
# Wait for all connections to close with an overall timeout
try:
await asyncio.wait_for(asyncio.gather(*close_tasks, return_exceptions=True), timeout=5.0)
except asyncio.TimeoutError:
logger.warning("Some connections did not close within the 5 second timeout")
async with self._lock:
# Final check
remaining = len(self.connections)
if remaining > 0:
logger.warning(f"{remaining} connections still remain after shutdown attempt.")
self.connections.clear() # Clear the dictionary
logger.info("Connection Manager shutdown complete.")
_connection_manager = ConnectionManager()
# --- Security and Validation ---
_PROHIBITED_SQL_PATTERN = r"""^\s*(DROP\s+(TABLE|DATABASE|INDEX|VIEW|FUNCTION|PROCEDURE|USER|ROLE)|
TRUNCATE\s+TABLE|
DELETE\s+FROM|
ALTER\s+(TABLE|DATABASE)\s+\S+\s+DROP\s+|
UPDATE\s+|INSERT\s+INTO(?!\s+OR\s+IGNORE)|
GRANT\s+|REVOKE\s+|
CREATE\s+USER|ALTER\s+USER|DROP\s+USER|
CREATE\s+ROLE|ALTER\s+ROLE|DROP\s+ROLE|
SHUTDOWN|REBOOT|RESTART)"""
_PROHIBITED_SQL_REGEX = re.compile(_PROHIBITED_SQL_PATTERN, re.I | re.X)
_TABLE_RX = re.compile(r"\b(?:FROM|JOIN|UPDATE|INSERT\s+INTO|DELETE\s+FROM)\s+([\w.\"$-]+)", re.I)
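# Illustrative sketch (not part of the original module): how these patterns behave on
# sample statements. _PROHIBITED_SQL_REGEX matches dangerous statement prefixes such as
# "DROP TABLE users" or "TRUNCATE TABLE logs", while _TABLE_RX pulls table names out of
# FROM / JOIN / UPDATE / INSERT INTO / DELETE FROM clauses, e.g.
#   _TABLE_RX.findall("SELECT * FROM orders o JOIN customers c ON o.cid = c.id")
#   -> ["orders", "customers"]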
# --- Masking ---
@dataclass
class MaskRule:
rx: re.Pattern
repl: Union[str, callable]
# Helper function for credit card masking
def _mask_cc(v: str) -> str:
return f"XXXX-...-{v[-4:]}"
# Helper function for email masking
def _mask_email(v: str) -> str:
if "@" in v:
parts = v.split("@")
prefix = parts[0][:2] + "***"
domain = parts[-1]
return f"{prefix}@{domain}"
else:
return "***"
_MASKING_RULES = [
MaskRule(re.compile(r"^\d{3}-\d{2}-\d{4}$"), "***-**-XXXX"), # SSN
MaskRule(re.compile(r"(\b\d{4}-?){3}\d{4}\b"), _mask_cc), # CC basic mask
MaskRule(re.compile(r"[\w\.-]+@[\w\.-]+\.\w+"), _mask_email), # Email
]
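# Illustrative sketch (not part of the original module): applying these rules via
# _sql_mask_row (defined later in this file) to a hypothetical row:
#   _sql_mask_row({"ssn": "123-45-6789", "email": "jane.doe@example.com"})
#   -> {"ssn": "***-**-XXXX", "email": "ja***@example.com"}
# Values that do not fully match any rule pass through unchanged.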
# --- ACLs ---
_RESTRICTED_TABLES: Set[str] = set()
_RESTRICTED_COLUMNS: Set[str] = set()
# --- Auditing ---
_AUDIT_LOG: List[Dict[str, Any]] = []
_AUDIT_ID_COUNTER: int = 0
_audit_lock = asyncio.Lock() # Lock for modifying audit counter and log
# --- Schema Drift Detection ---
_LINEAGE: List[Dict[str, Any]] = []
_SCHEMA_VERSIONS: Dict[str, str] = {} # connection_id -> schema_hash
# --- Prometheus Metrics ---
# Initialized as None, populated in initialize function if prom is available
_Q_CNT: Optional[Any] = None
_Q_LAT: Optional[Any] = None
_CONN_GAUGE: Optional[Any] = None
# =============================================================================
# Initialization and Shutdown Functions
# =============================================================================
# Flag to track if metrics have been initialized
_sql_metrics_initialized = False
async def initialize_sql_tools():
"""Initialize global state for SQL tools, like starting the cleanup task and metrics."""
global _sql_metrics_initialized
global _Q_CNT, _Q_LAT, _CONN_GAUGE # Ensure globals are declared for assignment
# Initialize metrics only once
if not _sql_metrics_initialized:
logger.info("Initializing SQL Tools module metrics...")
if prom:
try:
# Define metrics
_Q_CNT = prom.Counter("mcp_sqltool_calls", "SQL tool calls", ["tool", "action", "db"])
latency_buckets = (0.01, 0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60)
_Q_LAT = prom.Histogram(
"mcp_sqltool_latency_seconds",
"SQL latency",
["tool", "action", "db"],
buckets=latency_buckets,
)
_CONN_GAUGE = prom.Gauge(
"mcp_sqltool_active_connections",
"Number of active SQL connections"
)
# Define the gauge function referencing the global manager
# Wrap in try-except as accessing length during shutdown might be tricky
def _get_active_connections():
try:
# Access length directly if manager state is simple enough
# If complex state, acquire lock if necessary (_connection_manager._lock)
# For just length, direct access is usually okay unless adding/removing heavily concurrent
return len(_connection_manager.connections)
except Exception:
logger.exception("Error getting active connection count for Prometheus.")
return 0 # Default to 0 if error accessing
_CONN_GAUGE.set_function(_get_active_connections)
logger.info("Prometheus metrics initialized for SQL tools.")
_sql_metrics_initialized = True # Set flag only after successful initialization
except ValueError as e:
# Catch the specific duplicate error and log nicely, but don't crash
if "Duplicated timeseries" in str(e):
logger.warning(f"Prometheus metrics already registered: {e}. Skipping re-initialization.")
_sql_metrics_initialized = True # Assume they are initialized if duplicate error occurs
else:
# Re-raise other ValueErrors
logger.error(f"ValueError during Prometheus metric initialization: {e}", exc_info=True)
raise # Re-raise unexpected ValueError
except Exception as e:
logger.error(f"Failed to initialize Prometheus metrics for SQL tools: {e}", exc_info=True)
# Continue without metrics if initialization fails? Or raise? Let's continue for now.
else:
logger.info("Prometheus client not available, metrics disabled for SQL tools.")
_sql_metrics_initialized = True # Mark as "initialized" (i.e., done trying) even if prom not present
else:
logger.debug("SQL tools metrics already initialized, skipping metric creation.")
# Always try to start the cleanup task (it's internally idempotent)
# Ensure this happens *after* logging initialization attempt
logger.info("Ensuring SQL connection cleanup task is running...")
await _connection_manager.start_cleanup_task()
async def shutdown_sql_tools():
"""Gracefully shut down SQL tool resources, like the connection manager."""
logger.info("Shutting down SQL Tools module...")
try:
# Add timeout to connection manager shutdown
await asyncio.wait_for(_connection_manager.shutdown(), timeout=8.0)
except asyncio.TimeoutError:
logger.warning("Connection Manager shutdown timed out after 8 seconds")
# Clear other global state if necessary (e.g., save audit log)
logger.info("SQL Tools module shutdown complete.")
# =============================================================================
# Helper Functions (Private module-level functions)
# =============================================================================
@lru_cache(maxsize=64)
def _pull_secret_from_sources(name: str) -> str:
"""Retrieve a secret from various sources."""
# (Implementation remains the same as in the original class)
if boto3:
try:
client = boto3.client("secretsmanager")
# Consider region_name=os.getenv("AWS_REGION") or similar config
secret_value_response = client.get_secret_value(SecretId=name)
# Handle binary vs string secrets
if "SecretString" in secret_value_response:
secret = secret_value_response["SecretString"]
return secret
elif "SecretBinary" in secret_value_response:
# Decode binary appropriately if needed, default to utf-8
secret_bytes = secret_value_response["SecretBinary"]
secret = secret_bytes.decode("utf-8")
return secret
except Exception as aws_err:
logger.debug(f"Secret '{name}' not found or error in AWS Secrets Manager: {aws_err}")
pass
if hvac:
try:
vault_url = os.getenv("VAULT_ADDR")
vault_token = os.getenv("VAULT_TOKEN")
if vault_url and vault_token:
vault_client = hvac.Client(
url=vault_url, token=vault_token, timeout=2
) # Short timeout
is_auth = vault_client.is_authenticated()
if is_auth:
mount_point = os.getenv("VAULT_KV_MOUNT_POINT", "secret")
secret_path = name
read_response = vault_client.secrets.kv.v2.read_secret_version(
path=secret_path, mount_point=mount_point
)
# Standard KV v2 structure: response['data']['data'] is the dict of secrets
has_outer_data = "data" in read_response
has_inner_data = has_outer_data and "data" in read_response["data"]
if has_inner_data:
# Try common key names 'value' or the secret name itself
secret_data = read_response["data"]["data"]
if "value" in secret_data:
value = secret_data["value"]
return value
elif name in secret_data:
value = secret_data[name]
return value
else:
log_msg = f"Secret keys 'value' or '{name}' not found at path '{secret_path}' in Vault."
logger.debug(log_msg)
else:
logger.warning(f"Vault authentication failed for address: {vault_url}")
except Exception as e:
logger.debug(f"Error accessing Vault for secret '{name}': {e}")
pass
# Try environment variables
env_val_direct = os.getenv(name)
if env_val_direct:
return env_val_direct
mcp_secret_name = f"MCP_SECRET_{name.upper()}"
env_val_prefixed = os.getenv(mcp_secret_name)
if env_val_prefixed:
logger.debug(f"Found secret '{name}' using prefixed env var '{mcp_secret_name}'.")
return env_val_prefixed
error_msg = (
f"Secret '{name}' not found in any source (AWS, Vault, Env: {name}, Env: {mcp_secret_name})"
)
details = {"secret_name": name, "error_type": "SECRET_NOT_FOUND"}
raise ToolError(error_msg, http_status_code=404, details=details)
async def _sql_get_engine(cid: str) -> AsyncEngine:
"""Get engine by connection ID using the global ConnectionManager."""
engine = await _connection_manager.get_connection(cid)
return engine
def _sql_get_next_audit_id() -> str:
"""Generate the next sequential audit ID (thread-safe)."""
# Locking happens in _sql_audit where this is called
global _AUDIT_ID_COUNTER
_AUDIT_ID_COUNTER += 1
audit_id_str = f"a{_AUDIT_ID_COUNTER:09d}"
return audit_id_str
def _sql_now() -> str:
"""Get current UTC timestamp in ISO format."""
now_utc = dt.datetime.now(dt.timezone.utc)
iso_str = now_utc.isoformat(timespec="seconds")
return iso_str
async def _sql_audit(
*,
tool_name: str,
action: str,
connection_id: Optional[str],
sql: Optional[str],
tables: Optional[List[str]],
row_count: Optional[int],
success: bool,
error: Optional[str],
user_id: Optional[str],
session_id: Optional[str],
**extra_data: Any,
) -> None:
"""Record an audit trail entry (thread-safe)."""
global _AUDIT_LOG
async with _audit_lock:
audit_id = _sql_get_next_audit_id() # Get ID while locked
timestamp = _sql_now()
log_entry = {}
log_entry["audit_id"] = audit_id
log_entry["timestamp"] = timestamp
log_entry["tool_name"] = tool_name
log_entry["action"] = action
log_entry["user_id"] = user_id
log_entry["session_id"] = session_id
log_entry["connection_id"] = connection_id
log_entry["sql"] = sql
log_entry["tables"] = tables
log_entry["row_count"] = row_count
log_entry["success"] = success
log_entry["error"] = error
log_entry.update(extra_data) # Add extra data
_AUDIT_LOG.append(log_entry)
# Optional: Log to logger (outside lock)
log_base = f"Audit[{audit_id}]: Tool={tool_name}, Action={action}, Conn={connection_id}, Success={success}"
log_error = f", Error={error}" if error else ""
logger.info(log_base + log_error)
def _sql_update_acl(
*, tables: Optional[List[str]] = None, columns: Optional[List[str]] = None
) -> None:
"""Update the global ACL lists."""
global _RESTRICTED_TABLES, _RESTRICTED_COLUMNS
if tables is not None:
lowered_tables = {t.lower() for t in tables}
_RESTRICTED_TABLES = lowered_tables
logger.info(f"Updated restricted tables ACL: {_RESTRICTED_TABLES}")
if columns is not None:
lowered_columns = {c.lower() for c in columns}
_RESTRICTED_COLUMNS = lowered_columns
logger.info(f"Updated restricted columns ACL: {_RESTRICTED_COLUMNS}")
def _sql_check_acl(sql: str) -> None:
"""Check if SQL contains any restricted tables or columns using global ACLs."""
# (Implementation remains the same, uses global _RESTRICTED_TABLES/_COLUMNS)
raw_toks = re.findall(r'[\w$"\'.]+', sql.lower())
toks = set(raw_toks)
normalized_toks = set()
for tok in toks:
tok_norm = tok.strip("\"`'[]")
normalized_toks.add(tok_norm)
has_dot = "." in tok_norm
if has_dot:
last_part = tok_norm.split(".")[-1]
normalized_toks.add(last_part)
restricted_tables_found_set = _RESTRICTED_TABLES.intersection(normalized_toks)
restricted_tables_found = list(restricted_tables_found_set)
if restricted_tables_found:
tables_str = ", ".join(restricted_tables_found)
logger.warning(
f"ACL Violation: Restricted table(s) found in query: {restricted_tables_found}"
)
details = {
"restricted_tables": restricted_tables_found,
"error_type": "ACL_TABLE_VIOLATION",
}
raise ToolError(
f"Access denied: Query involves restricted table(s): {tables_str}",
http_status_code=403,
details=details,
)
restricted_columns_found_set = _RESTRICTED_COLUMNS.intersection(normalized_toks)
restricted_columns_found = list(restricted_columns_found_set)
if restricted_columns_found:
columns_str = ", ".join(restricted_columns_found)
logger.warning(
f"ACL Violation: Restricted column(s) found in query: {restricted_columns_found}"
)
details = {
"restricted_columns": restricted_columns_found,
"error_type": "ACL_COLUMN_VIOLATION",
}
raise ToolError(
f"Access denied: Query involves restricted column(s): {columns_str}",
http_status_code=403,
details=details,
)
def _sql_resolve_conn(raw: str) -> str:
"""Resolve connection string, handling secret references."""
# (Implementation remains the same)
is_secret_ref = raw.startswith("secrets://")
if is_secret_ref:
secret_name = raw[10:]
logger.info(f"Resolving secret reference: '{secret_name}'")
resolved_secret = _pull_secret_from_sources(secret_name)
return resolved_secret
return raw
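# Illustrative sketch (not part of the original module), using a hypothetical secret
# name "prod_db_dsn":
#   _sql_resolve_conn("secrets://prod_db_dsn") resolves the secret via
#   _pull_secret_from_sources (AWS Secrets Manager, then Vault, then the environment
#   variables "prod_db_dsn" or "MCP_SECRET_PROD_DB_DSN"), while a plain connection
#   string such as "sqlite:///./app.db" is returned unchanged.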
def _sql_mask_val(v: Any) -> Any:
"""Apply masking rules to a single value using global rules."""
# (Implementation remains the same, uses global _MASKING_RULES)
is_string = isinstance(v, str)
is_not_empty = bool(v)
if not is_string or not is_not_empty:
return v
for rule in _MASKING_RULES:
matches = rule.rx.fullmatch(v)
if matches:
is_callable = callable(rule.repl)
if is_callable:
try:
masked_value = rule.repl(v)
return masked_value
except Exception as e:
log_msg = f"Error applying dynamic mask rule {rule.rx.pattern}: {e}"
logger.error(log_msg)
return "MASKING_ERROR"
else:
return rule.repl
return v
def _sql_mask_row(row: Dict[str, Any]) -> Dict[str, Any]:
"""Apply masking rules to an entire row of data."""
masked_dict = {}
for k, v in row.items():
masked_val = _sql_mask_val(v)
masked_dict[k] = masked_val
return masked_dict
# return {k: _sql_mask_val(v) for k, v in row.items()} # Keep single-line comprehension
def _sql_driver_url(conn_str: str) -> Tuple[str, str]:
"""Convert generic connection string to dialect-specific async URL."""
# Check if it looks like a path (no ://) and exists or is :memory:
has_protocol = "://" in conn_str
looks_like_path = not has_protocol
path_obj = Path(conn_str)
path_exists = path_obj.exists()
is_memory = conn_str == ":memory:"
is_file_path = looks_like_path and (path_exists or is_memory)
if is_file_path:
if is_memory:
url_str = "sqlite+aiosqlite:///:memory:"
logger.info("Using in-memory SQLite database.")
else:
sqlite_path = path_obj.expanduser().resolve()
parent_dir = sqlite_path.parent
parent_exists = parent_dir.exists()
if not parent_exists:
try:
parent_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Created directory for SQLite DB: {parent_dir}")
except OSError as e:
details = {"path": str(parent_dir)}
raise ToolError(
f"Failed to create directory for SQLite DB '{parent_dir}': {e}",
http_status_code=500,
details=details,
) from e
url_str = f"sqlite+aiosqlite:///{sqlite_path}"
logger.info(f"Using SQLite database file: {sqlite_path}")
url = make_url(url_str)
final_url_str = str(url)
return final_url_str, "sqlite"
else:
url_str = conn_str
try:
url = make_url(url_str)
except Exception as e:
details = {"value": conn_str}
raise ToolInputError(
f"Invalid connection string format: {e}",
param_name="connection_string",
details=details,
) from e
drv = url.drivername.lower()
drivername = url.drivername # Preserve original case for setting later if needed
if drv.startswith("sqlite"):
new_url = url.set(drivername="sqlite+aiosqlite")
return str(new_url), "sqlite"
if drv.startswith("postgresql") or drv == "postgres":
new_url = url.set(drivername="postgresql+asyncpg")
return str(new_url), "postgresql"
if drv.startswith("mysql") or drv == "mariadb":
query = dict(url.query)
query.setdefault("charset", "utf8mb4")
new_url = url.set(drivername="mysql+aiomysql", query=query)
return str(new_url), "mysql"
if drv.startswith("mssql") or drv == "sqlserver":
odbc_driver = url.query.get("driver")
if not odbc_driver:
logger.warning(
"MSSQL connection string lacks 'driver' parameter. Ensure a valid ODBC driver (e.g., 'ODBC Driver 17 for SQL Server') is installed and specified."
)
new_url = url.set(drivername="mssql+aioodbc")
return str(new_url), "sqlserver"
if drv.startswith("snowflake"):
# Keep original snowflake driver
new_url = url.set(drivername=drivername)
return str(new_url), "snowflake"
logger.error(f"Unsupported database dialect: {drv}")
details = {"dialect": drv}
raise ToolInputError(
f"Unsupported database dialect: '{drv}'. Supported: sqlite, postgresql, mysql, mssql, snowflake",
param_name="connection_string",
details=details,
)
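# Illustrative sketch (not part of the original module): expected rewrites for a few
# hypothetical inputs, based on the branches above:
#   ":memory:"                 -> ("sqlite+aiosqlite:///:memory:", "sqlite")
#   "postgresql://u:p@host/db" -> ("postgresql+asyncpg://u:p@host/db", "postgresql")
#   "mysql://u:p@host/db"      -> ("mysql+aiomysql://u:p@host/db?charset=utf8mb4", "mysql")
#   "mssql://u:p@host/db?driver=ODBC+Driver+17+for+SQL+Server"
#                              -> (an mssql+aioodbc URL, "sqlserver")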
def _sql_auto_pool(db_type: str) -> Dict[str, Any]:
"""Provide sensible default connection pool settings."""
# (Implementation remains the same)
# Single-line dict return is acceptable
defaults = {
"pool_size": 5,
"max_overflow": 10,
"pool_recycle": 1800,
"pool_pre_ping": True,
"pool_timeout": 30,
}
if db_type == "sqlite":
return {"pool_pre_ping": True}
if db_type == "postgresql":
return {
"pool_size": 10,
"max_overflow": 20,
"pool_recycle": 900,
"pool_pre_ping": True,
"pool_timeout": 30,
}
if db_type == "mysql":
return {
"pool_size": 10,
"max_overflow": 20,
"pool_recycle": 900,
"pool_pre_ping": True,
"pool_timeout": 30,
}
if db_type == "sqlserver":
return {
"pool_size": 10,
"max_overflow": 20,
"pool_recycle": 900,
"pool_pre_ping": True,
"pool_timeout": 30,
}
if db_type == "snowflake":
return {"pool_size": 5, "max_overflow": 5, "pool_pre_ping": True, "pool_timeout": 60}
logger.warning(f"Using default pool settings for unknown db_type: {db_type}")
return defaults
def _sql_extract_tables(sql: str) -> List[str]:
"""Extract table names referenced in a SQL query."""
matches = _TABLE_RX.findall(sql)
tables = set()
for match in matches:
# Chained strip is one expression
table_stripped = match.strip()
table = table_stripped.strip("\"`'[]")
has_dot = "." in table
if has_dot:
# table.split('.')[-1].strip('"`\'[]') # Original combined
parts = table.split(".")
last_part = parts[-1]
table = last_part.strip("\"`'[]")
if table:
tables.add(table)
sorted_tables = sorted(list(tables))
return sorted_tables
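# Illustrative sketch (not part of the original module):
#   _sql_extract_tables('SELECT * FROM main.orders o JOIN "customers" c ON o.cid = c.id')
#   -> ["customers", "orders"]   (schema prefixes and quoting stripped, result sorted)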
def _sql_check_safe(sql: str, read_only: bool = True) -> None:
"""Validate SQL for safety using global patterns and ACLs."""
# Check ACLs first
_sql_check_acl(sql)
# Check prohibited statements
normalized_sql = sql.lstrip().upper()
check_sql_part = normalized_sql # Default part to check
starts_with_with = normalized_sql.startswith("WITH")
if starts_with_with:
try:
# Regex remains single-line expression assignment
search_regex = r"\)\s*(SELECT|INSERT|UPDATE|DELETE|MERGE)"
search_flags = re.IGNORECASE | re.DOTALL
main_statement_match = re.search(search_regex, normalized_sql, search_flags)
if main_statement_match:
# Chained calls okay on one line
main_statement_group = main_statement_match.group(0)
check_sql_part = main_statement_group.lstrip(") \t\n\r")
# else: keep check_sql_part as normalized_sql
except Exception:
# Ignore regex errors, fallback to checking whole normalized_sql
pass
prohibited_match_obj = _PROHIBITED_SQL_REGEX.match(check_sql_part)
if prohibited_match_obj:
# Chained calls okay on one line
prohibited_match = prohibited_match_obj.group(1)
prohibited_statement = prohibited_match.strip()
logger.warning(f"Security Violation: Prohibited statement detected: {prohibited_statement}")
details = {"statement": prohibited_statement, "error_type": "PROHIBITED_STATEMENT"}
raise ToolInputError(
f"Prohibited statement type detected: '{prohibited_statement}'",
param_name="query",
details=details,
)
# Check read-only constraint
if read_only:
allowed_starts = ("SELECT", "SHOW", "EXPLAIN", "DESCRIBE", "PRAGMA")
is_read_query = check_sql_part.startswith(allowed_starts)
if not is_read_query:
query_preview = sql[:100]
logger.warning(
f"Security Violation: Write operation attempted in read-only mode: {query_preview}..."
)
details = {"error_type": "READ_ONLY_VIOLATION"}
raise ToolInputError(
"Write operation attempted in read-only mode", param_name="query", details=details
)
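# Illustrative sketch (not part of the original module):
#   _sql_check_safe("SELECT * FROM orders")                    # passes (read-only SELECT)
#   _sql_check_safe("DROP TABLE orders", read_only=False)      # raises ToolInputError (prohibited statement)
#   _sql_check_safe("CREATE TABLE t (x INT)", read_only=True)  # raises ToolInputError (read-only violation)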
async def _sql_exec(
eng: AsyncEngine,
sql: str,
params: Optional[Dict[str, Any]],
*,
limit: Optional[int],
tool_name: str,
action_name: str,
timeout: float = 30.0,
) -> Tuple[List[str], List[Dict[str, Any]], int]:
"""Core async SQL executor helper."""
db_dialect = eng.dialect.name
start_time = time.perf_counter()
if _Q_CNT:
# Chained call okay
_Q_CNT.labels(tool=tool_name, action=action_name, db=db_dialect).inc()
cols: List[str] = []
rows_raw: List[Any] = []
row_count: int = 0
masked_rows: List[Dict[str, Any]] = []
async def _run(conn: AsyncConnection):
nonlocal cols, rows_raw, row_count, masked_rows
statement = text(sql)
query_params = params or {}
try:
res = await conn.execute(statement, query_params)
has_cursor = res.cursor is not None
has_description = has_cursor and res.cursor.description is not None
if not has_cursor or not has_description:
logger.debug(f"Query did not return rows or description. Action: {action_name}")
# Ternary okay
res_rowcount = res.rowcount if res.rowcount >= 0 else 0
row_count = res_rowcount
masked_rows = [] # Ensure it's an empty list
empty_cols: List[str] = []
empty_rows: List[Dict[str, Any]] = []
return empty_cols, empty_rows, row_count # Return empty lists for cols/rows
cols = list(res.keys())
try:
# --- START: Restored SQLite Handling ---
is_sqlite = db_dialect == "sqlite"
if is_sqlite:
# aiosqlite fetchall/fetchmany might not work reliably with async iteration or limits in all cases
# Fetch all as mappings (dicts) directly
# Lambda okay if single line
def sync_lambda(sync_conn):
return list(sync_conn.execute(statement, query_params).mappings())
all_rows_mapped = await conn.run_sync(sync_lambda)
rows_raw = all_rows_mapped # Keep the dict list format
needs_limit = limit is not None and limit >= 0
if needs_limit:
rows_raw = rows_raw[:limit] # Apply limit in Python
else:
# Standard async fetching for other dialects
needs_limit = limit is not None and limit >= 0
if needs_limit:
fetched_rows = await res.fetchmany(limit) # Returns Row objects
rows_raw = fetched_rows
else:
fetched_rows = await res.fetchall() # Returns Row objects
rows_raw = fetched_rows
# --- END: Restored SQLite Handling ---
row_count = len(rows_raw) # Count based on fetched/limited rows
except Exception as fetch_err:
log_msg = f"Error fetching rows for {tool_name}/{action_name}: {fetch_err}"
logger.error(log_msg, exc_info=True)
query_preview = sql[:100] + "..."
details = {"query": query_preview}
raise ToolError(
f"Error fetching results: {fetch_err}", http_status_code=500, details=details
) from fetch_err
# Apply masking using _sql_mask_row which uses global rules
# Adjust masking based on fetched format
if is_sqlite:
# List comprehension okay
masked_rows_list = [_sql_mask_row(r) for r in rows_raw] # Already dicts
masked_rows = masked_rows_list
else:
# List comprehension okay
masked_rows_list = [
_sql_mask_row(r._mapping) for r in rows_raw
] # Convert Row objects
masked_rows = masked_rows_list
return cols, masked_rows, row_count
except (ProgrammingError, OperationalError) as db_err:
err_type_name = type(db_err).__name__
log_msg = f"Database execution error ({err_type_name}) for {tool_name}/{action_name} on {db_dialect}: {db_err}"
logger.error(log_msg, exc_info=True)
query_preview = sql[:100] + "..."
details = {"db_error": str(db_err), "query": query_preview}
raise ToolError(
f"Database Error: {db_err}", http_status_code=400, details=details
) from db_err
except SQLAlchemyError as sa_err:
err_type_name = type(sa_err).__name__
log_msg = f"SQLAlchemy error ({err_type_name}) for {tool_name}/{action_name} on {db_dialect}: {sa_err}"
logger.error(log_msg, exc_info=True)
query_preview = sql[:100] + "..."
details = {"sqlalchemy_error": str(sa_err), "query": query_preview}
raise ToolError(
f"SQLAlchemy Error: {sa_err}", http_status_code=500, details=details
) from sa_err
except Exception as e: # Catch other potential errors within _run
log_msg = f"Unexpected error within _run for {tool_name}/{action_name}: {e}"
logger.error(log_msg, exc_info=True)
raise ToolError(
f"Unexpected error during query execution step: {e}", http_status_code=500
) from e
try:
async with eng.connect() as conn:
# Run within timeout
# Call okay
run_coro = _run(conn)
cols_res, masked_rows_res, cnt_res = await asyncio.wait_for(run_coro, timeout=timeout)
cols = cols_res
masked_rows = masked_rows_res
cnt = cnt_res
latency = time.perf_counter() - start_time
if _Q_LAT:
# Chained call okay
_Q_LAT.labels(tool=tool_name, action=action_name, db=db_dialect).observe(latency)
log_msg = f"Execution successful for {tool_name}/{action_name}. Latency: {latency:.3f}s, Rows fetched: {cnt}"
logger.debug(log_msg)
return cols, masked_rows, cnt
except asyncio.TimeoutError:
log_msg = (
f"Query timeout ({timeout}s) exceeded for {tool_name}/{action_name} on {db_dialect}."
)
logger.warning(log_msg)
query_preview = sql[:100] + "..."
details = {"timeout": timeout, "query": query_preview}
raise ToolError(
f"Query timed out after {timeout} seconds", http_status_code=504, details=details
) from None
except ToolError:
# Re-raise known ToolErrors
raise
except Exception as e:
log_msg = f"Unexpected error during _sql_exec for {tool_name}/{action_name}: {e}"
logger.error(log_msg, exc_info=True)
details = {"error_type": type(e).__name__}
raise ToolError(
f"An unexpected error occurred: {e}", http_status_code=500, details=details
) from e
def _sql_export_rows(
cols: List[str],
rows: List[Dict[str, Any]],
export_format: str,
export_path: Optional[str] = None,
) -> Tuple[Any | None, str | None]:
"""Export query results helper."""
if not export_format:
return None, None
export_format_lower = export_format.lower()
supported_formats = ["pandas", "excel", "csv"]
if export_format_lower not in supported_formats:
details = {"format": export_format}
msg = f"Unsupported export format: '{export_format}'. Use 'pandas', 'excel', or 'csv'."
raise ToolInputError(msg, param_name="export.format", details=details)
if pd is None:
details = {"library": "pandas"}
msg = f"Pandas library is not installed. Cannot export to '{export_format_lower}'."
raise ToolError(msg, http_status_code=501, details=details)
try:
# Ternary okay
df = pd.DataFrame(rows, columns=cols) if rows else pd.DataFrame(columns=cols)
logger.info(f"Created DataFrame with shape {df.shape} for export.")
except Exception as e:
logger.error(f"Error creating Pandas DataFrame: {e}", exc_info=True)
raise ToolError(f"Failed to create DataFrame for export: {e}", http_status_code=500) from e
if export_format_lower == "pandas":
logger.debug("Returning raw Pandas DataFrame.")
return df, None
final_path: str
temp_file_created = False
if export_path:
try:
# Chained calls okay
path_obj = Path(export_path)
path_expanded = path_obj.expanduser()
path_resolved = path_expanded.resolve()
parent_dir = path_resolved.parent
parent_dir.mkdir(parents=True, exist_ok=True)
final_path = str(path_resolved)
logger.info(f"Using specified export path: {final_path}")
except OSError as e:
details = {"path": export_path}
raise ToolError(
f"Cannot create directory for export path '{export_path}': {e}",
http_status_code=500,
details=details,
) from e
except Exception as e: # Catch other path errors
details = {"path": export_path}
msg = f"Invalid export path provided: {export_path}. Error: {e}"
raise ToolInputError(msg, param_name="export.path", details=details) from e
else:
# Ternary okay
suffix = ".xlsx" if export_format_lower == "excel" else ".csv"
try:
prefix = f"mcp_export_{export_format_lower}_"
fd, final_path_temp = tempfile.mkstemp(suffix=suffix, prefix=prefix)
final_path = final_path_temp
os.close(fd)
temp_file_created = True
logger.info(f"Created temporary file for export: {final_path}")
except Exception as e:
logger.error(f"Failed to create temporary file for export: {e}", exc_info=True)
raise ToolError(f"Failed to create temporary file: {e}", http_status_code=500) from e
try:
if export_format_lower == "excel":
df.to_excel(final_path, index=False, engine="xlsxwriter")
elif export_format_lower == "csv":
df.to_csv(final_path, index=False)
log_msg = f"Exported data to {export_format_lower.upper()} file: {final_path}"
logger.info(log_msg)
return None, final_path
except Exception as e:
log_msg = f"Error exporting DataFrame to {export_format_lower} file '{final_path}': {e}"
logger.error(log_msg, exc_info=True)
path_exists = Path(final_path).exists()
if temp_file_created and path_exists:
try:
Path(final_path).unlink()
except OSError:
logger.warning(f"Could not clean up temporary export file: {final_path}")
raise ToolError(
f"Failed to export data to {export_format_lower}: {e}", http_status_code=500
) from e
async def _sql_validate_df(df: Any, schema: Any | None) -> None:
"""Validate DataFrame against Pandera schema helper."""
if schema is None:
logger.debug("No Pandera schema provided for validation.")
return
if pa is None:
logger.warning("Pandera library not installed, skipping validation.")
return
is_pandas_df = pd is not None and isinstance(df, pd.DataFrame)
if not is_pandas_df:
logger.warning("Pandas DataFrame not available for validation.")
return
logger.info(f"Validating DataFrame (shape {df.shape}) against provided Pandera schema.")
try:
schema.validate(df, lazy=True)
logger.info("Pandera validation successful.")
except pa.errors.SchemaErrors as se:
# Ternary okay
error_details_df = se.failure_cases
can_dict = hasattr(error_details_df, "to_dict")
error_details = (
error_details_df.to_dict(orient="records") if can_dict else str(error_details_df)
)
# Ternary okay
can_len = hasattr(error_details_df, "__len__")
error_count = len(error_details_df) if can_len else "multiple"
log_msg = f"Pandera validation failed with {error_count} errors. Details: {error_details}"
logger.warning(log_msg)
# Break down error message construction
error_msg_base = f"Pandera validation failed ({error_count} errors):\n"
error_msg_lines = []
error_details_list = error_details if isinstance(error_details, list) else []
errors_to_show = error_details_list[:5]
for err in errors_to_show:
col = err.get("column", "N/A")
check = err.get("check", "N/A")
index = err.get("index", "N/A")
fail_case_raw = err.get("failure_case", "N/A")
fail_case_str = str(fail_case_raw)[:50]
line = f"- Column '{col}': {check} failed for index {index}. Data: {fail_case_str}..."
error_msg_lines.append(line)
error_msg = error_msg_base + "\n".join(error_msg_lines)
num_errors = error_count if isinstance(error_count, int) else 0
if num_errors > 5:
more_errors_count = num_errors - 5
error_msg += f"\n... and {more_errors_count} more errors."
validation_errors = error_details # Pass the original structure
details = {"error_type": "VALIDATION_ERROR"}
raise ToolError(
error_msg, http_status_code=422, validation_errors=validation_errors, details=details
) from se
except Exception as e:
logger.error(f"Unexpected error during Pandera validation: {e}", exc_info=True)
raise ToolError(
f"An unexpected error occurred during schema validation: {e}", http_status_code=500
) from e
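# Illustrative sketch (not part of the original module), assuming pandera (pa) is
# installed: the `schema` argument is expected to be a pandera schema object, e.g.
#   schema = pa.DataFrameSchema({"id": pa.Column(int), "email": pa.Column(str)})
#   await _sql_validate_df(df, schema)   # raises ToolError (HTTP 422) on validation failures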
async def _sql_convert_nl_to_sql(
connection_id: str,
natural_language: str,
confidence_threshold: float = 0.6,
user_id: Optional[str] = None, # Added for lineage
session_id: Optional[str] = None, # Added for lineage
) -> Dict[str, Any]:
"""Helper method to convert natural language to SQL."""
# (Implementation largely the same, uses _sql_get_engine, _sql_check_safe, global state _SCHEMA_VERSIONS, _LINEAGE)
nl_preview = natural_language[:100]
logger.info(f"Converting NL to SQL for connection {connection_id}. Query: '{nl_preview}...'")
eng = await _sql_get_engine(connection_id)
def _get_schema_fingerprint_sync(sync_conn) -> str:
# (Schema fingerprint sync helper implementation is the same)
try:
sync_inspector = sa_inspect(sync_conn)
tbls = []
schema_names = sync_inspector.get_schema_names()
default_schema = sync_inspector.default_schema_name
# List comprehension okay
other_schemas = [s for s in schema_names if s != default_schema]
schemas_to_inspect = [default_schema] + other_schemas
for schema_name in schemas_to_inspect:
# Ternary okay
prefix = f"{schema_name}." if schema_name and schema_name != default_schema else ""
table_names_in_schema = sync_inspector.get_table_names(schema=schema_name)
for t in table_names_in_schema:
try:
cols = sync_inspector.get_columns(t, schema=schema_name)
# List comprehension okay
col_defs = [f"{c['name']}:{str(c['type']).split('(')[0]}" for c in cols]
col_defs_str = ",".join(col_defs)
tbl_def = f"{prefix}{t}({col_defs_str})"
tbls.append(tbl_def)
except Exception as col_err:
logger.warning(f"Could not get columns for table {prefix}{t}: {col_err}")
tbl_def_err = f"{prefix}{t}(...)"
tbls.append(tbl_def_err)
# Call okay
fp = "; ".join(sorted(tbls))
if not fp:
logger.warning("Schema fingerprint generation resulted in empty string.")
return "Error: Could not retrieve schema."
return fp
except Exception as e:
logger.error(f"Error in _get_schema_fingerprint_sync: {e}", exc_info=True)
return "Error: Could not retrieve schema."
async def _get_schema_fingerprint(conn: AsyncConnection) -> str:
logger.debug("Generating schema fingerprint for NL->SQL...")
try:
# Lambda okay
def sync_func(sync_conn):
return _get_schema_fingerprint_sync(sync_conn)
fingerprint = await conn.run_sync(sync_func)
return fingerprint
except Exception as e:
logger.error(f"Error generating schema fingerprint: {e}", exc_info=True)
return "Error: Could not retrieve schema."
async with eng.connect() as conn:
schema_fingerprint = await _get_schema_fingerprint(conn)
# Multi-line string assignment okay
prompt = (
"You are a highly specialized AI assistant that translates natural language questions into SQL queries.\n"
"You must adhere STRICTLY to the following rules:\n"
"1. Generate only a SINGLE, executable SQL query for the given database schema and question.\n"
"2. Use the exact table and column names provided in the schema fingerprint.\n"
"3. Do NOT generate any explanatory text, comments, or markdown formatting.\n"
"4. The output MUST be a valid JSON object containing two keys: 'sql' (the generated SQL query as a string) and 'confidence' (a float between 0.0 and 1.0 indicating your confidence in the generated SQL).\n"
"5. If the question cannot be answered from the schema or is ambiguous, set confidence to 0.0 and provide a minimal, safe query like 'SELECT 1;' in the 'sql' field.\n"
"6. Prioritize safety: Avoid generating queries that could modify data (UPDATE, INSERT, DELETE, DROP, etc.). Generate SELECT statements ONLY.\n\n" # Stricter rule
f"Database Schema Fingerprint:\n```\n{schema_fingerprint}\n```\n\n"
f"Natural Language Question:\n```\n{natural_language}\n```\n\n"
"JSON Output:"
)
try:
logger.debug("Sending prompt to LLM for NL->SQL conversion.")
# Call okay
completion_result = await generate_completion(
prompt=prompt, max_tokens=350, temperature=0.2
)
llm_response_dict = completion_result
# Ternary okay
is_dict_response = isinstance(llm_response_dict, dict)
llm_response = llm_response_dict.get("text", "") if is_dict_response else ""
llm_response_preview = llm_response[:300]
logger.debug(f"LLM Response received: {llm_response_preview}...")
if not llm_response:
raise ToolError("LLM returned empty response for NL->SQL.", http_status_code=502)
except Exception as llm_err:
logger.error(f"LLM completion failed for NL->SQL: {llm_err}", exc_info=True)
details = {"error_type": "LLM_ERROR"}
raise ToolError(
f"Failed to get response from LLM: {llm_err}", http_status_code=502, details=details
) from llm_err
try:
data = {}
try:
# Try parsing the whole response as JSON first
data = json.loads(llm_response)
except json.JSONDecodeError as e:
# If that fails, look for a JSON block within the text
# Call okay
search_regex = r"\{.*\}"
search_flags = re.DOTALL | re.MULTILINE
json_match = re.search(search_regex, llm_response, search_flags)
if not json_match:
raise ValueError("No JSON object found in the LLM response.") from e
json_str = json_match.group(0)
data = json.loads(json_str)
is_dict_data = isinstance(data, dict)
has_sql = "sql" in data
has_confidence = "confidence" in data
if not is_dict_data or not has_sql or not has_confidence:
raise ValueError("LLM response JSON is missing required keys ('sql', 'confidence').")
sql = data["sql"]
conf_raw = data["confidence"]
conf = float(conf_raw)
is_sql_str = isinstance(sql, str)
is_conf_valid = 0.0 <= conf <= 1.0
if not is_sql_str or not is_conf_valid:
raise ValueError("LLM response has invalid types for 'sql' or 'confidence'.")
sql_preview = sql[:150]
logger.info(f"LLM generated SQL with confidence {conf:.2f}: {sql_preview}...")
except (json.JSONDecodeError, ValueError, TypeError) as e:
response_preview = str(llm_response)[:200]
error_detail = (
f"LLM returned invalid or malformed JSON: {e}. Response: '{response_preview}...'"
)
logger.error(error_detail)
details = {"error_type": "LLM_RESPONSE_INVALID"}
raise ToolError(error_detail, http_status_code=500, details=details) from e
is_below_threshold = conf < confidence_threshold
if is_below_threshold:
nl_query_preview = natural_language
low_conf_msg = f"LLM confidence ({conf:.2f}) is below the required threshold ({confidence_threshold}). NL Query: '{nl_query_preview}'"
logger.warning(low_conf_msg)
details = {"error_type": "LOW_CONFIDENCE"}
raise ToolError(
low_conf_msg, http_status_code=400, generated_sql=sql, confidence=conf, details=details
) from None
try:
_sql_check_safe(sql, read_only=True) # Enforce read-only for generated SQL
# Call okay
sql_upper = sql.upper()
sql_stripped = sql_upper.lstrip()
is_valid_start = sql_stripped.startswith(("SELECT", "WITH"))
if not is_valid_start:
details = {"error_type": "INVALID_GENERATED_SQL"}
raise ToolError(
"Generated query does not appear to be a valid SELECT statement.",
http_status_code=400,
details=details,
)
# Basic table check (optional, as before)
except ToolInputError as safety_err:
logger.error(f"Generated SQL failed safety check: {safety_err}. SQL: {sql}")
details = {"error_type": "SAFETY_VIOLATION"}
raise ToolError(
f"Generated SQL failed validation: {safety_err}",
http_status_code=400,
generated_sql=sql,
confidence=conf,
details=details,
) from safety_err
result_dict = {"sql": sql, "confidence": conf}
return result_dict
# =============================================================================
# Public Tool Functions (Standalone replacements for SQLTool methods)
# =============================================================================
@with_tool_metrics
@with_error_handling
async def manage_database(
action: str,
connection_string: Optional[str] = None,
connection_id: Optional[str] = None,
echo: bool = False,
user_id: Optional[str] = None,
session_id: Optional[str] = None,
ctx: Optional[Dict] = None, # Added ctx for potential future use
**options: Any,
) -> Dict[str, Any]:
"""
Unified database connection management tool.
Args:
action: The action to perform: "connect", "disconnect", "test", or "status".
connection_string: Database connection string or secrets:// reference. (Required for "connect").
connection_id: An existing connection ID (Required for "disconnect", "test"). Can be provided for "connect" to suggest an ID.
echo: Enable SQLAlchemy engine logging (For "connect" action, default: False).
user_id: Optional user identifier for audit logging.
session_id: Optional session identifier for audit logging.
ctx: Optional context from MCP server (not used currently).
**options: Additional options:
- For "connect": Passed directly to SQLAlchemy's `create_async_engine`.
- Can include custom audit context.
Returns:
Dict with action results and metadata. Varies based on action.
"""
tool_name = "manage_database"
db_dialect = "unknown"
# Dict comprehension okay
audit_extras_all = {k: v for k, v in options.items()}
audit_extras = {k: v for k, v in audit_extras_all.items() if k not in ["echo"]}
try:
if action == "connect":
if not connection_string:
raise ToolInputError(
"connection_string is required for 'connect'", param_name="connection_string"
)
# Ternary okay
cid = connection_id or str(uuid.uuid4())
logger.info(f"Attempting to connect with connection_id: {cid}")
resolved_conn_str = _sql_resolve_conn(connection_string)
url, db_type = _sql_driver_url(resolved_conn_str)
db_dialect = db_type # Update dialect for potential error logging
pool_opts = _sql_auto_pool(db_type)
# Dict unpacking okay
engine_opts = {**pool_opts, **options}
# Dict comprehension okay
log_opts = {k: v for k, v in engine_opts.items() if k != "password"}
logger.debug(f"Creating engine for {db_type} with options: {log_opts}")
connect_args = engine_opts.pop("connect_args", {})
# Ternary okay
exec_opts_base = {"async_execution": True} if db_type == "snowflake" else {}
# Pass other engine options directly
execution_options = {**exec_opts_base, **engine_opts.pop("execution_options", {})}
# Separate create_async_engine call
eng = create_async_engine(
url,
echo=echo,
connect_args=connect_args,
execution_options=execution_options,
**engine_opts, # Pass remaining options like pool settings
)
try:
# Ternary okay
test_sql = "SELECT CURRENT_TIMESTAMP" if db_type != "sqlite" else "SELECT 1"
# Call okay
await _sql_exec(
eng,
test_sql,
None,
limit=1,
tool_name=tool_name,
action_name="connect_test",
timeout=15,
)
logger.info(f"Connection test successful for {cid} ({db_type}).")
except ToolError as test_err:
logger.error(f"Connection test failed for {cid} ({db_type}): {test_err}")
await eng.dispose()
# Get details okay
err_details = getattr(test_err, "details", None)
raise ToolError(
f"Connection test failed: {test_err}", http_status_code=400, details=err_details
) from test_err
except Exception as e:
logger.error(
f"Unexpected error during connection test for {cid} ({db_type}): {e}",
exc_info=True,
)
await eng.dispose()
raise ToolError(
f"Unexpected error during connection test: {e}", http_status_code=500
) from e
await _connection_manager.add_connection(cid, eng)
# Call okay
await _sql_audit(
tool_name=tool_name,
action="connect",
connection_id=cid,
sql=None,
tables=None,
row_count=None,
success=True,
error=None,
user_id=user_id,
session_id=session_id,
database_type=db_type,
echo=echo,
**audit_extras,
)
# Return dict okay
return {
"action": "connect",
"connection_id": cid,
"database_type": db_type,
"success": True,
}
elif action == "disconnect":
if not connection_id:
raise ToolInputError(
"connection_id is required for 'disconnect'", param_name="connection_id"
)
logger.info(f"Attempting to disconnect connection_id: {connection_id}")
db_dialect_for_audit = "unknown" # Default if engine retrieval fails
try:
# Needs await before get_connection
engine_to_close = await _connection_manager.get_connection(connection_id)
db_dialect_for_audit = engine_to_close.dialect.name
except ToolInputError:
# This error means connection_id wasn't found by get_connection
logger.warning(f"Disconnect requested for unknown connection_id: {connection_id}")
# Call okay
await _sql_audit(
tool_name=tool_name,
action="disconnect",
connection_id=connection_id,
sql=None,
tables=None,
row_count=None,
success=False,
error="Connection ID not found",
user_id=user_id,
session_id=session_id,
**audit_extras,
)
# Return dict okay
return {
"action": "disconnect",
"connection_id": connection_id,
"success": False,
"message": "Connection ID not found",
}
except Exception as e:
# Catch other errors during engine retrieval itself
logger.error(f"Error retrieving engine for disconnect ({connection_id}): {e}")
# Proceed to attempt close, but audit will likely show failure or non-existence
# Attempt closing even if retrieval had issues (it might have been removed between check and close)
success = await _connection_manager.close_connection(connection_id)
# Ternary okay
error_msg = None if success else "Failed to close or already closed/not found"
# Call okay
await _sql_audit(
tool_name=tool_name,
action="disconnect",
connection_id=connection_id,
sql=None,
tables=None,
row_count=None,
success=success,
error=error_msg,
user_id=user_id,
session_id=session_id,
database_type=db_dialect_for_audit,
**audit_extras,
)
# Return dict okay
return {"action": "disconnect", "connection_id": connection_id, "success": success}
elif action == "test":
if not connection_id:
raise ToolInputError(
"connection_id is required for 'test'", param_name="connection_id"
)
logger.info(f"Testing connection_id: {connection_id}")
eng = await _sql_get_engine(connection_id)
db_dialect = eng.dialect.name # Now dialect is known for sure
t0 = time.perf_counter()
# Ternary conditions okay
vsql = (
"SELECT sqlite_version()"
if db_dialect == "sqlite"
else "SELECT CURRENT_VERSION()"
if db_dialect == "snowflake"
else "SELECT version()"
)
# Call okay
cols, rows, _ = await _sql_exec(
eng, vsql, None, limit=1, tool_name=tool_name, action_name="test", timeout=10
)
latency = time.perf_counter() - t0
# Ternary okay
has_rows_and_cols = rows and cols
version_info = rows[0].get(cols[0], "N/A") if has_rows_and_cols else "N/A"
log_msg = f"Connection test successful for {connection_id}. Version: {version_info}, Latency: {latency:.3f}s"
logger.info(log_msg)
# Return dict okay
return {
"action": "test",
"connection_id": connection_id,
"response_time_seconds": round(latency, 3),
"version": version_info,
"database_type": db_dialect,
"success": True,
}
elif action == "status":
logger.info("Retrieving connection status.")
connections_info = {}
current_time = time.time()
# Access connections safely using async with lock if needed, or make copy
conn_items = []
async with _connection_manager._lock: # Access lock directly for iteration safety
# Call okay
conn_items = list(_connection_manager.connections.items())
for conn_id, (eng, last_access) in conn_items:
try:
url_display_raw = str(eng.url)
parsed_url = make_url(url_display_raw)
url_display = url_display_raw # Default
if parsed_url.password:
# Call okay
url_masked = parsed_url.set(password="***")
url_display = str(url_masked)
# Break down dict assignment
conn_info_dict = {}
conn_info_dict["url_summary"] = url_display
conn_info_dict["dialect"] = eng.dialect.name
last_access_dt = dt.datetime.fromtimestamp(last_access)
conn_info_dict["last_accessed"] = last_access_dt.isoformat()
idle_seconds = current_time - last_access
conn_info_dict["idle_time_seconds"] = round(idle_seconds, 1)
connections_info[conn_id] = conn_info_dict
except Exception as status_err:
logger.error(f"Error retrieving status for connection {conn_id}: {status_err}")
connections_info[conn_id] = {"error": str(status_err)}
# Return dict okay
return {
"action": "status",
"active_connections_count": len(connections_info),
"connections": connections_info,
"cleanup_interval_seconds": _connection_manager.cleanup_interval,
"success": True,
}
else:
logger.error(f"Invalid action specified for manage_database: {action}")
details = {"action": action}
msg = f"Unknown action: '{action}'. Valid actions: connect, disconnect, test, status"
raise ToolInputError(msg, param_name="action", details=details)
except ToolInputError as tie:
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action,
connection_id=connection_id,
sql=None,
tables=None,
row_count=None,
success=False,
error=str(tie),
user_id=user_id,
session_id=session_id,
database_type=db_dialect,
**audit_extras,
)
raise tie
except ToolError as te:
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action,
connection_id=connection_id,
sql=None,
tables=None,
row_count=None,
success=False,
error=str(te),
user_id=user_id,
session_id=session_id,
database_type=db_dialect,
**audit_extras,
)
raise te
except Exception as e:
log_msg = f"Unexpected error in manage_database (action: {action}): {e}"
logger.error(log_msg, exc_info=True)
error_str = f"Unexpected error: {e}"
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action,
connection_id=connection_id,
sql=None,
tables=None,
row_count=None,
success=False,
error=error_str,
user_id=user_id,
session_id=session_id,
database_type=db_dialect,
**audit_extras,
)
raise ToolError(
f"An unexpected error occurred in manage_database: {e}", http_status_code=500
) from e
@with_tool_metrics
@with_error_handling
async def execute_sql(
connection_id: str,
query: Optional[str] = None,
natural_language: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None,
pagination: Optional[Dict[str, int]] = None,
read_only: bool = True,
export: Optional[Dict[str, Any]] = None,
timeout: float = 60.0,
validate_schema: Optional[Any] = None,
max_rows: Optional[int] = 1000,
confidence_threshold: float = 0.6,
user_id: Optional[str] = None,
session_id: Optional[str] = None,
ctx: Optional[Dict] = None, # Added ctx
**options: Any,
) -> Dict[str, Any]:
"""
Unified SQL query execution tool.
Handles direct SQL execution, NL-to-SQL conversion, pagination,
result masking, safety checks, validation, and export.
Args:
connection_id: The ID of the database connection to use.
query: The SQL query string to execute. (Use instead of natural_language).
natural_language: A natural language question to convert to SQL. (Use instead of query).
parameters: Dictionary of parameters for parameterized queries.
pagination: Dict with "page" (>=1) and "page_size" (>=1) for paginated results.
When pagination is used, page_size controls the rows returned and max_rows clipping is not applied.
read_only: If True (default), enforces safety checks against write operations (UPDATE, DELETE, etc.). Set to False only if writes are explicitly intended and allowed.
export: Dictionary with "format" ('pandas', 'excel', 'csv') and optional "path" (string) for exporting results.
timeout: Maximum execution time in seconds (default: 60.0).
validate_schema: A Pandera schema object to validate the results DataFrame against.
max_rows: Maximum number of rows to return in the result (default: 1000). Set to None or -1 for unlimited (potentially dangerous).
confidence_threshold: Minimum confidence score (0.0-1.0) required from the LLM for NL-to-SQL conversion (default: 0.6).
user_id: Optional user identifier for audit logging.
session_id: Optional session identifier for audit logging.
ctx: Optional context from MCP server.
**options: Additional options for audit logging or future extensions.
Returns:
A dictionary containing:
- columns (List[str]): List of column names.
- rows (List[Dict[str, Any]]): List of data rows (masked).
- row_count (int): Number of rows returned in this batch/page.
- truncated (bool): True if max_rows limited the results.
- pagination (Optional[Dict]): Info about the current page if pagination was used.
- generated_sql (Optional[str]): The SQL query generated from natural language, if applicable.
- confidence (Optional[float]): The confidence score from the NL-to-SQL conversion, if applicable.
- validation_status (Optional[str]): 'success', 'failed', 'skipped'.
- validation_errors (Optional[Any]): Details if validation failed.
- export_status (Optional[str]): Status message if export was attempted.
- <format>_path (Optional[str]): Path to the exported file if export to file was successful.
- dataframe (Optional[pd.DataFrame]): The raw Pandas DataFrame if export format was 'pandas'.
- success (bool): Always True if no exception was raised.
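Example:
A minimal sketch (connection_id, table, and column names are hypothetical):
res = await execute_sql(connection_id=cid, query="SELECT id, name FROM users WHERE status = :status", parameters={"status": "active"}, max_rows=50)
nl_res = await execute_sql(connection_id=cid, natural_language="How many users signed up last week?")
print(nl_res.get("generated_sql"), nl_res.get("confidence"))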
"""
tool_name = "execute_sql"
action_name = "query" # Default, may change
original_query_input = query # Keep track of original SQL input
original_nl_input = natural_language # Keep track of NL input
generated_sql = None
confidence = None
final_query: str
# Dict unpacking okay
final_params = parameters or {}
result: Dict[str, Any] = {}
tables: List[str] = []
# Dict unpacking okay
audit_extras = {**options}
try:
# 1. Determine Query
use_nl = natural_language and not query
use_sql = query and not natural_language
is_ambiguous = natural_language and query
no_input = not natural_language and not query
if is_ambiguous:
msg = "Provide either 'query' or 'natural_language', not both."
raise ToolInputError(msg, param_name="query/natural_language")
if no_input:
msg = "Either 'query' or 'natural_language' must be provided."
raise ToolInputError(msg, param_name="query/natural_language")
if use_nl:
action_name = "nl_to_sql_exec"
nl_preview = natural_language[:100]
log_msg = (
f"Received natural language query for connection {connection_id}: '{nl_preview}...'"
)
logger.info(log_msg)
try:
# Pass user_id/session_id to NL converter for lineage/audit trail consistency if needed
# Call okay
nl_result = await _sql_convert_nl_to_sql(
connection_id, natural_language, confidence_threshold, user_id, session_id
)
final_query = nl_result["sql"]
generated_sql = final_query
confidence = nl_result["confidence"]
# original_query remains None, original_nl_input has the NL
audit_extras["generated_sql"] = generated_sql
audit_extras["confidence"] = confidence
query_preview = final_query[:150]
log_msg = f"Successfully converted NL to SQL (Confidence: {confidence:.2f}): {query_preview}..."
logger.info(log_msg)
read_only = True # Ensure read-only for generated SQL
except ToolError as nl_err:
# Audit NL failure
await _sql_audit(
tool_name=tool_name,
action="nl_to_sql_fail",
connection_id=connection_id,
sql=natural_language, # Log the NL query that failed
tables=None,
row_count=None,
success=False,
error=str(nl_err),
user_id=user_id,
session_id=session_id,
**audit_extras,
)
raise nl_err # Re-raise the error
elif use_sql:
# Action name remains 'query'
final_query = query
query_preview = final_query[:150]
logger.info(f"Executing direct SQL query on {connection_id}: {query_preview}...")
# original_query_input has the SQL, original_nl_input is None
# else case already handled by initial checks
# 2. Check Safety
_sql_check_safe(final_query, read_only)
tables = _sql_extract_tables(final_query)
logger.debug(f"Query targets tables: {tables}")
# 3. Get Engine
eng = await _sql_get_engine(connection_id)
# 4. Handle Pagination or Standard Execution
if pagination:
action_name = "query_paginated"
page = pagination.get("page", 1)
page_size = pagination.get("page_size", 100)
is_page_valid = isinstance(page, int) and page >= 1
is_page_size_valid = isinstance(page_size, int) and page_size >= 1
if not is_page_valid:
raise ToolInputError(
"Pagination 'page' must be an integer >= 1.", param_name="pagination.page"
)
if not is_page_size_valid:
raise ToolInputError(
"Pagination 'page_size' must be an integer >= 1.",
param_name="pagination.page_size",
)
offset = (page - 1) * page_size
db_dialect = eng.dialect.name
paginated_query: str
if db_dialect in ("mssql", "sqlserver"):  # SQLAlchemy reports SQL Server's dialect name as "mssql"
query_lower = final_query.lower()
has_order_by = "order by" in query_lower
if not has_order_by:
raise ToolInputError(
"SQL Server pagination requires an ORDER BY clause in the query.",
param_name="query",
)
paginated_query = (
f"{final_query} OFFSET :_page_offset ROWS FETCH NEXT :_page_size ROWS ONLY"
)
elif db_dialect == "oracle":
paginated_query = (
f"{final_query} OFFSET :_page_offset ROWS FETCH NEXT :_page_size ROWS ONLY"
)
else: # Default LIMIT/OFFSET for others (MySQL, PostgreSQL, SQLite)
paginated_query = f"{final_query} LIMIT :_page_size OFFSET :_page_offset"
# Fetch one extra row to check for next page
fetch_size = page_size + 1
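# Worked example (illustrative): page=3, page_size=100 -> offset=200, so rows 201-300 are
# returned; requesting page_size + 1 rows lets has_next_page be derived from whether an
# extra row came back, avoiding a separate COUNT(*) round trip.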
# Dict unpacking okay
paginated_params = {**final_params, "_page_size": fetch_size, "_page_offset": offset}
log_msg = (
f"Executing paginated query (Page: {page}, Size: {page_size}): {paginated_query}"
)
logger.debug(log_msg)
# Call okay
cols, rows_with_extra, fetched_count_paged = await _sql_exec(
eng,
paginated_query,
paginated_params,
limit=None, # Limit is applied in SQL for pagination
tool_name=tool_name,
action_name=action_name,
timeout=timeout,
)
# Check if more rows exist than requested page size
has_next_page = len(rows_with_extra) > page_size
returned_rows = rows_with_extra[:page_size]
returned_row_count = len(returned_rows)
# Build result dict piece by piece
pagination_info = {}
pagination_info["page"] = page
pagination_info["page_size"] = page_size
pagination_info["has_next_page"] = has_next_page
pagination_info["has_previous_page"] = page > 1
result = {}
result["columns"] = cols
result["rows"] = returned_rows
result["row_count"] = returned_row_count
result["pagination"] = pagination_info
result["truncated"] = False # Not truncated by max_rows in pagination mode
result["success"] = True
else: # Standard execution (no pagination dict)
action_name = "query_standard"
# Ternary okay
needs_limit = max_rows is not None and max_rows >= 0
fetch_limit = (max_rows + 1) if needs_limit else None
query_preview = final_query[:150]
log_msg = f"Executing standard query (Max rows: {max_rows}): {query_preview}..."
logger.debug(log_msg)
# Call okay
cols, rows_maybe_extra, fetched_count = await _sql_exec(
eng,
final_query,
final_params,
limit=fetch_limit, # Use fetch_limit (max_rows + 1 or None)
tool_name=tool_name,
action_name=action_name,
timeout=timeout,
)
# Determine truncation based on fetch_limit
truncated = fetch_limit is not None and fetched_count >= fetch_limit
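# Worked example (illustrative): max_rows=1000 -> fetch_limit=1001; if 1001 rows come back,
# truncated is True and only the first 1000 rows are returned to the caller.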
# Apply actual max_rows limit to returned data
# Ternary okay
returned_rows = rows_maybe_extra[:max_rows] if needs_limit else rows_maybe_extra
returned_row_count = len(returned_rows)
# Build result dict piece by piece
result = {}
result["columns"] = cols
result["rows"] = returned_rows
result["row_count"] = returned_row_count
result["truncated"] = truncated
result["success"] = True
# No pagination key in standard mode
# Add NL->SQL info if applicable
if generated_sql:
result["generated_sql"] = generated_sql
result["confidence"] = confidence
# 5. Handle Validation
if validate_schema:
temp_df = None
validation_status = "skipped (unknown reason)"
validation_errors = None
if pd:
try:
# Ternary okay
df_data = result["rows"]
df_cols = result["columns"]
temp_df = (
pd.DataFrame(df_data, columns=df_cols)
if df_data
else pd.DataFrame(columns=df_cols)
)
try:
# Call okay
await _sql_validate_df(temp_df, validate_schema)
validation_status = "success"
logger.info("Pandera validation passed.")
except ToolError as val_err:
logger.warning(f"Pandera validation failed: {val_err}")
validation_status = "failed"
# Get validation errors okay
validation_errors = getattr(val_err, "validation_errors", str(val_err))
except Exception as df_err:
logger.error(f"Error creating DataFrame for validation: {df_err}")
validation_status = f"skipped (Failed to create DataFrame: {df_err})"
else:
logger.warning("Pandas not installed, skipping Pandera validation.")
validation_status = "skipped (Pandas not installed)"
result["validation_status"] = validation_status
if validation_errors:
result["validation_errors"] = validation_errors
# 6. Handle Export
export_requested = export and export.get("format")
if export_requested:
export_format = export["format"]  # Original casing kept for log messages; the exported path key uses lowercase
export_format_lower = export_format.lower()
req_path = export.get("path")
log_msg = f"Export requested: Format={export_format}, Path={req_path or 'Temporary'}"
logger.info(log_msg)
export_status = "failed (unknown reason)"
try:
# Call okay
dataframe, export_path = _sql_export_rows(
result["columns"], result["rows"], export_format_lower, req_path
)
export_status = "success"
if dataframe is not None: # Only if format was 'pandas'
result["dataframe"] = dataframe
if export_path: # If file was created
path_key = f"{export_format_lower}_path" # Use lowercase format for key
result[path_key] = export_path
log_msg = f"Export successful. Format: {export_format}, Path: {export_path or 'In-memory DataFrame'}"
logger.info(log_msg)
audit_extras["export_format"] = export_format
audit_extras["export_path"] = export_path
except (ToolError, ToolInputError) as export_err:
logger.error(f"Export failed: {export_err}")
export_status = f"Failed: {export_err}"
result["export_status"] = export_status
# 7. Audit Success
# Determine which query to log based on input
audit_sql = original_nl_input if use_nl else original_query_input
audit_row_count = result.get("row_count", 0)
audit_val_status = result.get("validation_status")
audit_exp_status = result.get("export_status", "not requested")
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action_name,
connection_id=connection_id,
sql=audit_sql,
tables=tables,
row_count=audit_row_count,
success=True,
error=None,
user_id=user_id,
session_id=session_id,
read_only=read_only,
pagination_used=bool(pagination),
validation_status=audit_val_status,
export_status=audit_exp_status,
**audit_extras,
)
return result
except ToolInputError as tie:
# Audit failure, use original inputs for logging context
audit_sql = original_nl_input if original_nl_input else original_query_input
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action_name + "_fail",
connection_id=connection_id,
sql=audit_sql,
tables=tables,
row_count=0,
success=False,
error=str(tie),
user_id=user_id,
session_id=session_id,
**audit_extras,
)
raise tie
except ToolError as te:
# Audit failure
audit_sql = original_nl_input if original_nl_input else original_query_input
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action_name + "_fail",
connection_id=connection_id,
sql=audit_sql,
tables=tables,
row_count=0,
success=False,
error=str(te),
user_id=user_id,
session_id=session_id,
**audit_extras,
)
raise te
except Exception as e:
log_msg = f"Unexpected error in execute_sql (action: {action_name}): {e}"
logger.error(log_msg, exc_info=True)
# Audit failure
audit_sql = original_nl_input if original_nl_input else original_query_input
error_str = f"Unexpected error: {e}"
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action_name + "_fail",
connection_id=connection_id,
sql=audit_sql,
tables=tables,
row_count=0,
success=False,
error=error_str,
user_id=user_id,
session_id=session_id,
**audit_extras,
)
raise ToolError(
f"An unexpected error occurred during SQL execution: {e}", http_status_code=500
) from e
@with_tool_metrics
@with_error_handling
async def explore_database(
connection_id: str,
action: str,
table_name: Optional[str] = None,
column_name: Optional[str] = None,
schema_name: Optional[str] = None,
user_id: Optional[str] = None,
session_id: Optional[str] = None,
ctx: Optional[Dict] = None, # Added ctx
**options: Any,
) -> Dict[str, Any]:
"""
Unified database schema exploration and documentation tool.
Performs actions like listing schemas, tables, views, columns,
getting table/column details, finding relationships, and generating documentation.
Args:
connection_id: The ID of the database connection to use.
action: The exploration action:
- "schema": Get full schema details (tables, views, columns, relationships).
- "table": Get details for a specific table (columns, PK, FKs, indexes, optionally sample data/stats). Requires `table_name`.
- "column": Get statistics for a specific column (nulls, distinct, optionally histogram). Requires `table_name` and `column_name`.
- "relationships": Find related tables via foreign keys up to a certain depth. Requires `table_name`.
- "documentation": Generate schema documentation (markdown or JSON).
table_name: Name of the table for 'table', 'column', 'relationships' actions.
column_name: Name of the column for 'column' action.
schema_name: Specific schema to inspect (if supported by dialect and needed). Defaults to connection's default schema.
user_id: Optional user identifier for audit logging.
session_id: Optional session identifier for audit logging.
ctx: Optional context from MCP server.
**options: Additional options depending on the action:
- schema: include_indexes (bool), include_foreign_keys (bool), detailed (bool)
- table: include_sample_data (bool), sample_size (int), include_statistics (bool)
- column: histogram (bool), num_buckets (int)
- relationships: depth (int)
- documentation: output_format ('markdown'|'json'), include_indexes(bool), include_foreign_keys(bool)
Returns:
A dictionary containing the results of the exploration action and a 'success' flag.
Structure varies significantly based on the action.
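Example:
A minimal sketch (the table name is hypothetical):
schema = await explore_database(connection_id=cid, action="schema", detailed=True)
tbl = await explore_database(connection_id=cid, action="table", table_name="users", include_sample_data=True, sample_size=5)
docs = await explore_database(connection_id=cid, action="documentation", output_format="markdown")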
"""
tool_name = "explore_database"
# Break down audit_extras creation
audit_extras = {}
audit_extras.update(options)
audit_extras["table_name"] = table_name
audit_extras["column_name"] = column_name
audit_extras["schema_name"] = schema_name
try:
log_msg = f"Exploring database for connection {connection_id}. Action: {action}, Table: {table_name}, Column: {column_name}, Schema: {schema_name}"
logger.info(log_msg)
eng = await _sql_get_engine(connection_id)
db_dialect = eng.dialect.name
audit_extras["database_type"] = db_dialect
# Define sync inspection helper (runs within connect block)
def _run_sync_inspection(
inspector_target: Union[AsyncConnection, AsyncEngine], func_to_run: callable
):
# Call okay
sync_inspector = sa_inspect(inspector_target)
return func_to_run(sync_inspector)
async with eng.connect() as conn:
# --- Action: schema ---
if action == "schema":
include_indexes = options.get("include_indexes", True)
include_foreign_keys = options.get("include_foreign_keys", True)
detailed = options.get("detailed", False)
filter_schema = schema_name # Use provided schema or None for default
def _get_full_schema(sync_conn) -> Dict[str, Any]:
# Separate inspector and target schema assignment
insp = sa_inspect(sync_conn)
target_schema = filter_schema or getattr(insp, "default_schema_name", None)
log_msg = f"Inspecting schema: {target_schema or 'Default'}. Detailed: {detailed}, Indexes: {include_indexes}, FKs: {include_foreign_keys}"
logger.info(log_msg)
tables_data: List[Dict[str, Any]] = []
views_data: List[Dict[str, Any]] = []
relationships: List[Dict[str, Any]] = []
try:
table_names = insp.get_table_names(schema=target_schema)
view_names = insp.get_view_names(schema=target_schema)
except Exception as inspect_err:
msg = f"Failed to list tables/views for schema '{target_schema}': {inspect_err}"
raise ToolError(msg, http_status_code=500) from inspect_err
for tbl_name in table_names:
try:
# Build t_info dict step-by-step
t_info: Dict[str, Any] = {}
t_info["name"] = tbl_name
t_info["columns"] = []
if target_schema:
t_info["schema"] = target_schema
columns_raw = insp.get_columns(tbl_name, schema=target_schema)
for c in columns_raw:
# Build col_info dict step-by-step
col_info = {}
col_info["name"] = c["name"]
col_info["type"] = str(c["type"])
col_info["nullable"] = c["nullable"]
col_info["primary_key"] = bool(c.get("primary_key"))
if detailed:
col_info["default"] = c.get("default")
col_info["comment"] = c.get("comment")
col_info["autoincrement"] = c.get("autoincrement", "auto")
t_info["columns"].append(col_info)
if include_indexes:
try:
idxs_raw = insp.get_indexes(tbl_name, schema=target_schema)
# List comprehension okay
t_info["indexes"] = [
{
"name": i["name"],
"columns": i["column_names"],
"unique": i.get("unique", False),
}
for i in idxs_raw
]
except Exception as idx_err:
logger.warning(
f"Could not retrieve indexes for table {tbl_name}: {idx_err}"
)
t_info["indexes"] = []
if include_foreign_keys:
try:
fks_raw = insp.get_foreign_keys(tbl_name, schema=target_schema)
if fks_raw:
t_info["foreign_keys"] = []
for fk in fks_raw:
# Build fk_info dict step-by-step
fk_info = {}
fk_info["name"] = fk.get("name")
fk_info["constrained_columns"] = fk[
"constrained_columns"
]
fk_info["referred_schema"] = fk.get("referred_schema")
fk_info["referred_table"] = fk["referred_table"]
fk_info["referred_columns"] = fk["referred_columns"]
t_info["foreign_keys"].append(fk_info)
# Build relationship dict step-by-step
rel_info = {}
rel_info["source_schema"] = target_schema
rel_info["source_table"] = tbl_name
rel_info["source_columns"] = fk["constrained_columns"]
rel_info["target_schema"] = fk.get("referred_schema")
rel_info["target_table"] = fk["referred_table"]
rel_info["target_columns"] = fk["referred_columns"]
relationships.append(rel_info)
except Exception as fk_err:
logger.warning(
f"Could not retrieve foreign keys for table {tbl_name}: {fk_err}"
)
tables_data.append(t_info)
except Exception as tbl_err:
log_msg = f"Failed to inspect table '{tbl_name}' in schema '{target_schema}': {tbl_err}"
logger.error(log_msg, exc_info=True)
# Append error dict
error_entry = {
"name": tbl_name,
"schema": target_schema,
"error": f"Failed to inspect: {tbl_err}",
}
tables_data.append(error_entry)
for view_name in view_names:
try:
# Build view_info dict step-by-step
view_info: Dict[str, Any] = {}
view_info["name"] = view_name
if target_schema:
view_info["schema"] = target_schema
try:
view_def_raw = insp.get_view_definition(
view_name, schema=target_schema
)
# Ternary okay
view_def = view_def_raw or ""
view_info["definition"] = view_def
except Exception as view_def_err:
log_msg = f"Could not retrieve definition for view {view_name}: {view_def_err}"
logger.warning(log_msg)
view_info["definition"] = "Error retrieving definition"
try:
view_cols_raw = insp.get_columns(view_name, schema=target_schema)
# List comprehension okay
view_info["columns"] = [
{"name": vc["name"], "type": str(vc["type"])}
for vc in view_cols_raw
]
except Exception:
pass # Ignore column errors for views if definition failed etc.
views_data.append(view_info)
except Exception as view_err:
log_msg = f"Failed to inspect view '{view_name}' in schema '{target_schema}': {view_err}"
logger.error(log_msg, exc_info=True)
# Append error dict
error_entry = {
"name": view_name,
"schema": target_schema,
"error": f"Failed to inspect: {view_err}",
}
views_data.append(error_entry)
# Build schema_result dict step-by-step
schema_result: Dict[str, Any] = {}
schema_result["action"] = "schema"
schema_result["database_type"] = db_dialect
schema_result["inspected_schema"] = target_schema or "Default"
schema_result["tables"] = tables_data
schema_result["views"] = views_data
schema_result["relationships"] = relationships
schema_result["success"] = True
# Schema Hashing and Lineage
try:
# Call okay
schema_json = json.dumps(schema_result, sort_keys=True, default=str)
schema_bytes = schema_json.encode()
# Call okay
schema_hash = hashlib.sha256(schema_bytes).hexdigest()
timestamp = _sql_now()
last_hash = _SCHEMA_VERSIONS.get(connection_id)
schema_changed = last_hash != schema_hash
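# The SHA-256 digest of the sorted-key schema JSON serves as a version fingerprint: any added,
# dropped, or altered table/view/column changes the digest, so comparing it with the previously
# stored digest is a cheap change detector that drives the in-memory lineage log below.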
if schema_changed:
_SCHEMA_VERSIONS[connection_id] = schema_hash
# Build lineage_entry dict step-by-step
lineage_entry = {}
lineage_entry["connection_id"] = connection_id
lineage_entry["timestamp"] = timestamp
lineage_entry["schema_hash"] = schema_hash
lineage_entry["previous_hash"] = last_hash
lineage_entry["user_id"] = user_id # Include user from outer scope
lineage_entry["tables_count"] = len(tables_data)
lineage_entry["views_count"] = len(views_data)
lineage_entry["action_source"] = f"{tool_name}/{action}"
_LINEAGE.append(lineage_entry)
hash_preview = schema_hash[:8]
prev_hash_preview = last_hash[:8] if last_hash else "None"
log_msg = f"Schema change detected or initial capture for {connection_id}. New hash: {hash_preview}..., Previous: {prev_hash_preview}"
logger.info(log_msg)
schema_result["schema_hash"] = schema_hash
# Boolean conversion okay
schema_result["schema_change_detected"] = bool(last_hash)
except Exception as hash_err:
log_msg = f"Error generating schema hash or recording lineage: {hash_err}"
logger.error(log_msg, exc_info=True)
return schema_result
# Call okay
def sync_func(sync_conn_arg):
return _get_full_schema(sync_conn_arg)
result = await conn.run_sync(sync_func) # Pass sync connection
# --- Action: table ---
elif action == "table":
if not table_name:
raise ToolInputError(
"`table_name` is required for 'table'", param_name="table_name"
)
include_sample = options.get("include_sample_data", False)
sample_size_raw = options.get("sample_size", 5)
sample_size = int(sample_size_raw)
include_stats = options.get("include_statistics", False)
if sample_size < 0:
sample_size = 0
def _get_basic_table_meta(sync_conn) -> Dict[str, Any]:
# Assign inspector and schema
insp = sa_inspect(sync_conn)
target_schema = schema_name or getattr(insp, "default_schema_name", None)
logger.info(f"Inspecting table details: {target_schema}.{table_name}")
try:
all_tables = insp.get_table_names(schema=target_schema)
if table_name not in all_tables:
msg = f"Table '{table_name}' not found in schema '{target_schema}'."
raise ToolInputError(msg, param_name="table_name")
except Exception as list_err:
msg = f"Could not verify if table '{table_name}' exists: {list_err}"
raise ToolError(msg, http_status_code=500) from list_err
# Initialize meta parts
cols = []
idx = []
fks = []
pk_constraint = {}
table_comment_text = None
cols = insp.get_columns(table_name, schema=target_schema)
try:
idx = insp.get_indexes(table_name, schema=target_schema)
except Exception as idx_err:
logger.warning(f"Could not get indexes for table {table_name}: {idx_err}")
try:
fks = insp.get_foreign_keys(table_name, schema=target_schema)
except Exception as fk_err:
logger.warning(
f"Could not get foreign keys for table {table_name}: {fk_err}"
)
try:
pk_info = insp.get_pk_constraint(table_name, schema=target_schema)
# Split pk_constraint assignment
if pk_info and pk_info.get("constrained_columns"):
pk_constraint = {
"name": pk_info.get("name"),
"columns": pk_info["constrained_columns"],
}
# else pk_constraint remains {}
except Exception as pk_err:
logger.warning(f"Could not get PK constraint for {table_name}: {pk_err}")
try:
table_comment_raw = insp.get_table_comment(table_name, schema=target_schema)
# Ternary okay
table_comment_text = (
table_comment_raw.get("text") if table_comment_raw else None
)
except Exception as cmt_err:
logger.warning(f"Could not get table comment for {table_name}: {cmt_err}")
# Build return dict step-by-step
meta_result = {}
meta_result["columns"] = cols
meta_result["indexes"] = idx
meta_result["foreign_keys"] = fks
meta_result["pk_constraint"] = pk_constraint
meta_result["table_comment"] = table_comment_text
meta_result["schema_name"] = target_schema # Add schema name for reference
return meta_result
# Call okay
def sync_func_meta(sync_conn_arg):
return _get_basic_table_meta(sync_conn_arg)
meta = await conn.run_sync(sync_func_meta) # Pass sync connection
# Build result dict step-by-step
result = {}
result["action"] = "table"
result["table_name"] = table_name
# Use schema name returned from meta function
result["schema_name"] = meta.get("schema_name")
result["comment"] = meta.get("table_comment")
# List comprehension okay
result["columns"] = [
{
"name": c["name"],
"type": str(c["type"]),
"nullable": c["nullable"],
"primary_key": bool(c.get("primary_key")),
"default": c.get("default"),
"comment": c.get("comment"),
}
for c in meta["columns"]
]
result["primary_key"] = meta.get("pk_constraint")
result["indexes"] = meta.get("indexes", [])
result["foreign_keys"] = meta.get("foreign_keys", [])
result["success"] = True
# Quote identifiers
id_prep = eng.dialect.identifier_preparer
quoted_table_name = id_prep.quote(table_name)
quoted_schema_name = id_prep.quote(schema_name) if schema_name else None
# Ternary okay
full_table_name = (
f"{quoted_schema_name}.{quoted_table_name}"
if quoted_schema_name
else quoted_table_name
)
# Row count
try:
# Call okay
_, count_rows, _ = await _sql_exec(
eng,
f"SELECT COUNT(*) AS row_count FROM {full_table_name}",
None,
limit=1,
tool_name=tool_name,
action_name="table_count",
timeout=30,
)
# Ternary okay
result["row_count"] = count_rows[0]["row_count"] if count_rows else 0
except Exception as count_err:
logger.warning(
f"Could not get row count for table {full_table_name}: {count_err}"
)
result["row_count"] = "Error"
# Sample data
if include_sample and sample_size > 0:
try:
# Call okay
sample_cols, sample_rows, _ = await _sql_exec(
eng,
f"SELECT * FROM {full_table_name} LIMIT :n",
{"n": sample_size},
limit=sample_size,
tool_name=tool_name,
action_name="table_sample",
timeout=30,
)
# Assign sample data dict okay
result["sample_data"] = {"columns": sample_cols, "rows": sample_rows}
except Exception as sample_err:
logger.warning(
f"Could not get sample data for table {full_table_name}: {sample_err}"
)
# Assign error dict okay
result["sample_data"] = {
"error": f"Failed to retrieve sample data: {sample_err}"
}
# Statistics
if include_stats:
stats = {}
logger.debug(f"Calculating basic statistics for columns in {full_table_name}")
columns_to_stat = result.get("columns", [])
for c in columns_to_stat:
col_name = c["name"]
quoted_col = id_prep.quote(col_name)
col_stat_data = {}
try:
# Null count
# Call okay
_, null_rows, _ = await _sql_exec(
eng,
f"SELECT COUNT(*) AS null_count FROM {full_table_name} WHERE {quoted_col} IS NULL",
None,
limit=1,
tool_name=tool_name,
action_name="col_stat_null",
timeout=20,
)
# Ternary okay
null_count = null_rows[0]["null_count"] if null_rows else "Error"
# Distinct count
# Call okay
_, distinct_rows, _ = await _sql_exec(
eng,
f"SELECT COUNT(DISTINCT {quoted_col}) AS distinct_count FROM {full_table_name}",
None,
limit=1,
tool_name=tool_name,
action_name="col_stat_distinct",
timeout=45,
)
# Ternary okay
distinct_count = (
distinct_rows[0]["distinct_count"] if distinct_rows else "Error"
)
# Assign stats dict okay
col_stat_data = {
"null_count": null_count,
"distinct_count": distinct_count,
}
except Exception as stat_err:
log_msg = f"Could not calculate statistics for column {col_name} in {full_table_name}: {stat_err}"
logger.warning(log_msg)
# Assign error dict okay
col_stat_data = {"error": f"Failed: {stat_err}"}
stats[col_name] = col_stat_data
result["statistics"] = stats
# --- Action: column ---
elif action == "column":
if not table_name:
raise ToolInputError(
"`table_name` required for 'column'", param_name="table_name"
)
if not column_name:
raise ToolInputError(
"`column_name` required for 'column'", param_name="column_name"
)
generate_histogram = options.get("histogram", False)
num_buckets_raw = options.get("num_buckets", 10)
num_buckets = int(num_buckets_raw)
num_buckets = max(1, num_buckets) # Ensure at least one bucket
# Quote identifiers
id_prep = eng.dialect.identifier_preparer
quoted_table = id_prep.quote(table_name)
quoted_column = id_prep.quote(column_name)
quoted_schema = id_prep.quote(schema_name) if schema_name else None
# Ternary okay
full_table_name = (
f"{quoted_schema}.{quoted_table}" if quoted_schema else quoted_table
)
logger.info(f"Analyzing column {full_table_name}.{quoted_column}")
stats_data: Dict[str, Any] = {}
try:
# Total Rows
# Call okay
_, total_rows_res, _ = await _sql_exec(
eng,
f"SELECT COUNT(*) as cnt FROM {full_table_name}",
None,
limit=1,
tool_name=tool_name,
action_name="col_stat_total",
timeout=30,
)
# Ternary okay
total_rows_count = total_rows_res[0]["cnt"] if total_rows_res else 0
stats_data["total_rows"] = total_rows_count
# Null Count
# Call okay
_, null_rows_res, _ = await _sql_exec(
eng,
f"SELECT COUNT(*) as cnt FROM {full_table_name} WHERE {quoted_column} IS NULL",
None,
limit=1,
tool_name=tool_name,
action_name="col_stat_null",
timeout=30,
)
# Ternary okay
null_count = null_rows_res[0]["cnt"] if null_rows_res else 0
stats_data["null_count"] = null_count
# Ternary okay
null_perc = (
round((null_count / total_rows_count) * 100, 2) if total_rows_count else 0
)
stats_data["null_percentage"] = null_perc
# Distinct Count
# Call okay
_, distinct_rows_res, _ = await _sql_exec(
eng,
f"SELECT COUNT(DISTINCT {quoted_column}) as cnt FROM {full_table_name}",
None,
limit=1,
tool_name=tool_name,
action_name="col_stat_distinct",
timeout=60,
)
# Ternary okay
distinct_count = distinct_rows_res[0]["cnt"] if distinct_rows_res else 0
stats_data["distinct_count"] = distinct_count
# Ternary okay
distinct_perc = (
round((distinct_count / total_rows_count) * 100, 2)
if total_rows_count
else 0
)
stats_data["distinct_percentage"] = distinct_perc
except Exception as stat_err:
log_msg = f"Failed to get basic statistics for column {full_table_name}.{quoted_column}: {stat_err}"
logger.error(log_msg, exc_info=True)
stats_data["error"] = f"Failed to retrieve some statistics: {stat_err}"
# Build result dict step-by-step
result = {}
result["action"] = "column"
result["table_name"] = table_name
result["column_name"] = column_name
result["schema_name"] = schema_name
result["statistics"] = stats_data
result["success"] = True
if generate_histogram:
logger.debug(f"Generating histogram for {full_table_name}.{quoted_column}")
histogram_data: Optional[Dict[str, Any]] = None
try:
hist_query = f"SELECT {quoted_column} FROM {full_table_name} WHERE {quoted_column} IS NOT NULL"
# Call okay
_, value_rows, _ = await _sql_exec(
eng,
hist_query,
None,
limit=None, # Fetch all non-null values
tool_name=tool_name,
action_name="col_hist_fetch",
timeout=90,
)
# List comprehension okay
values = [r[column_name] for r in value_rows]
if not values:
histogram_data = {"type": "empty", "buckets": []}
else:
first_val = values[0]
# Check type okay
is_numeric = isinstance(first_val, (int, float))
if is_numeric:
try:
min_val = min(values)
max_val = max(values)
buckets = []
if min_val == max_val:
# Single bucket dict okay
bucket = {"range": f"{min_val}", "count": len(values)}
buckets.append(bucket)
else:
# Calculate bin width okay
val_range = max_val - min_val
bin_width = val_range / num_buckets
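# Worked example (illustrative): min=0, max=50, num_buckets=10 -> bin_width=5, giving ranges
# [0,5), [5,10), ..., [45,50]; a value of 12 maps to int((12-0)/5) = bucket index 2.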
# List comprehension okay
bucket_ranges_raw = [
(min_val + i * bin_width, min_val + (i + 1) * bin_width)
for i in range(num_buckets)
]
# Adjust last bucket range okay
last_bucket_idx = num_buckets - 1
last_bucket_start = bucket_ranges_raw[last_bucket_idx][0]
bucket_ranges_raw[last_bucket_idx] = (
last_bucket_start,
max_val,
)
bucket_ranges = bucket_ranges_raw
# List comprehension for bucket init okay
buckets = [
{"range": f"{r[0]:.4g} - {r[1]:.4g}", "count": 0}
for r in bucket_ranges
]
for v in values:
# Ternary okay
idx_float = (
(v - min_val) / bin_width if bin_width > 0 else 0
)
idx_int = int(idx_float)
# Ensure index is within bounds
idx = min(idx_int, num_buckets - 1)
# Handle max value potentially falling into last bucket due to precision
if v == max_val:
idx = num_buckets - 1
buckets[idx]["count"] += 1
# Assign numeric histogram dict okay
histogram_data = {
"type": "numeric",
"min": min_val,
"max": max_val,
"buckets": buckets,
}
except Exception as num_hist_err:
log_msg = f"Error generating numeric histogram: {num_hist_err}"
logger.error(log_msg, exc_info=True)
# Assign error dict okay
histogram_data = {
"error": f"Failed to generate numeric histogram: {num_hist_err}"
}
else: # Categorical / Frequency
try:
# Import okay
from collections import Counter
# Call okay
str_values = map(str, values)
value_counts = Counter(str_values)
# Call okay
top_buckets_raw = value_counts.most_common(num_buckets)
# List comprehension okay
buckets_data = [
{"value": str(k)[:100], "count": v} # Limit value length
for k, v in top_buckets_raw
]
# Sum okay
top_n_count = sum(b["count"] for b in buckets_data)
other_count = len(values) - top_n_count
# Assign frequency histogram dict okay
histogram_data = {
"type": "frequency",
"top_n": num_buckets,
"buckets": buckets_data,
}
if other_count > 0:
histogram_data["other_values_count"] = other_count
except Exception as freq_hist_err:
log_msg = (
f"Error generating frequency histogram: {freq_hist_err}"
)
logger.error(log_msg, exc_info=True)
# Assign error dict okay
histogram_data = {
"error": f"Failed to generate frequency histogram: {freq_hist_err}"
}
except Exception as hist_err:
log_msg = f"Failed to generate histogram for column {full_table_name}.{quoted_column}: {hist_err}"
logger.error(log_msg, exc_info=True)
# Assign error dict okay
histogram_data = {"error": f"Histogram generation failed: {hist_err}"}
result["histogram"] = histogram_data
# --- Action: relationships ---
elif action == "relationships":
if not table_name:
raise ToolInputError(
"`table_name` required for 'relationships'", param_name="table_name"
)
depth_raw = options.get("depth", 1)
depth_int = int(depth_raw)
# Clamp depth
depth = max(1, min(depth_int, 5))
log_msg = f"Finding relationships for table '{table_name}' (depth: {depth}, schema: {schema_name})"
logger.info(log_msg)
# Call explore_database for schema info - this recursive call is okay
schema_info = await explore_database(
connection_id=connection_id,
action="schema",
schema_name=schema_name,
include_indexes=False, # Don't need indexes for relationships
include_foreign_keys=True, # Need FKs
detailed=False, # Don't need detailed column info
)
# Check success okay
schema_success = schema_info.get("success", False)
if not schema_success:
raise ToolError(
"Failed to retrieve schema information needed to find relationships."
)
# Dict comprehension okay
tables_list = schema_info.get("tables", [])
tables_by_name: Dict[str, Dict] = {t["name"]: t for t in tables_list}
if table_name not in tables_by_name:
msg = f"Starting table '{table_name}' not found in schema '{schema_name}'."
raise ToolInputError(msg, param_name="table_name")
visited_nodes = set() # Track visited nodes to prevent cycles
# Define the recursive helper function *inside* this action block
# so it has access to tables_by_name and visited_nodes
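# Illustrative walk (hypothetical tables): starting at "orders" with depth=2, a parent edge is
# recorded for orders.customer_id -> customers and a child edge for order_items.order_id -> orders;
# the visited_nodes set prevents infinite recursion on cyclic references, and the
# max_depth_reached / cyclic_reference flags mark where traversal stopped.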
def _build_relationship_graph_standalone(
current_table: str, current_depth: int
) -> Dict[str, Any]:
# Build node_id string okay
current_schema = schema_name or "default"
node_id = f"{current_schema}.{current_table}"
is_max_depth = current_depth >= depth
is_visited = node_id in visited_nodes
if is_max_depth or is_visited:
# Return dict okay
return {
"table": current_table,
"schema": schema_name, # Use original schema_name context
"max_depth_reached": is_max_depth,
"cyclic_reference": is_visited,
}
visited_nodes.add(node_id)
node_info = tables_by_name.get(current_table)
if not node_info:
visited_nodes.remove(node_id) # Backtrack
# Return dict okay
return {
"table": current_table,
"schema": schema_name,
"error": "Table info not found",
}
# Build graph_node dict step-by-step
graph_node: Dict[str, Any] = {}
graph_node["table"] = current_table
graph_node["schema"] = schema_name
graph_node["children"] = []
graph_node["parents"] = []
# Find Parents (current table's FKs point to parents)
foreign_keys_list = node_info.get("foreign_keys", [])
for fk in foreign_keys_list:
ref_table = fk["referred_table"]
ref_schema = fk.get(
"referred_schema", schema_name
) # Assume same schema if not specified
if ref_table in tables_by_name:
# Recursive call okay
parent_node = _build_relationship_graph_standalone(
ref_table, current_depth + 1
)
else:
# Return dict okay for outside scope
parent_node = {
"table": ref_table,
"schema": ref_schema,
"outside_scope": True,
}
# Build relationship string okay
constrained_cols_str = ",".join(fk["constrained_columns"])
referred_cols_str = ",".join(fk["referred_columns"])
rel_str = f"{current_table}.({constrained_cols_str}) -> {ref_table}.({referred_cols_str})"
# Append parent relationship dict okay
graph_node["parents"].append(
{"relationship": rel_str, "target": parent_node}
)
# Find Children (other tables' FKs point to current table)
for other_table_name, other_table_info in tables_by_name.items():
if other_table_name == current_table:
continue # Skip self-reference check here
other_fks = other_table_info.get("foreign_keys", [])
for fk in other_fks:
points_to_current = fk["referred_table"] == current_table
# Check schema match (use original schema_name context)
referred_schema_matches = (
fk.get("referred_schema", schema_name) == schema_name
)
if points_to_current and referred_schema_matches:
# Recursive call okay
child_node = _build_relationship_graph_standalone(
other_table_name, current_depth + 1
)
# Build relationship string okay
constrained_cols_str = ",".join(fk["constrained_columns"])
referred_cols_str = ",".join(fk["referred_columns"])
rel_str = f"{other_table_name}.({constrained_cols_str}) -> {current_table}.({referred_cols_str})"
# Append child relationship dict okay
graph_node["children"].append(
{"relationship": rel_str, "source": child_node}
)
visited_nodes.remove(node_id) # Backtrack visited state
return graph_node
# Initial call to the recursive function
relationship_graph = _build_relationship_graph_standalone(table_name, 0)
# Build result dict step-by-step
result = {}
result["action"] = "relationships"
result["source_table"] = table_name
result["schema_name"] = schema_name
result["max_depth"] = depth
result["relationship_graph"] = relationship_graph
result["success"] = True
# --- Action: documentation ---
elif action == "documentation":
output_format_raw = options.get("output_format", "markdown")
output_format = output_format_raw.lower()
valid_formats = ["markdown", "json"]
if output_format not in valid_formats:
msg = "Invalid 'output_format'. Use 'markdown' or 'json'."
raise ToolInputError(msg, param_name="output_format")
doc_include_indexes = options.get("include_indexes", True)
doc_include_fks = options.get("include_foreign_keys", True)
log_msg = f"Generating database documentation (Format: {output_format}, Schema: {schema_name})"
logger.info(log_msg)
# Call explore_database for schema info (recursive call okay)
schema_data = await explore_database(
connection_id=connection_id,
action="schema",
schema_name=schema_name,
include_indexes=doc_include_indexes,
include_foreign_keys=doc_include_fks,
detailed=True, # Need details for documentation
)
schema_success = schema_data.get("success", False)
if not schema_success:
raise ToolError(
"Failed to retrieve schema information needed for documentation."
)
if output_format == "json":
# Build result dict step-by-step
result = {}
result["action"] = "documentation"
result["format"] = "json"
result["documentation"] = schema_data # Embed the schema result directly
result["success"] = True
else: # Markdown
# --- Markdown Generation ---
lines = []
lines.append(f"# Database Documentation ({db_dialect})")
db_schema_name = schema_data.get("inspected_schema", "Default Schema")
lines.append(f"Schema: **{db_schema_name}**")
now_str = _sql_now()
lines.append(f"Generated: {now_str}")
schema_hash_val = schema_data.get("schema_hash")
if schema_hash_val:
hash_preview = schema_hash_val[:12]
lines.append(f"Schema Version (Hash): `{hash_preview}`")
lines.append("") # Blank line
lines.append("## Tables")
lines.append("")
# Sort okay
tables_list_raw = schema_data.get("tables", [])
tables = sorted(tables_list_raw, key=lambda x: x["name"])
if not tables:
lines.append("*No tables found in this schema.*")
for t in tables:
table_name_doc = t["name"]
if t.get("error"):
lines.append(f"### {table_name_doc} (Error)")
lines.append(f"```\n{t['error']}\n```")
lines.append("")
continue # Skip rest for this table
lines.append(f"### {table_name_doc}")
lines.append("")
table_comment = t.get("comment")
if table_comment:
lines.append(f"> {table_comment}")
lines.append("")
# Column Header
lines.append("| Column | Type | Nullable | PK | Default | Comment |")
lines.append("|--------|------|----------|----|---------|---------|")
columns_list = t.get("columns", [])
for c in columns_list:
# Ternary okay
pk_flag = "✅" if c["primary_key"] else ""
null_flag = "✅" if c["nullable"] else ""
default_raw = c.get("default")
# Ternary okay
default_val_str = f"`{default_raw}`" if default_raw is not None else ""
comment_val = c.get("comment") or ""
col_name_str = f"`{c['name']}`"
col_type_str = f"`{c['type']}`"
# Build line okay
line = f"| {col_name_str} | {col_type_str} | {null_flag} | {pk_flag} | {default_val_str} | {comment_val} |"
lines.append(line)
lines.append("") # Blank line after table
# Primary Key section
pk_info = t.get("primary_key")
pk_cols = pk_info.get("columns") if pk_info else None
if pk_info and pk_cols:
pk_name = pk_info.get("name", "PK")
# List comprehension okay
pk_cols_formatted = [f"`{c}`" for c in pk_cols]
pk_cols_str = ", ".join(pk_cols_formatted)
lines.append(f"**Primary Key:** `{pk_name}` ({pk_cols_str})")
lines.append("")
# Indexes section
indexes_list = t.get("indexes")
if doc_include_indexes and indexes_list:
lines.append("**Indexes:**")
lines.append("")
lines.append("| Name | Columns | Unique |")
lines.append("|------|---------|--------|")
for idx in indexes_list:
# Ternary okay
unique_flag = "✅" if idx["unique"] else ""
# List comprehension okay
idx_cols_formatted = [f"`{c}`" for c in idx["columns"]]
cols_str = ", ".join(idx_cols_formatted)
idx_name_str = f"`{idx['name']}`"
# Build line okay
line = f"| {idx_name_str} | {cols_str} | {unique_flag} |"
lines.append(line)
lines.append("")
# Foreign Keys section
fks_list = t.get("foreign_keys")
if doc_include_fks and fks_list:
lines.append("**Foreign Keys:**")
lines.append("")
lines.append("| Name | Column(s) | References |")
lines.append("|------|-----------|------------|")
for fk in fks_list:
# List comprehension okay
constrained_cols_fmt = [f"`{c}`" for c in fk["constrained_columns"]]
constrained_cols_str = ", ".join(constrained_cols_fmt)
ref_schema = fk.get("referred_schema", db_schema_name)
ref_table_name = fk["referred_table"]
ref_table_str = f"`{ref_schema}`.`{ref_table_name}`"
# List comprehension okay
ref_cols_fmt = [f"`{c}`" for c in fk["referred_columns"]]
ref_cols_str = ", ".join(ref_cols_fmt)
fk_name = fk.get("name", "FK")
fk_name_str = f"`{fk_name}`"
ref_full_str = f"{ref_table_str} ({ref_cols_str})"
# Build line okay
line = (
f"| {fk_name_str} | {constrained_cols_str} | {ref_full_str} |"
)
lines.append(line)
lines.append("")
# Views Section
views_list_raw = schema_data.get("views", [])
views = sorted(views_list_raw, key=lambda x: x["name"])
if views:
lines.append("## Views")
lines.append("")
for v in views:
view_name_doc = v["name"]
if v.get("error"):
lines.append(f"### {view_name_doc} (Error)")
lines.append(f"```\n{v['error']}\n```")
lines.append("")
continue # Skip rest for this view
lines.append(f"### {view_name_doc}")
lines.append("")
view_columns = v.get("columns")
if view_columns:
# List comprehension okay
view_cols_fmt = [
f"`{vc['name']}` ({vc['type']})" for vc in view_columns
]
view_cols_str = ", ".join(view_cols_fmt)
lines.append(f"**Columns:** {view_cols_str}")
lines.append("")
view_def = v.get("definition")
# Check for valid definition string
# Treat both the error sentinel and the "not implemented" placeholder as missing definitions
is_valid_def = view_def and view_def not in ("Error retrieving definition", "N/A (Not Implemented by Dialect)")
if is_valid_def:
lines.append("**Definition:**")
lines.append("```sql")
lines.append(view_def)
lines.append("```")
lines.append("")
else:
lines.append(
"**Definition:** *Not available or not implemented by dialect.*"
)
lines.append("")
# --- End Markdown Generation ---
# Join lines okay
markdown_output = "\n".join(lines)
# Build result dict step-by-step
result = {}
result["action"] = "documentation"
result["format"] = "markdown"
result["documentation"] = markdown_output
result["success"] = True
else:
logger.error(f"Invalid action specified for explore_database: {action}")
details = {"action": action}
valid_actions = "schema, table, column, relationships, documentation"
msg = f"Unknown action: '{action}'. Valid actions: {valid_actions}"
raise ToolInputError(msg, param_name="action", details=details)
# Audit success for all successful actions
# Ternary okay
audit_table = [table_name] if table_name else None
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action,
connection_id=connection_id,
sql=None,
tables=audit_table,
row_count=None,
success=True,
error=None,
user_id=user_id,
session_id=session_id,
**audit_extras,
)
return result # Return the constructed result dict
except ToolInputError as tie:
# Audit failure
# Ternary okay
audit_table = [table_name] if table_name else None
action_fail = action + "_fail"
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action_fail,
connection_id=connection_id,
sql=None,
tables=audit_table,
row_count=None,
success=False,
error=str(tie),
user_id=user_id,
session_id=session_id,
**audit_extras,
)
raise tie
except ToolError as te:
# Audit failure
# Ternary okay
audit_table = [table_name] if table_name else None
action_fail = action + "_fail"
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action_fail,
connection_id=connection_id,
sql=None,
tables=audit_table,
row_count=None,
success=False,
error=str(te),
user_id=user_id,
session_id=session_id,
**audit_extras,
)
raise te
except Exception as e:
log_msg = f"Unexpected error in explore_database (action: {action}): {e}"
logger.error(log_msg, exc_info=True)
# Audit failure
# Ternary okay
audit_table = [table_name] if table_name else None
action_fail = action + "_fail"
error_str = f"Unexpected error: {e}"
# Call okay
await _sql_audit(
tool_name=tool_name,
action=action_fail,
connection_id=connection_id,
sql=None,
tables=audit_table,
row_count=None,
success=False,
error=error_str,
user_id=user_id,
session_id=session_id,
**audit_extras,
)
raise ToolError(
f"An unexpected error occurred during database exploration: {e}", http_status_code=500
) from e
@with_tool_metrics
@with_error_handling
async def access_audit_log(
action: str = "view",
export_format: Optional[str] = None,
limit: Optional[int] = 100,
user_id: Optional[str] = None,
connection_id: Optional[str] = None,
ctx: Optional[Dict] = None, # Added ctx
) -> Dict[str, Any]:
"""
Access and export the in-memory SQL audit log.
Allows viewing recent log entries or exporting them to a file.
Note: The audit log is currently stored only in memory and will be lost on server restart.
Args:
action: "view" (default) or "export".
export_format: Required if action is "export". Supports "json", "excel", "csv".
limit: For "view", the maximum number of most recent records to return (default: 100). Use None or -1 for all.
user_id: Filter log entries by this user ID.
connection_id: Filter log entries by this connection ID.
ctx: Optional context from MCP server.
Returns:
Dict containing results:
- For "view": {action: "view", records: List[Dict], filtered_record_count: int, total_records_in_log: int, filters_applied: Dict, success: True}
- For "export": {action: "export", path: str, format: str, record_count: int, success: True} or {action: "export", message: str, record_count: 0, success: True} if no records.
"""
tool_name = "access_audit_log" # noqa: F841
# Apply filters using global _AUDIT_LOG
async with _audit_lock: # Need lock to safely read/copy log
# Call okay
full_log_copy = list(_AUDIT_LOG)
total_records_in_log = len(full_log_copy)
# Start with the full copy
filtered_log = full_log_copy
# Apply filters sequentially
if user_id:
# List comprehension okay
filtered_log = [r for r in filtered_log if r.get("user_id") == user_id]
if connection_id:
# List comprehension okay
filtered_log = [r for r in filtered_log if r.get("connection_id") == connection_id]
filtered_record_count = len(filtered_log)
if action == "view":
# Guard limit == 0 explicitly: a [-0:] slice would otherwise return the full list.
needs_limit = limit is not None and limit >= 0
records_to_return = (filtered_log[-limit:] if limit > 0 else []) if needs_limit else filtered_log
num_returned = len(records_to_return)
log_msg = f"View audit log requested. Returning {num_returned}/{filtered_record_count} filtered records (Total in log: {total_records_in_log})."
logger.info(log_msg)
# Build filters applied dict okay
filters_applied = {"user_id": user_id, "connection_id": connection_id}
# Build result dict step-by-step
result = {}
result["action"] = "view"
result["records"] = records_to_return
result["filtered_record_count"] = filtered_record_count
result["total_records_in_log"] = total_records_in_log
result["filters_applied"] = filters_applied
result["success"] = True
return result
elif action == "export":
if not export_format:
raise ToolInputError(
"`export_format` is required for 'export'", param_name="export_format"
)
export_format_lower = export_format.lower()
log_msg = f"Export audit log requested. Format: {export_format_lower}. Records to export: {filtered_record_count}"
logger.info(log_msg)
if not filtered_log:
logger.warning("Audit log is empty or filtered log is empty, nothing to export.")
# Return dict okay
return {
"action": "export",
"message": "No audit records found matching filters to export.",
"record_count": 0,
"success": True,
}
if export_format_lower == "json":
path = "" # Initialize path
try:
# Call okay
fd, temp_path = tempfile.mkstemp(suffix=".json", prefix="mcp_audit_export_")
path = temp_path # Assign path now we know mkstemp succeeded
os.close(fd)
# Use sync write for simplicity here
with open(path, "w", encoding="utf-8") as f:
# Call okay
json.dump(filtered_log, f, indent=2, default=str)
log_msg = (
f"Successfully exported {filtered_record_count} audit records to JSON: {path}"
)
logger.info(log_msg)
# Return dict okay
return {
"action": "export",
"path": path,
"format": "json",
"record_count": filtered_record_count,
"success": True,
}
except Exception as e:
log_msg = f"Failed to export audit log to JSON: {e}"
logger.error(log_msg, exc_info=True)
# Clean up temp file if created
if path and Path(path).exists():
try:
Path(path).unlink()
except OSError:
logger.warning(f"Could not clean up failed JSON export file: {path}")
raise ToolError(
f"Failed to export audit log to JSON: {e}", http_status_code=500
) from e
elif export_format_lower in ["excel", "csv"]:
if pd is None:
details = {"library": "pandas"}
msg = f"Pandas library not installed, cannot export audit log to '{export_format_lower}'."
raise ToolError(msg, http_status_code=501, details=details)
path = "" # Initialize path
try:
# Call okay
df = pd.DataFrame(filtered_log)
# Ternary okay for suffix/writer/engine
is_excel = export_format_lower == "excel"
suffix = ".xlsx" if is_excel else ".csv"
writer_func = df.to_excel if is_excel else df.to_csv
engine = "xlsxwriter" if is_excel else None
# Call okay
fd, temp_path = tempfile.mkstemp(suffix=suffix, prefix="mcp_audit_export_")
path = temp_path # Assign path
os.close(fd)
# Build export args dict okay
export_kwargs: Dict[str, Any] = {"index": False}
if engine:
export_kwargs["engine"] = engine
# Call writer function
writer_func(path, **export_kwargs)
log_msg = f"Successfully exported {filtered_record_count} audit records to {export_format_lower.upper()}: {path}"
logger.info(log_msg)
# Return dict okay
return {
"action": "export",
"path": path,
"format": export_format_lower,
"record_count": filtered_record_count,
"success": True,
}
except Exception as e:
log_msg = f"Failed to export audit log to {export_format_lower}: {e}"
logger.error(log_msg, exc_info=True)
# Clean up temp file if created
if path and Path(path).exists():
try:
Path(path).unlink()
except OSError:
logger.warning(f"Could not clean up temporary export file: {path}")
msg = f"Failed to export audit log to {export_format_lower}: {e}"
raise ToolError(msg, http_status_code=500) from e
else:
details = {"format": export_format}
valid_formats = "'excel', 'csv', or 'json'"
msg = f"Unsupported export format: '{export_format}'. Use {valid_formats}."
raise ToolInputError(msg, param_name="export_format", details=details)
else:
details = {"action": action}
msg = f"Unknown action: '{action}'. Use 'view' or 'export'."
raise ToolInputError(msg, param_name="action", details=details)
```