This is page 41 of 45. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│ ├── __init__.py
│ ├── advanced_agent_flows_using_unified_memory_system_demo.py
│ ├── advanced_extraction_demo.py
│ ├── advanced_unified_memory_system_demo.py
│ ├── advanced_vector_search_demo.py
│ ├── analytics_reporting_demo.py
│ ├── audio_transcription_demo.py
│ ├── basic_completion_demo.py
│ ├── cache_demo.py
│ ├── claude_integration_demo.py
│ ├── compare_synthesize_demo.py
│ ├── cost_optimization.py
│ ├── data
│ │ ├── sample_event.txt
│ │ ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│ │ └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│ ├── docstring_refiner_demo.py
│ ├── document_conversion_and_processing_demo.py
│ ├── entity_relation_graph_demo.py
│ ├── filesystem_operations_demo.py
│ ├── grok_integration_demo.py
│ ├── local_text_tools_demo.py
│ ├── marqo_fused_search_demo.py
│ ├── measure_model_speeds.py
│ ├── meta_api_demo.py
│ ├── multi_provider_demo.py
│ ├── ollama_integration_demo.py
│ ├── prompt_templates_demo.py
│ ├── python_sandbox_demo.py
│ ├── rag_example.py
│ ├── research_workflow_demo.py
│ ├── sample
│ │ ├── article.txt
│ │ ├── backprop_paper.pdf
│ │ ├── buffett.pdf
│ │ ├── contract_link.txt
│ │ ├── legal_contract.txt
│ │ ├── medical_case.txt
│ │ ├── northwind.db
│ │ ├── research_paper.txt
│ │ ├── sample_data.json
│ │ └── text_classification_samples
│ │ ├── email_classification.txt
│ │ ├── news_samples.txt
│ │ ├── product_reviews.txt
│ │ └── support_tickets.txt
│ ├── sample_docs
│ │ └── downloaded
│ │ └── attention_is_all_you_need.pdf
│ ├── sentiment_analysis_demo.py
│ ├── simple_completion_demo.py
│ ├── single_shot_synthesis_demo.py
│ ├── smart_browser_demo.py
│ ├── sql_database_demo.py
│ ├── sse_client_demo.py
│ ├── test_code_extraction.py
│ ├── test_content_detection.py
│ ├── test_ollama.py
│ ├── text_classification_demo.py
│ ├── text_redline_demo.py
│ ├── tool_composition_examples.py
│ ├── tournament_code_demo.py
│ ├── tournament_text_demo.py
│ ├── unified_memory_system_demo.py
│ ├── vector_search_demo.py
│ ├── web_automation_instruction_packs.py
│ └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│ └── smart_browser_internal
│ ├── locator_cache.db
│ ├── readability.js
│ └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── integration
│ │ ├── __init__.py
│ │ └── test_server.py
│ ├── manual
│ │ ├── test_extraction_advanced.py
│ │ └── test_extraction.py
│ └── unit
│ ├── __init__.py
│ ├── test_cache.py
│ ├── test_providers.py
│ └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│ ├── __init__.py
│ ├── __main__.py
│ ├── cli
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── commands.py
│ │ ├── helpers.py
│ │ └── typer_cli.py
│ ├── clients
│ │ ├── __init__.py
│ │ ├── completion_client.py
│ │ └── rag_client.py
│ ├── config
│ │ └── examples
│ │ └── filesystem_config.yaml
│ ├── config.py
│ ├── constants.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── evaluation
│ │ │ ├── base.py
│ │ │ └── evaluators.py
│ │ ├── providers
│ │ │ ├── __init__.py
│ │ │ ├── anthropic.py
│ │ │ ├── base.py
│ │ │ ├── deepseek.py
│ │ │ ├── gemini.py
│ │ │ ├── grok.py
│ │ │ ├── ollama.py
│ │ │ ├── openai.py
│ │ │ └── openrouter.py
│ │ ├── server.py
│ │ ├── state_store.py
│ │ ├── tournaments
│ │ │ ├── manager.py
│ │ │ ├── tasks.py
│ │ │ └── utils.py
│ │ └── ums_api
│ │ ├── __init__.py
│ │ ├── ums_database.py
│ │ ├── ums_endpoints.py
│ │ ├── ums_models.py
│ │ └── ums_services.py
│ ├── exceptions.py
│ ├── graceful_shutdown.py
│ ├── services
│ │ ├── __init__.py
│ │ ├── analytics
│ │ │ ├── __init__.py
│ │ │ ├── metrics.py
│ │ │ └── reporting.py
│ │ ├── cache
│ │ │ ├── __init__.py
│ │ │ ├── cache_service.py
│ │ │ ├── persistence.py
│ │ │ ├── strategies.py
│ │ │ └── utils.py
│ │ ├── cache.py
│ │ ├── document.py
│ │ ├── knowledge_base
│ │ │ ├── __init__.py
│ │ │ ├── feedback.py
│ │ │ ├── manager.py
│ │ │ ├── rag_engine.py
│ │ │ ├── retriever.py
│ │ │ └── utils.py
│ │ ├── prompts
│ │ │ ├── __init__.py
│ │ │ ├── repository.py
│ │ │ └── templates.py
│ │ ├── prompts.py
│ │ └── vector
│ │ ├── __init__.py
│ │ ├── embeddings.py
│ │ └── vector_service.py
│ ├── tool_token_counter.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── audio_transcription.py
│ │ ├── base.py
│ │ ├── completion.py
│ │ ├── docstring_refiner.py
│ │ ├── document_conversion_and_processing.py
│ │ ├── enhanced-ums-lookbook.html
│ │ ├── entity_relation_graph.py
│ │ ├── excel_spreadsheet_automation.py
│ │ ├── extraction.py
│ │ ├── filesystem.py
│ │ ├── html_to_markdown.py
│ │ ├── local_text_tools.py
│ │ ├── marqo_fused_search.py
│ │ ├── meta_api_tool.py
│ │ ├── ocr_tools.py
│ │ ├── optimization.py
│ │ ├── provider.py
│ │ ├── pyodide_boot_template.html
│ │ ├── python_sandbox.py
│ │ ├── rag.py
│ │ ├── redline-compiled.css
│ │ ├── sentiment_analysis.py
│ │ ├── single_shot_synthesis.py
│ │ ├── smart_browser.py
│ │ ├── sql_databases.py
│ │ ├── text_classification.py
│ │ ├── text_redline_tools.py
│ │ ├── tournament.py
│ │ ├── ums_explorer.html
│ │ └── unified_memory_system.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── async_utils.py
│ │ ├── display.py
│ │ ├── logging
│ │ │ ├── __init__.py
│ │ │ ├── console.py
│ │ │ ├── emojis.py
│ │ │ ├── formatter.py
│ │ │ ├── logger.py
│ │ │ ├── panels.py
│ │ │ ├── progress.py
│ │ │ └── themes.py
│ │ ├── parse_yaml.py
│ │ ├── parsing.py
│ │ ├── security.py
│ │ └── text.py
│ └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/ultimate_mcp_server/tools/sql_databases.py:
--------------------------------------------------------------------------------
```python
1 | # ultimate_mcp_server/tools/sql_databases.py
2 | from __future__ import annotations
3 |
4 | import asyncio
5 | import datetime as dt
6 | import hashlib
7 | import json
8 | import os
9 | import re
10 | import tempfile
11 | import time
12 | import uuid
13 | from dataclasses import dataclass
14 | from functools import lru_cache
15 | from pathlib import Path
16 |
17 | # --- START: Expanded typing imports ---
18 | from typing import Any, Dict, List, Optional, Set, Tuple, Union
19 |
20 | # --- END: Expanded typing imports ---
21 | # --- Removed BaseTool import ---
22 | # SQLAlchemy imports
23 | from sqlalchemy import inspect as sa_inspect
24 | from sqlalchemy import text
25 | from sqlalchemy.engine.url import make_url
26 | from sqlalchemy.exc import OperationalError, ProgrammingError, SQLAlchemyError
27 | from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
28 |
29 | # Local imports
30 | from ultimate_mcp_server.exceptions import ToolError, ToolInputError
31 |
32 | # --- START: Expanded base imports ---
33 | from ultimate_mcp_server.tools.base import with_error_handling, with_tool_metrics
34 |
35 | # --- END: Expanded base imports ---
36 | from ultimate_mcp_server.tools.completion import generate_completion # For NL→SQL
37 | from ultimate_mcp_server.utils import get_logger
38 |
39 | # Optional imports with graceful fallbacks
40 | try:
41 | import boto3 # For AWS Secrets Manager
42 | except ImportError:
43 | boto3 = None
44 |
45 | try:
46 | import hvac # For HashiCorp Vault
47 | except ImportError:
48 | hvac = None
49 |
50 | try:
51 | import pandas as pd
52 | except ImportError:
53 | pd = None
54 |
55 | try:
56 | import pandera as pa
57 | except ImportError:
58 | pa = None
59 |
60 | try:
61 | import prometheus_client as prom
62 | except ImportError:
63 | prom = None
64 |
65 | logger = get_logger("ultimate_mcp_server.tools.sql_databases")
66 |
67 | # =============================================================================
68 | # Global State and Configuration (Replaces instance variables)
69 | # =============================================================================
70 |
71 |
72 | # --- Connection Management ---
73 | class ConnectionManager:
74 | """Manages database connections with automatic cleanup after inactivity."""
75 |
76 | # Helper utility managing AsyncEngine instances; used via the module-level _connection_manager instance below.
77 | def __init__(self, cleanup_interval_seconds=600, check_interval_seconds=60):
78 | self.connections: Dict[str, Tuple[AsyncEngine, float]] = {}
79 | self.cleanup_interval = cleanup_interval_seconds
80 | self.check_interval = check_interval_seconds
81 | self._cleanup_task: Optional[asyncio.Task] = None
82 | self._lock = asyncio.Lock() # Added lock for thread-safe modifications
83 |
84 | async def start_cleanup_task(self):
85 | async with self._lock:
86 | cleanup_task_is_none = self._cleanup_task is None
87 | cleanup_task_is_done = self._cleanup_task is not None and self._cleanup_task.done()
88 | if cleanup_task_is_none or cleanup_task_is_done:
89 | try:
90 | loop = asyncio.get_running_loop()
91 | task_coro = self._cleanup_loop()
92 | self._cleanup_task = loop.create_task(task_coro)
93 | logger.info("Started connection cleanup task.")
94 | except RuntimeError:
95 | logger.warning("No running event loop found, cleanup task not started.")
96 |
97 | async def _cleanup_loop(self):
98 | log_msg = f"Cleanup loop started. Check interval: {self.check_interval}s, Inactivity threshold: {self.cleanup_interval}s"
99 | logger.debug(log_msg)
100 | while True:
101 | await asyncio.sleep(self.check_interval)
102 | try:
103 | await self.cleanup_inactive_connections()
104 | except asyncio.CancelledError:
105 | logger.info("Cleanup loop cancelled.")
106 | break # Exit loop cleanly on cancellation
107 | except Exception as e:
108 | logger.error(f"Error during connection cleanup: {e}", exc_info=True)
109 |
110 | async def cleanup_inactive_connections(self):
111 | current_time = time.time()
112 | conn_ids_to_close = []
113 |
114 | # Need lock here as we iterate over potentially changing dict
115 | async with self._lock:
116 | # Use items() for safe iteration while potentially modifying dict later
117 | # Create a copy to avoid issues if the dict is modified elsewhere concurrently (though unlikely with lock)
118 | current_connections = list(self.connections.items())
119 |
120 | for conn_id, (_engine, last_accessed) in current_connections:
121 | idle_time = current_time - last_accessed
122 | is_inactive = idle_time > self.cleanup_interval
123 | if is_inactive:
124 | log_msg = f"Connection {conn_id} exceeded inactivity timeout ({idle_time:.1f}s > {self.cleanup_interval}s)"
125 | logger.info(log_msg)
126 | conn_ids_to_close.append(conn_id)
127 |
128 | closed_count = 0
129 | for conn_id in conn_ids_to_close:
130 | # close_connection acquires its own lock
131 | closed = await self.close_connection(conn_id)
132 | if closed:
133 | logger.info(f"Auto-closed inactive connection: {conn_id}")
134 | closed_count += 1
135 | if closed_count > 0:
136 | logger.debug(f"Closed {closed_count} inactive connections.")
137 | elif conn_ids_to_close:
138 | num_attempted = len(conn_ids_to_close)
139 | logger.debug(
140 | f"Attempted to close {num_attempted} connections, but they might have been removed already."
141 | )
142 |
143 | async def get_connection(self, conn_id: str) -> AsyncEngine:
144 | async with self._lock:
145 | if conn_id not in self.connections:
146 | details = {"error_type": "CONNECTION_NOT_FOUND"}
147 | raise ToolInputError(
148 | f"Unknown connection_id: {conn_id}", param_name="connection_id", details=details
149 | )
150 |
151 | engine, _ = self.connections[conn_id]
152 | # Update last accessed time
153 | current_time = time.time()
154 | self.connections[conn_id] = (engine, current_time)
155 | logger.debug(f"Accessed connection {conn_id}, updated last accessed time.")
156 | return engine
157 |
158 | async def add_connection(self, conn_id: str, engine: AsyncEngine):
159 | # close_connection handles locking internally
160 | has_existing = conn_id in self.connections
161 | if has_existing:
162 | logger.warning(f"Overwriting existing connection entry for {conn_id}.")
163 | await self.close_connection(conn_id) # Close the old one first
164 |
165 | async with self._lock:
166 | current_time = time.time()
167 | self.connections[conn_id] = (engine, current_time)
168 | url_str = str(engine.url)
169 | url_prefix = url_str.split("@")[0]
170 | log_msg = (
171 | f"Added connection {conn_id} for URL: {url_prefix}..." # Avoid logging credentials
172 | )
173 | logger.info(log_msg)
174 | await self.start_cleanup_task() # Ensure cleanup is running
175 |
176 | async def close_connection(self, conn_id: str) -> bool:
177 | engine = None
178 | async with self._lock:
179 | connection_exists = conn_id in self.connections
180 | if connection_exists:
181 | engine, _ = self.connections.pop(conn_id)
182 | else:
183 | logger.warning(f"Attempted to close non-existent connection ID: {conn_id}")
184 | return False # Not found
185 |
186 | if engine:
187 | logger.info(f"Closing connection {conn_id}...")
188 | try:
189 | await engine.dispose()
190 | logger.info(f"Connection {conn_id} disposed successfully.")
191 | return True
192 | except Exception as e:
193 | log_msg = f"Error disposing engine for connection {conn_id}: {e}"
194 | logger.error(log_msg, exc_info=True)
195 | # Removed from dict, but disposal failed
196 | return False
197 | return False # Should not be reached if found
198 |
199 | async def shutdown(self):
200 | logger.info("Shutting down Connection Manager...")
201 | # Cancel cleanup task first
202 | cleanup_task = None
203 | async with self._lock:
204 | task_exists = self._cleanup_task is not None
205 | task_not_done = task_exists and not self._cleanup_task.done()
206 | if task_exists and task_not_done:
207 | cleanup_task = self._cleanup_task # Get reference before clearing
208 | self._cleanup_task = None # Prevent restarting
209 |
210 | if cleanup_task:
211 | cleanup_task.cancel()
212 | try:
213 | # Add timeout for task cancellation
214 | await asyncio.wait_for(cleanup_task, timeout=2.0)
215 | except asyncio.TimeoutError:
216 | logger.warning("Cleanup task cancellation timed out after 2 seconds")
217 | except asyncio.CancelledError:
218 | logger.info("Cleanup task cancelled.")
219 | except Exception as e:
220 | logger.error(f"Error stopping cleanup task: {e}", exc_info=True)
221 |
222 | # Close remaining connections
223 | async with self._lock:
224 | conn_ids = list(self.connections.keys())
225 |
226 | if conn_ids:
227 | num_conns = len(conn_ids)
228 | logger.info(f"Closing {num_conns} active connections...")
229 | # Call close_connection which handles locking and removal
230 | close_tasks = []
231 | for conn_id in conn_ids:
232 | # Create a task that times out for each connection
233 | async def close_with_timeout(conn_id):
234 | try:
235 | await asyncio.wait_for(self.close_connection(conn_id), timeout=2.0)
236 | return True
237 | except asyncio.TimeoutError:
238 | logger.warning(f"Connection {conn_id} close timed out after 2 seconds")
239 | return False
240 | close_tasks.append(close_with_timeout(conn_id))
241 |
242 | # Wait for all connections to close with an overall timeout
243 | try:
244 | await asyncio.wait_for(asyncio.gather(*close_tasks, return_exceptions=True), timeout=5.0)
245 | except asyncio.TimeoutError:
246 | logger.warning("Some connections did not close within the 5 second timeout")
247 |
248 | async with self._lock:
249 | # Final check
250 | remaining = len(self.connections)
251 | if remaining > 0:
252 | logger.warning(f"{remaining} connections still remain after shutdown attempt.")
253 | self.connections.clear() # Clear the dictionary
254 |
255 | logger.info("Connection Manager shutdown complete.")
256 |
257 |
258 | _connection_manager = ConnectionManager()
259 |
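# Illustrative usage sketch of the module-level manager above (the connection id and
# engine URL are hypothetical; all calls are awaited from async code):
#   eng = create_async_engine("sqlite+aiosqlite:///:memory:")
#   await _connection_manager.add_connection("demo-conn", eng)      # registers engine, starts cleanup task
#   engine = await _connection_manager.get_connection("demo-conn")  # refreshes last-accessed time
#   await _connection_manager.close_connection("demo-conn")         # pops entry and disposes engine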
260 | # --- Security and Validation ---
261 | _PROHIBITED_SQL_PATTERN = r"""^\s*(DROP\s+(TABLE|DATABASE|INDEX|VIEW|FUNCTION|PROCEDURE|USER|ROLE)|
262 | TRUNCATE\s+TABLE|
263 | DELETE\s+FROM|
264 | ALTER\s+(TABLE|DATABASE)\s+\S+\s+DROP\s+|
265 | UPDATE\s+|INSERT\s+INTO(?!\s+OR\s+IGNORE)|
266 | GRANT\s+|REVOKE\s+|
267 | CREATE\s+USER|ALTER\s+USER|DROP\s+USER|
268 | CREATE\s+ROLE|ALTER\s+ROLE|DROP\s+ROLE|
269 | SHUTDOWN|REBOOT|RESTART)"""
270 | _PROHIBITED_SQL_REGEX = re.compile(_PROHIBITED_SQL_PATTERN, re.I | re.X)
271 |
272 | _TABLE_RX = re.compile(r"\b(?:FROM|JOIN|UPDATE|INSERT\s+INTO|DELETE\s+FROM)\s+([\w.\"$-]+)", re.I)
273 |
274 |
275 | # --- Masking ---
276 | @dataclass
277 | class MaskRule:
278 | rx: re.Pattern
279 | repl: Union[str, callable]
280 |
281 |
282 | # Helper lambda for credit card masking
283 | def _mask_cc(v: str) -> str:
284 | return f"XXXX-...-{v[-4:]}"
285 |
286 |
287 | # Helper lambda for email masking
288 | def _mask_email(v: str) -> str:
289 | if "@" in v:
290 | parts = v.split("@")
291 | prefix = parts[0][:2] + "***"
292 | domain = parts[-1]
293 | return f"{prefix}@{domain}"
294 | else:
295 | return "***"
296 |
297 |
298 | _MASKING_RULES = [
299 | MaskRule(re.compile(r"^\d{3}-\d{2}-\d{4}$"), "***-**-XXXX"), # SSN
300 | MaskRule(re.compile(r"(\b\d{4}-?){3}\d{4}\b"), _mask_cc), # CC basic mask
301 | MaskRule(re.compile(r"[\w\.-]+@[\w\.-]+\.\w+"), _mask_email), # Email
302 | ]
303 |
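# Illustrative behaviour of the masking rules above, via _sql_mask_val (defined
# further below); the sample values are hypothetical:
#   _sql_mask_val("123-45-6789")          -> "***-**-XXXX"
#   _sql_mask_val("4111-1111-1111-1111")  -> "XXXX-...-1111"
#   _sql_mask_val("jane.doe@example.com") -> "ja***@example.com"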
304 | # --- ACLs ---
305 | _RESTRICTED_TABLES: Set[str] = set()
306 | _RESTRICTED_COLUMNS: Set[str] = set()
307 |
308 | # --- Auditing ---
309 | _AUDIT_LOG: List[Dict[str, Any]] = []
310 | _AUDIT_ID_COUNTER: int = 0
311 | _audit_lock = asyncio.Lock() # Lock for modifying audit counter and log
312 |
313 | # --- Schema Drift Detection ---
314 | _LINEAGE: List[Dict[str, Any]] = []
315 | _SCHEMA_VERSIONS: Dict[str, str] = {} # connection_id -> schema_hash
316 |
317 | # --- Prometheus Metrics ---
318 | # Initialized as None, populated in initialize function if prom is available
319 | _Q_CNT: Optional[Any] = None
320 | _Q_LAT: Optional[Any] = None
321 | _CONN_GAUGE: Optional[Any] = None
322 |
323 |
324 | # =============================================================================
325 | # Initialization and Shutdown Functions
326 | # =============================================================================
327 |
328 | # Flag to track if metrics have been initialized
329 | _sql_metrics_initialized = False
330 |
331 | async def initialize_sql_tools():
332 | """Initialize global state for SQL tools, like starting the cleanup task and metrics."""
333 | global _sql_metrics_initialized
334 | global _Q_CNT, _Q_LAT, _CONN_GAUGE # Ensure globals are declared for assignment
335 |
336 | # Initialize metrics only once
337 | if not _sql_metrics_initialized:
338 | logger.info("Initializing SQL Tools module metrics...")
339 | if prom:
340 | try:
341 | # Define metrics
342 | _Q_CNT = prom.Counter("mcp_sqltool_calls", "SQL tool calls", ["tool", "action", "db"])
343 | latency_buckets = (0.01, 0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60)
344 | _Q_LAT = prom.Histogram(
345 | "mcp_sqltool_latency_seconds",
346 | "SQL latency",
347 | ["tool", "action", "db"],
348 | buckets=latency_buckets,
349 | )
350 | _CONN_GAUGE = prom.Gauge(
351 | "mcp_sqltool_active_connections",
352 | "Number of active SQL connections"
353 | )
354 |
355 | # Define the gauge function referencing the global manager
356 | # Wrap in try-except as accessing length during shutdown might be tricky
357 | def _get_active_connections():
358 | try:
359 | # Access length directly if manager state is simple enough
360 | # If complex state, acquire lock if necessary (_connection_manager._lock)
361 | # Reading just the length is usually safe unless connections are being added/removed heavily and concurrently
362 | return len(_connection_manager.connections)
363 | except Exception:
364 | logger.exception("Error getting active connection count for Prometheus.")
365 | return 0 # Default to 0 if error accessing
366 |
367 | _CONN_GAUGE.set_function(_get_active_connections)
368 | logger.info("Prometheus metrics initialized for SQL tools.")
369 | _sql_metrics_initialized = True # Set flag only after successful initialization
370 |
371 | except ValueError as e:
372 | # Catch the specific duplicate error and log nicely, but don't crash
373 | if "Duplicated timeseries" in str(e):
374 | logger.warning(f"Prometheus metrics already registered: {e}. Skipping re-initialization.")
375 | _sql_metrics_initialized = True # Assume they are initialized if duplicate error occurs
376 | else:
377 | # Re-raise other ValueErrors
378 | logger.error(f"ValueError during Prometheus metric initialization: {e}", exc_info=True)
379 | raise # Re-raise unexpected ValueError
380 | except Exception as e:
381 | logger.error(f"Failed to initialize Prometheus metrics for SQL tools: {e}", exc_info=True)
382 | # If metric initialization fails, continue without metrics rather than raising.
383 |
384 | else:
385 | logger.info("Prometheus client not available, metrics disabled for SQL tools.")
386 | _sql_metrics_initialized = True # Mark as "initialized" (i.e., done trying) even if prom not present
387 | else:
388 | logger.debug("SQL tools metrics already initialized, skipping metric creation.")
389 |
390 | # Always try to start the cleanup task (it's internally idempotent)
391 | # Ensure this happens *after* logging initialization attempt
392 | logger.info("Ensuring SQL connection cleanup task is running...")
393 | await _connection_manager.start_cleanup_task()
394 |
395 |
396 | async def shutdown_sql_tools():
397 | """Gracefully shut down SQL tool resources, like the connection manager."""
398 | logger.info("Shutting down SQL Tools module...")
399 | try:
400 | # Add timeout to connection manager shutdown
401 | await asyncio.wait_for(_connection_manager.shutdown(), timeout=8.0)
402 | except asyncio.TimeoutError:
403 | logger.warning("Connection Manager shutdown timed out after 8 seconds")
404 | # Clear other global state if necessary (e.g., save audit log)
405 | logger.info("SQL Tools module shutdown complete.")
406 |
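# Illustrative lifecycle sketch (how a host application would typically wire these
# hooks; the surrounding startup/shutdown mechanism is assumed, not shown here):
#   await initialize_sql_tools()   # registers metrics (if prometheus_client is present), starts cleanup task
#   ...serve requests...
#   await shutdown_sql_tools()     # cancels the cleanup task and disposes pooled engines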
407 |
408 | # =============================================================================
409 | # Helper Functions (Private module-level functions)
410 | # =============================================================================
411 |
412 |
413 | @lru_cache(maxsize=64)
414 | def _pull_secret_from_sources(name: str) -> str:
415 | """Retrieve a secret from various sources."""
416 | # (Implementation remains the same as in the original class)
417 | if boto3:
418 | try:
419 | client = boto3.client("secretsmanager")
420 | # Consider region_name=os.getenv("AWS_REGION") or similar config
421 | secret_value_response = client.get_secret_value(SecretId=name)
422 | # Handle binary vs string secrets
423 | if "SecretString" in secret_value_response:
424 | secret = secret_value_response["SecretString"]
425 | return secret
426 | elif "SecretBinary" in secret_value_response:
427 | # Decode binary appropriately if needed, default to utf-8
428 | secret_bytes = secret_value_response["SecretBinary"]
429 | secret = secret_bytes.decode("utf-8")
430 | return secret
431 | except Exception as aws_err:
432 | logger.debug(f"Secret '{name}' not found or error in AWS Secrets Manager: {aws_err}")
433 | pass
434 |
435 | if hvac:
436 | try:
437 | vault_url = os.getenv("VAULT_ADDR")
438 | vault_token = os.getenv("VAULT_TOKEN")
439 | if vault_url and vault_token:
440 | vault_client = hvac.Client(
441 | url=vault_url, token=vault_token, timeout=2
442 | ) # Short timeout
443 | is_auth = vault_client.is_authenticated()
444 | if is_auth:
445 | mount_point = os.getenv("VAULT_KV_MOUNT_POINT", "secret")
446 | secret_path = name
447 | read_response = vault_client.secrets.kv.v2.read_secret_version(
448 | path=secret_path, mount_point=mount_point
449 | )
450 | # Standard KV v2 structure: response['data']['data'] is the dict of secrets
451 | has_outer_data = "data" in read_response
452 | has_inner_data = has_outer_data and "data" in read_response["data"]
453 | if has_inner_data:
454 | # Try common key names 'value' or the secret name itself
455 | secret_data = read_response["data"]["data"]
456 | if "value" in secret_data:
457 | value = secret_data["value"]
458 | return value
459 | elif name in secret_data:
460 | value = secret_data[name]
461 | return value
462 | else:
463 | log_msg = f"Secret keys 'value' or '{name}' not found at path '{secret_path}' in Vault."
464 | logger.debug(log_msg)
465 | else:
466 | logger.warning(f"Vault authentication failed for address: {vault_url}")
467 |
468 | except Exception as e:
469 | logger.debug(f"Error accessing Vault for secret '{name}': {e}")
470 | pass
471 |
472 | # Try environment variables
473 | env_val_direct = os.getenv(name)
474 | if env_val_direct:
475 | return env_val_direct
476 |
477 | mcp_secret_name = f"MCP_SECRET_{name.upper()}"
478 | env_val_prefixed = os.getenv(mcp_secret_name)
479 | if env_val_prefixed:
480 | logger.debug(f"Found secret '{name}' using prefixed env var '{mcp_secret_name}'.")
481 | return env_val_prefixed
482 |
483 | error_msg = (
484 | f"Secret '{name}' not found in any source (AWS, Vault, Env: {name}, Env: {mcp_secret_name})"
485 | )
486 | details = {"secret_name": name, "error_type": "SECRET_NOT_FOUND"}
487 | raise ToolError(error_msg, http_status_code=404, details=details)
488 |
489 |
490 | async def _sql_get_engine(cid: str) -> AsyncEngine:
491 | """Get engine by connection ID using the global ConnectionManager."""
492 | engine = await _connection_manager.get_connection(cid)
493 | return engine
494 |
495 |
496 | def _sql_get_next_audit_id() -> str:
497 | """Generate the next sequential audit ID (thread-safe)."""
498 | # Locking happens in _sql_audit where this is called
499 | global _AUDIT_ID_COUNTER
500 | _AUDIT_ID_COUNTER += 1
501 | audit_id_str = f"a{_AUDIT_ID_COUNTER:09d}"
502 | return audit_id_str
503 |
504 |
505 | def _sql_now() -> str:
506 | """Get current UTC timestamp in ISO format."""
507 | now_utc = dt.datetime.now(dt.timezone.utc)
508 | iso_str = now_utc.isoformat(timespec="seconds")
509 | return iso_str
510 |
511 |
512 | async def _sql_audit(
513 | *,
514 | tool_name: str,
515 | action: str,
516 | connection_id: Optional[str],
517 | sql: Optional[str],
518 | tables: Optional[List[str]],
519 | row_count: Optional[int],
520 | success: bool,
521 | error: Optional[str],
522 | user_id: Optional[str],
523 | session_id: Optional[str],
524 | **extra_data: Any,
525 | ) -> None:
526 | """Record an audit trail entry (thread-safe)."""
527 | global _AUDIT_LOG
528 | async with _audit_lock:
529 | audit_id = _sql_get_next_audit_id() # Get ID while locked
530 | timestamp = _sql_now()
531 | log_entry = {}
532 | log_entry["audit_id"] = audit_id
533 | log_entry["timestamp"] = timestamp
534 | log_entry["tool_name"] = tool_name
535 | log_entry["action"] = action
536 | log_entry["user_id"] = user_id
537 | log_entry["session_id"] = session_id
538 | log_entry["connection_id"] = connection_id
539 | log_entry["sql"] = sql
540 | log_entry["tables"] = tables
541 | log_entry["row_count"] = row_count
542 | log_entry["success"] = success
543 | log_entry["error"] = error
544 | log_entry.update(extra_data) # Add extra data
545 |
546 | _AUDIT_LOG.append(log_entry)
547 |
548 | # Optional: Log to logger (outside lock)
549 | log_base = f"Audit[{audit_id}]: Tool={tool_name}, Action={action}, Conn={connection_id}, Success={success}"
550 | log_error = f", Error={error}" if error else ""
551 | logger.info(log_base + log_error)
552 |
553 |
554 | def _sql_update_acl(
555 | *, tables: Optional[List[str]] = None, columns: Optional[List[str]] = None
556 | ) -> None:
557 | """Update the global ACL lists."""
558 | global _RESTRICTED_TABLES, _RESTRICTED_COLUMNS
559 | if tables is not None:
560 | lowered_tables = {t.lower() for t in tables}
561 | _RESTRICTED_TABLES = lowered_tables
562 | logger.info(f"Updated restricted tables ACL: {_RESTRICTED_TABLES}")
563 | if columns is not None:
564 | lowered_columns = {c.lower() for c in columns}
565 | _RESTRICTED_COLUMNS = lowered_columns
566 | logger.info(f"Updated restricted columns ACL: {_RESTRICTED_COLUMNS}")
567 |
568 |
569 | def _sql_check_acl(sql: str) -> None:
570 | """Check if SQL contains any restricted tables or columns using global ACLs."""
571 | # (Implementation remains the same, uses global _RESTRICTED_TABLES/_COLUMNS)
572 | raw_toks = re.findall(r'[\w$"\'.]+', sql.lower())
573 | toks = set(raw_toks)
574 | normalized_toks = set()
575 | for tok in toks:
576 | tok_norm = tok.strip("\"`'[]")
577 | normalized_toks.add(tok_norm)
578 | has_dot = "." in tok_norm
579 | if has_dot:
580 | last_part = tok_norm.split(".")[-1]
581 | normalized_toks.add(last_part)
582 |
583 | restricted_tables_found_set = _RESTRICTED_TABLES.intersection(normalized_toks)
584 | restricted_tables_found = list(restricted_tables_found_set)
585 | if restricted_tables_found:
586 | tables_str = ", ".join(restricted_tables_found)
587 | logger.warning(
588 | f"ACL Violation: Restricted table(s) found in query: {restricted_tables_found}"
589 | )
590 | details = {
591 | "restricted_tables": restricted_tables_found,
592 | "error_type": "ACL_TABLE_VIOLATION",
593 | }
594 | raise ToolError(
595 | f"Access denied: Query involves restricted table(s): {tables_str}",
596 | http_status_code=403,
597 | details=details,
598 | )
599 |
600 | restricted_columns_found_set = _RESTRICTED_COLUMNS.intersection(normalized_toks)
601 | restricted_columns_found = list(restricted_columns_found_set)
602 | if restricted_columns_found:
603 | columns_str = ", ".join(restricted_columns_found)
604 | logger.warning(
605 | f"ACL Violation: Restricted column(s) found in query: {restricted_columns_found}"
606 | )
607 | details = {
608 | "restricted_columns": restricted_columns_found,
609 | "error_type": "ACL_COLUMN_VIOLATION",
610 | }
611 | raise ToolError(
612 | f"Access denied: Query involves restricted column(s): {columns_str}",
613 | http_status_code=403,
614 | details=details,
615 | )
616 |
617 |
618 | def _sql_resolve_conn(raw: str) -> str:
619 | """Resolve connection string, handling secret references."""
620 | # (Implementation remains the same)
621 | is_secret_ref = raw.startswith("secrets://")
622 | if is_secret_ref:
623 | secret_name = raw[10:]
624 | logger.info(f"Resolving secret reference: '{secret_name}'")
625 | resolved_secret = _pull_secret_from_sources(secret_name)
626 | return resolved_secret
627 | return raw
628 |
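# Illustrative resolution (the secret name and DSN below are hypothetical):
#   _sql_resolve_conn("secrets://prod_pg_dsn")  -> looked up via AWS Secrets Manager / Vault / env vars
#   _sql_resolve_conn("sqlite:///./app.db")     -> returned unchanged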
629 |
630 | def _sql_mask_val(v: Any) -> Any:
631 | """Apply masking rules to a single value using global rules."""
632 | # (Implementation remains the same, uses global _MASKING_RULES)
633 | is_string = isinstance(v, str)
634 | is_not_empty = bool(v)
635 | if not is_string or not is_not_empty:
636 | return v
637 | for rule in _MASKING_RULES:
638 | matches = rule.rx.fullmatch(v)
639 | if matches:
640 | is_callable = callable(rule.repl)
641 | if is_callable:
642 | try:
643 | masked_value = rule.repl(v)
644 | return masked_value
645 | except Exception as e:
646 | log_msg = f"Error applying dynamic mask rule {rule.rx.pattern}: {e}"
647 | logger.error(log_msg)
648 | return "MASKING_ERROR"
649 | else:
650 | return rule.repl
651 | return v
652 |
653 |
654 | def _sql_mask_row(row: Dict[str, Any]) -> Dict[str, Any]:
655 | """Apply masking rules to an entire row of data."""
656 | masked_dict = {}
657 | for k, v in row.items():
658 | masked_val = _sql_mask_val(v)
659 | masked_dict[k] = masked_val
660 | return masked_dict
661 | # Equivalent single-line form: return {k: _sql_mask_val(v) for k, v in row.items()}
662 |
663 |
664 | def _sql_driver_url(conn_str: str) -> Tuple[str, str]:
665 | """Convert generic connection string to dialect-specific async URL."""
666 | # Check if it looks like a path (no ://) and exists or is :memory:
667 | has_protocol = "://" in conn_str
668 | looks_like_path = not has_protocol
669 | path_obj = Path(conn_str)
670 | path_exists = path_obj.exists()
671 | is_memory = conn_str == ":memory:"
672 | is_file_path = looks_like_path and (path_exists or is_memory)
673 |
674 | if is_file_path:
675 | if is_memory:
676 | url_str = "sqlite+aiosqlite:///:memory:"
677 | logger.info("Using in-memory SQLite database.")
678 | else:
679 | sqlite_path = path_obj.expanduser().resolve()
680 | parent_dir = sqlite_path.parent
681 | parent_exists = parent_dir.exists()
682 | if not parent_exists:
683 | try:
684 | parent_dir.mkdir(parents=True, exist_ok=True)
685 | logger.info(f"Created directory for SQLite DB: {parent_dir}")
686 | except OSError as e:
687 | details = {"path": str(parent_dir)}
688 | raise ToolError(
689 | f"Failed to create directory for SQLite DB '{parent_dir}': {e}",
690 | http_status_code=500,
691 | details=details,
692 | ) from e
693 | url_str = f"sqlite+aiosqlite:///{sqlite_path}"
694 | logger.info(f"Using SQLite database file: {sqlite_path}")
695 | url = make_url(url_str)
696 | final_url_str = str(url)
697 | return final_url_str, "sqlite"
698 | else:
699 | url_str = conn_str
700 | try:
701 | url = make_url(url_str)
702 | except Exception as e:
703 | details = {"value": conn_str}
704 | raise ToolInputError(
705 | f"Invalid connection string format: {e}",
706 | param_name="connection_string",
707 | details=details,
708 | ) from e
709 |
710 | drv = url.drivername.lower()
711 | drivername = url.drivername # Preserve original case for setting later if needed
712 |
713 | if drv.startswith("sqlite"):
714 | new_url = url.set(drivername="sqlite+aiosqlite")
715 | return str(new_url), "sqlite"
716 | if drv.startswith("postgresql") or drv == "postgres":
717 | new_url = url.set(drivername="postgresql+asyncpg")
718 | return str(new_url), "postgresql"
719 | if drv.startswith("mysql") or drv == "mariadb":
720 | query = dict(url.query)
721 | query.setdefault("charset", "utf8mb4")
722 | new_url = url.set(drivername="mysql+aiomysql", query=query)
723 | return str(new_url), "mysql"
724 | if drv.startswith("mssql") or drv == "sqlserver":
725 | odbc_driver = url.query.get("driver")
726 | if not odbc_driver:
727 | logger.warning(
728 | "MSSQL connection string lacks 'driver' parameter. Ensure a valid ODBC driver (e.g., 'ODBC Driver 17 for SQL Server') is installed and specified."
729 | )
730 | new_url = url.set(drivername="mssql+aioodbc")
731 | return str(new_url), "sqlserver"
732 | if drv.startswith("snowflake"):
733 | # Keep original snowflake driver
734 | new_url = url.set(drivername=drivername)
735 | return str(new_url), "snowflake"
736 |
737 | logger.error(f"Unsupported database dialect: {drv}")
738 | details = {"dialect": drv}
739 | raise ToolInputError(
740 | f"Unsupported database dialect: '{drv}'. Supported: sqlite, postgresql, mysql, mssql, snowflake",
741 | param_name="connection_string",
742 | details=details,
743 | )
744 |
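# Illustrative conversions (connection strings are hypothetical; rendered URLs shown
# approximately):
#   _sql_driver_url(":memory:")                -> ("sqlite+aiosqlite:///:memory:", "sqlite")
#   _sql_driver_url("postgresql://u@host/db")  -> ("postgresql+asyncpg://u@host/db", "postgresql")
#   _sql_driver_url("mysql://u@host/db")       -> ("mysql+aiomysql://u@host/db?charset=utf8mb4", "mysql")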
745 |
746 | def _sql_auto_pool(db_type: str) -> Dict[str, Any]:
747 | """Provide sensible default connection pool settings."""
748 | # (Implementation remains the same)
749 | # Single-line dict return is acceptable
750 | defaults = {
751 | "pool_size": 5,
752 | "max_overflow": 10,
753 | "pool_recycle": 1800,
754 | "pool_pre_ping": True,
755 | "pool_timeout": 30,
756 | }
757 | if db_type == "sqlite":
758 | return {"pool_pre_ping": True}
759 | if db_type == "postgresql":
760 | return {
761 | "pool_size": 10,
762 | "max_overflow": 20,
763 | "pool_recycle": 900,
764 | "pool_pre_ping": True,
765 | "pool_timeout": 30,
766 | }
767 | if db_type == "mysql":
768 | return {
769 | "pool_size": 10,
770 | "max_overflow": 20,
771 | "pool_recycle": 900,
772 | "pool_pre_ping": True,
773 | "pool_timeout": 30,
774 | }
775 | if db_type == "sqlserver":
776 | return {
777 | "pool_size": 10,
778 | "max_overflow": 20,
779 | "pool_recycle": 900,
780 | "pool_pre_ping": True,
781 | "pool_timeout": 30,
782 | }
783 | if db_type == "snowflake":
784 | return {"pool_size": 5, "max_overflow": 5, "pool_pre_ping": True, "pool_timeout": 60}
785 | logger.warning(f"Using default pool settings for unknown db_type: {db_type}")
786 | return defaults
787 |
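# Illustrative use when creating an engine (the async_url variable is hypothetical):
#   pool_kwargs = _sql_auto_pool("postgresql")   # pool_size=10, max_overflow=20, pool_recycle=900, ...
#   engine = create_async_engine(async_url, **pool_kwargs)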
788 |
789 | def _sql_extract_tables(sql: str) -> List[str]:
790 | """Extract table names referenced in a SQL query."""
791 | matches = _TABLE_RX.findall(sql)
792 | tables = set()
793 | for match in matches:
794 | # Chained strip is one expression
795 | table_stripped = match.strip()
796 | table = table_stripped.strip("\"`'[]")
797 | has_dot = "." in table
798 | if has_dot:
799 | # table.split('.')[-1].strip('"`\'[]') # Original combined
800 | parts = table.split(".")
801 | last_part = parts[-1]
802 | table = last_part.strip("\"`'[]")
803 | if table:
804 | tables.add(table)
805 | sorted_tables = sorted(list(tables))
806 | return sorted_tables
807 |
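# Illustrative extraction (query text is hypothetical):
#   _sql_extract_tables("SELECT o.id FROM orders o JOIN customers c ON c.id = o.cust_id")
#   -> ["customers", "orders"]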
808 |
809 | def _sql_check_safe(sql: str, read_only: bool = True) -> None:
810 | """Validate SQL for safety using global patterns and ACLs."""
811 | # Check ACLs first
812 | _sql_check_acl(sql)
813 |
814 | # Check prohibited statements
815 | normalized_sql = sql.lstrip().upper()
816 | check_sql_part = normalized_sql # Default part to check
817 |
818 | starts_with_with = normalized_sql.startswith("WITH")
819 | if starts_with_with:
820 | try:
821 | # Regex remains single-line expression assignment
822 | search_regex = r"\)\s*(SELECT|INSERT|UPDATE|DELETE|MERGE)"
823 | search_flags = re.IGNORECASE | re.DOTALL
824 | main_statement_match = re.search(search_regex, normalized_sql, search_flags)
825 | if main_statement_match:
826 | # Chained calls okay on one line
827 | main_statement_group = main_statement_match.group(0)
828 | check_sql_part = main_statement_group.lstrip(") \t\n\r")
829 | # else: keep check_sql_part as normalized_sql
830 | except Exception:
831 | # Ignore regex errors and fall back to checking the whole normalized_sql
832 | pass
833 |
834 | prohibited_match_obj = _PROHIBITED_SQL_REGEX.match(check_sql_part)
835 | if prohibited_match_obj:
836 | # Chained calls okay on one line
837 | prohibited_match = prohibited_match_obj.group(1)
838 | prohibited_statement = prohibited_match.strip()
839 | logger.warning(f"Security Violation: Prohibited statement detected: {prohibited_statement}")
840 | details = {"statement": prohibited_statement, "error_type": "PROHIBITED_STATEMENT"}
841 | raise ToolInputError(
842 | f"Prohibited statement type detected: '{prohibited_statement}'",
843 | param_name="query",
844 | details=details,
845 | )
846 |
847 | # Check read-only constraint
848 | if read_only:
849 | allowed_starts = ("SELECT", "SHOW", "EXPLAIN", "DESCRIBE", "PRAGMA")
850 | is_read_query = check_sql_part.startswith(allowed_starts)
851 | if not is_read_query:
852 | query_preview = sql[:100]
853 | logger.warning(
854 | f"Security Violation: Write operation attempted in read-only mode: {query_preview}..."
855 | )
856 | details = {"error_type": "READ_ONLY_VIOLATION"}
857 | raise ToolInputError(
858 | "Write operation attempted in read-only mode", param_name="query", details=details
859 | )
860 |
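# Illustrative checks (queries are hypothetical; the ACL sets are assumed empty):
#   _sql_check_safe("SELECT * FROM users", read_only=True)       # passes
#   _sql_check_safe("DELETE FROM users", read_only=True)         # raises ToolInputError (prohibited statement)
#   _sql_check_safe("CREATE TABLE t (id INT)", read_only=True)   # raises ToolInputError (read-only violation)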
861 |
862 | async def _sql_exec(
863 | eng: AsyncEngine,
864 | sql: str,
865 | params: Optional[Dict[str, Any]],
866 | *,
867 | limit: Optional[int],
868 | tool_name: str,
869 | action_name: str,
870 | timeout: float = 30.0,
871 | ) -> Tuple[List[str], List[Dict[str, Any]], int]:
872 | """Core async SQL executor helper."""
873 | db_dialect = eng.dialect.name
874 | start_time = time.perf_counter()
875 |
876 | if _Q_CNT:
877 | # Chained call okay
878 | _Q_CNT.labels(tool=tool_name, action=action_name, db=db_dialect).inc()
879 |
880 | cols: List[str] = []
881 | rows_raw: List[Any] = []
882 | row_count: int = 0
883 | masked_rows: List[Dict[str, Any]] = []
884 |
885 | async def _run(conn: AsyncConnection):
886 | nonlocal cols, rows_raw, row_count, masked_rows
887 | statement = text(sql)
888 | query_params = params or {}
889 | try:
890 | res = await conn.execute(statement, query_params)
891 | has_cursor = res.cursor is not None
892 | has_description = has_cursor and res.cursor.description is not None
893 | if not has_cursor or not has_description:
894 | logger.debug(f"Query did not return rows or description. Action: {action_name}")
895 | # Ternary okay
896 | res_rowcount = res.rowcount if res.rowcount >= 0 else 0
897 | row_count = res_rowcount
898 | masked_rows = [] # Ensure it's an empty list
899 | empty_cols: List[str] = []
900 | empty_rows: List[Dict[str, Any]] = []
901 | return empty_cols, empty_rows, row_count # Return empty lists for cols/rows
902 |
903 | cols = list(res.keys())
904 | try:
905 | # --- START: Restored SQLite Handling ---
906 | is_sqlite = db_dialect == "sqlite"
907 | if is_sqlite:
908 | # aiosqlite fetchall/fetchmany might not work reliably with async iteration or limits in all cases
909 | # Fetch all as mappings (dicts) directly
910 | # Lambda okay if single line
911 | def sync_lambda(sync_conn):
912 | return list(sync_conn.execute(statement, query_params).mappings())
913 |
914 | all_rows_mapped = await conn.run_sync(sync_lambda)
915 | rows_raw = all_rows_mapped # Keep the dict list format
916 | needs_limit = limit is not None and limit >= 0
917 | if needs_limit:
918 | rows_raw = rows_raw[:limit] # Apply limit in Python
919 | else:
920 | # Standard async fetching for other dialects
921 | needs_limit = limit is not None and limit >= 0
922 | if needs_limit:
923 | fetched_rows = await res.fetchmany(limit) # Returns Row objects
924 | rows_raw = fetched_rows
925 | else:
926 | fetched_rows = await res.fetchall() # Returns Row objects
927 | rows_raw = fetched_rows
928 | # --- END: Restored SQLite Handling ---
929 |
930 | row_count = len(rows_raw) # Count based on fetched/limited rows
931 |
932 | except Exception as fetch_err:
933 | log_msg = f"Error fetching rows for {tool_name}/{action_name}: {fetch_err}"
934 | logger.error(log_msg, exc_info=True)
935 | query_preview = sql[:100] + "..."
936 | details = {"query": query_preview}
937 | raise ToolError(
938 | f"Error fetching results: {fetch_err}", http_status_code=500, details=details
939 | ) from fetch_err
940 |
941 | # Apply masking using _sql_mask_row which uses global rules
942 | # Adjust masking based on fetched format
943 | if is_sqlite:
944 | # List comprehension okay
945 | masked_rows_list = [_sql_mask_row(r) for r in rows_raw] # Already dicts
946 | masked_rows = masked_rows_list
947 | else:
948 | # List comprehension okay
949 | masked_rows_list = [
950 | _sql_mask_row(r._mapping) for r in rows_raw
951 | ] # Convert Row objects
952 | masked_rows = masked_rows_list
953 |
954 | return cols, masked_rows, row_count
955 |
956 | except (ProgrammingError, OperationalError) as db_err:
957 | err_type_name = type(db_err).__name__
958 | log_msg = f"Database execution error ({err_type_name}) for {tool_name}/{action_name} on {db_dialect}: {db_err}"
959 | logger.error(log_msg, exc_info=True)
960 | query_preview = sql[:100] + "..."
961 | details = {"db_error": str(db_err), "query": query_preview}
962 | raise ToolError(
963 | f"Database Error: {db_err}", http_status_code=400, details=details
964 | ) from db_err
965 | except SQLAlchemyError as sa_err:
966 | err_type_name = type(sa_err).__name__
967 | log_msg = f"SQLAlchemy error ({err_type_name}) for {tool_name}/{action_name} on {db_dialect}: {sa_err}"
968 | logger.error(log_msg, exc_info=True)
969 | query_preview = sql[:100] + "..."
970 | details = {"sqlalchemy_error": str(sa_err), "query": query_preview}
971 | raise ToolError(
972 | f"SQLAlchemy Error: {sa_err}", http_status_code=500, details=details
973 | ) from sa_err
974 | except Exception as e: # Catch other potential errors within _run
975 | log_msg = f"Unexpected error within _run for {tool_name}/{action_name}: {e}"
976 | logger.error(log_msg, exc_info=True)
977 | raise ToolError(
978 | f"Unexpected error during query execution step: {e}", http_status_code=500
979 | ) from e
980 |
981 | try:
982 | async with eng.connect() as conn:
983 | # Run within timeout
984 | # Call okay
985 | run_coro = _run(conn)
986 | cols_res, masked_rows_res, cnt_res = await asyncio.wait_for(run_coro, timeout=timeout)
987 | cols = cols_res
988 | masked_rows = masked_rows_res
989 | cnt = cnt_res
990 |
991 | latency = time.perf_counter() - start_time
992 | if _Q_LAT:
993 | # Chained call okay
994 | _Q_LAT.labels(tool=tool_name, action=action_name, db=db_dialect).observe(latency)
995 | log_msg = f"Execution successful for {tool_name}/{action_name}. Latency: {latency:.3f}s, Rows fetched: {cnt}"
996 | logger.debug(log_msg)
997 | return cols, masked_rows, cnt
998 |
999 | except asyncio.TimeoutError:
1000 | log_msg = (
1001 | f"Query timeout ({timeout}s) exceeded for {tool_name}/{action_name} on {db_dialect}."
1002 | )
1003 | logger.warning(log_msg)
1004 | query_preview = sql[:100] + "..."
1005 | details = {"timeout": timeout, "query": query_preview}
1006 | raise ToolError(
1007 | f"Query timed out after {timeout} seconds", http_status_code=504, details=details
1008 | ) from None
1009 | except ToolError:
1010 | # Re-raise known ToolErrors
1011 | raise
1012 | except Exception as e:
1013 | log_msg = f"Unexpected error during _sql_exec for {tool_name}/{action_name}: {e}"
1014 | logger.error(log_msg, exc_info=True)
1015 | details = {"error_type": type(e).__name__}
1016 | raise ToolError(
1017 | f"An unexpected error occurred: {e}", http_status_code=500, details=details
1018 | ) from e
1019 |
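# Illustrative call of the executor above (the engine and SQL are hypothetical):
#   cols, rows, count = await _sql_exec(
#       eng, "SELECT id, name FROM users", None,
#       limit=100, tool_name="execute_sql", action_name="query", timeout=30.0,
#   )
#   # cols -> ["id", "name"]; rows -> masked dicts; count -> number of rows fetched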
1020 |
1021 | def _sql_export_rows(
1022 | cols: List[str],
1023 | rows: List[Dict[str, Any]],
1024 | export_format: str,
1025 | export_path: Optional[str] = None,
1026 | ) -> Tuple[Any | None, str | None]:
1027 | """Export query results helper."""
1028 | if not export_format:
1029 | return None, None
1030 | export_format_lower = export_format.lower()
1031 | supported_formats = ["pandas", "excel", "csv"]
1032 | if export_format_lower not in supported_formats:
1033 | details = {"format": export_format}
1034 | msg = f"Unsupported export format: '{export_format}'. Use 'pandas', 'excel', or 'csv'."
1035 | raise ToolInputError(msg, param_name="export.format", details=details)
1036 | if pd is None:
1037 | details = {"library": "pandas"}
1038 | msg = f"Pandas library is not installed. Cannot export to '{export_format_lower}'."
1039 | raise ToolError(msg, http_status_code=501, details=details)
1040 |
1041 | try:
1042 | # Ternary okay
1043 | df = pd.DataFrame(rows, columns=cols) if rows else pd.DataFrame(columns=cols)
1044 | logger.info(f"Created DataFrame with shape {df.shape} for export.")
1045 | except Exception as e:
1046 | logger.error(f"Error creating Pandas DataFrame: {e}", exc_info=True)
1047 | raise ToolError(f"Failed to create DataFrame for export: {e}", http_status_code=500) from e
1048 |
1049 | if export_format_lower == "pandas":
1050 | logger.debug("Returning raw Pandas DataFrame.")
1051 | return df, None
1052 |
1053 | final_path: str
1054 | temp_file_created = False
1055 | if export_path:
1056 | try:
1057 | # Chained calls okay
1058 | path_obj = Path(export_path)
1059 | path_expanded = path_obj.expanduser()
1060 | path_resolved = path_expanded.resolve()
1061 | parent_dir = path_resolved.parent
1062 | parent_dir.mkdir(parents=True, exist_ok=True)
1063 | final_path = str(path_resolved)
1064 | logger.info(f"Using specified export path: {final_path}")
1065 | except OSError as e:
1066 | details = {"path": export_path}
1067 | raise ToolError(
1068 | f"Cannot create directory for export path '{export_path}': {e}",
1069 | http_status_code=500,
1070 | details=details,
1071 | ) from e
1072 | except Exception as e: # Catch other path errors
1073 | details = {"path": export_path}
1074 | msg = f"Invalid export path provided: {export_path}. Error: {e}"
1075 | raise ToolInputError(msg, param_name="export.path", details=details) from e
1076 | else:
1077 | # Ternary okay
1078 | suffix = ".xlsx" if export_format_lower == "excel" else ".csv"
1079 | try:
1080 | prefix = f"mcp_export_{export_format_lower}_"
1081 | fd, final_path_temp = tempfile.mkstemp(suffix=suffix, prefix=prefix)
1082 | final_path = final_path_temp
1083 | os.close(fd)
1084 | temp_file_created = True
1085 | logger.info(f"Created temporary file for export: {final_path}")
1086 | except Exception as e:
1087 | logger.error(f"Failed to create temporary file for export: {e}", exc_info=True)
1088 | raise ToolError(f"Failed to create temporary file: {e}", http_status_code=500) from e
1089 |
1090 | try:
1091 | if export_format_lower == "excel":
1092 | df.to_excel(final_path, index=False, engine="xlsxwriter")
1093 | elif export_format_lower == "csv":
1094 | df.to_csv(final_path, index=False)
1095 | log_msg = f"Exported data to {export_format_lower.upper()} file: {final_path}"
1096 | logger.info(log_msg)
1097 | return None, final_path
1098 | except Exception as e:
1099 | log_msg = f"Error exporting DataFrame to {export_format_lower} file '{final_path}': {e}"
1100 | logger.error(log_msg, exc_info=True)
1101 | path_exists = Path(final_path).exists()
1102 | if temp_file_created and path_exists:
1103 | try:
1104 | Path(final_path).unlink()
1105 | except OSError:
1106 | logger.warning(f"Could not clean up temporary export file: {final_path}")
1107 | raise ToolError(
1108 | f"Failed to export data to {export_format_lower}: {e}", http_status_code=500
1109 | ) from e
1110 |
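# Illustrative exports (column/row data is hypothetical; requires pandas):
#   df, _   = _sql_export_rows(["id", "name"], [{"id": 1, "name": "Ada"}], "pandas")
#   _, path = _sql_export_rows(["id", "name"], [{"id": 1, "name": "Ada"}], "csv")  # temp .csv path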
1111 |
1112 | async def _sql_validate_df(df: Any, schema: Any | None) -> None:
1113 | """Validate DataFrame against Pandera schema helper."""
1114 | if schema is None:
1115 | logger.debug("No Pandera schema provided for validation.")
1116 | return
1117 | if pa is None:
1118 | logger.warning("Pandera library not installed, skipping validation.")
1119 | return
1120 | is_pandas_df = pd is not None and isinstance(df, pd.DataFrame)
1121 | if not is_pandas_df:
1122 | logger.warning("Pandas DataFrame not available for validation.")
1123 | return
1124 |
1125 | logger.info(f"Validating DataFrame (shape {df.shape}) against provided Pandera schema.")
1126 | try:
1127 | schema.validate(df, lazy=True)
1128 | logger.info("Pandera validation successful.")
1129 | except pa.errors.SchemaErrors as se:
1130 | # Ternary okay
1131 | error_details_df = se.failure_cases
1132 | can_dict = hasattr(error_details_df, "to_dict")
1133 | error_details = (
1134 | error_details_df.to_dict(orient="records") if can_dict else str(error_details_df)
1135 | )
1136 | # Ternary okay
1137 | can_len = hasattr(error_details_df, "__len__")
1138 | error_count = len(error_details_df) if can_len else "multiple"
1139 |
1140 | log_msg = f"Pandera validation failed with {error_count} errors. Details: {error_details}"
1141 | logger.warning(log_msg)
1142 |
1143 | # Break down error message construction
1144 | error_msg_base = f"Pandera validation failed ({error_count} errors):\n"
1145 | error_msg_lines = []
1146 | error_details_list = error_details if isinstance(error_details, list) else []
1147 | errors_to_show = error_details_list[:5]
1148 |
1149 | for err in errors_to_show:
1150 | col = err.get("column", "N/A")
1151 | check = err.get("check", "N/A")
1152 | index = err.get("index", "N/A")
1153 | fail_case_raw = err.get("failure_case", "N/A")
1154 | fail_case_str = str(fail_case_raw)[:50]
1155 | line = f"- Column '{col}': {check} failed for index {index}. Data: {fail_case_str}..."
1156 | error_msg_lines.append(line)
1157 |
1158 | error_msg = error_msg_base + "\n".join(error_msg_lines)
1159 |
1160 | num_errors = error_count if isinstance(error_count, int) else 0
1161 | if num_errors > 5:
1162 | more_errors_count = num_errors - 5
1163 | error_msg += f"\n... and {more_errors_count} more errors."
1164 |
1165 | validation_errors = error_details # Pass the original structure
1166 | details = {"error_type": "VALIDATION_ERROR"}
1167 | raise ToolError(
1168 | error_msg, http_status_code=422, validation_errors=validation_errors, details=details
1169 | ) from se
1170 | except Exception as e:
1171 | logger.error(f"Unexpected error during Pandera validation: {e}", exc_info=True)
1172 | raise ToolError(
1173 | f"An unexpected error occurred during schema validation: {e}", http_status_code=500
1174 | ) from e
1175 |
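# Illustrative validation (the schema below is hypothetical; requires pandas and pandera):
#   schema = pa.DataFrameSchema({"id": pa.Column(int), "name": pa.Column(str)})
#   await _sql_validate_df(df, schema)   # raises ToolError (HTTP 422) on schema failures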
1176 |
1177 | async def _sql_convert_nl_to_sql(
1178 | connection_id: str,
1179 | natural_language: str,
1180 | confidence_threshold: float = 0.6,
1181 | user_id: Optional[str] = None, # Added for lineage
1182 | session_id: Optional[str] = None, # Added for lineage
1183 | ) -> Dict[str, Any]:
1184 | """Helper method to convert natural language to SQL."""
1185 | # (Implementation largely the same, uses _sql_get_engine, _sql_check_safe, global state _SCHEMA_VERSIONS, _LINEAGE)
1186 | nl_preview = natural_language[:100]
1187 | logger.info(f"Converting NL to SQL for connection {connection_id}. Query: '{nl_preview}...'")
1188 | eng = await _sql_get_engine(connection_id)
1189 |
1190 | def _get_schema_fingerprint_sync(sync_conn) -> str:
1191 | # (Schema fingerprint sync helper implementation is the same)
1192 | try:
1193 | sync_inspector = sa_inspect(sync_conn)
1194 | tbls = []
1195 | schema_names = sync_inspector.get_schema_names()
1196 | default_schema = sync_inspector.default_schema_name
1197 | # List comprehension okay
1198 | other_schemas = [s for s in schema_names if s != default_schema]
1199 | schemas_to_inspect = [default_schema] + other_schemas
1200 |
1201 | for schema_name in schemas_to_inspect:
1202 | # Ternary okay
1203 | prefix = f"{schema_name}." if schema_name and schema_name != default_schema else ""
1204 | table_names_in_schema = sync_inspector.get_table_names(schema=schema_name)
1205 | for t in table_names_in_schema:
1206 | try:
1207 | cols = sync_inspector.get_columns(t, schema=schema_name)
1208 | # List comprehension okay
1209 | col_defs = [f"{c['name']}:{str(c['type']).split('(')[0]}" for c in cols]
1210 | col_defs_str = ",".join(col_defs)
1211 | tbl_def = f"{prefix}{t}({col_defs_str})"
1212 | tbls.append(tbl_def)
1213 | except Exception as col_err:
1214 | logger.warning(f"Could not get columns for table {prefix}{t}: {col_err}")
1215 | tbl_def_err = f"{prefix}{t}(...)"
1216 | tbls.append(tbl_def_err)
1217 | # Call okay
1218 | fp = "; ".join(sorted(tbls))
1219 | if not fp:
1220 | logger.warning("Schema fingerprint generation resulted in empty string.")
1221 | return "Error: Could not retrieve schema."
1222 | return fp
1223 | except Exception as e:
1224 | logger.error(f"Error in _get_schema_fingerprint_sync: {e}", exc_info=True)
1225 | return "Error: Could not retrieve schema."
1226 |
1227 | async def _get_schema_fingerprint(conn: AsyncConnection) -> str:
1228 | logger.debug("Generating schema fingerprint for NL->SQL...")
1229 | try:
1230 | # Lambda okay
1231 | def sync_func(sync_conn):
1232 | return _get_schema_fingerprint_sync(sync_conn)
1233 |
1234 | fingerprint = await conn.run_sync(sync_func)
1235 | return fingerprint
1236 | except Exception as e:
1237 | logger.error(f"Error generating schema fingerprint: {e}", exc_info=True)
1238 | return "Error: Could not retrieve schema."
1239 |
1240 | async with eng.connect() as conn:
1241 | schema_fingerprint = await _get_schema_fingerprint(conn)
1242 |
1243 | # Multi-line string assignment okay
1244 | prompt = (
1245 | "You are a highly specialized AI assistant that translates natural language questions into SQL queries.\n"
1246 | "You must adhere STRICTLY to the following rules:\n"
1247 | "1. Generate only a SINGLE, executable SQL query for the given database schema and question.\n"
1248 | "2. Use the exact table and column names provided in the schema fingerprint.\n"
1249 | "3. Do NOT generate any explanatory text, comments, or markdown formatting.\n"
1250 | "4. The output MUST be a valid JSON object containing two keys: 'sql' (the generated SQL query as a string) and 'confidence' (a float between 0.0 and 1.0 indicating your confidence in the generated SQL).\n"
1251 | "5. If the question cannot be answered from the schema or is ambiguous, set confidence to 0.0 and provide a minimal, safe query like 'SELECT 1;' in the 'sql' field.\n"
1252 | "6. Prioritize safety: Avoid generating queries that could modify data (UPDATE, INSERT, DELETE, DROP, etc.). Generate SELECT statements ONLY.\n\n" # Stricter rule
1253 | f"Database Schema Fingerprint:\n```\n{schema_fingerprint}\n```\n\n"
1254 | f"Natural Language Question:\n```\n{natural_language}\n```\n\n"
1255 | "JSON Output:"
1256 | )
1257 |
1258 | try:
1259 | logger.debug("Sending prompt to LLM for NL->SQL conversion.")
1260 | # Call okay
1261 | completion_result = await generate_completion(
1262 | prompt=prompt, max_tokens=350, temperature=0.2
1263 | )
1264 | llm_response_dict = completion_result
1265 |
1266 | # Ternary okay
1267 | is_dict_response = isinstance(llm_response_dict, dict)
1268 | llm_response = llm_response_dict.get("text", "") if is_dict_response else ""
1269 |
1270 | llm_response_preview = llm_response[:300]
1271 | logger.debug(f"LLM Response received: {llm_response_preview}...")
1272 | if not llm_response:
1273 | raise ToolError("LLM returned empty response for NL->SQL.", http_status_code=502)
1274 |
1275 | except Exception as llm_err:
1276 | logger.error(f"LLM completion failed for NL->SQL: {llm_err}", exc_info=True)
1277 | details = {"error_type": "LLM_ERROR"}
1278 | raise ToolError(
1279 | f"Failed to get response from LLM: {llm_err}", http_status_code=502, details=details
1280 | ) from llm_err
1281 |
1282 | try:
1283 | data = {}
1284 | try:
1285 | # Try parsing the whole response as JSON first
1286 | data = json.loads(llm_response)
1287 | except json.JSONDecodeError as e:
1288 | # If that fails, look for a JSON block within the text
1289 | # Greedy match from the first '{' to the last '}'; DOTALL lets it span newlines
1290 | search_regex = r"\{.*\}"
1291 | search_flags = re.DOTALL | re.MULTILINE
1292 | json_match = re.search(search_regex, llm_response, search_flags)
1293 | if not json_match:
1294 | raise ValueError("No JSON object found in the LLM response.") from e
1295 | json_str = json_match.group(0)
1296 | data = json.loads(json_str)
1297 |
1298 | is_dict_data = isinstance(data, dict)
1299 | has_sql = "sql" in data
1300 | has_confidence = "confidence" in data
1301 | if not is_dict_data or not has_sql or not has_confidence:
1302 | raise ValueError("LLM response JSON is missing required keys ('sql', 'confidence').")
1303 |
1304 | sql = data["sql"]
1305 | conf_raw = data["confidence"]
1306 | conf = float(conf_raw)
1307 |
1308 | is_sql_str = isinstance(sql, str)
1309 | is_conf_valid = 0.0 <= conf <= 1.0
1310 | if not is_sql_str or not is_conf_valid:
1311 | raise ValueError("LLM response has invalid types for 'sql' or 'confidence'.")
1312 |
1313 | sql_preview = sql[:150]
1314 | logger.info(f"LLM generated SQL with confidence {conf:.2f}: {sql_preview}...")
1315 |
1316 | except (json.JSONDecodeError, ValueError, TypeError) as e:
1317 | response_preview = str(llm_response)[:200]
1318 | error_detail = (
1319 | f"LLM returned invalid or malformed JSON: {e}. Response: '{response_preview}...'"
1320 | )
1321 | logger.error(error_detail)
1322 | details = {"error_type": "LLM_RESPONSE_INVALID"}
1323 | raise ToolError(error_detail, http_status_code=500, details=details) from e
1324 |
1325 | is_below_threshold = conf < confidence_threshold
1326 | if is_below_threshold:
1327 | nl_query_preview = natural_language[:100]
1328 | low_conf_msg = f"LLM confidence ({conf:.2f}) is below the required threshold ({confidence_threshold}). NL Query: '{nl_query_preview}'"
1329 | logger.warning(low_conf_msg)
1330 | details = {"error_type": "LOW_CONFIDENCE"}
1331 | raise ToolError(
1332 | low_conf_msg, http_status_code=400, generated_sql=sql, confidence=conf, details=details
1333 | ) from None
1334 |
1335 | try:
1336 | _sql_check_safe(sql, read_only=True) # Enforce read-only for generated SQL
1337 | # Call okay
1338 | sql_upper = sql.upper()
1339 | sql_stripped = sql_upper.lstrip()
1340 | is_valid_start = sql_stripped.startswith(("SELECT", "WITH"))
1341 | if not is_valid_start:
1342 | details = {"error_type": "INVALID_GENERATED_SQL"}
1343 | raise ToolError(
1344 | "Generated query does not appear to be a valid SELECT statement.",
1345 | http_status_code=400,
1346 | details=details,
1347 | )
1348 | # Basic table check (optional, as before)
1349 | except ToolInputError as safety_err:
1350 | logger.error(f"Generated SQL failed safety check: {safety_err}. SQL: {sql}")
1351 | details = {"error_type": "SAFETY_VIOLATION"}
1352 | raise ToolError(
1353 | f"Generated SQL failed validation: {safety_err}",
1354 | http_status_code=400,
1355 | generated_sql=sql,
1356 | confidence=conf,
1357 | details=details,
1358 | ) from safety_err
1359 |
1360 | result_dict = {"sql": sql, "confidence": conf}
1361 | return result_dict
1362 |
1363 |
1364 | # =============================================================================
1365 | # Public Tool Functions (Standalone replacements for SQLTool methods)
1366 | # =============================================================================
1367 |
1368 |
1369 | @with_tool_metrics
1370 | @with_error_handling
1371 | async def manage_database(
1372 | action: str,
1373 | connection_string: Optional[str] = None,
1374 | connection_id: Optional[str] = None,
1375 | echo: bool = False,
1376 | user_id: Optional[str] = None,
1377 | session_id: Optional[str] = None,
1378 | ctx: Optional[Dict] = None, # Added ctx for potential future use
1379 | **options: Any,
1380 | ) -> Dict[str, Any]:
1381 | """
1382 | Unified database connection management tool.
1383 |
1384 | Args:
1385 | action: The action to perform: "connect", "disconnect", "test", or "status".
1386 | connection_string: Database connection string or secrets:// reference. (Required for "connect").
1387 | connection_id: An existing connection ID (Required for "disconnect", "test"). Can be provided for "connect" to suggest an ID.
1388 | echo: Enable SQLAlchemy engine logging (For "connect" action, default: False).
1389 | user_id: Optional user identifier for audit logging.
1390 | session_id: Optional session identifier for audit logging.
1391 | ctx: Optional context from MCP server (not used currently).
1392 | **options: Additional options:
1393 | - For "connect": Passed directly to SQLAlchemy's `create_async_engine`.
1394 | - Can include custom audit context.
1395 |
1396 | Returns:
1397 | Dict with action results and metadata. Varies based on action.
1398 | """
1399 | tool_name = "manage_database"
1400 | db_dialect = "unknown"
1401 | # Dict comprehension okay
1402 | audit_extras_all = {k: v for k, v in options.items()}
1403 | audit_extras = {k: v for k, v in audit_extras_all.items() if k not in ["echo"]}
1404 |
1405 | try:
1406 | if action == "connect":
1407 | if not connection_string:
1408 | raise ToolInputError(
1409 | "connection_string is required for 'connect'", param_name="connection_string"
1410 | )
1411 | # Ternary okay
1412 | cid = connection_id or str(uuid.uuid4())
1413 | logger.info(f"Attempting to connect with connection_id: {cid}")
1414 | resolved_conn_str = _sql_resolve_conn(connection_string)
1415 | url, db_type = _sql_driver_url(resolved_conn_str)
1416 | db_dialect = db_type # Update dialect for potential error logging
1417 | pool_opts = _sql_auto_pool(db_type)
1418 | # Dict unpacking okay
1419 | engine_opts = {**pool_opts, **options}
1420 | # Dict comprehension okay
1421 | log_opts = {k: v for k, v in engine_opts.items() if k != "password"}
1422 | logger.debug(f"Creating engine for {db_type} with options: {log_opts}")
1423 | connect_args = engine_opts.pop("connect_args", {})
1424 | # Ternary okay
1425 | exec_opts_base = {"async_execution": True} if db_type == "snowflake" else {}
1426 | # Pass other engine options directly
1427 | execution_options = {**exec_opts_base, **engine_opts.pop("execution_options", {})}
1428 |
1429 | # Separate create_async_engine call
1430 | eng = create_async_engine(
1431 | url,
1432 | echo=echo,
1433 | connect_args=connect_args,
1434 | execution_options=execution_options,
1435 | **engine_opts, # Pass remaining options like pool settings
1436 | )
1437 |
1438 | try:
1439 | # Ternary okay
1440 | test_sql = "SELECT CURRENT_TIMESTAMP" if db_type != "sqlite" else "SELECT 1"
1441 | # Call okay
1442 | await _sql_exec(
1443 | eng,
1444 | test_sql,
1445 | None,
1446 | limit=1,
1447 | tool_name=tool_name,
1448 | action_name="connect_test",
1449 | timeout=15,
1450 | )
1451 | logger.info(f"Connection test successful for {cid} ({db_type}).")
1452 | except ToolError as test_err:
1453 | logger.error(f"Connection test failed for {cid} ({db_type}): {test_err}")
1454 | await eng.dispose()
1455 | # Get details okay
1456 | err_details = getattr(test_err, "details", None)
1457 | raise ToolError(
1458 | f"Connection test failed: {test_err}", http_status_code=400, details=err_details
1459 | ) from test_err
1460 | except Exception as e:
1461 | logger.error(
1462 | f"Unexpected error during connection test for {cid} ({db_type}): {e}",
1463 | exc_info=True,
1464 | )
1465 | await eng.dispose()
1466 | raise ToolError(
1467 | f"Unexpected error during connection test: {e}", http_status_code=500
1468 | ) from e
1469 |
1470 | await _connection_manager.add_connection(cid, eng)
1471 | # Call okay
1472 | await _sql_audit(
1473 | tool_name=tool_name,
1474 | action="connect",
1475 | connection_id=cid,
1476 | sql=None,
1477 | tables=None,
1478 | row_count=None,
1479 | success=True,
1480 | error=None,
1481 | user_id=user_id,
1482 | session_id=session_id,
1483 | database_type=db_type,
1484 | echo=echo,
1485 | **audit_extras,
1486 | )
1487 | # Return dict okay
1488 | return {
1489 | "action": "connect",
1490 | "connection_id": cid,
1491 | "database_type": db_type,
1492 | "success": True,
1493 | }
1494 |
1495 | elif action == "disconnect":
1496 | if not connection_id:
1497 | raise ToolInputError(
1498 | "connection_id is required for 'disconnect'", param_name="connection_id"
1499 | )
1500 | logger.info(f"Attempting to disconnect connection_id: {connection_id}")
1501 | db_dialect_for_audit = "unknown" # Default if engine retrieval fails
1502 | try:
1503 | # get_connection is async and raises ToolInputError if the ID is unknown
1504 | engine_to_close = await _connection_manager.get_connection(connection_id)
1505 | db_dialect_for_audit = engine_to_close.dialect.name
1506 | except ToolInputError:
1507 | # This error means connection_id wasn't found by get_connection
1508 | logger.warning(f"Disconnect requested for unknown connection_id: {connection_id}")
1509 | # Call okay
1510 | await _sql_audit(
1511 | tool_name=tool_name,
1512 | action="disconnect",
1513 | connection_id=connection_id,
1514 | sql=None,
1515 | tables=None,
1516 | row_count=None,
1517 | success=False,
1518 | error="Connection ID not found",
1519 | user_id=user_id,
1520 | session_id=session_id,
1521 | **audit_extras,
1522 | )
1523 | # Return dict okay
1524 | return {
1525 | "action": "disconnect",
1526 | "connection_id": connection_id,
1527 | "success": False,
1528 | "message": "Connection ID not found",
1529 | }
1530 | except Exception as e:
1531 | # Catch other errors during engine retrieval itself
1532 | logger.error(f"Error retrieving engine for disconnect ({connection_id}): {e}")
1533 | # Proceed to attempt close, but audit will likely show failure or non-existence
1534 |
1535 | # Attempt closing even if retrieval had issues (it might have been removed between check and close)
1536 | success = await _connection_manager.close_connection(connection_id)
1537 | # Ternary okay
1538 | error_msg = None if success else "Failed to close or already closed/not found"
1539 | # Call okay
1540 | await _sql_audit(
1541 | tool_name=tool_name,
1542 | action="disconnect",
1543 | connection_id=connection_id,
1544 | sql=None,
1545 | tables=None,
1546 | row_count=None,
1547 | success=success,
1548 | error=error_msg,
1549 | user_id=user_id,
1550 | session_id=session_id,
1551 | database_type=db_dialect_for_audit,
1552 | **audit_extras,
1553 | )
1554 | # Return dict okay
1555 | return {"action": "disconnect", "connection_id": connection_id, "success": success}
1556 |
1557 | elif action == "test":
1558 | if not connection_id:
1559 | raise ToolInputError(
1560 | "connection_id is required for 'test'", param_name="connection_id"
1561 | )
1562 | logger.info(f"Testing connection_id: {connection_id}")
1563 | eng = await _sql_get_engine(connection_id)
1564 | db_dialect = eng.dialect.name # Now dialect is known for sure
1565 | t0 = time.perf_counter()
1566 | # Ternary conditions okay
1567 | vsql = (
1568 | "SELECT sqlite_version()"
1569 | if db_dialect == "sqlite"
1570 | else "SELECT CURRENT_VERSION()"
1571 | if db_dialect == "snowflake"
1572 | else "SELECT version()"
1573 | )
1574 | # Call okay
1575 | cols, rows, _ = await _sql_exec(
1576 | eng, vsql, None, limit=1, tool_name=tool_name, action_name="test", timeout=10
1577 | )
1578 | latency = time.perf_counter() - t0
1579 | # Ternary okay
1580 | has_rows_and_cols = rows and cols
1581 | version_info = rows[0].get(cols[0], "N/A") if has_rows_and_cols else "N/A"
1582 | log_msg = f"Connection test successful for {connection_id}. Version: {version_info}, Latency: {latency:.3f}s"
1583 | logger.info(log_msg)
1584 | # Return dict okay
1585 | return {
1586 | "action": "test",
1587 | "connection_id": connection_id,
1588 | "response_time_seconds": round(latency, 3),
1589 | "version": version_info,
1590 | "database_type": db_dialect,
1591 | "success": True,
1592 | }
1593 |
1594 | elif action == "status":
1595 | logger.info("Retrieving connection status.")
1596 | connections_info = {}
1597 | current_time = time.time()
1598 | # Snapshot the connection items under the manager's lock for safe iteration
1599 | conn_items = []
1600 | async with _connection_manager._lock: # Access lock directly for iteration safety
1601 | # Call okay
1602 | conn_items = list(_connection_manager.connections.items())
1603 |
1604 | for conn_id, (eng, last_access) in conn_items:
1605 | try:
1606 | url_display_raw = str(eng.url)
1607 | parsed_url = make_url(url_display_raw)
1608 | url_display = url_display_raw # Default
1609 | if parsed_url.password:
1610 | # Call okay
1611 | url_masked = parsed_url.set(password="***")
1612 | url_display = str(url_masked)
1613 | # Break down dict assignment
1614 | conn_info_dict = {}
1615 | conn_info_dict["url_summary"] = url_display
1616 | conn_info_dict["dialect"] = eng.dialect.name
1617 | last_access_dt = dt.datetime.fromtimestamp(last_access)
1618 | conn_info_dict["last_accessed"] = last_access_dt.isoformat()
1619 | idle_seconds = current_time - last_access
1620 | conn_info_dict["idle_time_seconds"] = round(idle_seconds, 1)
1621 | connections_info[conn_id] = conn_info_dict
1622 | except Exception as status_err:
1623 | logger.error(f"Error retrieving status for connection {conn_id}: {status_err}")
1624 | connections_info[conn_id] = {"error": str(status_err)}
1625 | # Return dict okay
1626 | return {
1627 | "action": "status",
1628 | "active_connections_count": len(connections_info),
1629 | "connections": connections_info,
1630 | "cleanup_interval_seconds": _connection_manager.cleanup_interval,
1631 | "success": True,
1632 | }
1633 |
1634 | else:
1635 | logger.error(f"Invalid action specified for manage_database: {action}")
1636 | details = {"action": action}
1637 | msg = f"Unknown action: '{action}'. Valid actions: connect, disconnect, test, status"
1638 | raise ToolInputError(msg, param_name="action", details=details)
1639 |
1640 | except ToolInputError as tie:
1641 | # Call okay
1642 | await _sql_audit(
1643 | tool_name=tool_name,
1644 | action=action,
1645 | connection_id=connection_id,
1646 | sql=None,
1647 | tables=None,
1648 | row_count=None,
1649 | success=False,
1650 | error=str(tie),
1651 | user_id=user_id,
1652 | session_id=session_id,
1653 | database_type=db_dialect,
1654 | **audit_extras,
1655 | )
1656 | raise tie
1657 | except ToolError as te:
1658 | # Call okay
1659 | await _sql_audit(
1660 | tool_name=tool_name,
1661 | action=action,
1662 | connection_id=connection_id,
1663 | sql=None,
1664 | tables=None,
1665 | row_count=None,
1666 | success=False,
1667 | error=str(te),
1668 | user_id=user_id,
1669 | session_id=session_id,
1670 | database_type=db_dialect,
1671 | **audit_extras,
1672 | )
1673 | raise te
1674 | except Exception as e:
1675 | log_msg = f"Unexpected error in manage_database (action: {action}): {e}"
1676 | logger.error(log_msg, exc_info=True)
1677 | error_str = f"Unexpected error: {e}"
1678 | # Call okay
1679 | await _sql_audit(
1680 | tool_name=tool_name,
1681 | action=action,
1682 | connection_id=connection_id,
1683 | sql=None,
1684 | tables=None,
1685 | row_count=None,
1686 | success=False,
1687 | error=error_str,
1688 | user_id=user_id,
1689 | session_id=session_id,
1690 | database_type=db_dialect,
1691 | **audit_extras,
1692 | )
1693 | raise ToolError(
1694 | f"An unexpected error occurred in manage_database: {e}", http_status_code=500
1695 | ) from e
1696 |
1697 |
1698 | @with_tool_metrics
1699 | @with_error_handling
1700 | async def execute_sql(
1701 | connection_id: str,
1702 | query: Optional[str] = None,
1703 | natural_language: Optional[str] = None,
1704 | parameters: Optional[Dict[str, Any]] = None,
1705 | pagination: Optional[Dict[str, int]] = None,
1706 | read_only: bool = True,
1707 | export: Optional[Dict[str, Any]] = None,
1708 | timeout: float = 60.0,
1709 | validate_schema: Optional[Any] = None,
1710 | max_rows: Optional[int] = 1000,
1711 | confidence_threshold: float = 0.6,
1712 | user_id: Optional[str] = None,
1713 | session_id: Optional[str] = None,
1714 | ctx: Optional[Dict] = None, # Added ctx
1715 | **options: Any,
1716 | ) -> Dict[str, Any]:
1717 | """
1718 | Unified SQL query execution tool.
1719 |
1720 | Handles direct SQL execution, NL-to-SQL conversion, pagination,
1721 | result masking, safety checks, validation, and export.
1722 |
1723 | Args:
1724 | connection_id: The ID of the database connection to use.
1725 | query: The SQL query string to execute. (Use instead of natural_language).
1726 | natural_language: A natural language question to convert to SQL. (Use instead of query).
1727 | parameters: Dictionary of parameters for parameterized queries.
1728 | pagination: Dict with "page" (>=1) and "page_size" (>=1) for paginated results.
1729 | When pagination is used, max_rows is not applied; page_size limits the returned rows.
1730 | read_only: If True (default), enforces safety checks against write operations (UPDATE, DELETE, etc.). Set to False only if writes are explicitly intended and allowed.
1731 | export: Dictionary with "format" ('pandas', 'excel', 'csv') and optional "path" (string) for exporting results.
1732 | timeout: Maximum execution time in seconds (default: 60.0).
1733 | validate_schema: A Pandera schema object to validate the results DataFrame against.
1734 | max_rows: Maximum number of rows to return in the result (default: 1000). Set to None or -1 for unlimited (potentially dangerous).
1735 | confidence_threshold: Minimum confidence score (0.0-1.0) required from the LLM for NL-to-SQL conversion (default: 0.6).
1736 | user_id: Optional user identifier for audit logging.
1737 | session_id: Optional session identifier for audit logging.
1738 | ctx: Optional context from MCP server.
1739 | **options: Additional options for audit logging or future extensions.
1740 |
1741 | Returns:
1742 | A dictionary containing:
1743 | - columns (List[str]): List of column names.
1744 | - rows (List[Dict[str, Any]]): List of data rows (masked).
1745 | - row_count (int): Number of rows returned in this batch/page.
1746 | - truncated (bool): True if max_rows limited the results.
1747 | - pagination (Optional[Dict]): Info about the current page if pagination was used.
1748 | - generated_sql (Optional[str]): The SQL query generated from natural language, if applicable.
1749 | - confidence (Optional[float]): The confidence score from the NL-to-SQL conversion, if applicable.
1750 | - validation_status (Optional[str]): 'success', 'failed', 'skipped'.
1751 | - validation_errors (Optional[Any]): Details if validation failed.
1752 | - export_status (Optional[str]): Status message if export was attempted.
1753 | - <format>_path (Optional[str]): Path to the exported file if export to file was successful.
1754 | - dataframe (Optional[pd.DataFrame]): The raw Pandas DataFrame if export format was 'pandas'.
1755 | - success (bool): Always True if no exception was raised.
1756 | """
1757 | tool_name = "execute_sql"
1758 | action_name = "query" # Default, may change
1759 | original_query_input = query # Keep track of original SQL input
1760 | original_nl_input = natural_language # Keep track of NL input
1761 | generated_sql = None
1762 | confidence = None
1763 | final_query: str
1764 | # Dict unpacking okay
1765 | final_params = parameters or {}
1766 | result: Dict[str, Any] = {}
1767 | tables: List[str] = []
1768 | # Dict unpacking okay
1769 | audit_extras = {**options}
1770 |
1771 | try:
1772 | # 1. Determine Query
1773 | use_nl = natural_language and not query
1774 | use_sql = query and not natural_language
1775 | is_ambiguous = natural_language and query
1776 | no_input = not natural_language and not query
1777 |
1778 | if is_ambiguous:
1779 | msg = "Provide either 'query' or 'natural_language', not both."
1780 | raise ToolInputError(msg, param_name="query/natural_language")
1781 | if no_input:
1782 | msg = "Either 'query' or 'natural_language' must be provided."
1783 | raise ToolInputError(msg, param_name="query/natural_language")
1784 |
1785 | if use_nl:
1786 | action_name = "nl_to_sql_exec"
1787 | nl_preview = natural_language[:100]
1788 | log_msg = (
1789 | f"Received natural language query for connection {connection_id}: '{nl_preview}...'"
1790 | )
1791 | logger.info(log_msg)
1792 | try:
1793 | # Pass user_id/session_id to NL converter for lineage/audit trail consistency if needed
1794 | # Call okay
1795 | nl_result = await _sql_convert_nl_to_sql(
1796 | connection_id, natural_language, confidence_threshold, user_id, session_id
1797 | )
1798 | final_query = nl_result["sql"]
1799 | generated_sql = final_query
1800 | confidence = nl_result["confidence"]
1801 | # original_query_input remains None; original_nl_input holds the NL question
1802 | audit_extras["generated_sql"] = generated_sql
1803 | audit_extras["confidence"] = confidence
1804 | query_preview = final_query[:150]
1805 | log_msg = f"Successfully converted NL to SQL (Confidence: {confidence:.2f}): {query_preview}..."
1806 | logger.info(log_msg)
1807 | read_only = True # Ensure read-only for generated SQL
1808 | except ToolError as nl_err:
1809 | # Audit NL failure
1810 | await _sql_audit(
1811 | tool_name=tool_name,
1812 | action="nl_to_sql_fail",
1813 | connection_id=connection_id,
1814 | sql=natural_language, # Log the NL query that failed
1815 | tables=None,
1816 | row_count=None,
1817 | success=False,
1818 | error=str(nl_err),
1819 | user_id=user_id,
1820 | session_id=session_id,
1821 | **audit_extras,
1822 | )
1823 | raise nl_err # Re-raise the error
1824 | elif use_sql:
1825 | # Action name remains 'query'
1826 | final_query = query
1827 | query_preview = final_query[:150]
1828 | logger.info(f"Executing direct SQL query on {connection_id}: {query_preview}...")
1829 | # original_query_input has the SQL, original_nl_input is None
1830 | # else case already handled by initial checks
1831 |
1832 | # 2. Check Safety
1833 | _sql_check_safe(final_query, read_only)
1834 | tables = _sql_extract_tables(final_query)
1835 | logger.debug(f"Query targets tables: {tables}")
1836 |
1837 | # 3. Get Engine
1838 | eng = await _sql_get_engine(connection_id)
1839 |
1840 | # 4. Handle Pagination or Standard Execution
1841 | if pagination:
1842 | action_name = "query_paginated"
1843 | page = pagination.get("page", 1)
1844 | page_size = pagination.get("page_size", 100)
1845 | is_page_valid = isinstance(page, int) and page >= 1
1846 | is_page_size_valid = isinstance(page_size, int) and page_size >= 1
1847 | if not is_page_valid:
1848 | raise ToolInputError(
1849 | "Pagination 'page' must be an integer >= 1.", param_name="pagination.page"
1850 | )
1851 | if not is_page_size_valid:
1852 | raise ToolInputError(
1853 | "Pagination 'page_size' must be an integer >= 1.",
1854 | param_name="pagination.page_size",
1855 | )
1856 |
1857 | offset = (page - 1) * page_size
1858 | db_dialect = eng.dialect.name
1859 | paginated_query: str
1860 | if db_dialect in ("mssql", "sqlserver"):  # SQLAlchemy reports SQL Server as "mssql"
1861 | query_lower = final_query.lower()
1862 | has_order_by = "order by" in query_lower
1863 | if not has_order_by:
1864 | raise ToolInputError(
1865 | "SQL Server pagination requires an ORDER BY clause in the query.",
1866 | param_name="query",
1867 | )
1868 | paginated_query = (
1869 | f"{final_query} OFFSET :_page_offset ROWS FETCH NEXT :_page_size ROWS ONLY"
1870 | )
1871 | elif db_dialect == "oracle":
1872 | paginated_query = (
1873 | f"{final_query} OFFSET :_page_offset ROWS FETCH NEXT :_page_size ROWS ONLY"
1874 | )
1875 | else: # Default LIMIT/OFFSET for others (MySQL, PostgreSQL, SQLite)
1876 | paginated_query = f"{final_query} LIMIT :_page_size OFFSET :_page_offset"
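# Illustrative rendering of the default branch above for a hypothetical base query
# on PostgreSQL/SQLite:
#   "SELECT id, name FROM users ORDER BY id LIMIT :_page_size OFFSET :_page_offset"
# Both bound parameters are supplied just below (page_size + 1 and the computed offset).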
1877 |
1878 | # Fetch one extra row to check for next page
1879 | fetch_size = page_size + 1
1880 | # Dict unpacking okay
1881 | paginated_params = {**final_params, "_page_size": fetch_size, "_page_offset": offset}
1882 | log_msg = (
1883 | f"Executing paginated query (Page: {page}, Size: {page_size}): {paginated_query}"
1884 | )
1885 | logger.debug(log_msg)
1886 | # Call okay
1887 | cols, rows_with_extra, fetched_count_paged = await _sql_exec(
1888 | eng,
1889 | paginated_query,
1890 | paginated_params,
1891 | limit=None, # Limit is applied in SQL for pagination
1892 | tool_name=tool_name,
1893 | action_name=action_name,
1894 | timeout=timeout,
1895 | )
1896 |
1897 | # Check if more rows exist than requested page size
1898 | has_next_page = len(rows_with_extra) > page_size
1899 | returned_rows = rows_with_extra[:page_size]
1900 | returned_row_count = len(returned_rows)
1901 |
1902 | # Build result dict piece by piece
1903 | pagination_info = {}
1904 | pagination_info["page"] = page
1905 | pagination_info["page_size"] = page_size
1906 | pagination_info["has_next_page"] = has_next_page
1907 | pagination_info["has_previous_page"] = page > 1
1908 |
1909 | result = {}
1910 | result["columns"] = cols
1911 | result["rows"] = returned_rows
1912 | result["row_count"] = returned_row_count
1913 | result["pagination"] = pagination_info
1914 | result["truncated"] = False # Not truncated by max_rows in pagination mode
1915 | result["success"] = True
1916 |
1917 | else: # Standard execution (no pagination dict)
1918 | action_name = "query_standard"
1919 | # Ternary okay
1920 | needs_limit = max_rows is not None and max_rows >= 0
1921 | fetch_limit = (max_rows + 1) if needs_limit else None
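# e.g. max_rows=1000 -> fetch_limit=1001: one extra row is fetched purely to detect
# truncation; only the first max_rows rows are returned further down.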
1922 |
1923 | query_preview = final_query[:150]
1924 | log_msg = f"Executing standard query (Max rows: {max_rows}): {query_preview}..."
1925 | logger.debug(log_msg)
1926 | # Call okay
1927 | cols, rows_maybe_extra, fetched_count = await _sql_exec(
1928 | eng,
1929 | final_query,
1930 | final_params,
1931 | limit=fetch_limit, # Use fetch_limit (max_rows + 1 or None)
1932 | tool_name=tool_name,
1933 | action_name=action_name,
1934 | timeout=timeout,
1935 | )
1936 |
1937 | # Determine truncation based on fetch_limit
1938 | truncated = fetch_limit is not None and fetched_count >= fetch_limit
1939 | # Apply actual max_rows limit to returned data
1940 | # Ternary okay
1941 | returned_rows = rows_maybe_extra[:max_rows] if needs_limit else rows_maybe_extra
1942 | returned_row_count = len(returned_rows)
1943 |
1944 | # Build result dict piece by piece
1945 | result = {}
1946 | result["columns"] = cols
1947 | result["rows"] = returned_rows
1948 | result["row_count"] = returned_row_count
1949 | result["truncated"] = truncated
1950 | result["success"] = True
1951 | # No pagination key in standard mode
1952 |
1953 | # Add NL->SQL info if applicable
1954 | if generated_sql:
1955 | result["generated_sql"] = generated_sql
1956 | result["confidence"] = confidence
1957 |
1958 | # 5. Handle Validation
1959 | if validate_schema:
1960 | temp_df = None
1961 | validation_status = "skipped (unknown reason)"
1962 | validation_errors = None
1963 | if pd:
1964 | try:
1965 | # Ternary okay
1966 | df_data = result["rows"]
1967 | df_cols = result["columns"]
1968 | temp_df = (
1969 | pd.DataFrame(df_data, columns=df_cols)
1970 | if df_data
1971 | else pd.DataFrame(columns=df_cols)
1972 | )
1973 | try:
1974 | # Call okay
1975 | await _sql_validate_df(temp_df, validate_schema)
1976 | validation_status = "success"
1977 | logger.info("Pandera validation passed.")
1978 | except ToolError as val_err:
1979 | logger.warning(f"Pandera validation failed: {val_err}")
1980 | validation_status = "failed"
1981 | # Get validation errors okay
1982 | validation_errors = getattr(val_err, "validation_errors", str(val_err))
1983 | except Exception as df_err:
1984 | logger.error(f"Error creating DataFrame for validation: {df_err}")
1985 | validation_status = f"skipped (Failed to create DataFrame: {df_err})"
1986 | else:
1987 | logger.warning("Pandas not installed, skipping Pandera validation.")
1988 | validation_status = "skipped (Pandas not installed)"
1989 |
1990 | result["validation_status"] = validation_status
1991 | if validation_errors:
1992 | result["validation_errors"] = validation_errors
1993 |
1994 | # 6. Handle Export
1995 | export_requested = export and export.get("format")
1996 | if export_requested:
1997 | export_format = export["format"]  # Original case kept for log/audit messages; the result path key uses lowercase
1998 | export_format_lower = export_format.lower()
1999 | req_path = export.get("path")
2000 | log_msg = f"Export requested: Format={export_format}, Path={req_path or 'Temporary'}"
2001 | logger.info(log_msg)
2002 | export_status = "failed (unknown reason)"
2003 | try:
2004 | # Call okay
2005 | dataframe, export_path = _sql_export_rows(
2006 | result["columns"], result["rows"], export_format_lower, req_path
2007 | )
2008 | export_status = "success"
2009 | if dataframe is not None: # Only if format was 'pandas'
2010 | result["dataframe"] = dataframe
2011 | if export_path: # If file was created
2012 | path_key = f"{export_format_lower}_path" # Use lowercase format for key
2013 | result[path_key] = export_path
2014 | log_msg = f"Export successful. Format: {export_format}, Path: {export_path or 'In-memory DataFrame'}"
2015 | logger.info(log_msg)
2016 | audit_extras["export_format"] = export_format
2017 | audit_extras["export_path"] = export_path
2018 | except (ToolError, ToolInputError) as export_err:
2019 | logger.error(f"Export failed: {export_err}")
2020 | export_status = f"Failed: {export_err}"
2021 | result["export_status"] = export_status
2022 |
2023 | # 7. Audit Success
2024 | # Determine which query to log based on input
2025 | audit_sql = original_nl_input if use_nl else original_query_input
2026 | audit_row_count = result.get("row_count", 0)
2027 | audit_val_status = result.get("validation_status")
2028 | audit_exp_status = result.get("export_status", "not requested")
2029 | # Call okay
2030 | await _sql_audit(
2031 | tool_name=tool_name,
2032 | action=action_name,
2033 | connection_id=connection_id,
2034 | sql=audit_sql,
2035 | tables=tables,
2036 | row_count=audit_row_count,
2037 | success=True,
2038 | error=None,
2039 | user_id=user_id,
2040 | session_id=session_id,
2041 | read_only=read_only,
2042 | pagination_used=bool(pagination),
2043 | validation_status=audit_val_status,
2044 | export_status=audit_exp_status,
2045 | **audit_extras,
2046 | )
2047 | return result
2048 |
2049 | except ToolInputError as tie:
2050 | # Audit failure, use original inputs for logging context
2051 | audit_sql = original_nl_input if original_nl_input else original_query_input
2052 | # Call okay
2053 | await _sql_audit(
2054 | tool_name=tool_name,
2055 | action=action_name + "_fail",
2056 | connection_id=connection_id,
2057 | sql=audit_sql,
2058 | tables=tables,
2059 | row_count=0,
2060 | success=False,
2061 | error=str(tie),
2062 | user_id=user_id,
2063 | session_id=session_id,
2064 | **audit_extras,
2065 | )
2066 | raise tie
2067 | except ToolError as te:
2068 | # Audit failure
2069 | audit_sql = original_nl_input if original_nl_input else original_query_input
2070 | # Call okay
2071 | await _sql_audit(
2072 | tool_name=tool_name,
2073 | action=action_name + "_fail",
2074 | connection_id=connection_id,
2075 | sql=audit_sql,
2076 | tables=tables,
2077 | row_count=0,
2078 | success=False,
2079 | error=str(te),
2080 | user_id=user_id,
2081 | session_id=session_id,
2082 | **audit_extras,
2083 | )
2084 | raise te
2085 | except Exception as e:
2086 | log_msg = f"Unexpected error in execute_sql (action: {action_name}): {e}"
2087 | logger.error(log_msg, exc_info=True)
2088 | # Audit failure
2089 | audit_sql = original_nl_input if original_nl_input else original_query_input
2090 | error_str = f"Unexpected error: {e}"
2091 | # Call okay
2092 | await _sql_audit(
2093 | tool_name=tool_name,
2094 | action=action_name + "_fail",
2095 | connection_id=connection_id,
2096 | sql=audit_sql,
2097 | tables=tables,
2098 | row_count=0,
2099 | success=False,
2100 | error=error_str,
2101 | user_id=user_id,
2102 | session_id=session_id,
2103 | **audit_extras,
2104 | )
2105 | raise ToolError(
2106 | f"An unexpected error occurred during SQL execution: {e}", http_status_code=500
2107 | ) from e
2108 |
2109 |
2110 | @with_tool_metrics
2111 | @with_error_handling
2112 | async def explore_database(
2113 | connection_id: str,
2114 | action: str,
2115 | table_name: Optional[str] = None,
2116 | column_name: Optional[str] = None,
2117 | schema_name: Optional[str] = None,
2118 | user_id: Optional[str] = None,
2119 | session_id: Optional[str] = None,
2120 | ctx: Optional[Dict] = None, # Added ctx
2121 | **options: Any,
2122 | ) -> Dict[str, Any]:
2123 | """
2124 | Unified database schema exploration and documentation tool.
2125 |
2126 | Performs actions like listing schemas, tables, views, columns,
2127 | getting table/column details, finding relationships, and generating documentation.
2128 |
2129 | Args:
2130 | connection_id: The ID of the database connection to use.
2131 | action: The exploration action:
2132 | - "schema": Get full schema details (tables, views, columns, relationships).
2133 | - "table": Get details for a specific table (columns, PK, FKs, indexes, optionally sample data/stats). Requires `table_name`.
2134 | - "column": Get statistics for a specific column (nulls, distinct, optionally histogram). Requires `table_name` and `column_name`.
2135 | - "relationships": Find related tables via foreign keys up to a certain depth. Requires `table_name`.
2136 | - "documentation": Generate schema documentation (markdown or JSON).
2137 | table_name: Name of the table for 'table', 'column', 'relationships' actions.
2138 | column_name: Name of the column for 'column' action.
2139 | schema_name: Specific schema to inspect (if supported by dialect and needed). Defaults to connection's default schema.
2140 | user_id: Optional user identifier for audit logging.
2141 | session_id: Optional session identifier for audit logging.
2142 | ctx: Optional context from MCP server.
2143 | **options: Additional options depending on the action:
2144 | - schema: include_indexes (bool), include_foreign_keys (bool), detailed (bool)
2145 | - table: include_sample_data (bool), sample_size (int), include_statistics (bool)
2146 | - column: histogram (bool), num_buckets (int)
2147 | - relationships: depth (int)
2148 | - documentation: output_format ('markdown'|'json'), include_indexes(bool), include_foreign_keys(bool)
2149 |
2150 | Returns:
2151 | A dictionary containing the results of the exploration action and a 'success' flag.
2152 | Structure varies significantly based on the action.
2153 | """
2154 | tool_name = "explore_database"
2155 | # Break down audit_extras creation
2156 | audit_extras = {}
2157 | audit_extras.update(options)
2158 | audit_extras["table_name"] = table_name
2159 | audit_extras["column_name"] = column_name
2160 | audit_extras["schema_name"] = schema_name
2161 |
2162 | try:
2163 | log_msg = f"Exploring database for connection {connection_id}. Action: {action}, Table: {table_name}, Column: {column_name}, Schema: {schema_name}"
2164 | logger.info(log_msg)
2165 | eng = await _sql_get_engine(connection_id)
2166 | db_dialect = eng.dialect.name
2167 | audit_extras["database_type"] = db_dialect
2168 |
2169 | # Define sync inspection helper (runs within connect block)
2170 | def _run_sync_inspection(
2171 | inspector_target: Union[AsyncConnection, AsyncEngine], func_to_run: callable
2172 | ):
2173 | # Call okay
2174 | sync_inspector = sa_inspect(inspector_target)
2175 | return func_to_run(sync_inspector)
2176 |
2177 | async with eng.connect() as conn:
2178 | # --- Action: schema ---
2179 | if action == "schema":
2180 | include_indexes = options.get("include_indexes", True)
2181 | include_foreign_keys = options.get("include_foreign_keys", True)
2182 | detailed = options.get("detailed", False)
2183 | filter_schema = schema_name # Use provided schema or None for default
2184 |
2185 | def _get_full_schema(sync_conn) -> Dict[str, Any]:
2186 | # Separate inspector and target schema assignment
2187 | insp = sa_inspect(sync_conn)
2188 | target_schema = filter_schema or getattr(insp, "default_schema_name", None)
2189 |
2190 | log_msg = f"Inspecting schema: {target_schema or 'Default'}. Detailed: {detailed}, Indexes: {include_indexes}, FKs: {include_foreign_keys}"
2191 | logger.info(log_msg)
2192 | tables_data: List[Dict[str, Any]] = []
2193 | views_data: List[Dict[str, Any]] = []
2194 | relationships: List[Dict[str, Any]] = []
2195 | try:
2196 | table_names = insp.get_table_names(schema=target_schema)
2197 | view_names = insp.get_view_names(schema=target_schema)
2198 | except Exception as inspect_err:
2199 | msg = f"Failed to list tables/views for schema '{target_schema}': {inspect_err}"
2200 | raise ToolError(msg, http_status_code=500) from inspect_err
2201 |
2202 | for tbl_name in table_names:
2203 | try:
2204 | # Build t_info dict step-by-step
2205 | t_info: Dict[str, Any] = {}
2206 | t_info["name"] = tbl_name
2207 | t_info["columns"] = []
2208 | if target_schema:
2209 | t_info["schema"] = target_schema
2210 |
2211 | columns_raw = insp.get_columns(tbl_name, schema=target_schema)
2212 | for c in columns_raw:
2213 | # Build col_info dict step-by-step
2214 | col_info = {}
2215 | col_info["name"] = c["name"]
2216 | col_info["type"] = str(c["type"])
2217 | col_info["nullable"] = c["nullable"]
2218 | col_info["primary_key"] = bool(c.get("primary_key"))
2219 | if detailed:
2220 | col_info["default"] = c.get("default")
2221 | col_info["comment"] = c.get("comment")
2222 | col_info["autoincrement"] = c.get("autoincrement", "auto")
2223 | t_info["columns"].append(col_info)
2224 |
2225 | if include_indexes:
2226 | try:
2227 | idxs_raw = insp.get_indexes(tbl_name, schema=target_schema)
2228 | # List comprehension okay
2229 | t_info["indexes"] = [
2230 | {
2231 | "name": i["name"],
2232 | "columns": i["column_names"],
2233 | "unique": i.get("unique", False),
2234 | }
2235 | for i in idxs_raw
2236 | ]
2237 | except Exception as idx_err:
2238 | logger.warning(
2239 | f"Could not retrieve indexes for table {tbl_name}: {idx_err}"
2240 | )
2241 | t_info["indexes"] = []
2242 | if include_foreign_keys:
2243 | try:
2244 | fks_raw = insp.get_foreign_keys(tbl_name, schema=target_schema)
2245 | if fks_raw:
2246 | t_info["foreign_keys"] = []
2247 | for fk in fks_raw:
2248 | # Build fk_info dict step-by-step
2249 | fk_info = {}
2250 | fk_info["name"] = fk.get("name")
2251 | fk_info["constrained_columns"] = fk[
2252 | "constrained_columns"
2253 | ]
2254 | fk_info["referred_schema"] = fk.get("referred_schema")
2255 | fk_info["referred_table"] = fk["referred_table"]
2256 | fk_info["referred_columns"] = fk["referred_columns"]
2257 | t_info["foreign_keys"].append(fk_info)
2258 |
2259 | # Build relationship dict step-by-step
2260 | rel_info = {}
2261 | rel_info["source_schema"] = target_schema
2262 | rel_info["source_table"] = tbl_name
2263 | rel_info["source_columns"] = fk["constrained_columns"]
2264 | rel_info["target_schema"] = fk.get("referred_schema")
2265 | rel_info["target_table"] = fk["referred_table"]
2266 | rel_info["target_columns"] = fk["referred_columns"]
2267 | relationships.append(rel_info)
2268 | except Exception as fk_err:
2269 | logger.warning(
2270 | f"Could not retrieve foreign keys for table {tbl_name}: {fk_err}"
2271 | )
2272 | tables_data.append(t_info)
2273 | except Exception as tbl_err:
2274 | log_msg = f"Failed to inspect table '{tbl_name}' in schema '{target_schema}': {tbl_err}"
2275 | logger.error(log_msg, exc_info=True)
2276 | # Append error dict
2277 | error_entry = {
2278 | "name": tbl_name,
2279 | "schema": target_schema,
2280 | "error": f"Failed to inspect: {tbl_err}",
2281 | }
2282 | tables_data.append(error_entry)
2283 |
2284 | for view_name in view_names:
2285 | try:
2286 | # Build view_info dict step-by-step
2287 | view_info: Dict[str, Any] = {}
2288 | view_info["name"] = view_name
2289 | if target_schema:
2290 | view_info["schema"] = target_schema
2291 | try:
2292 | view_def_raw = insp.get_view_definition(
2293 | view_name, schema=target_schema
2294 | )
2295 | # Ternary okay
2296 | view_def = view_def_raw or ""
2297 | view_info["definition"] = view_def
2298 | except Exception as view_def_err:
2299 | log_msg = f"Could not retrieve definition for view {view_name}: {view_def_err}"
2300 | logger.warning(log_msg)
2301 | view_info["definition"] = "Error retrieving definition"
2302 | try:
2303 | view_cols_raw = insp.get_columns(view_name, schema=target_schema)
2304 | # List comprehension okay
2305 | view_info["columns"] = [
2306 | {"name": vc["name"], "type": str(vc["type"])}
2307 | for vc in view_cols_raw
2308 | ]
2309 | except Exception:
2310 | pass # Ignore column errors for views if definition failed etc.
2311 | views_data.append(view_info)
2312 | except Exception as view_err:
2313 | log_msg = f"Failed to inspect view '{view_name}' in schema '{target_schema}': {view_err}"
2314 | logger.error(log_msg, exc_info=True)
2315 | # Append error dict
2316 | error_entry = {
2317 | "name": view_name,
2318 | "schema": target_schema,
2319 | "error": f"Failed to inspect: {view_err}",
2320 | }
2321 | views_data.append(error_entry)
2322 |
2323 | # Build schema_result dict step-by-step
2324 | schema_result: Dict[str, Any] = {}
2325 | schema_result["action"] = "schema"
2326 | schema_result["database_type"] = db_dialect
2327 | schema_result["inspected_schema"] = target_schema or "Default"
2328 | schema_result["tables"] = tables_data
2329 | schema_result["views"] = views_data
2330 | schema_result["relationships"] = relationships
2331 | schema_result["success"] = True
2332 |
2333 | # Schema Hashing and Lineage
2334 | try:
2335 | # Call okay
2336 | schema_json = json.dumps(schema_result, sort_keys=True, default=str)
2337 | schema_bytes = schema_json.encode()
2338 | # Call okay
2339 | schema_hash = hashlib.sha256(schema_bytes).hexdigest()
2340 |
2341 | timestamp = _sql_now()
2342 | last_hash = _SCHEMA_VERSIONS.get(connection_id)
2343 | schema_changed = last_hash != schema_hash
2344 |
2345 | if schema_changed:
2346 | _SCHEMA_VERSIONS[connection_id] = schema_hash
2347 | # Build lineage_entry dict step-by-step
2348 | lineage_entry = {}
2349 | lineage_entry["connection_id"] = connection_id
2350 | lineage_entry["timestamp"] = timestamp
2351 | lineage_entry["schema_hash"] = schema_hash
2352 | lineage_entry["previous_hash"] = last_hash
2353 | lineage_entry["user_id"] = user_id # Include user from outer scope
2354 | lineage_entry["tables_count"] = len(tables_data)
2355 | lineage_entry["views_count"] = len(views_data)
2356 | lineage_entry["action_source"] = f"{tool_name}/{action}"
2357 | _LINEAGE.append(lineage_entry)
2358 |
2359 | hash_preview = schema_hash[:8]
2360 | prev_hash_preview = last_hash[:8] if last_hash else "None"
2361 | log_msg = f"Schema change detected or initial capture for {connection_id}. New hash: {hash_preview}..., Previous: {prev_hash_preview}"
2362 | logger.info(log_msg)
2363 | schema_result["schema_hash"] = schema_hash
2364 | # Boolean conversion okay
2365 | schema_result["schema_change_detected"] = bool(last_hash)
2366 | except Exception as hash_err:
2367 | log_msg = f"Error generating schema hash or recording lineage: {hash_err}"
2368 | logger.error(log_msg, exc_info=True)
2369 |
2370 | return schema_result
2371 |
2372 | # Wrap the sync helper for conn.run_sync
2373 | def sync_func(sync_conn_arg):
2374 | return _get_full_schema(sync_conn_arg)
2375 |
2376 | result = await conn.run_sync(sync_func) # Pass sync connection
2377 |
2378 | # --- Action: table ---
2379 | elif action == "table":
2380 | if not table_name:
2381 | raise ToolInputError(
2382 | "`table_name` is required for 'table'", param_name="table_name"
2383 | )
2384 | include_sample = options.get("include_sample_data", False)
2385 | sample_size_raw = options.get("sample_size", 5)
2386 | sample_size = int(sample_size_raw)
2387 | include_stats = options.get("include_statistics", False)
2388 | if sample_size < 0:
2389 | sample_size = 0
2390 |
2391 | def _get_basic_table_meta(sync_conn) -> Dict[str, Any]:
2392 | # Assign inspector and schema
2393 | insp = sa_inspect(sync_conn)
2394 | target_schema = schema_name or getattr(insp, "default_schema_name", None)
2395 | logger.info(f"Inspecting table details: {target_schema}.{table_name}")
2396 | try:
2397 | all_tables = insp.get_table_names(schema=target_schema)
2398 | if table_name not in all_tables:
2399 | msg = f"Table '{table_name}' not found in schema '{target_schema}'."
2400 | raise ToolInputError(msg, param_name="table_name")
2401 | except Exception as list_err:
2402 | msg = f"Could not verify if table '{table_name}' exists: {list_err}"
2403 | raise ToolError(msg, http_status_code=500) from list_err
2404 |
2405 | # Initialize meta parts
2406 | cols = []
2407 | idx = []
2408 | fks = []
2409 | pk_constraint = {}
2410 | table_comment_text = None
2411 |
2412 | cols = insp.get_columns(table_name, schema=target_schema)
2413 | try:
2414 | idx = insp.get_indexes(table_name, schema=target_schema)
2415 | except Exception as idx_err:
2416 | logger.warning(f"Could not get indexes for table {table_name}: {idx_err}")
2417 | try:
2418 | fks = insp.get_foreign_keys(table_name, schema=target_schema)
2419 | except Exception as fk_err:
2420 | logger.warning(
2421 | f"Could not get foreign keys for table {table_name}: {fk_err}"
2422 | )
2423 | try:
2424 | pk_info = insp.get_pk_constraint(table_name, schema=target_schema)
2425 | # Split pk_constraint assignment
2426 | if pk_info and pk_info.get("constrained_columns"):
2427 | pk_constraint = {
2428 | "name": pk_info.get("name"),
2429 | "columns": pk_info["constrained_columns"],
2430 | }
2431 | # else pk_constraint remains {}
2432 | except Exception as pk_err:
2433 | logger.warning(f"Could not get PK constraint for {table_name}: {pk_err}")
2434 | try:
2435 | table_comment_raw = insp.get_table_comment(table_name, schema=target_schema)
2436 | # Ternary okay
2437 | table_comment_text = (
2438 | table_comment_raw.get("text") if table_comment_raw else None
2439 | )
2440 | except Exception as cmt_err:
2441 | logger.warning(f"Could not get table comment for {table_name}: {cmt_err}")
2442 |
2443 | # Build return dict step-by-step
2444 | meta_result = {}
2445 | meta_result["columns"] = cols
2446 | meta_result["indexes"] = idx
2447 | meta_result["foreign_keys"] = fks
2448 | meta_result["pk_constraint"] = pk_constraint
2449 | meta_result["table_comment"] = table_comment_text
2450 | meta_result["schema_name"] = target_schema # Add schema name for reference
2451 | return meta_result
2452 |
2453 | # Wrap the sync helper for conn.run_sync
2454 | def sync_func_meta(sync_conn_arg):
2455 | return _get_basic_table_meta(sync_conn_arg)
2456 |
2457 | meta = await conn.run_sync(sync_func_meta) # Pass sync connection
2458 |
2459 | # Build result dict step-by-step
2460 | result = {}
2461 | result["action"] = "table"
2462 | result["table_name"] = table_name
2463 | # Use schema name returned from meta function
2464 | result["schema_name"] = meta.get("schema_name")
2465 | result["comment"] = meta.get("table_comment")
2466 | # List comprehension okay
2467 | result["columns"] = [
2468 | {
2469 | "name": c["name"],
2470 | "type": str(c["type"]),
2471 | "nullable": c["nullable"],
2472 | "primary_key": bool(c.get("primary_key")),
2473 | "default": c.get("default"),
2474 | "comment": c.get("comment"),
2475 | }
2476 | for c in meta["columns"]
2477 | ]
2478 | result["primary_key"] = meta.get("pk_constraint")
2479 | result["indexes"] = meta.get("indexes", [])
2480 | result["foreign_keys"] = meta.get("foreign_keys", [])
2481 | result["success"] = True
2482 |
2483 | # Quote identifiers
2484 | id_prep = eng.dialect.identifier_preparer
2485 | quoted_table_name = id_prep.quote(table_name)
2486 | quoted_schema_name = id_prep.quote(schema_name) if schema_name else None
2487 | # Ternary okay
2488 | full_table_name = (
2489 | f"{quoted_schema_name}.{quoted_table_name}"
2490 | if quoted_schema_name
2491 | else quoted_table_name
2492 | )
2493 |
2494 | # Row count
2495 | try:
2496 | # Call okay
2497 | _, count_rows, _ = await _sql_exec(
2498 | eng,
2499 | f"SELECT COUNT(*) AS row_count FROM {full_table_name}",
2500 | None,
2501 | limit=1,
2502 | tool_name=tool_name,
2503 | action_name="table_count",
2504 | timeout=30,
2505 | )
2506 | # Ternary okay
2507 | result["row_count"] = count_rows[0]["row_count"] if count_rows else 0
2508 | except Exception as count_err:
2509 | logger.warning(
2510 | f"Could not get row count for table {full_table_name}: {count_err}"
2511 | )
2512 | result["row_count"] = "Error"
2513 |
2514 | # Sample data
2515 | if include_sample and sample_size > 0:
2516 | try:
2517 | # Call okay
2518 | sample_cols, sample_rows, _ = await _sql_exec(
2519 | eng,
2520 | f"SELECT * FROM {full_table_name} LIMIT :n",
2521 | {"n": sample_size},
2522 | limit=sample_size,
2523 | tool_name=tool_name,
2524 | action_name="table_sample",
2525 | timeout=30,
2526 | )
2527 | # Assign sample data dict okay
2528 | result["sample_data"] = {"columns": sample_cols, "rows": sample_rows}
2529 | except Exception as sample_err:
2530 | logger.warning(
2531 | f"Could not get sample data for table {full_table_name}: {sample_err}"
2532 | )
2533 | # Assign error dict okay
2534 | result["sample_data"] = {
2535 | "error": f"Failed to retrieve sample data: {sample_err}"
2536 | }
2537 |
2538 | # Statistics
2539 | if include_stats:
2540 | stats = {}
2541 | logger.debug(f"Calculating basic statistics for columns in {full_table_name}")
2542 | columns_to_stat = result.get("columns", [])
2543 | for c in columns_to_stat:
2544 | col_name = c["name"]
2545 | quoted_col = id_prep.quote(col_name)
2546 | col_stat_data = {}
2547 | try:
2548 | # Null count
2549 | # Call okay
2550 | _, null_rows, _ = await _sql_exec(
2551 | eng,
2552 | f"SELECT COUNT(*) AS null_count FROM {full_table_name} WHERE {quoted_col} IS NULL",
2553 | None,
2554 | limit=1,
2555 | tool_name=tool_name,
2556 | action_name="col_stat_null",
2557 | timeout=20,
2558 | )
2559 | # Ternary okay
2560 | null_count = null_rows[0]["null_count"] if null_rows else "Error"
2561 |
2562 | # Distinct count
2563 | # Call okay
2564 | _, distinct_rows, _ = await _sql_exec(
2565 | eng,
2566 | f"SELECT COUNT(DISTINCT {quoted_col}) AS distinct_count FROM {full_table_name}",
2567 | None,
2568 | limit=1,
2569 | tool_name=tool_name,
2570 | action_name="col_stat_distinct",
2571 | timeout=45,
2572 | )
2573 | # Ternary okay
2574 | distinct_count = (
2575 | distinct_rows[0]["distinct_count"] if distinct_rows else "Error"
2576 | )
2577 |
2578 | # Assign stats dict okay
2579 | col_stat_data = {
2580 | "null_count": null_count,
2581 | "distinct_count": distinct_count,
2582 | }
2583 | except Exception as stat_err:
2584 | log_msg = f"Could not calculate statistics for column {col_name} in {full_table_name}: {stat_err}"
2585 | logger.warning(log_msg)
2586 | # Assign error dict okay
2587 | col_stat_data = {"error": f"Failed: {stat_err}"}
2588 | stats[col_name] = col_stat_data
2589 | result["statistics"] = stats
2590 |
2591 | # --- Action: column ---
2592 | elif action == "column":
2593 | if not table_name:
2594 | raise ToolInputError(
2595 | "`table_name` required for 'column'", param_name="table_name"
2596 | )
2597 | if not column_name:
2598 | raise ToolInputError(
2599 | "`column_name` required for 'column'", param_name="column_name"
2600 | )
2601 |
2602 | generate_histogram = options.get("histogram", False)
2603 | num_buckets_raw = options.get("num_buckets", 10)
2604 | num_buckets = int(num_buckets_raw)
2605 | num_buckets = max(1, num_buckets) # Ensure at least one bucket
2606 |
2607 | # Quote identifiers
2608 | id_prep = eng.dialect.identifier_preparer
2609 | quoted_table = id_prep.quote(table_name)
2610 | quoted_column = id_prep.quote(column_name)
2611 | quoted_schema = id_prep.quote(schema_name) if schema_name else None
2612 | # Ternary okay
2613 | full_table_name = (
2614 | f"{quoted_schema}.{quoted_table}" if quoted_schema else quoted_table
2615 | )
2616 | logger.info(f"Analyzing column {full_table_name}.{quoted_column}")
2617 |
2618 | stats_data: Dict[str, Any] = {}
2619 | try:
2620 | # Total Rows
2621 | # Call okay
2622 | _, total_rows_res, _ = await _sql_exec(
2623 | eng,
2624 | f"SELECT COUNT(*) as cnt FROM {full_table_name}",
2625 | None,
2626 | limit=1,
2627 | tool_name=tool_name,
2628 | action_name="col_stat_total",
2629 | timeout=30,
2630 | )
2631 | # Ternary okay
2632 | total_rows_count = total_rows_res[0]["cnt"] if total_rows_res else 0
2633 | stats_data["total_rows"] = total_rows_count
2634 |
2635 | # Null Count
2636 | # Call okay
2637 | _, null_rows_res, _ = await _sql_exec(
2638 | eng,
2639 | f"SELECT COUNT(*) as cnt FROM {full_table_name} WHERE {quoted_column} IS NULL",
2640 | None,
2641 | limit=1,
2642 | tool_name=tool_name,
2643 | action_name="col_stat_null",
2644 | timeout=30,
2645 | )
2646 | # Ternary okay
2647 | null_count = null_rows_res[0]["cnt"] if null_rows_res else 0
2648 | stats_data["null_count"] = null_count
2649 | # Ternary okay
2650 | null_perc = (
2651 | round((null_count / total_rows_count) * 100, 2) if total_rows_count else 0
2652 | )
2653 | stats_data["null_percentage"] = null_perc
2654 |
2655 | # Distinct Count
2656 | # Call okay
2657 | _, distinct_rows_res, _ = await _sql_exec(
2658 | eng,
2659 | f"SELECT COUNT(DISTINCT {quoted_column}) as cnt FROM {full_table_name}",
2660 | None,
2661 | limit=1,
2662 | tool_name=tool_name,
2663 | action_name="col_stat_distinct",
2664 | timeout=60,
2665 | )
2666 | # Ternary okay
2667 | distinct_count = distinct_rows_res[0]["cnt"] if distinct_rows_res else 0
2668 | stats_data["distinct_count"] = distinct_count
2669 | # Ternary okay
2670 | distinct_perc = (
2671 | round((distinct_count / total_rows_count) * 100, 2)
2672 | if total_rows_count
2673 | else 0
2674 | )
2675 | stats_data["distinct_percentage"] = distinct_perc
2676 | except Exception as stat_err:
2677 | log_msg = f"Failed to get basic statistics for column {full_table_name}.{quoted_column}: {stat_err}"
2678 | logger.error(log_msg, exc_info=True)
2679 | stats_data["error"] = f"Failed to retrieve some statistics: {stat_err}"
2680 |
2681 | # Build result dict step-by-step
2682 | result = {}
2683 | result["action"] = "column"
2684 | result["table_name"] = table_name
2685 | result["column_name"] = column_name
2686 | result["schema_name"] = schema_name
2687 | result["statistics"] = stats_data
2688 | result["success"] = True
2689 |
2690 | if generate_histogram:
2691 | logger.debug(f"Generating histogram for {full_table_name}.{quoted_column}")
2692 | histogram_data: Optional[Dict[str, Any]] = None
2693 | try:
2694 | hist_query = f"SELECT {quoted_column} FROM {full_table_name} WHERE {quoted_column} IS NOT NULL"
2695 | # Call okay
2696 | _, value_rows, _ = await _sql_exec(
2697 | eng,
2698 | hist_query,
2699 | None,
2700 | limit=None, # Fetch all non-null values
2701 | tool_name=tool_name,
2702 | action_name="col_hist_fetch",
2703 | timeout=90,
2704 | )
2705 | # List comprehension okay
2706 | values = [r[column_name] for r in value_rows]
2707 |
2708 | if not values:
2709 | histogram_data = {"type": "empty", "buckets": []}
2710 | else:
2711 | first_val = values[0]
2712 | # Check type okay
2713 | is_numeric = isinstance(first_val, (int, float))
2714 |
2715 | if is_numeric:
2716 | try:
2717 | min_val = min(values)
2718 | max_val = max(values)
2719 | buckets = []
2720 | if min_val == max_val:
2721 | # Single bucket dict okay
2722 | bucket = {"range": f"{min_val}", "count": len(values)}
2723 | buckets.append(bucket)
2724 | else:
2725 | # Calculate bin width okay
2726 | val_range = max_val - min_val
2727 | bin_width = val_range / num_buckets
2728 | # List comprehension okay
2729 | bucket_ranges_raw = [
2730 | (min_val + i * bin_width, min_val + (i + 1) * bin_width)
2731 | for i in range(num_buckets)
2732 | ]
2733 | # Adjust last bucket range okay
2734 | last_bucket_idx = num_buckets - 1
2735 | last_bucket_start = bucket_ranges_raw[last_bucket_idx][0]
2736 | bucket_ranges_raw[last_bucket_idx] = (
2737 | last_bucket_start,
2738 | max_val,
2739 | )
2740 | bucket_ranges = bucket_ranges_raw
2741 |
2742 | # List comprehension for bucket init okay
2743 | buckets = [
2744 | {"range": f"{r[0]:.4g} - {r[1]:.4g}", "count": 0}
2745 | for r in bucket_ranges
2746 | ]
2747 | for v in values:
2748 | # Ternary okay
2749 | idx_float = (
2750 | (v - min_val) / bin_width if bin_width > 0 else 0
2751 | )
2752 | idx_int = int(idx_float)
2753 | # Ensure index is within bounds
2754 | idx = min(idx_int, num_buckets - 1)
2755 | # Handle max value potentially falling into last bucket due to precision
2756 | if v == max_val:
2757 | idx = num_buckets - 1
2758 | buckets[idx]["count"] += 1
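                                    # Illustrative walk-through: with min_val=0, max_val=10 and num_buckets=5, bin_width is 2.0,
                                    # so a value of 7 gives idx_float=3.5 -> bucket index 3 (the "6 - 8" range); values equal to max_val land in the last bucket.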
2759 |
2760 | # Assign numeric histogram dict okay
2761 | histogram_data = {
2762 | "type": "numeric",
2763 | "min": min_val,
2764 | "max": max_val,
2765 | "buckets": buckets,
2766 | }
2767 | except Exception as num_hist_err:
2768 | log_msg = f"Error generating numeric histogram: {num_hist_err}"
2769 | logger.error(log_msg, exc_info=True)
2770 | # Assign error dict okay
2771 | histogram_data = {
2772 | "error": f"Failed to generate numeric histogram: {num_hist_err}"
2773 | }
2774 | else: # Categorical / Frequency
2775 | try:
2776 | # Import okay
2777 | from collections import Counter
2778 |
2779 | # Call okay
2780 | str_values = map(str, values)
2781 | value_counts = Counter(str_values)
2782 | # Call okay
2783 | top_buckets_raw = value_counts.most_common(num_buckets)
2784 | # List comprehension okay
2785 | buckets_data = [
2786 | {"value": str(k)[:100], "count": v} # Limit value length
2787 | for k, v in top_buckets_raw
2788 | ]
2789 | # Sum okay
2790 | top_n_count = sum(b["count"] for b in buckets_data)
2791 | other_count = len(values) - top_n_count
2792 |
2793 | # Assign frequency histogram dict okay
2794 | histogram_data = {
2795 | "type": "frequency",
2796 | "top_n": num_buckets,
2797 | "buckets": buckets_data,
2798 | }
2799 | if other_count > 0:
2800 | histogram_data["other_values_count"] = other_count
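                                # Illustrative: with num_buckets=10 and 25 distinct values, only the 10 most frequent become buckets; occurrences of the remaining values are summed into other_values_count.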
2801 | except Exception as freq_hist_err:
2802 | log_msg = (
2803 | f"Error generating frequency histogram: {freq_hist_err}"
2804 | )
2805 | logger.error(log_msg, exc_info=True)
2806 | # Assign error dict okay
2807 | histogram_data = {
2808 | "error": f"Failed to generate frequency histogram: {freq_hist_err}"
2809 | }
2810 | except Exception as hist_err:
2811 | log_msg = f"Failed to generate histogram for column {full_table_name}.{quoted_column}: {hist_err}"
2812 | logger.error(log_msg, exc_info=True)
2813 | # Assign error dict okay
2814 | histogram_data = {"error": f"Histogram generation failed: {hist_err}"}
2815 | result["histogram"] = histogram_data
2816 |
2817 | # --- Action: relationships ---
2818 | elif action == "relationships":
2819 | if not table_name:
2820 | raise ToolInputError(
2821 | "`table_name` required for 'relationships'", param_name="table_name"
2822 | )
2823 | depth_raw = options.get("depth", 1)
2824 | depth_int = int(depth_raw)
2825 | # Clamp depth
2826 | depth = max(1, min(depth_int, 5))
2827 |
2828 | log_msg = f"Finding relationships for table '{table_name}' (depth: {depth}, schema: {schema_name})"
2829 | logger.info(log_msg)
2830 | # Call explore_database for schema info - this recursive call is okay
2831 | schema_info = await explore_database(
2832 | connection_id=connection_id,
2833 | action="schema",
2834 | schema_name=schema_name,
2835 | include_indexes=False, # Don't need indexes for relationships
2836 | include_foreign_keys=True, # Need FKs
2837 | detailed=False, # Don't need detailed column info
2838 | )
2839 | # Check success okay
2840 | schema_success = schema_info.get("success", False)
2841 | if not schema_success:
2842 | raise ToolError(
2843 | "Failed to retrieve schema information needed to find relationships."
2844 | )
2845 |
2846 |             # Index tables by name for quick lookup
2847 | tables_list = schema_info.get("tables", [])
2848 | tables_by_name: Dict[str, Dict] = {t["name"]: t for t in tables_list}
2849 |
2850 | if table_name not in tables_by_name:
2851 | msg = f"Starting table '{table_name}' not found in schema '{schema_name}'."
2852 | raise ToolInputError(msg, param_name="table_name")
2853 |
2854 | visited_nodes = set() # Track visited nodes to prevent cycles
2855 |
2856 | # Define the recursive helper function *inside* this action block
2857 | # so it has access to tables_by_name and visited_nodes
2858 | def _build_relationship_graph_standalone(
2859 | current_table: str, current_depth: int
2860 | ) -> Dict[str, Any]:
2861 | # Build node_id string okay
2862 | current_schema = schema_name or "default"
2863 | node_id = f"{current_schema}.{current_table}"
2864 |
2865 | is_max_depth = current_depth >= depth
2866 | is_visited = node_id in visited_nodes
2867 | if is_max_depth or is_visited:
2868 | # Return dict okay
2869 | return {
2870 | "table": current_table,
2871 | "schema": schema_name, # Use original schema_name context
2872 | "max_depth_reached": is_max_depth,
2873 | "cyclic_reference": is_visited,
2874 | }
2875 |
2876 | visited_nodes.add(node_id)
2877 | node_info = tables_by_name.get(current_table)
2878 |
2879 | if not node_info:
2880 | visited_nodes.remove(node_id) # Backtrack
2881 | # Return dict okay
2882 | return {
2883 | "table": current_table,
2884 | "schema": schema_name,
2885 | "error": "Table info not found",
2886 | }
2887 |
2888 | # Build graph_node dict step-by-step
2889 | graph_node: Dict[str, Any] = {}
2890 | graph_node["table"] = current_table
2891 | graph_node["schema"] = schema_name
2892 | graph_node["children"] = []
2893 | graph_node["parents"] = []
2894 |
2895 | # Find Parents (current table's FKs point to parents)
2896 | foreign_keys_list = node_info.get("foreign_keys", [])
2897 | for fk in foreign_keys_list:
2898 | ref_table = fk["referred_table"]
2899 | ref_schema = fk.get(
2900 | "referred_schema", schema_name
2901 | ) # Assume same schema if not specified
2902 |
2903 | if ref_table in tables_by_name:
2904 | # Recursive call okay
2905 | parent_node = _build_relationship_graph_standalone(
2906 | ref_table, current_depth + 1
2907 | )
2908 | else:
2909 |                         # Referred table is outside the inspected scope; record a placeholder node
2910 | parent_node = {
2911 | "table": ref_table,
2912 | "schema": ref_schema,
2913 | "outside_scope": True,
2914 | }
2915 |
2916 | # Build relationship string okay
2917 | constrained_cols_str = ",".join(fk["constrained_columns"])
2918 | referred_cols_str = ",".join(fk["referred_columns"])
2919 | rel_str = f"{current_table}.({constrained_cols_str}) -> {ref_table}.({referred_cols_str})"
2920 | # Append parent relationship dict okay
2921 | graph_node["parents"].append(
2922 | {"relationship": rel_str, "target": parent_node}
2923 | )
2924 |
2925 | # Find Children (other tables' FKs point to current table)
2926 | for other_table_name, other_table_info in tables_by_name.items():
2927 | if other_table_name == current_table:
2928 | continue # Skip self-reference check here
2929 |
2930 | other_fks = other_table_info.get("foreign_keys", [])
2931 | for fk in other_fks:
2932 | points_to_current = fk["referred_table"] == current_table
2933 | # Check schema match (use original schema_name context)
2934 | referred_schema_matches = (
2935 | fk.get("referred_schema", schema_name) == schema_name
2936 | )
2937 | if points_to_current and referred_schema_matches:
2938 | # Recursive call okay
2939 | child_node = _build_relationship_graph_standalone(
2940 | other_table_name, current_depth + 1
2941 | )
2942 | # Build relationship string okay
2943 | constrained_cols_str = ",".join(fk["constrained_columns"])
2944 | referred_cols_str = ",".join(fk["referred_columns"])
2945 | rel_str = f"{other_table_name}.({constrained_cols_str}) -> {current_table}.({referred_cols_str})"
2946 | # Append child relationship dict okay
2947 | graph_node["children"].append(
2948 | {"relationship": rel_str, "source": child_node}
2949 | )
2950 |
2951 | visited_nodes.remove(node_id) # Backtrack visited state
2952 | return graph_node
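            # The helper above returns nodes shaped like {"table", "schema", "parents", "children"}, where each parent/child
            # entry pairs a "relationship" string such as "orders.(customer_id) -> customers.(id)" (illustrative names) with a
            # nested node under "target" (parents) or "source" (children); recursion is bounded by `depth` and the visited_nodes cycle guard.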
2953 |
2954 | # Initial call to the recursive function
2955 | relationship_graph = _build_relationship_graph_standalone(table_name, 0)
2956 | # Build result dict step-by-step
2957 | result = {}
2958 | result["action"] = "relationships"
2959 | result["source_table"] = table_name
2960 | result["schema_name"] = schema_name
2961 | result["max_depth"] = depth
2962 | result["relationship_graph"] = relationship_graph
2963 | result["success"] = True
2964 |
2965 | # --- Action: documentation ---
2966 | elif action == "documentation":
2967 | output_format_raw = options.get("output_format", "markdown")
2968 | output_format = output_format_raw.lower()
2969 | valid_formats = ["markdown", "json"]
2970 | if output_format not in valid_formats:
2971 | msg = "Invalid 'output_format'. Use 'markdown' or 'json'."
2972 | raise ToolInputError(msg, param_name="output_format")
2973 |
2974 | doc_include_indexes = options.get("include_indexes", True)
2975 | doc_include_fks = options.get("include_foreign_keys", True)
2976 | log_msg = f"Generating database documentation (Format: {output_format}, Schema: {schema_name})"
2977 | logger.info(log_msg)
2978 |
2979 | # Call explore_database for schema info (recursive call okay)
2980 | schema_data = await explore_database(
2981 | connection_id=connection_id,
2982 | action="schema",
2983 | schema_name=schema_name,
2984 | include_indexes=doc_include_indexes,
2985 | include_foreign_keys=doc_include_fks,
2986 | detailed=True, # Need details for documentation
2987 | )
2988 | schema_success = schema_data.get("success", False)
2989 | if not schema_success:
2990 | raise ToolError(
2991 | "Failed to retrieve schema information needed for documentation."
2992 | )
2993 |
2994 | if output_format == "json":
2995 | # Build result dict step-by-step
2996 | result = {}
2997 | result["action"] = "documentation"
2998 | result["format"] = "json"
2999 | result["documentation"] = schema_data # Embed the schema result directly
3000 | result["success"] = True
3001 | else: # Markdown
3002 | # --- Markdown Generation ---
3003 | lines = []
3004 | lines.append(f"# Database Documentation ({db_dialect})")
3005 | db_schema_name = schema_data.get("inspected_schema", "Default Schema")
3006 | lines.append(f"Schema: **{db_schema_name}**")
3007 | now_str = _sql_now()
3008 | lines.append(f"Generated: {now_str}")
3009 | schema_hash_val = schema_data.get("schema_hash")
3010 | if schema_hash_val:
3011 | hash_preview = schema_hash_val[:12]
3012 | lines.append(f"Schema Version (Hash): `{hash_preview}`")
3013 | lines.append("") # Blank line
3014 |
3015 | lines.append("## Tables")
3016 | lines.append("")
3017 | # Sort okay
3018 | tables_list_raw = schema_data.get("tables", [])
3019 | tables = sorted(tables_list_raw, key=lambda x: x["name"])
3020 |
3021 | if not tables:
3022 | lines.append("*No tables found in this schema.*")
3023 |
3024 | for t in tables:
3025 | table_name_doc = t["name"]
3026 | if t.get("error"):
3027 | lines.append(f"### {table_name_doc} (Error)")
3028 | lines.append(f"```\n{t['error']}\n```")
3029 | lines.append("")
3030 | continue # Skip rest for this table
3031 |
3032 | lines.append(f"### {table_name_doc}")
3033 | lines.append("")
3034 | table_comment = t.get("comment")
3035 | if table_comment:
3036 | lines.append(f"> {table_comment}")
3037 | lines.append("")
3038 |
3039 | # Column Header
3040 | lines.append("| Column | Type | Nullable | PK | Default | Comment |")
3041 | lines.append("|--------|------|----------|----|---------|---------|")
3042 | columns_list = t.get("columns", [])
3043 | for c in columns_list:
3044 | # Ternary okay
3045 | pk_flag = "✅" if c["primary_key"] else ""
3046 | null_flag = "✅" if c["nullable"] else ""
3047 | default_raw = c.get("default")
3048 | # Ternary okay
3049 | default_val_str = f"`{default_raw}`" if default_raw is not None else ""
3050 | comment_val = c.get("comment") or ""
3051 | col_name_str = f"`{c['name']}`"
3052 | col_type_str = f"`{c['type']}`"
3053 | # Build line okay
3054 | line = f"| {col_name_str} | {col_type_str} | {null_flag} | {pk_flag} | {default_val_str} | {comment_val} |"
3055 | lines.append(line)
3056 | lines.append("") # Blank line after table
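                    # Example rendered row (illustrative): | `order_id` | `INTEGER` |  | ✅ |  | Surrogate key |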
3057 |
3058 | # Primary Key section
3059 | pk_info = t.get("primary_key")
3060 | pk_cols = pk_info.get("columns") if pk_info else None
3061 | if pk_info and pk_cols:
3062 | pk_name = pk_info.get("name", "PK")
3063 | # List comprehension okay
3064 | pk_cols_formatted = [f"`{c}`" for c in pk_cols]
3065 | pk_cols_str = ", ".join(pk_cols_formatted)
3066 | lines.append(f"**Primary Key:** `{pk_name}` ({pk_cols_str})")
3067 | lines.append("")
3068 |
3069 | # Indexes section
3070 | indexes_list = t.get("indexes")
3071 | if doc_include_indexes and indexes_list:
3072 | lines.append("**Indexes:**")
3073 | lines.append("")
3074 | lines.append("| Name | Columns | Unique |")
3075 | lines.append("|------|---------|--------|")
3076 | for idx in indexes_list:
3077 | # Ternary okay
3078 | unique_flag = "✅" if idx["unique"] else ""
3079 | # List comprehension okay
3080 | idx_cols_formatted = [f"`{c}`" for c in idx["columns"]]
3081 | cols_str = ", ".join(idx_cols_formatted)
3082 | idx_name_str = f"`{idx['name']}`"
3083 | # Build line okay
3084 | line = f"| {idx_name_str} | {cols_str} | {unique_flag} |"
3085 | lines.append(line)
3086 | lines.append("")
3087 |
3088 | # Foreign Keys section
3089 | fks_list = t.get("foreign_keys")
3090 | if doc_include_fks and fks_list:
3091 | lines.append("**Foreign Keys:**")
3092 | lines.append("")
3093 | lines.append("| Name | Column(s) | References |")
3094 | lines.append("|------|-----------|------------|")
3095 | for fk in fks_list:
3096 | # List comprehension okay
3097 | constrained_cols_fmt = [f"`{c}`" for c in fk["constrained_columns"]]
3098 | constrained_cols_str = ", ".join(constrained_cols_fmt)
3099 |
3100 | ref_schema = fk.get("referred_schema", db_schema_name)
3101 | ref_table_name = fk["referred_table"]
3102 | ref_table_str = f"`{ref_schema}`.`{ref_table_name}`"
3103 |
3104 | # List comprehension okay
3105 | ref_cols_fmt = [f"`{c}`" for c in fk["referred_columns"]]
3106 | ref_cols_str = ", ".join(ref_cols_fmt)
3107 |
3108 | fk_name = fk.get("name", "FK")
3109 | fk_name_str = f"`{fk_name}`"
3110 | ref_full_str = f"{ref_table_str} ({ref_cols_str})"
3111 | # Build line okay
3112 | line = (
3113 | f"| {fk_name_str} | {constrained_cols_str} | {ref_full_str} |"
3114 | )
3115 | lines.append(line)
3116 | lines.append("")
3117 |
3118 | # Views Section
3119 | views_list_raw = schema_data.get("views", [])
3120 | views = sorted(views_list_raw, key=lambda x: x["name"])
3121 | if views:
3122 | lines.append("## Views")
3123 | lines.append("")
3124 | for v in views:
3125 | view_name_doc = v["name"]
3126 | if v.get("error"):
3127 | lines.append(f"### {view_name_doc} (Error)")
3128 | lines.append(f"```\n{v['error']}\n```")
3129 | lines.append("")
3130 | continue # Skip rest for this view
3131 |
3132 | lines.append(f"### {view_name_doc}")
3133 | lines.append("")
3134 | view_columns = v.get("columns")
3135 | if view_columns:
3136 | # List comprehension okay
3137 | view_cols_fmt = [
3138 | f"`{vc['name']}` ({vc['type']})" for vc in view_columns
3139 | ]
3140 | view_cols_str = ", ".join(view_cols_fmt)
3141 | lines.append(f"**Columns:** {view_cols_str}")
3142 | lines.append("")
3143 |
3144 | view_def = v.get("definition")
3145 | # Check for valid definition string
3146 | is_valid_def = (
3147 | view_def and view_def != "N/A (Not Implemented by Dialect)"
3148 | )
3149 | if is_valid_def:
3150 | lines.append("**Definition:**")
3151 | lines.append("```sql")
3152 | lines.append(view_def)
3153 | lines.append("```")
3154 | lines.append("")
3155 | else:
3156 | lines.append(
3157 | "**Definition:** *Not available or not implemented by dialect.*"
3158 | )
3159 | lines.append("")
3160 | # --- End Markdown Generation ---
3161 |
3162 | # Join lines okay
3163 | markdown_output = "\n".join(lines)
3164 | # Build result dict step-by-step
3165 | result = {}
3166 | result["action"] = "documentation"
3167 | result["format"] = "markdown"
3168 | result["documentation"] = markdown_output
3169 | result["success"] = True
3170 |
3171 | else:
3172 | logger.error(f"Invalid action specified for explore_database: {action}")
3173 | details = {"action": action}
3174 | valid_actions = "schema, table, column, relationships, documentation"
3175 | msg = f"Unknown action: '{action}'. Valid actions: {valid_actions}"
3176 | raise ToolInputError(msg, param_name="action", details=details)
3177 |
3178 | # Audit success for all successful actions
3179 | # Ternary okay
3180 | audit_table = [table_name] if table_name else None
3181 | # Call okay
3182 | await _sql_audit(
3183 | tool_name=tool_name,
3184 | action=action,
3185 | connection_id=connection_id,
3186 | sql=None,
3187 | tables=audit_table,
3188 | row_count=None,
3189 | success=True,
3190 | error=None,
3191 | user_id=user_id,
3192 | session_id=session_id,
3193 | **audit_extras,
3194 | )
3195 | return result # Return the constructed result dict
3196 |
3197 | except ToolInputError as tie:
3198 | # Audit failure
3199 | # Ternary okay
3200 | audit_table = [table_name] if table_name else None
3201 | action_fail = action + "_fail"
3202 | # Call okay
3203 | await _sql_audit(
3204 | tool_name=tool_name,
3205 | action=action_fail,
3206 | connection_id=connection_id,
3207 | sql=None,
3208 | tables=audit_table,
3209 | row_count=None,
3210 | success=False,
3211 | error=str(tie),
3212 | user_id=user_id,
3213 | session_id=session_id,
3214 | **audit_extras,
3215 | )
3216 | raise tie
3217 | except ToolError as te:
3218 | # Audit failure
3219 | # Ternary okay
3220 | audit_table = [table_name] if table_name else None
3221 | action_fail = action + "_fail"
3222 | # Call okay
3223 | await _sql_audit(
3224 | tool_name=tool_name,
3225 | action=action_fail,
3226 | connection_id=connection_id,
3227 | sql=None,
3228 | tables=audit_table,
3229 | row_count=None,
3230 | success=False,
3231 | error=str(te),
3232 | user_id=user_id,
3233 | session_id=session_id,
3234 | **audit_extras,
3235 | )
3236 | raise te
3237 | except Exception as e:
3238 | log_msg = f"Unexpected error in explore_database (action: {action}): {e}"
3239 | logger.error(log_msg, exc_info=True)
3240 | # Audit failure
3241 | # Ternary okay
3242 | audit_table = [table_name] if table_name else None
3243 | action_fail = action + "_fail"
3244 | error_str = f"Unexpected error: {e}"
3245 | # Call okay
3246 | await _sql_audit(
3247 | tool_name=tool_name,
3248 | action=action_fail,
3249 | connection_id=connection_id,
3250 | sql=None,
3251 | tables=audit_table,
3252 | row_count=None,
3253 | success=False,
3254 | error=error_str,
3255 | user_id=user_id,
3256 | session_id=session_id,
3257 | **audit_extras,
3258 | )
3259 | raise ToolError(
3260 | f"An unexpected error occurred during database exploration: {e}", http_status_code=500
3261 | ) from e
3262 |
3263 |
3264 | @with_tool_metrics
3265 | @with_error_handling
3266 | async def access_audit_log(
3267 | action: str = "view",
3268 | export_format: Optional[str] = None,
3269 | limit: Optional[int] = 100,
3270 | user_id: Optional[str] = None,
3271 | connection_id: Optional[str] = None,
3272 | ctx: Optional[Dict] = None, # Added ctx
3273 | ) -> Dict[str, Any]:
3274 | """
3275 | Access and export the in-memory SQL audit log.
3276 |
3277 | Allows viewing recent log entries or exporting them to a file.
3278 | Note: The audit log is currently stored only in memory and will be lost on server restart.
3279 |
3280 | Args:
3281 | action: "view" (default) or "export".
3282 | export_format: Required if action is "export". Supports "json", "excel", "csv".
3283 | limit: For "view", the maximum number of most recent records to return (default: 100). Use None or -1 for all.
3284 | user_id: Filter log entries by this user ID.
3285 | connection_id: Filter log entries by this connection ID.
3286 | ctx: Optional context from MCP server.
3287 |
3288 | Returns:
3289 | Dict containing results:
3290 | - For "view": {action: "view", records: List[Dict], filtered_record_count: int, total_records_in_log: int, filters_applied: Dict, success: True}
3291 | - For "export": {action: "export", path: str, format: str, record_count: int, success: True} or {action: "export", message: str, record_count: 0, success: True} if no records.
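    Example (illustrative):
        recent = await access_audit_log(action="view", limit=20)
        export = await access_audit_log(action="export", export_format="csv", user_id="alice")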
3292 | """
3293 | tool_name = "access_audit_log" # noqa: F841
3294 |
3295 | # Apply filters using global _AUDIT_LOG
3296 | async with _audit_lock: # Need lock to safely read/copy log
3297 | # Call okay
3298 | full_log_copy = list(_AUDIT_LOG)
3299 | total_records_in_log = len(full_log_copy)
3300 |
3301 | # Start with the full copy
3302 | filtered_log = full_log_copy
3303 |
3304 | # Apply filters sequentially
3305 | if user_id:
3306 | # List comprehension okay
3307 | filtered_log = [r for r in filtered_log if r.get("user_id") == user_id]
3308 | if connection_id:
3309 | # List comprehension okay
3310 | filtered_log = [r for r in filtered_log if r.get("connection_id") == connection_id]
3311 | filtered_record_count = len(filtered_log)
3312 |
3313 | if action == "view":
3314 |         # Treat None/negative limit as "all"; handle 0 explicitly, since filtered_log[-0:] would return the whole list
3315 |         needs_limit = limit is not None and limit >= 0
3316 |         records_to_return = (filtered_log[-limit:] if limit else []) if needs_limit else filtered_log
3317 | num_returned = len(records_to_return)
3318 | log_msg = f"View audit log requested. Returning {num_returned}/{filtered_record_count} filtered records (Total in log: {total_records_in_log})."
3319 | logger.info(log_msg)
3320 |
3321 | # Build filters applied dict okay
3322 | filters_applied = {"user_id": user_id, "connection_id": connection_id}
3323 | # Build result dict step-by-step
3324 | result = {}
3325 | result["action"] = "view"
3326 | result["records"] = records_to_return
3327 | result["filtered_record_count"] = filtered_record_count
3328 | result["total_records_in_log"] = total_records_in_log
3329 | result["filters_applied"] = filters_applied
3330 | result["success"] = True
3331 | return result
3332 |
3333 | elif action == "export":
3334 | if not export_format:
3335 | raise ToolInputError(
3336 | "`export_format` is required for 'export'", param_name="export_format"
3337 | )
3338 | export_format_lower = export_format.lower()
3339 | log_msg = f"Export audit log requested. Format: {export_format_lower}. Records to export: {filtered_record_count}"
3340 | logger.info(log_msg)
3341 |
3342 | if not filtered_log:
3343 | logger.warning("Audit log is empty or filtered log is empty, nothing to export.")
3344 | # Return dict okay
3345 | return {
3346 | "action": "export",
3347 | "message": "No audit records found matching filters to export.",
3348 | "record_count": 0,
3349 | "success": True,
3350 | }
3351 |
3352 | if export_format_lower == "json":
3353 | path = "" # Initialize path
3354 | try:
3355 | # Call okay
3356 | fd, temp_path = tempfile.mkstemp(suffix=".json", prefix="mcp_audit_export_")
3357 | path = temp_path # Assign path now we know mkstemp succeeded
3358 | os.close(fd)
3359 | # Use sync write for simplicity here
3360 | with open(path, "w", encoding="utf-8") as f:
3361 | # Call okay
3362 | json.dump(filtered_log, f, indent=2, default=str)
3363 | log_msg = (
3364 | f"Successfully exported {filtered_record_count} audit records to JSON: {path}"
3365 | )
3366 | logger.info(log_msg)
3367 | # Return dict okay
3368 | return {
3369 | "action": "export",
3370 | "path": path,
3371 | "format": "json",
3372 | "record_count": filtered_record_count,
3373 | "success": True,
3374 | }
3375 | except Exception as e:
3376 | log_msg = f"Failed to export audit log to JSON: {e}"
3377 | logger.error(log_msg, exc_info=True)
3378 | # Clean up temp file if created
3379 | if path and Path(path).exists():
3380 | try:
3381 | Path(path).unlink()
3382 | except OSError:
3383 | logger.warning(f"Could not clean up failed JSON export file: {path}")
3384 | raise ToolError(
3385 | f"Failed to export audit log to JSON: {e}", http_status_code=500
3386 | ) from e
3387 |
3388 | elif export_format_lower in ["excel", "csv"]:
3389 | if pd is None:
3390 | details = {"library": "pandas"}
3391 | msg = f"Pandas library not installed, cannot export audit log to '{export_format_lower}'."
3392 | raise ToolError(msg, http_status_code=501, details=details)
3393 | path = "" # Initialize path
3394 | try:
3395 | # Call okay
3396 | df = pd.DataFrame(filtered_log)
3397 | # Ternary okay for suffix/writer/engine
3398 | is_excel = export_format_lower == "excel"
3399 | suffix = ".xlsx" if is_excel else ".csv"
3400 | writer_func = df.to_excel if is_excel else df.to_csv
3401 | engine = "xlsxwriter" if is_excel else None
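            # The "xlsxwriter" engine is an optional pandas dependency; to_excel raises ImportError at write time if it is missing.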
3402 |
3403 | # Call okay
3404 | fd, temp_path = tempfile.mkstemp(suffix=suffix, prefix="mcp_audit_export_")
3405 | path = temp_path # Assign path
3406 | os.close(fd)
3407 |
3408 | # Build export args dict okay
3409 | export_kwargs: Dict[str, Any] = {"index": False}
3410 | if engine:
3411 | export_kwargs["engine"] = engine
3412 |
3413 | # Call writer function
3414 | writer_func(path, **export_kwargs)
3415 |
3416 | log_msg = f"Successfully exported {filtered_record_count} audit records to {export_format_lower.upper()}: {path}"
3417 | logger.info(log_msg)
3418 | # Return dict okay
3419 | return {
3420 | "action": "export",
3421 | "path": path,
3422 | "format": export_format_lower,
3423 | "record_count": filtered_record_count,
3424 | "success": True,
3425 | }
3426 | except Exception as e:
3427 | log_msg = f"Failed to export audit log to {export_format_lower}: {e}"
3428 | logger.error(log_msg, exc_info=True)
3429 | # Clean up temp file if created
3430 | if path and Path(path).exists():
3431 | try:
3432 | Path(path).unlink()
3433 | except OSError:
3434 | logger.warning(f"Could not clean up temporary export file: {path}")
3435 | msg = f"Failed to export audit log to {export_format_lower}: {e}"
3436 | raise ToolError(msg, http_status_code=500) from e
3437 | else:
3438 | details = {"format": export_format}
3439 | valid_formats = "'excel', 'csv', or 'json'"
3440 | msg = f"Unsupported export format: '{export_format}'. Use {valid_formats}."
3441 | raise ToolInputError(msg, param_name="export_format", details=details)
3442 | else:
3443 | details = {"action": action}
3444 | msg = f"Unknown action: '{action}'. Use 'view' or 'export'."
3445 | raise ToolInputError(msg, param_name="action", details=details)
3446 |
```
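
Both tools above are plain async coroutines, so they can be exercised directly from an async test harness as well as through the MCP server. The sketch below is illustrative only: the import path and the pre-existing `connection_id` are assumptions (a real connection id would come from the module's connection tool), while the keyword arguments mirror the internal `explore_database` calls and the `access_audit_log` signature shown above.

```python
# Minimal usage sketch, not part of the repository. The import path and the
# connection id below are assumptions made for illustration only.
import asyncio

# Assumed module path; adjust to wherever these tools are defined in the package.
from ultimate_mcp_server.tools.sql_databases import access_audit_log, explore_database


async def main() -> None:
    conn_id = "example-connection-id"  # assumed to be registered by a prior connect step

    # Schema overview, using the same keyword arguments explore_database passes
    # to itself for the 'relationships' and 'documentation' actions.
    schema = await explore_database(
        connection_id=conn_id,
        action="schema",
        schema_name=None,  # default schema; assumed optional
        include_indexes=True,
        include_foreign_keys=True,
        detailed=False,
    )
    print(f"tables found: {len(schema.get('tables', []))}")

    # View the 20 most recent audit entries recorded for this connection.
    recent = await access_audit_log(action="view", limit=20, connection_id=conn_id)
    print(f"audit records returned: {recent['filtered_record_count']}")


if __name__ == "__main__":
    asyncio.run(main())
```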