This is page 41 of 45. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│   ├── __init__.py
│   ├── advanced_agent_flows_using_unified_memory_system_demo.py
│   ├── advanced_extraction_demo.py
│   ├── advanced_unified_memory_system_demo.py
│   ├── advanced_vector_search_demo.py
│   ├── analytics_reporting_demo.py
│   ├── audio_transcription_demo.py
│   ├── basic_completion_demo.py
│   ├── cache_demo.py
│   ├── claude_integration_demo.py
│   ├── compare_synthesize_demo.py
│   ├── cost_optimization.py
│   ├── data
│   │   ├── sample_event.txt
│   │   ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│   │   └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│   ├── docstring_refiner_demo.py
│   ├── document_conversion_and_processing_demo.py
│   ├── entity_relation_graph_demo.py
│   ├── filesystem_operations_demo.py
│   ├── grok_integration_demo.py
│   ├── local_text_tools_demo.py
│   ├── marqo_fused_search_demo.py
│   ├── measure_model_speeds.py
│   ├── meta_api_demo.py
│   ├── multi_provider_demo.py
│   ├── ollama_integration_demo.py
│   ├── prompt_templates_demo.py
│   ├── python_sandbox_demo.py
│   ├── rag_example.py
│   ├── research_workflow_demo.py
│   ├── sample
│   │   ├── article.txt
│   │   ├── backprop_paper.pdf
│   │   ├── buffett.pdf
│   │   ├── contract_link.txt
│   │   ├── legal_contract.txt
│   │   ├── medical_case.txt
│   │   ├── northwind.db
│   │   ├── research_paper.txt
│   │   ├── sample_data.json
│   │   └── text_classification_samples
│   │       ├── email_classification.txt
│   │       ├── news_samples.txt
│   │       ├── product_reviews.txt
│   │       └── support_tickets.txt
│   ├── sample_docs
│   │   └── downloaded
│   │       └── attention_is_all_you_need.pdf
│   ├── sentiment_analysis_demo.py
│   ├── simple_completion_demo.py
│   ├── single_shot_synthesis_demo.py
│   ├── smart_browser_demo.py
│   ├── sql_database_demo.py
│   ├── sse_client_demo.py
│   ├── test_code_extraction.py
│   ├── test_content_detection.py
│   ├── test_ollama.py
│   ├── text_classification_demo.py
│   ├── text_redline_demo.py
│   ├── tool_composition_examples.py
│   ├── tournament_code_demo.py
│   ├── tournament_text_demo.py
│   ├── unified_memory_system_demo.py
│   ├── vector_search_demo.py
│   ├── web_automation_instruction_packs.py
│   └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│   └── smart_browser_internal
│       ├── locator_cache.db
│       ├── readability.js
│       └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── integration
│   │   ├── __init__.py
│   │   └── test_server.py
│   ├── manual
│   │   ├── test_extraction_advanced.py
│   │   └── test_extraction.py
│   └── unit
│       ├── __init__.py
│       ├── test_cache.py
│       ├── test_providers.py
│       └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│   ├── __init__.py
│   ├── __main__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── commands.py
│   │   ├── helpers.py
│   │   └── typer_cli.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── completion_client.py
│   │   └── rag_client.py
│   ├── config
│   │   └── examples
│   │       └── filesystem_config.yaml
│   ├── config.py
│   ├── constants.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── evaluation
│   │   │   ├── base.py
│   │   │   └── evaluators.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   ├── base.py
│   │   │   ├── deepseek.py
│   │   │   ├── gemini.py
│   │   │   ├── grok.py
│   │   │   ├── ollama.py
│   │   │   ├── openai.py
│   │   │   └── openrouter.py
│   │   ├── server.py
│   │   ├── state_store.py
│   │   ├── tournaments
│   │   │   ├── manager.py
│   │   │   ├── tasks.py
│   │   │   └── utils.py
│   │   └── ums_api
│   │       ├── __init__.py
│   │       ├── ums_database.py
│   │       ├── ums_endpoints.py
│   │       ├── ums_models.py
│   │       └── ums_services.py
│   ├── exceptions.py
│   ├── graceful_shutdown.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── analytics
│   │   │   ├── __init__.py
│   │   │   ├── metrics.py
│   │   │   └── reporting.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── cache_service.py
│   │   │   ├── persistence.py
│   │   │   ├── strategies.py
│   │   │   └── utils.py
│   │   ├── cache.py
│   │   ├── document.py
│   │   ├── knowledge_base
│   │   │   ├── __init__.py
│   │   │   ├── feedback.py
│   │   │   ├── manager.py
│   │   │   ├── rag_engine.py
│   │   │   ├── retriever.py
│   │   │   └── utils.py
│   │   ├── prompts
│   │   │   ├── __init__.py
│   │   │   ├── repository.py
│   │   │   └── templates.py
│   │   ├── prompts.py
│   │   └── vector
│   │       ├── __init__.py
│   │       ├── embeddings.py
│   │       └── vector_service.py
│   ├── tool_token_counter.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── audio_transcription.py
│   │   ├── base.py
│   │   ├── completion.py
│   │   ├── docstring_refiner.py
│   │   ├── document_conversion_and_processing.py
│   │   ├── enhanced-ums-lookbook.html
│   │   ├── entity_relation_graph.py
│   │   ├── excel_spreadsheet_automation.py
│   │   ├── extraction.py
│   │   ├── filesystem.py
│   │   ├── html_to_markdown.py
│   │   ├── local_text_tools.py
│   │   ├── marqo_fused_search.py
│   │   ├── meta_api_tool.py
│   │   ├── ocr_tools.py
│   │   ├── optimization.py
│   │   ├── provider.py
│   │   ├── pyodide_boot_template.html
│   │   ├── python_sandbox.py
│   │   ├── rag.py
│   │   ├── redline-compiled.css
│   │   ├── sentiment_analysis.py
│   │   ├── single_shot_synthesis.py
│   │   ├── smart_browser.py
│   │   ├── sql_databases.py
│   │   ├── text_classification.py
│   │   ├── text_redline_tools.py
│   │   ├── tournament.py
│   │   ├── ums_explorer.html
│   │   └── unified_memory_system.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── async_utils.py
│   │   ├── display.py
│   │   ├── logging
│   │   │   ├── __init__.py
│   │   │   ├── console.py
│   │   │   ├── emojis.py
│   │   │   ├── formatter.py
│   │   │   ├── logger.py
│   │   │   ├── panels.py
│   │   │   ├── progress.py
│   │   │   └── themes.py
│   │   ├── parse_yaml.py
│   │   ├── parsing.py
│   │   ├── security.py
│   │   └── text.py
│   └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/ultimate_mcp_server/tools/sql_databases.py:
--------------------------------------------------------------------------------

```python
   1 | # ultimate_mcp_server/tools/sql_databases.py
   2 | from __future__ import annotations
   3 | 
   4 | import asyncio
   5 | import datetime as dt
   6 | import hashlib
   7 | import json
   8 | import os
   9 | import re
  10 | import tempfile
  11 | import time
  12 | import uuid
  13 | from dataclasses import dataclass
  14 | from functools import lru_cache
  15 | from pathlib import Path
  16 | 
  17 | # --- START: Expanded typing imports ---
  18 | from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
  19 | 
  20 | # --- END: Expanded typing imports ---
  21 | # --- Removed BaseTool import ---
  22 | # SQLAlchemy imports
  23 | from sqlalchemy import inspect as sa_inspect
  24 | from sqlalchemy import text
  25 | from sqlalchemy.engine.url import make_url
  26 | from sqlalchemy.exc import OperationalError, ProgrammingError, SQLAlchemyError
  27 | from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
  28 | 
  29 | # Local imports
  30 | from ultimate_mcp_server.exceptions import ToolError, ToolInputError
  31 | 
  32 | # --- START: Expanded base imports ---
  33 | from ultimate_mcp_server.tools.base import with_error_handling, with_tool_metrics
  34 | 
  35 | # --- END: Expanded base imports ---
  36 | from ultimate_mcp_server.tools.completion import generate_completion  # For NL→SQL
  37 | from ultimate_mcp_server.utils import get_logger
  38 | 
  39 | # Optional imports with graceful fallbacks
  40 | try:
  41 |     import boto3  # For AWS Secrets Manager
  42 | except ImportError:
  43 |     boto3 = None
  44 | 
  45 | try:
  46 |     import hvac  # For HashiCorp Vault
  47 | except ImportError:
  48 |     hvac = None
  49 | 
  50 | try:
  51 |     import pandas as pd
  52 | except ImportError:
  53 |     pd = None
  54 | 
  55 | try:
  56 |     import pandera as pa
  57 | except ImportError:
  58 |     pa = None
  59 | 
  60 | try:
  61 |     import prometheus_client as prom
  62 | except ImportError:
  63 |     prom = None
  64 | 
  65 | logger = get_logger("ultimate_mcp_server.tools.sql_databases")
  66 | 
  67 | # =============================================================================
  68 | # Global State and Configuration (Replaces instance variables)
  69 | # =============================================================================
  70 | 
  71 | 
  72 | # --- Connection Management ---
  73 | class ConnectionManager:
  74 |     """Manages database connections with automatic cleanup after inactivity."""
  75 | 
  76 |     # Module-level helper; instantiated once below as _connection_manager.
  77 |     def __init__(self, cleanup_interval_seconds=600, check_interval_seconds=60):
  78 |         self.connections: Dict[str, Tuple[AsyncEngine, float]] = {}
  79 |         self.cleanup_interval = cleanup_interval_seconds
  80 |         self.check_interval = check_interval_seconds
  81 |         self._cleanup_task: Optional[asyncio.Task] = None
  82 |         self._lock = asyncio.Lock()  # Added lock for thread-safe modifications
  83 | 
  84 |     async def start_cleanup_task(self):
  85 |         async with self._lock:
  86 |             cleanup_task_is_none = self._cleanup_task is None
  87 |             cleanup_task_is_done = self._cleanup_task is not None and self._cleanup_task.done()
  88 |             if cleanup_task_is_none or cleanup_task_is_done:
  89 |                 try:
  90 |                     loop = asyncio.get_running_loop()
  91 |                     task_coro = self._cleanup_loop()
  92 |                     self._cleanup_task = loop.create_task(task_coro)
  93 |                     logger.info("Started connection cleanup task.")
  94 |                 except RuntimeError:
  95 |                     logger.warning("No running event loop found, cleanup task not started.")
  96 | 
  97 |     async def _cleanup_loop(self):
  98 |         log_msg = f"Cleanup loop started. Check interval: {self.check_interval}s, Inactivity threshold: {self.cleanup_interval}s"
  99 |         logger.debug(log_msg)
 100 |         while True:
 101 |             await asyncio.sleep(self.check_interval)
 102 |             try:
 103 |                 await self.cleanup_inactive_connections()
 104 |             except asyncio.CancelledError:
 105 |                 logger.info("Cleanup loop cancelled.")
 106 |                 break  # Exit loop cleanly on cancellation
 107 |             except Exception as e:
 108 |                 logger.error(f"Error during connection cleanup: {e}", exc_info=True)
 109 | 
 110 |     async def cleanup_inactive_connections(self):
 111 |         current_time = time.time()
 112 |         conn_ids_to_close = []
 113 | 
 114 |         # Need lock here as we iterate over potentially changing dict
 115 |         async with self._lock:
 116 |             # Use items() for safe iteration while potentially modifying dict later
 117 |             # Create a copy to avoid issues if the dict is modified elsewhere concurrently (though unlikely with lock)
 118 |             current_connections = list(self.connections.items())
 119 | 
 120 |         for conn_id, (_engine, last_accessed) in current_connections:
 121 |             idle_time = current_time - last_accessed
 122 |             is_inactive = idle_time > self.cleanup_interval
 123 |             if is_inactive:
 124 |                 log_msg = f"Connection {conn_id} exceeded inactivity timeout ({idle_time:.1f}s > {self.cleanup_interval}s)"
 125 |                 logger.info(log_msg)
 126 |                 conn_ids_to_close.append(conn_id)
 127 | 
 128 |         closed_count = 0
 129 |         for conn_id in conn_ids_to_close:
 130 |             # close_connection acquires its own lock
 131 |             closed = await self.close_connection(conn_id)
 132 |             if closed:
 133 |                 logger.info(f"Auto-closed inactive connection: {conn_id}")
 134 |                 closed_count += 1
 135 |         if closed_count > 0:
 136 |             logger.debug(f"Closed {closed_count} inactive connections.")
 137 |         elif conn_ids_to_close:
 138 |             num_attempted = len(conn_ids_to_close)
 139 |             logger.debug(
 140 |                 f"Attempted to close {num_attempted} connections, but they might have been removed already."
 141 |             )
 142 | 
 143 |     async def get_connection(self, conn_id: str) -> AsyncEngine:
 144 |         async with self._lock:
 145 |             if conn_id not in self.connections:
 146 |                 details = {"error_type": "CONNECTION_NOT_FOUND"}
 147 |                 raise ToolInputError(
 148 |                     f"Unknown connection_id: {conn_id}", param_name="connection_id", details=details
 149 |                 )
 150 | 
 151 |             engine, _ = self.connections[conn_id]
 152 |             # Update last accessed time
 153 |             current_time = time.time()
 154 |             self.connections[conn_id] = (engine, current_time)
 155 |             logger.debug(f"Accessed connection {conn_id}, updated last accessed time.")
 156 |             return engine
 157 | 
 158 |     async def add_connection(self, conn_id: str, engine: AsyncEngine):
 159 |         # close_connection handles locking internally
 160 |         has_existing = conn_id in self.connections
 161 |         if has_existing:
 162 |             logger.warning(f"Overwriting existing connection entry for {conn_id}.")
 163 |             await self.close_connection(conn_id)  # Close the old one first
 164 | 
 165 |         async with self._lock:
 166 |             current_time = time.time()
 167 |             self.connections[conn_id] = (engine, current_time)
 168 |         # Render the URL with the password hidden so credentials never reach the logs
 169 |         url_safe = engine.url.render_as_string(hide_password=True)
 170 |         log_msg = (
 171 |             f"Added connection {conn_id} for URL: {url_safe}"
 172 |         )
 173 |         logger.info(log_msg)
 174 |         await self.start_cleanup_task()  # Ensure cleanup is running
 175 | 
 176 |     async def close_connection(self, conn_id: str) -> bool:
 177 |         engine = None
 178 |         async with self._lock:
 179 |             connection_exists = conn_id in self.connections
 180 |             if connection_exists:
 181 |                 engine, _ = self.connections.pop(conn_id)
 182 |             else:
 183 |                 logger.warning(f"Attempted to close non-existent connection ID: {conn_id}")
 184 |                 return False  # Not found
 185 | 
 186 |         if engine:
 187 |             logger.info(f"Closing connection {conn_id}...")
 188 |             try:
 189 |                 await engine.dispose()
 190 |                 logger.info(f"Connection {conn_id} disposed successfully.")
 191 |                 return True
 192 |             except Exception as e:
 193 |                 log_msg = f"Error disposing engine for connection {conn_id}: {e}"
 194 |                 logger.error(log_msg, exc_info=True)
 195 |                 # Removed from dict, but disposal failed
 196 |                 return False
 197 |         return False  # Should not be reached if found
 198 | 
 199 |     async def shutdown(self):
 200 |         logger.info("Shutting down Connection Manager...")
 201 |         # Cancel cleanup task first
 202 |         cleanup_task = None
 203 |         async with self._lock:
 204 |             task_exists = self._cleanup_task is not None
 205 |             task_not_done = task_exists and not self._cleanup_task.done()
 206 |             if task_exists and task_not_done:
 207 |                 cleanup_task = self._cleanup_task  # Get reference before clearing
 208 |                 self._cleanup_task = None  # Prevent restarting
 209 | 
 210 |         if cleanup_task:
 211 |             cleanup_task.cancel()
 212 |             try:
 213 |                 # Add timeout for task cancellation
 214 |                 await asyncio.wait_for(cleanup_task, timeout=2.0)
 215 |             except asyncio.TimeoutError:
 216 |                 logger.warning("Cleanup task cancellation timed out after 2 seconds")
 217 |             except asyncio.CancelledError:
 218 |                 logger.info("Cleanup task cancelled.")
 219 |             except Exception as e:
 220 |                 logger.error(f"Error stopping cleanup task: {e}", exc_info=True)
 221 | 
 222 |         # Close remaining connections
 223 |         async with self._lock:
 224 |             conn_ids = list(self.connections.keys())
 225 | 
 226 |         if conn_ids:
 227 |             num_conns = len(conn_ids)
 228 |             logger.info(f"Closing {num_conns} active connections...")
 229 |             # Call close_connection which handles locking and removal
 230 |             close_tasks = []
 231 |             for conn_id in conn_ids:
 232 |                 # Create a task that times out for each connection
 233 |                 async def close_with_timeout(conn_id):
 234 |                     try:
 235 |                         await asyncio.wait_for(self.close_connection(conn_id), timeout=2.0)
 236 |                         return True
 237 |                     except asyncio.TimeoutError:
 238 |                         logger.warning(f"Connection {conn_id} close timed out after 2 seconds")
 239 |                         return False
 240 |                 close_tasks.append(close_with_timeout(conn_id))
 241 |             
 242 |             # Wait for all connections to close with an overall timeout
 243 |             try:
 244 |                 await asyncio.wait_for(asyncio.gather(*close_tasks, return_exceptions=True), timeout=5.0)
 245 |             except asyncio.TimeoutError:
 246 |                 logger.warning("Some connections did not close within the 5 second timeout")
 247 | 
 248 |         async with self._lock:
 249 |             # Final check
 250 |             remaining = len(self.connections)
 251 |             if remaining > 0:
 252 |                 logger.warning(f"{remaining} connections still remain after shutdown attempt.")
 253 |             self.connections.clear()  # Clear the dictionary
 254 | 
 255 |         logger.info("Connection Manager shutdown complete.")
 256 | 
 257 | 
 258 | _connection_manager = ConnectionManager()
 259 | 
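# --- Illustrative usage sketch (editor's note, not part of the original file;
# --- "conn-abc123" is a hypothetical connection ID) ---
#
#   engine = create_async_engine("sqlite+aiosqlite:///:memory:")
#   await _connection_manager.add_connection("conn-abc123", engine)
#   engine = await _connection_manager.get_connection("conn-abc123")  # refreshes last-accessed time
#   await _connection_manager.close_connection("conn-abc123")         # disposes the engine
#
# Idle connections are reaped automatically once they sit unused for longer
# than cleanup_interval_seconds.
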
 260 | # --- Security and Validation ---
 261 | _PROHIBITED_SQL_PATTERN = r"""^\s*(DROP\s+(TABLE|DATABASE|INDEX|VIEW|FUNCTION|PROCEDURE|USER|ROLE)|
 262 |              TRUNCATE\s+TABLE|
 263 |              DELETE\s+FROM|
 264 |              ALTER\s+(TABLE|DATABASE)\s+\S+\s+DROP\s+|
 265 |              UPDATE\s+|INSERT\s+INTO(?!\s+OR\s+IGNORE)|
 266 |              GRANT\s+|REVOKE\s+|
 267 |              CREATE\s+USER|ALTER\s+USER|DROP\s+USER|
 268 |              CREATE\s+ROLE|ALTER\s+ROLE|DROP\s+ROLE|
 269 |              SHUTDOWN|REBOOT|RESTART)"""
 270 | _PROHIBITED_SQL_REGEX = re.compile(_PROHIBITED_SQL_PATTERN, re.I | re.X)
 271 | 
 272 | _TABLE_RX = re.compile(r"\b(?:FROM|JOIN|UPDATE|INSERT\s+INTO|DELETE\s+FROM)\s+([\w.\"$-]+)", re.I)
 273 | 
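# Rough behavior of the patterns above (editor's sketch):
#
#   _PROHIBITED_SQL_REGEX.match("DROP TABLE users")    -> match  (statement blocked)
#   _PROHIBITED_SQL_REGEX.match("SELECT * FROM users") -> None   (statement allowed)
#   _TABLE_RX.findall("SELECT * FROM a JOIN b ON a.id = b.id") -> ["a", "b"]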
 274 | 
 275 | # --- Masking ---
 276 | @dataclass
 277 | class MaskRule:
 278 |     rx: re.Pattern
 279 |     repl: Union[str, Callable[[str], str]]
 280 | 
 281 | 
 282 | # Helper function for credit card masking
 283 | def _mask_cc(v: str) -> str:
 284 |     return f"XXXX-...-{v[-4:]}"
 285 | 
 286 | 
 287 | # Helper function for email masking
 288 | def _mask_email(v: str) -> str:
 289 |     if "@" in v:
 290 |         parts = v.split("@")
 291 |         prefix = parts[0][:2] + "***"
 292 |         domain = parts[-1]
 293 |         return f"{prefix}@{domain}"
 294 |     else:
 295 |         return "***"
 296 | 
 297 | 
 298 | _MASKING_RULES = [
 299 |     MaskRule(re.compile(r"^\d{3}-\d{2}-\d{4}$"), "***-**-XXXX"),  # SSN
 300 |     MaskRule(re.compile(r"(\b\d{4}-?){3}\d{4}\b"), _mask_cc),  # CC basic mask
 301 |     MaskRule(re.compile(r"[\w\.-]+@[\w\.-]+\.\w+"), _mask_email),  # Email
 302 | ]
 303 | 
 304 | # --- ACLs ---
 305 | _RESTRICTED_TABLES: Set[str] = set()
 306 | _RESTRICTED_COLUMNS: Set[str] = set()
 307 | 
 308 | # --- Auditing ---
 309 | _AUDIT_LOG: List[Dict[str, Any]] = []
 310 | _AUDIT_ID_COUNTER: int = 0
 311 | _audit_lock = asyncio.Lock()  # Lock for modifying audit counter and log
 312 | 
 313 | # --- Schema Drift Detection ---
 314 | _LINEAGE: List[Dict[str, Any]] = []
 315 | _SCHEMA_VERSIONS: Dict[str, str] = {}  # connection_id -> schema_hash
 316 | 
 317 | # --- Prometheus Metrics ---
 318 | # Initialized as None, populated in initialize function if prom is available
 319 | _Q_CNT: Optional[Any] = None
 320 | _Q_LAT: Optional[Any] = None
 321 | _CONN_GAUGE: Optional[Any] = None
 322 | 
 323 | 
 324 | # =============================================================================
 325 | # Initialization and Shutdown Functions
 326 | # =============================================================================
 327 | 
 328 | # Flag to track if metrics have been initialized
 329 | _sql_metrics_initialized = False
 330 | 
 331 | async def initialize_sql_tools():
 332 |     """Initialize global state for SQL tools, like starting the cleanup task and metrics."""
 333 |     global _sql_metrics_initialized
 334 |     global _Q_CNT, _Q_LAT, _CONN_GAUGE # Ensure globals are declared for assignment
 335 | 
 336 |     # Initialize metrics only once
 337 |     if not _sql_metrics_initialized:
 338 |         logger.info("Initializing SQL Tools module metrics...")
 339 |         if prom:
 340 |             try:
 341 |                 # Define metrics
 342 |                 _Q_CNT = prom.Counter("mcp_sqltool_calls", "SQL tool calls", ["tool", "action", "db"])
 343 |                 latency_buckets = (0.01, 0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60)
 344 |                 _Q_LAT = prom.Histogram(
 345 |                     "mcp_sqltool_latency_seconds",
 346 |                     "SQL latency",
 347 |                     ["tool", "action", "db"],
 348 |                     buckets=latency_buckets,
 349 |                 )
 350 |                 _CONN_GAUGE = prom.Gauge(
 351 |                     "mcp_sqltool_active_connections",
 352 |                     "Number of active SQL connections"
 353 |                 )
 354 | 
 355 |                 # Define the gauge function referencing the global manager
 356 |                 # Wrap in try-except as accessing length during shutdown might be tricky
 357 |                 def _get_active_connections():
 358 |                     try:
 359 |                         # Access length directly if manager state is simple enough
 360 |                         # If complex state, acquire lock if necessary (_connection_manager._lock)
 361 |                         # For just length, direct access is usually okay unless adding/removing heavily concurrent
 362 |                         return len(_connection_manager.connections)
 363 |                     except Exception:
 364 |                         logger.exception("Error getting active connection count for Prometheus.")
 365 |                         return 0 # Default to 0 if error accessing
 366 | 
 367 |                 _CONN_GAUGE.set_function(_get_active_connections)
 368 |                 logger.info("Prometheus metrics initialized for SQL tools.")
 369 |                 _sql_metrics_initialized = True # Set flag only after successful initialization
 370 | 
 371 |             except ValueError as e:
 372 |                 # Catch the specific duplicate error and log nicely, but don't crash
 373 |                 if "Duplicated timeseries" in str(e):
 374 |                     logger.warning(f"Prometheus metrics already registered: {e}. Skipping re-initialization.")
 375 |                     _sql_metrics_initialized = True # Assume they are initialized if duplicate error occurs
 376 |                 else:
 377 |                     # Re-raise other ValueErrors
 378 |                     logger.error(f"ValueError during Prometheus metric initialization: {e}", exc_info=True)
 379 |                     raise # Re-raise unexpected ValueError
 380 |             except Exception as e:
 381 |                 logger.error(f"Failed to initialize Prometheus metrics for SQL tools: {e}", exc_info=True)
 382 |                 # Continue without metrics if initialization fails; metrics are optional.
 383 | 
 384 |         else:
 385 |             logger.info("Prometheus client not available, metrics disabled for SQL tools.")
 386 |             _sql_metrics_initialized = True # Mark as "initialized" (i.e., done trying) even if prom not present
 387 |     else:
 388 |         logger.debug("SQL tools metrics already initialized, skipping metric creation.")
 389 | 
 390 |     # Always try to start the cleanup task (it's internally idempotent)
 391 |     # Ensure this happens *after* logging initialization attempt
 392 |     logger.info("Ensuring SQL connection cleanup task is running...")
 393 |     await _connection_manager.start_cleanup_task()
 394 | 
 395 | 
 396 | async def shutdown_sql_tools():
 397 |     """Gracefully shut down SQL tool resources, like the connection manager."""
 398 |     logger.info("Shutting down SQL Tools module...")
 399 |     try:
 400 |         # Add timeout to connection manager shutdown
 401 |         await asyncio.wait_for(_connection_manager.shutdown(), timeout=8.0)
 402 |     except asyncio.TimeoutError:
 403 |         logger.warning("Connection Manager shutdown timed out after 8 seconds")
 404 |     # Clear other global state if necessary (e.g., save audit log)
 405 |     logger.info("SQL Tools module shutdown complete.")
 406 | 
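# Typical lifecycle wiring (editor's sketch; the actual hook-up lives in the
# server startup/shutdown code, outside this file):
#
#   await initialize_sql_tools()   # idempotent: registers metrics, starts cleanup task
#   ...                            # serve requests
#   await shutdown_sql_tools()     # disposes all engines, bounded by an 8s timeout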
 407 | 
 408 | # =============================================================================
 409 | # Helper Functions (Private module-level functions)
 410 | # =============================================================================
 411 | 
 412 | 
 413 | @lru_cache(maxsize=64)
 414 | def _pull_secret_from_sources(name: str) -> str:
 415 |     """Retrieve a secret from various sources."""
 416 |     # Tries AWS Secrets Manager, then HashiCorp Vault, then environment variables.
 417 |     if boto3:
 418 |         try:
 419 |             client = boto3.client("secretsmanager")
 420 |             # Consider region_name=os.getenv("AWS_REGION") or similar config
 421 |             secret_value_response = client.get_secret_value(SecretId=name)
 422 |             # Handle binary vs string secrets
 423 |             if "SecretString" in secret_value_response:
 424 |                 secret = secret_value_response["SecretString"]
 425 |                 return secret
 426 |             elif "SecretBinary" in secret_value_response:
 427 |                 # Decode binary appropriately if needed, default to utf-8
 428 |                 secret_bytes = secret_value_response["SecretBinary"]
 429 |                 secret = secret_bytes.decode("utf-8")
 430 |                 return secret
 431 |         except Exception as aws_err:
 432 |             logger.debug(f"Secret '{name}' not found or error in AWS Secrets Manager: {aws_err}")
 433 |             pass
 434 | 
 435 |     if hvac:
 436 |         try:
 437 |             vault_url = os.getenv("VAULT_ADDR")
 438 |             vault_token = os.getenv("VAULT_TOKEN")
 439 |             if vault_url and vault_token:
 440 |                 vault_client = hvac.Client(
 441 |                     url=vault_url, token=vault_token, timeout=2
 442 |                 )  # Short timeout
 443 |                 is_auth = vault_client.is_authenticated()
 444 |                 if is_auth:
 445 |                     mount_point = os.getenv("VAULT_KV_MOUNT_POINT", "secret")
 446 |                     secret_path = name
 447 |                     read_response = vault_client.secrets.kv.v2.read_secret_version(
 448 |                         path=secret_path, mount_point=mount_point
 449 |                     )
 450 |                     # Standard KV v2 structure: response['data']['data'] is the dict of secrets
 451 |                     has_outer_data = "data" in read_response
 452 |                     has_inner_data = has_outer_data and "data" in read_response["data"]
 453 |                     if has_inner_data:
 454 |                         # Try common key names 'value' or the secret name itself
 455 |                         secret_data = read_response["data"]["data"]
 456 |                         if "value" in secret_data:
 457 |                             value = secret_data["value"]
 458 |                             return value
 459 |                         elif name in secret_data:
 460 |                             value = secret_data[name]
 461 |                             return value
 462 |                         else:
 463 |                             log_msg = f"Secret keys 'value' or '{name}' not found at path '{secret_path}' in Vault."
 464 |                             logger.debug(log_msg)
 465 |                 else:
 466 |                     logger.warning(f"Vault authentication failed for address: {vault_url}")
 467 | 
 468 |         except Exception as e:
 469 |             logger.debug(f"Error accessing Vault for secret '{name}': {e}")
 470 |             pass
 471 | 
 472 |     # Try environment variables
 473 |     env_val_direct = os.getenv(name)
 474 |     if env_val_direct:
 475 |         return env_val_direct
 476 | 
 477 |     mcp_secret_name = f"MCP_SECRET_{name.upper()}"
 478 |     env_val_prefixed = os.getenv(mcp_secret_name)
 479 |     if env_val_prefixed:
 480 |         logger.debug(f"Found secret '{name}' using prefixed env var '{mcp_secret_name}'.")
 481 |         return env_val_prefixed
 482 | 
 483 |     error_msg = (
 484 |         f"Secret '{name}' not found in any source (AWS, Vault, Env: {name}, Env: {mcp_secret_name})"
 485 |     )
 486 |     details = {"secret_name": name, "error_type": "SECRET_NOT_FOUND"}
 487 |     raise ToolError(error_msg, http_status_code=404, details=details)
 488 | 
 489 | 
 490 | async def _sql_get_engine(cid: str) -> AsyncEngine:
 491 |     """Get engine by connection ID using the global ConnectionManager."""
 492 |     engine = await _connection_manager.get_connection(cid)
 493 |     return engine
 494 | 
 495 | 
 496 | def _sql_get_next_audit_id() -> str:
 497 |     """Generate the next sequential audit ID (thread-safe)."""
 498 |     # Locking happens in _sql_audit where this is called
 499 |     global _AUDIT_ID_COUNTER
 500 |     _AUDIT_ID_COUNTER += 1
 501 |     audit_id_str = f"a{_AUDIT_ID_COUNTER:09d}"
 502 |     return audit_id_str
 503 | 
 504 | 
 505 | def _sql_now() -> str:
 506 |     """Get current UTC timestamp in ISO format."""
 507 |     now_utc = dt.datetime.now(dt.timezone.utc)
 508 |     iso_str = now_utc.isoformat(timespec="seconds")
 509 |     return iso_str
 510 | 
 511 | 
 512 | async def _sql_audit(
 513 |     *,
 514 |     tool_name: str,
 515 |     action: str,
 516 |     connection_id: Optional[str],
 517 |     sql: Optional[str],
 518 |     tables: Optional[List[str]],
 519 |     row_count: Optional[int],
 520 |     success: bool,
 521 |     error: Optional[str],
 522 |     user_id: Optional[str],
 523 |     session_id: Optional[str],
 524 |     **extra_data: Any,
 525 | ) -> None:
 526 |     """Record an audit trail entry (thread-safe)."""
 527 |     global _AUDIT_LOG
 528 |     async with _audit_lock:
 529 |         audit_id = _sql_get_next_audit_id()  # Get ID while locked
 530 |         timestamp = _sql_now()
 531 |         log_entry = {}
 532 |         log_entry["audit_id"] = audit_id
 533 |         log_entry["timestamp"] = timestamp
 534 |         log_entry["tool_name"] = tool_name
 535 |         log_entry["action"] = action
 536 |         log_entry["user_id"] = user_id
 537 |         log_entry["session_id"] = session_id
 538 |         log_entry["connection_id"] = connection_id
 539 |         log_entry["sql"] = sql
 540 |         log_entry["tables"] = tables
 541 |         log_entry["row_count"] = row_count
 542 |         log_entry["success"] = success
 543 |         log_entry["error"] = error
 544 |         log_entry.update(extra_data)  # Add extra data
 545 | 
 546 |         _AUDIT_LOG.append(log_entry)
 547 | 
 548 |     # Optional: Log to logger (outside lock)
 549 |     log_base = f"Audit[{audit_id}]: Tool={tool_name}, Action={action}, Conn={connection_id}, Success={success}"
 550 |     log_error = f", Error={error}" if error else ""
 551 |     logger.info(log_base + log_error)
 552 | 
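# Example audit call (editor's sketch; every argument is keyword-only):
#
#   await _sql_audit(
#       tool_name="execute_sql", action="query", connection_id=cid,
#       sql="SELECT ...", tables=["users"], row_count=10,
#       success=True, error=None, user_id=None, session_id=None,
#   )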
 553 | 
 554 | def _sql_update_acl(
 555 |     *, tables: Optional[List[str]] = None, columns: Optional[List[str]] = None
 556 | ) -> None:
 557 |     """Update the global ACL lists."""
 558 |     global _RESTRICTED_TABLES, _RESTRICTED_COLUMNS
 559 |     if tables is not None:
 560 |         lowered_tables = {t.lower() for t in tables}
 561 |         _RESTRICTED_TABLES = lowered_tables
 562 |         logger.info(f"Updated restricted tables ACL: {_RESTRICTED_TABLES}")
 563 |     if columns is not None:
 564 |         lowered_columns = {c.lower() for c in columns}
 565 |         _RESTRICTED_COLUMNS = lowered_columns
 566 |         logger.info(f"Updated restricted columns ACL: {_RESTRICTED_COLUMNS}")
 567 | 
 568 | 
 569 | def _sql_check_acl(sql: str) -> None:
 570 |     """Check if SQL contains any restricted tables or columns using global ACLs."""
 571 |     # Tokenizes the SQL and checks tokens against the global _RESTRICTED_TABLES/_COLUMNS sets.
 572 |     raw_toks = re.findall(r'[\w$"\'.]+', sql.lower())
 573 |     toks = set(raw_toks)
 574 |     normalized_toks = set()
 575 |     for tok in toks:
 576 |         tok_norm = tok.strip("\"`'[]")
 577 |         normalized_toks.add(tok_norm)
 578 |         has_dot = "." in tok_norm
 579 |         if has_dot:
 580 |             last_part = tok_norm.split(".")[-1]
 581 |             normalized_toks.add(last_part)
 582 | 
 583 |     restricted_tables_found_set = _RESTRICTED_TABLES.intersection(normalized_toks)
 584 |     restricted_tables_found = list(restricted_tables_found_set)
 585 |     if restricted_tables_found:
 586 |         tables_str = ", ".join(restricted_tables_found)
 587 |         logger.warning(
 588 |             f"ACL Violation: Restricted table(s) found in query: {restricted_tables_found}"
 589 |         )
 590 |         details = {
 591 |             "restricted_tables": restricted_tables_found,
 592 |             "error_type": "ACL_TABLE_VIOLATION",
 593 |         }
 594 |         raise ToolError(
 595 |             f"Access denied: Query involves restricted table(s): {tables_str}",
 596 |             http_status_code=403,
 597 |             details=details,
 598 |         )
 599 | 
 600 |     restricted_columns_found_set = _RESTRICTED_COLUMNS.intersection(normalized_toks)
 601 |     restricted_columns_found = list(restricted_columns_found_set)
 602 |     if restricted_columns_found:
 603 |         columns_str = ", ".join(restricted_columns_found)
 604 |         logger.warning(
 605 |             f"ACL Violation: Restricted column(s) found in query: {restricted_columns_found}"
 606 |         )
 607 |         details = {
 608 |             "restricted_columns": restricted_columns_found,
 609 |             "error_type": "ACL_COLUMN_VIOLATION",
 610 |         }
 611 |         raise ToolError(
 612 |             f"Access denied: Query involves restricted column(s): {columns_str}",
 613 |             http_status_code=403,
 614 |             details=details,
 615 |         )
 616 | 
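# Rough ACL behavior (editor's sketch):
#
#   _sql_update_acl(tables=["employees"], columns=["salary"])
#   _sql_check_acl("SELECT name FROM customers")    # passes silently
#   _sql_check_acl("SELECT salary FROM employees")  # raises ToolError (HTTP 403)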
 617 | 
 618 | def _sql_resolve_conn(raw: str) -> str:
 619 |     """Resolve connection string, handling secret references."""
 620 |     # "secrets://<name>" references are resolved via _pull_secret_from_sources.
 621 |     is_secret_ref = raw.startswith("secrets://")
 622 |     if is_secret_ref:
 623 |         secret_name = raw[10:]
 624 |         logger.info(f"Resolving secret reference: '{secret_name}'")
 625 |         resolved_secret = _pull_secret_from_sources(secret_name)
 626 |         return resolved_secret
 627 |     return raw
 628 | 
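# Resolution order sketch (editor's note; "PROD_DB_URL" is a hypothetical secret name):
#
#   _sql_resolve_conn("secrets://PROD_DB_URL")
#       -> tries AWS Secrets Manager, then Vault, then $PROD_DB_URL,
#          then $MCP_SECRET_PROD_DB_URL; raises ToolError (404) if none match
#   _sql_resolve_conn("sqlite:///./app.db")  -> returned unchanged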
 629 | 
 630 | def _sql_mask_val(v: Any) -> Any:
 631 |     """Apply masking rules to a single value using global rules."""
 632 |     # First rule in the global _MASKING_RULES whose regex fully matches wins.
 633 |     is_string = isinstance(v, str)
 634 |     is_not_empty = bool(v)
 635 |     if not is_string or not is_not_empty:
 636 |         return v
 637 |     for rule in _MASKING_RULES:
 638 |         matches = rule.rx.fullmatch(v)
 639 |         if matches:
 640 |             is_callable = callable(rule.repl)
 641 |             if is_callable:
 642 |                 try:
 643 |                     masked_value = rule.repl(v)
 644 |                     return masked_value
 645 |                 except Exception as e:
 646 |                     log_msg = f"Error applying dynamic mask rule {rule.rx.pattern}: {e}"
 647 |                     logger.error(log_msg)
 648 |                     return "MASKING_ERROR"
 649 |             else:
 650 |                 return rule.repl
 651 |     return v
 652 | 
 653 | 
 654 | def _sql_mask_row(row: Dict[str, Any]) -> Dict[str, Any]:
 655 |     """Apply masking rules to an entire row of data."""
 656 |     masked_dict = {}
 657 |     for k, v in row.items():
 658 |         masked_val = _sql_mask_val(v)
 659 |         masked_dict[k] = masked_val
 660 |     return masked_dict
 661 |     # Equivalent one-liner: {k: _sql_mask_val(v) for k, v in row.items()}
 662 | 
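# Rough masking behavior (editor's sketch). Rules use fullmatch, so only values
# that are entirely an SSN / card number / email get masked:
#
#   _sql_mask_val("123-45-6789")         -> "***-**-XXXX"        (SSN rule)
#   _sql_mask_val("1234-5678-9012-3456") -> "XXXX-...-3456"      (credit-card rule)
#   _sql_mask_val("alice@example.com")   -> "al***@example.com"  (email rule)
#   _sql_mask_val(42)                    -> 42                   (non-strings pass through)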
 663 | 
 664 | def _sql_driver_url(conn_str: str) -> Tuple[str, str]:
 665 |     """Convert generic connection string to dialect-specific async URL."""
 666 |     # Check if it looks like a path (no ://) and exists or is :memory:
 667 |     has_protocol = "://" in conn_str
 668 |     looks_like_path = not has_protocol
 669 |     path_obj = Path(conn_str)
 670 |     path_exists = path_obj.exists()
 671 |     is_memory = conn_str == ":memory:"
 672 |     is_file_path = looks_like_path and (path_exists or is_memory)
 673 | 
 674 |     if is_file_path:
 675 |         if is_memory:
 676 |             url_str = "sqlite+aiosqlite:///:memory:"
 677 |             logger.info("Using in-memory SQLite database.")
 678 |         else:
 679 |             sqlite_path = path_obj.expanduser().resolve()
 680 |             parent_dir = sqlite_path.parent
 681 |             parent_exists = parent_dir.exists()
 682 |             if not parent_exists:
 683 |                 try:
 684 |                     parent_dir.mkdir(parents=True, exist_ok=True)
 685 |                     logger.info(f"Created directory for SQLite DB: {parent_dir}")
 686 |                 except OSError as e:
 687 |                     details = {"path": str(parent_dir)}
 688 |                     raise ToolError(
 689 |                         f"Failed to create directory for SQLite DB '{parent_dir}': {e}",
 690 |                         http_status_code=500,
 691 |                         details=details,
 692 |                     ) from e
 693 |             url_str = f"sqlite+aiosqlite:///{sqlite_path}"
 694 |             logger.info(f"Using SQLite database file: {sqlite_path}")
 695 |         url = make_url(url_str)
 696 |         final_url_str = str(url)
 697 |         return final_url_str, "sqlite"
 698 |     else:
 699 |         url_str = conn_str
 700 |         try:
 701 |             url = make_url(url_str)
 702 |         except Exception as e:
 703 |             details = {"value": conn_str}
 704 |             raise ToolInputError(
 705 |                 f"Invalid connection string format: {e}",
 706 |                 param_name="connection_string",
 707 |                 details=details,
 708 |             ) from e
 709 | 
 710 |     drv = url.drivername.lower()
 711 |     drivername = url.drivername  # Preserve original case for setting later if needed
 712 | 
 713 |     if drv.startswith("sqlite"):
 714 |         new_url = url.set(drivername="sqlite+aiosqlite")
 715 |         return str(new_url), "sqlite"
 716 |     if drv.startswith("postgresql") or drv == "postgres":
 717 |         new_url = url.set(drivername="postgresql+asyncpg")
 718 |         return str(new_url), "postgresql"
 719 |     if drv.startswith("mysql") or drv == "mariadb":
 720 |         query = dict(url.query)
 721 |         query.setdefault("charset", "utf8mb4")
 722 |         new_url = url.set(drivername="mysql+aiomysql", query=query)
 723 |         return str(new_url), "mysql"
 724 |     if drv.startswith("mssql") or drv == "sqlserver":
 725 |         odbc_driver = url.query.get("driver")
 726 |         if not odbc_driver:
 727 |             logger.warning(
 728 |                 "MSSQL connection string lacks 'driver' parameter. Ensure a valid ODBC driver (e.g., 'ODBC Driver 17 for SQL Server') is installed and specified."
 729 |             )
 730 |         new_url = url.set(drivername="mssql+aioodbc")
 731 |         return str(new_url), "sqlserver"
 732 |     if drv.startswith("snowflake"):
 733 |         # Keep original snowflake driver
 734 |         new_url = url.set(drivername=drivername)
 735 |         return str(new_url), "snowflake"
 736 | 
 737 |     logger.error(f"Unsupported database dialect: {drv}")
 738 |     details = {"dialect": drv}
 739 |     raise ToolInputError(
 740 |         f"Unsupported database dialect: '{drv}'. Supported: sqlite, postgresql, mysql, mssql, snowflake",
 741 |         param_name="connection_string",
 742 |         details=details,
 743 |     )
 744 | 
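# Example conversions (editor's sketch):
#
#   _sql_driver_url(":memory:")                 -> ("sqlite+aiosqlite:///:memory:", "sqlite")
#   _sql_driver_url("postgresql://u:p@host/db") -> ("postgresql+asyncpg://u:p@host/db", "postgresql")
#   _sql_driver_url("mysql://u:p@host/db")      -> ("mysql+aiomysql://u:p@host/db?charset=utf8mb4", "mysql")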
 745 | 
 746 | def _sql_auto_pool(db_type: str) -> Dict[str, Any]:
 747 |     """Provide sensible default connection pool settings."""
 748 |     # Generic defaults; per-dialect overrides follow.
 749 |     # SQLite uses minimal pooling (pre-ping only); see branch below.
 750 |     defaults = {
 751 |         "pool_size": 5,
 752 |         "max_overflow": 10,
 753 |         "pool_recycle": 1800,
 754 |         "pool_pre_ping": True,
 755 |         "pool_timeout": 30,
 756 |     }
 757 |     if db_type == "sqlite":
 758 |         return {"pool_pre_ping": True}
 759 |     if db_type == "postgresql":
 760 |         return {
 761 |             "pool_size": 10,
 762 |             "max_overflow": 20,
 763 |             "pool_recycle": 900,
 764 |             "pool_pre_ping": True,
 765 |             "pool_timeout": 30,
 766 |         }
 767 |     if db_type == "mysql":
 768 |         return {
 769 |             "pool_size": 10,
 770 |             "max_overflow": 20,
 771 |             "pool_recycle": 900,
 772 |             "pool_pre_ping": True,
 773 |             "pool_timeout": 30,
 774 |         }
 775 |     if db_type == "sqlserver":
 776 |         return {
 777 |             "pool_size": 10,
 778 |             "max_overflow": 20,
 779 |             "pool_recycle": 900,
 780 |             "pool_pre_ping": True,
 781 |             "pool_timeout": 30,
 782 |         }
 783 |     if db_type == "snowflake":
 784 |         return {"pool_size": 5, "max_overflow": 5, "pool_pre_ping": True, "pool_timeout": 60}
 785 |     logger.warning(f"Using default pool settings for unknown db_type: {db_type}")
 786 |     return defaults
 787 | 
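# Example (editor's sketch): the connect tool presumably feeds these kwargs
# straight into create_async_engine:
#
#   pool_kwargs = _sql_auto_pool("postgresql")
#   # -> {"pool_size": 10, "max_overflow": 20, "pool_recycle": 900,
#   #     "pool_pre_ping": True, "pool_timeout": 30}
#   engine = create_async_engine(url, **pool_kwargs)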
 788 | 
 789 | def _sql_extract_tables(sql: str) -> List[str]:
 790 |     """Extract table names referenced in a SQL query."""
 791 |     matches = _TABLE_RX.findall(sql)
 792 |     tables = set()
 793 |     for match in matches:
 794 |         # Chained strip is one expression
 795 |         table_stripped = match.strip()
 796 |         table = table_stripped.strip("\"`'[]")
 797 |         has_dot = "." in table
 798 |         if has_dot:
 799 |             # Equivalent one-liner: table.split('.')[-1].strip('"`\'[]')
 800 |             parts = table.split(".")
 801 |             last_part = parts[-1]
 802 |             table = last_part.strip("\"`'[]")
 803 |         if table:
 804 |             tables.add(table)
 805 |     sorted_tables = sorted(list(tables))
 806 |     return sorted_tables
 807 | 
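# Example (editor's sketch):
#
#   _sql_extract_tables("SELECT * FROM main.users u JOIN orders o ON u.id = o.uid")
#       -> ["orders", "users"]   # schema prefixes stripped, result sorted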
 808 | 
 809 | def _sql_check_safe(sql: str, read_only: bool = True) -> None:
 810 |     """Validate SQL for safety using global patterns and ACLs."""
 811 |     # Check ACLs first
 812 |     _sql_check_acl(sql)
 813 | 
 814 |     # Check prohibited statements
 815 |     normalized_sql = sql.lstrip().upper()
 816 |     check_sql_part = normalized_sql  # Default part to check
 817 | 
 818 |     starts_with_with = normalized_sql.startswith("WITH")
 819 |     if starts_with_with:
 820 |         try:
 821 |             # Regex remains single-line expression assignment
 822 |             search_regex = r"\)\s*(SELECT|INSERT|UPDATE|DELETE|MERGE)"
 823 |             search_flags = re.IGNORECASE | re.DOTALL
 824 |             main_statement_match = re.search(search_regex, normalized_sql, search_flags)
 825 |             if main_statement_match:
 826 |                 # Chained calls okay on one line
 827 |                 main_statement_group = main_statement_match.group(0)
 828 |                 check_sql_part = main_statement_group.lstrip(") \t\n\r")
 829 |             # else: keep check_sql_part as normalized_sql
 830 |         except Exception:
 831 |             # Ignore regex errors, fallback to checking whole normalized_sql
 832 |             pass
 833 | 
 834 |     prohibited_match_obj = _PROHIBITED_SQL_REGEX.match(check_sql_part)
 835 |     if prohibited_match_obj:
 836 |         # Chained calls okay on one line
 837 |         prohibited_match = prohibited_match_obj.group(1)
 838 |         prohibited_statement = prohibited_match.strip()
 839 |         logger.warning(f"Security Violation: Prohibited statement detected: {prohibited_statement}")
 840 |         details = {"statement": prohibited_statement, "error_type": "PROHIBITED_STATEMENT"}
 841 |         raise ToolInputError(
 842 |             f"Prohibited statement type detected: '{prohibited_statement}'",
 843 |             param_name="query",
 844 |             details=details,
 845 |         )
 846 | 
 847 |     # Check read-only constraint
 848 |     if read_only:
 849 |         allowed_starts = ("SELECT", "SHOW", "EXPLAIN", "DESCRIBE", "PRAGMA")
 850 |         is_read_query = check_sql_part.startswith(allowed_starts)
 851 |         if not is_read_query:
 852 |             query_preview = sql[:100]
 853 |             logger.warning(
 854 |                 f"Security Violation: Write operation attempted in read-only mode: {query_preview}..."
 855 |             )
 856 |             details = {"error_type": "READ_ONLY_VIOLATION"}
 857 |             raise ToolInputError(
 858 |                 "Write operation attempted in read-only mode", param_name="query", details=details
 859 |             )
 860 | 
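# Rough behavior with empty ACLs (editor's sketch):
#
#   _sql_check_safe("SELECT * FROM t")                           # passes
#   _sql_check_safe("DROP TABLE t")                              # raises (prohibited statement)
#   _sql_check_safe("CREATE TABLE t (x INT)")                    # raises (read-only violation)
#   _sql_check_safe("CREATE TABLE t (x INT)", read_only=False)   # passes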
 861 | 
 862 | async def _sql_exec(
 863 |     eng: AsyncEngine,
 864 |     sql: str,
 865 |     params: Optional[Dict[str, Any]],
 866 |     *,
 867 |     limit: Optional[int],
 868 |     tool_name: str,
 869 |     action_name: str,
 870 |     timeout: float = 30.0,
 871 | ) -> Tuple[List[str], List[Dict[str, Any]], int]:
 872 |     """Core async SQL executor helper."""
 873 |     db_dialect = eng.dialect.name
 874 |     start_time = time.perf_counter()
 875 | 
 876 |     if _Q_CNT:
 877 |         # Chained call okay
 878 |         _Q_CNT.labels(tool=tool_name, action=action_name, db=db_dialect).inc()
 879 | 
 880 |     cols: List[str] = []
 881 |     rows_raw: List[Any] = []
 882 |     row_count: int = 0
 883 |     masked_rows: List[Dict[str, Any]] = []
 884 | 
 885 |     async def _run(conn: AsyncConnection):
 886 |         nonlocal cols, rows_raw, row_count, masked_rows
 887 |         statement = text(sql)
 888 |         query_params = params or {}
 889 |         try:
 890 |             res = await conn.execute(statement, query_params)
 891 |             has_cursor = res.cursor is not None
 892 |             has_description = has_cursor and res.cursor.description is not None
 893 |             if not has_cursor or not has_description:
 894 |                 logger.debug(f"Query did not return rows or description. Action: {action_name}")
 895 |                 # Ternary okay
 896 |                 res_rowcount = res.rowcount if res.rowcount >= 0 else 0
 897 |                 row_count = res_rowcount
 898 |                 masked_rows = []  # Ensure it's an empty list
 899 |                 empty_cols: List[str] = []
 900 |                 empty_rows: List[Dict[str, Any]] = []
 901 |                 return empty_cols, empty_rows, row_count  # Return empty lists for cols/rows
 902 | 
 903 |             cols = list(res.keys())
 904 |             try:
 905 |                 # --- START: Restored SQLite Handling ---
 906 |                 is_sqlite = db_dialect == "sqlite"
 907 |                 if is_sqlite:
 908 |                     # aiosqlite fetchall/fetchmany might not work reliably with async iteration or limits in all cases
 909 |                     # Fetch all as mappings (dicts) directly
 910 |                     # Lambda okay if single line
 911 |                     def sync_lambda(sync_conn):
 912 |                         return list(sync_conn.execute(statement, query_params).mappings())
 913 | 
 914 |                     all_rows_mapped = await conn.run_sync(sync_lambda)
 915 |                     rows_raw = all_rows_mapped  # Keep the dict list format
 916 |                     needs_limit = limit is not None and limit >= 0
 917 |                     if needs_limit:
 918 |                         rows_raw = rows_raw[:limit]  # Apply limit in Python
 919 |                 else:
 920 |                     # Standard async fetching for other dialects
 921 |                     needs_limit = limit is not None and limit >= 0
 922 |                     if needs_limit:
 923 |                         fetched_rows = await res.fetchmany(limit)  # Returns Row objects
 924 |                         rows_raw = fetched_rows
 925 |                     else:
 926 |                         fetched_rows = await res.fetchall()  # Returns Row objects
 927 |                         rows_raw = fetched_rows
 928 |                 # --- END: Restored SQLite Handling ---
 929 | 
 930 |                 row_count = len(rows_raw)  # Count based on fetched/limited rows
 931 | 
 932 |             except Exception as fetch_err:
 933 |                 log_msg = f"Error fetching rows for {tool_name}/{action_name}: {fetch_err}"
 934 |                 logger.error(log_msg, exc_info=True)
 935 |                 query_preview = sql[:100] + "..."
 936 |                 details = {"query": query_preview}
 937 |                 raise ToolError(
 938 |                     f"Error fetching results: {fetch_err}", http_status_code=500, details=details
 939 |                 ) from fetch_err
 940 | 
 941 |             # Apply masking using _sql_mask_row which uses global rules
 942 |             # Adjust masking based on fetched format
 943 |             if is_sqlite:
 944 |                 # List comprehension okay
 945 |                 masked_rows_list = [_sql_mask_row(r) for r in rows_raw]  # Already dicts
 946 |                 masked_rows = masked_rows_list
 947 |             else:
 948 |                 # List comprehension okay
 949 |                 masked_rows_list = [
 950 |                     _sql_mask_row(r._mapping) for r in rows_raw
 951 |                 ]  # Convert Row objects
 952 |                 masked_rows = masked_rows_list
 953 | 
 954 |             return cols, masked_rows, row_count
 955 | 
 956 |         except (ProgrammingError, OperationalError) as db_err:
 957 |             err_type_name = type(db_err).__name__
 958 |             log_msg = f"Database execution error ({err_type_name}) for {tool_name}/{action_name} on {db_dialect}: {db_err}"
 959 |             logger.error(log_msg, exc_info=True)
 960 |             query_preview = sql[:100] + "..."
 961 |             details = {"db_error": str(db_err), "query": query_preview}
 962 |             raise ToolError(
 963 |                 f"Database Error: {db_err}", http_status_code=400, details=details
 964 |             ) from db_err
 965 |         except SQLAlchemyError as sa_err:
 966 |             err_type_name = type(sa_err).__name__
 967 |             log_msg = f"SQLAlchemy error ({err_type_name}) for {tool_name}/{action_name} on {db_dialect}: {sa_err}"
 968 |             logger.error(log_msg, exc_info=True)
 969 |             query_preview = sql[:100] + "..."
 970 |             details = {"sqlalchemy_error": str(sa_err), "query": query_preview}
 971 |             raise ToolError(
 972 |                 f"SQLAlchemy Error: {sa_err}", http_status_code=500, details=details
 973 |             ) from sa_err
 974 |         except Exception as e:  # Catch other potential errors within _run
 975 |             log_msg = f"Unexpected error within _run for {tool_name}/{action_name}: {e}"
 976 |             logger.error(log_msg, exc_info=True)
 977 |             raise ToolError(
 978 |                 f"Unexpected error during query execution step: {e}", http_status_code=500
 979 |             ) from e
 980 | 
 981 |     try:
 982 |         async with eng.connect() as conn:
 983 |             # Run within timeout
 984 |             # Call okay
 985 |             run_coro = _run(conn)
 986 |             cols_res, masked_rows_res, cnt_res = await asyncio.wait_for(run_coro, timeout=timeout)
 987 |             cols = cols_res
 988 |             masked_rows = masked_rows_res
 989 |             cnt = cnt_res
 990 | 
 991 |             latency = time.perf_counter() - start_time
 992 |             if _Q_LAT:
 993 |                 # Chained call okay
 994 |                 _Q_LAT.labels(tool=tool_name, action=action_name, db=db_dialect).observe(latency)
 995 |             log_msg = f"Execution successful for {tool_name}/{action_name}. Latency: {latency:.3f}s, Rows fetched: {cnt}"
 996 |             logger.debug(log_msg)
 997 |             return cols, masked_rows, cnt
 998 | 
 999 |     except asyncio.TimeoutError:
1000 |         log_msg = (
1001 |             f"Query timeout ({timeout}s) exceeded for {tool_name}/{action_name} on {db_dialect}."
1002 |         )
1003 |         logger.warning(log_msg)
1004 |         query_preview = sql[:100] + "..."
1005 |         details = {"timeout": timeout, "query": query_preview}
1006 |         raise ToolError(
1007 |             f"Query timed out after {timeout} seconds", http_status_code=504, details=details
1008 |         ) from None
1009 |     except ToolError:
1010 |         # Re-raise known ToolErrors
1011 |         raise
1012 |     except Exception as e:
1013 |         log_msg = f"Unexpected error during _sql_exec for {tool_name}/{action_name}: {e}"
1014 |         logger.error(log_msg, exc_info=True)
1015 |         details = {"error_type": type(e).__name__}
1016 |         raise ToolError(
1017 |             f"An unexpected error occurred: {e}", http_status_code=500, details=details
1018 |         ) from e
1019 | 
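# Example invocation (editor's sketch):
#
#   cols, rows, n = await _sql_exec(
#       engine, "SELECT id, email FROM users", None,
#       limit=100, tool_name="execute_sql", action_name="query", timeout=30.0,
#   )
#   # rows come back masked, e.g. [{"id": 1, "email": "al***@example.com"}, ...]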
1020 | 
1021 | def _sql_export_rows(
1022 |     cols: List[str],
1023 |     rows: List[Dict[str, Any]],
1024 |     export_format: str,
1025 |     export_path: Optional[str] = None,
1026 | ) -> Tuple[Optional[Any], Optional[str]]:
1027 |     """Export query results helper."""
1028 |     if not export_format:
1029 |         return None, None
1030 |     export_format_lower = export_format.lower()
1031 |     supported_formats = ["pandas", "excel", "csv"]
1032 |     if export_format_lower not in supported_formats:
1033 |         details = {"format": export_format}
1034 |         msg = f"Unsupported export format: '{export_format}'. Use 'pandas', 'excel', or 'csv'."
1035 |         raise ToolInputError(msg, param_name="export.format", details=details)
1036 |     if pd is None:
1037 |         details = {"library": "pandas"}
1038 |         msg = f"Pandas library is not installed. Cannot export to '{export_format_lower}'."
1039 |         raise ToolError(msg, http_status_code=501, details=details)
1040 | 
1041 |     try:
1042 |         # Preserve column order even when the result set is empty
1043 |         df = pd.DataFrame(rows, columns=cols) if rows else pd.DataFrame(columns=cols)
1044 |         logger.info(f"Created DataFrame with shape {df.shape} for export.")
1045 |     except Exception as e:
1046 |         logger.error(f"Error creating Pandas DataFrame: {e}", exc_info=True)
1047 |         raise ToolError(f"Failed to create DataFrame for export: {e}", http_status_code=500) from e
1048 | 
1049 |     if export_format_lower == "pandas":
1050 |         logger.debug("Returning raw Pandas DataFrame.")
1051 |         return df, None
1052 | 
1053 |     final_path: str
1054 |     temp_file_created = False
1055 |     if export_path:
1056 |         try:
1057 |             # Resolve the target path and ensure its parent directory exists
1058 |             path_obj = Path(export_path)
1059 |             path_expanded = path_obj.expanduser()
1060 |             path_resolved = path_expanded.resolve()
1061 |             parent_dir = path_resolved.parent
1062 |             parent_dir.mkdir(parents=True, exist_ok=True)
1063 |             final_path = str(path_resolved)
1064 |             logger.info(f"Using specified export path: {final_path}")
1065 |         except OSError as e:
1066 |             details = {"path": export_path}
1067 |             raise ToolError(
1068 |                 f"Cannot create directory for export path '{export_path}': {e}",
1069 |                 http_status_code=500,
1070 |                 details=details,
1071 |             ) from e
1072 |         except Exception as e:  # Catch other path errors
1073 |             details = {"path": export_path}
1074 |             msg = f"Invalid export path provided: {export_path}. Error: {e}"
1075 |             raise ToolInputError(msg, param_name="export.path", details=details) from e
1076 |     else:
1077 |         # No path given: write to a temporary file with the right extension
1078 |         suffix = ".xlsx" if export_format_lower == "excel" else ".csv"
1079 |         try:
1080 |             prefix = f"mcp_export_{export_format_lower}_"
1081 |             fd, final_path_temp = tempfile.mkstemp(suffix=suffix, prefix=prefix)
1082 |             final_path = final_path_temp
1083 |             os.close(fd)
1084 |             temp_file_created = True
1085 |             logger.info(f"Created temporary file for export: {final_path}")
1086 |         except Exception as e:
1087 |             logger.error(f"Failed to create temporary file for export: {e}", exc_info=True)
1088 |             raise ToolError(f"Failed to create temporary file: {e}", http_status_code=500) from e
1089 | 
1090 |     try:
1091 |         if export_format_lower == "excel":
1092 |             df.to_excel(final_path, index=False, engine="xlsxwriter")
1093 |         elif export_format_lower == "csv":
1094 |             df.to_csv(final_path, index=False)
1095 |         log_msg = f"Exported data to {export_format_lower.upper()} file: {final_path}"
1096 |         logger.info(log_msg)
1097 |         return None, final_path
1098 |     except Exception as e:
1099 |         log_msg = f"Error exporting DataFrame to {export_format_lower} file '{final_path}': {e}"
1100 |         logger.error(log_msg, exc_info=True)
1101 |         path_exists = Path(final_path).exists()
1102 |         if temp_file_created and path_exists:
1103 |             try:
1104 |                 Path(final_path).unlink()
1105 |             except OSError:
1106 |                 logger.warning(f"Could not clean up temporary export file: {final_path}")
1107 |         raise ToolError(
1108 |             f"Failed to export data to {export_format_lower}: {e}", http_status_code=500
1109 |         ) from e
1110 | 
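     | # Illustrative usage sketch (hypothetical values): exporting fetched rows.
     | #
     | #     df, path = _sql_export_rows(cols, rows, "csv", "/tmp/users.csv")  # writes a file
     | #     df, _ = _sql_export_rows(cols, rows, "pandas")  # returns the DataFrame instead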
1111 | 
1112 | async def _sql_validate_df(df: Any, schema: Optional[Any]) -> None:
1113 |     """Validate a DataFrame against a Pandera schema, raising ToolError on failure."""
1114 |     if schema is None:
1115 |         logger.debug("No Pandera schema provided for validation.")
1116 |         return
1117 |     if pa is None:
1118 |         logger.warning("Pandera library not installed, skipping validation.")
1119 |         return
1120 |     is_pandas_df = pd is not None and isinstance(df, pd.DataFrame)
1121 |     if not is_pandas_df:
1122 |         logger.warning("Result is not a pandas DataFrame; skipping validation.")
1123 |         return
1124 | 
1125 |     logger.info(f"Validating DataFrame (shape {df.shape}) against provided Pandera schema.")
1126 |     try:
1127 |         schema.validate(df, lazy=True)
1128 |         logger.info("Pandera validation successful.")
1129 |     except pa.errors.SchemaErrors as se:
1130 |         # failure_cases is a DataFrame describing each failed check
1131 |         error_details_df = se.failure_cases
1132 |         can_dict = hasattr(error_details_df, "to_dict")
1133 |         error_details = (
1134 |             error_details_df.to_dict(orient="records") if can_dict else str(error_details_df)
1135 |         )
1137 |         can_len = hasattr(error_details_df, "__len__")
1138 |         error_count = len(error_details_df) if can_len else "multiple"
1139 | 
1140 |         log_msg = f"Pandera validation failed with {error_count} errors. Details: {error_details}"
1141 |         logger.warning(log_msg)
1142 | 
1143 |         # Summarize the first few failures in a readable multi-line message
1144 |         error_msg_base = f"Pandera validation failed ({error_count} errors):\n"
1145 |         error_msg_lines = []
1146 |         error_details_list = error_details if isinstance(error_details, list) else []
1147 |         errors_to_show = error_details_list[:5]
1148 | 
1149 |         for err in errors_to_show:
1150 |             col = err.get("column", "N/A")
1151 |             check = err.get("check", "N/A")
1152 |             index = err.get("index", "N/A")
1153 |             fail_case_raw = err.get("failure_case", "N/A")
1154 |             fail_case_str = str(fail_case_raw)[:50]
1155 |             line = f"- Column '{col}': {check} failed for index {index}. Data: {fail_case_str}..."
1156 |             error_msg_lines.append(line)
1157 | 
1158 |         error_msg = error_msg_base + "\n".join(error_msg_lines)
1159 | 
1160 |         num_errors = error_count if isinstance(error_count, int) else 0
1161 |         if num_errors > 5:
1162 |             more_errors_count = num_errors - 5
1163 |             error_msg += f"\n... and {more_errors_count} more errors."
1164 | 
1165 |         validation_errors = error_details  # Pass the original structure
1166 |         details = {"error_type": "VALIDATION_ERROR"}
1167 |         raise ToolError(
1168 |             error_msg, http_status_code=422, validation_errors=validation_errors, details=details
1169 |         ) from se
1170 |     except Exception as e:
1171 |         logger.error(f"Unexpected error during Pandera validation: {e}", exc_info=True)
1172 |         raise ToolError(
1173 |             f"An unexpected error occurred during schema validation: {e}", http_status_code=500
1174 |         ) from e
1175 | 
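     | # Illustrative sketch: a Pandera schema that _sql_validate_df can check results
     | # against (assumes pandera is installed; column names are hypothetical).
     | #
     | #     import pandera as pa
     | #     user_schema = pa.DataFrameSchema({
     | #         "id": pa.Column(int, pa.Check.ge(0)),
     | #         "email": pa.Column(str, nullable=False),
     | #     })
     | #     await _sql_validate_df(df, user_schema)  # raises ToolError(422) on failure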
1176 | 
1177 | async def _sql_convert_nl_to_sql(
1178 |     connection_id: str,
1179 |     natural_language: str,
1180 |     confidence_threshold: float = 0.6,
1181 |     user_id: Optional[str] = None,  # Added for lineage
1182 |     session_id: Optional[str] = None,  # Added for lineage
1183 | ) -> Dict[str, Any]:
1184 |     """Convert a natural language question into a read-only SQL query via the LLM."""
1185 |     # Relies on _sql_get_engine, _sql_check_safe, and module-level _SCHEMA_VERSIONS/_LINEAGE state.
1186 |     nl_preview = natural_language[:100]
1187 |     logger.info(f"Converting NL to SQL for connection {connection_id}. Query: '{nl_preview}...'")
1188 |     eng = await _sql_get_engine(connection_id)
1189 | 
1190 |     def _get_schema_fingerprint_sync(sync_conn) -> str:
1191 |         # Build a compact "table(col:type, ...)" fingerprint across all schemas
1192 |         try:
1193 |             sync_inspector = sa_inspect(sync_conn)
1194 |             tbls = []
1195 |             schema_names = sync_inspector.get_schema_names()
1196 |             default_schema = sync_inspector.default_schema_name
1198 |             other_schemas = [s for s in schema_names if s != default_schema]
1199 |             schemas_to_inspect = [default_schema] + other_schemas
1200 | 
1201 |             for schema_name in schemas_to_inspect:
1202 |                 # Qualify names from non-default schemas
1203 |                 prefix = f"{schema_name}." if schema_name and schema_name != default_schema else ""
1204 |                 table_names_in_schema = sync_inspector.get_table_names(schema=schema_name)
1205 |                 for t in table_names_in_schema:
1206 |                     try:
1207 |                         cols = sync_inspector.get_columns(t, schema=schema_name)
1209 |                         col_defs = [f"{c['name']}:{str(c['type']).split('(')[0]}" for c in cols]
1210 |                         col_defs_str = ",".join(col_defs)
1211 |                         tbl_def = f"{prefix}{t}({col_defs_str})"
1212 |                         tbls.append(tbl_def)
1213 |                     except Exception as col_err:
1214 |                         logger.warning(f"Could not get columns for table {prefix}{t}: {col_err}")
1215 |                         tbl_def_err = f"{prefix}{t}(...)"
1216 |                         tbls.append(tbl_def_err)
1217 |             # Sort so the fingerprint is deterministic
1218 |             fp = "; ".join(sorted(tbls))
1219 |             if not fp:
1220 |                 logger.warning("Schema fingerprint generation resulted in empty string.")
1221 |                 return "Error: Could not retrieve schema."
1222 |             return fp
1223 |         except Exception as e:
1224 |             logger.error(f"Error in _get_schema_fingerprint_sync: {e}", exc_info=True)
1225 |             return "Error: Could not retrieve schema."
1226 | 
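     |     # A successful fingerprint looks like this (illustrative):
     |     #   "orders(id:INTEGER,total:NUMERIC,user_id:INTEGER); users(id:INTEGER,name:VARCHAR)"
     |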
1227 |     async def _get_schema_fingerprint(conn: AsyncConnection) -> str:
1228 |         logger.debug("Generating schema fingerprint for NL->SQL...")
1229 |         try:
1230 |             # run_sync accepts the sync helper directly; no wrapper needed
1231 |             return await conn.run_sync(_get_schema_fingerprint_sync)
1236 |         except Exception as e:
1237 |             logger.error(f"Error generating schema fingerprint: {e}", exc_info=True)
1238 |             return "Error: Could not retrieve schema."
1239 | 
1240 |     async with eng.connect() as conn:
1241 |         schema_fingerprint = await _get_schema_fingerprint(conn)
1242 | 
1243 |     # Prompt instructing the LLM to return strict JSON containing a single SELECT query
1244 |     prompt = (
1245 |         "You are a highly specialized AI assistant that translates natural language questions into SQL queries.\n"
1246 |         "You must adhere STRICTLY to the following rules:\n"
1247 |         "1. Generate only a SINGLE, executable SQL query for the given database schema and question.\n"
1248 |         "2. Use the exact table and column names provided in the schema fingerprint.\n"
1249 |         "3. Do NOT generate any explanatory text, comments, or markdown formatting.\n"
1250 |         "4. The output MUST be a valid JSON object containing two keys: 'sql' (the generated SQL query as a string) and 'confidence' (a float between 0.0 and 1.0 indicating your confidence in the generated SQL).\n"
1251 |         "5. If the question cannot be answered from the schema or is ambiguous, set confidence to 0.0 and provide a minimal, safe query like 'SELECT 1;' in the 'sql' field.\n"
1252 |         "6. Prioritize safety: Avoid generating queries that could modify data (UPDATE, INSERT, DELETE, DROP, etc.). Generate SELECT statements ONLY.\n\n"  # Stricter rule
1253 |         f"Database Schema Fingerprint:\n```\n{schema_fingerprint}\n```\n\n"
1254 |         f"Natural Language Question:\n```\n{natural_language}\n```\n\n"
1255 |         "JSON Output:"
1256 |     )
1257 | 
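     | # For reference, a well-formed LLM reply to the prompt above looks like this
     | # (values illustrative):
     | #
     | #     {"sql": "SELECT COUNT(*) FROM users WHERE active = 1", "confidence": 0.85}
     |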
1258 |     try:
1259 |         logger.debug("Sending prompt to LLM for NL->SQL conversion.")
1261 |         completion_result = await generate_completion(
1262 |             prompt=prompt, max_tokens=350, temperature=0.2
1263 |         )
1265 | 
1267 |         is_dict_response = isinstance(completion_result, dict)
1268 |         llm_response = completion_result.get("text", "") if is_dict_response else ""
1269 | 
1270 |         llm_response_preview = llm_response[:300]
1271 |         logger.debug(f"LLM Response received: {llm_response_preview}...")
1272 |         if not llm_response:
1273 |             raise ToolError("LLM returned empty response for NL->SQL.", http_status_code=502)
1274 | 
1275 |     except Exception as llm_err:
1276 |         logger.error(f"LLM completion failed for NL->SQL: {llm_err}", exc_info=True)
1277 |         details = {"error_type": "LLM_ERROR"}
1278 |         raise ToolError(
1279 |             f"Failed to get response from LLM: {llm_err}", http_status_code=502, details=details
1280 |         ) from llm_err
1281 | 
1282 |     try:
1283 |         data = {}
1284 |         try:
1285 |             # Try parsing the whole response as JSON first
1286 |             data = json.loads(llm_response)
1287 |         except json.JSONDecodeError as e:
1288 |             # If that fails, look for a JSON block within the text
1290 |             search_regex = r"\{.*\}"
1291 |             search_flags = re.DOTALL | re.MULTILINE
1292 |             json_match = re.search(search_regex, llm_response, search_flags)
1293 |             if not json_match:
1294 |                 raise ValueError("No JSON object found in the LLM response.") from e
1295 |             json_str = json_match.group(0)
1296 |             data = json.loads(json_str)
1297 | 
1298 |         is_dict_data = isinstance(data, dict)
1299 |         has_sql = "sql" in data
1300 |         has_confidence = "confidence" in data
1301 |         if not is_dict_data or not has_sql or not has_confidence:
1302 |             raise ValueError("LLM response JSON is missing required keys ('sql', 'confidence').")
1303 | 
1304 |         sql = data["sql"]
1305 |         conf_raw = data["confidence"]
1306 |         conf = float(conf_raw)
1307 | 
1308 |         is_sql_str = isinstance(sql, str)
1309 |         is_conf_valid = 0.0 <= conf <= 1.0
1310 |         if not is_sql_str or not is_conf_valid:
1311 |             raise ValueError("LLM response has invalid types for 'sql' or 'confidence'.")
1312 | 
1313 |         sql_preview = sql[:150]
1314 |         logger.info(f"LLM generated SQL with confidence {conf:.2f}: {sql_preview}...")
1315 | 
1316 |     except (json.JSONDecodeError, ValueError, TypeError) as e:
1317 |         response_preview = str(llm_response)[:200]
1318 |         error_detail = (
1319 |             f"LLM returned invalid or malformed JSON: {e}. Response: '{response_preview}...'"
1320 |         )
1321 |         logger.error(error_detail)
1322 |         details = {"error_type": "LLM_RESPONSE_INVALID"}
1323 |         raise ToolError(error_detail, http_status_code=500, details=details) from e
1324 | 
1325 |     is_below_threshold = conf < confidence_threshold
1326 |     if is_below_threshold:
1327 |         nl_query_preview = natural_language[:100]
1328 |         low_conf_msg = f"LLM confidence ({conf:.2f}) is below the required threshold ({confidence_threshold}). NL Query: '{nl_query_preview}'"
1329 |         logger.warning(low_conf_msg)
1330 |         details = {"error_type": "LOW_CONFIDENCE"}
1331 |         raise ToolError(
1332 |             low_conf_msg, http_status_code=400, generated_sql=sql, confidence=conf, details=details
1333 |         ) from None
1334 | 
1335 |     try:
1336 |         _sql_check_safe(sql, read_only=True)  # Enforce read-only for generated SQL
1337 |         # Defense in depth: generated SQL must start with SELECT or WITH
1338 |         sql_upper = sql.upper()
1339 |         sql_stripped = sql_upper.lstrip()
1340 |         is_valid_start = sql_stripped.startswith(("SELECT", "WITH"))
1341 |         if not is_valid_start:
1342 |             details = {"error_type": "INVALID_GENERATED_SQL"}
1343 |             raise ToolError(
1344 |                 "Generated query does not appear to be a valid SELECT statement.",
1345 |                 http_status_code=400,
1346 |                 details=details,
1347 |             )
1348 |         # (An optional check of referenced tables against the known schema could go here.)
1349 |     except ToolInputError as safety_err:
1350 |         logger.error(f"Generated SQL failed safety check: {safety_err}. SQL: {sql}")
1351 |         details = {"error_type": "SAFETY_VIOLATION"}
1352 |         raise ToolError(
1353 |             f"Generated SQL failed validation: {safety_err}",
1354 |             http_status_code=400,
1355 |             generated_sql=sql,
1356 |             confidence=conf,
1357 |             details=details,
1358 |         ) from safety_err
1359 | 
1360 |     result_dict = {"sql": sql, "confidence": conf}
1361 |     return result_dict
1362 | 
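     | # Illustrative usage sketch (hypothetical connection id):
     | #
     | #     nl_result = await _sql_convert_nl_to_sql("my-conn-id", "How many active users?")
     | #     sql_text, confidence = nl_result["sql"], nl_result["confidence"]
     |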
1363 | 
1364 | # =============================================================================
1365 | # Public Tool Functions (Standalone replacements for SQLTool methods)
1366 | # =============================================================================
1367 | 
1368 | 
1369 | @with_tool_metrics
1370 | @with_error_handling
1371 | async def manage_database(
1372 |     action: str,
1373 |     connection_string: Optional[str] = None,
1374 |     connection_id: Optional[str] = None,
1375 |     echo: bool = False,
1376 |     user_id: Optional[str] = None,
1377 |     session_id: Optional[str] = None,
1378 |     ctx: Optional[Dict] = None,  # Added ctx for potential future use
1379 |     **options: Any,
1380 | ) -> Dict[str, Any]:
1381 |     """
1382 |     Unified database connection management tool.
1383 | 
1384 |     Args:
1385 |         action: The action to perform: "connect", "disconnect", "test", or "status".
1386 |         connection_string: Database connection string or secrets:// reference. (Required for "connect").
1387 |         connection_id: An existing connection ID (Required for "disconnect", "test"). Can be provided for "connect" to suggest an ID.
1388 |         echo: Enable SQLAlchemy engine logging (For "connect" action, default: False).
1389 |         user_id: Optional user identifier for audit logging.
1390 |         session_id: Optional session identifier for audit logging.
1391 |         ctx: Optional context from MCP server (not used currently).
1392 |         **options: Additional options:
1393 |             - For "connect": Passed directly to SQLAlchemy's `create_async_engine`.
1394 |             - Can include custom audit context.
1395 | 
1396 |     Returns:
1397 |         Dict with action results and metadata. Varies based on action.
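     |
     |     Example (illustrative; the SQLite URL is hypothetical):
     |         result = await manage_database(
     |             action="connect", connection_string="sqlite+aiosqlite:///./example.db"
     |         )
     |         conn_id = result["connection_id"]
     |         await manage_database(action="test", connection_id=conn_id)
     |         await manage_database(action="disconnect", connection_id=conn_id)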
1398 |     """
1399 |     tool_name = "manage_database"
1400 |     db_dialect = "unknown"
1401 |     # Everything in **options except engine flags is forwarded to the audit log
1402 |     audit_extras = {k: v for k, v in options.items() if k != "echo"}
1404 | 
1405 |     try:
1406 |         if action == "connect":
1407 |             if not connection_string:
1408 |                 raise ToolInputError(
1409 |                     "connection_string is required for 'connect'", param_name="connection_string"
1410 |                 )
1411 |             # Use the caller-supplied connection id or mint a new one
1412 |             cid = connection_id or str(uuid.uuid4())
1413 |             logger.info(f"Attempting to connect with connection_id: {cid}")
1414 |             resolved_conn_str = _sql_resolve_conn(connection_string)
1415 |             url, db_type = _sql_driver_url(resolved_conn_str)
1416 |             db_dialect = db_type  # Update dialect for potential error logging
1417 |             pool_opts = _sql_auto_pool(db_type)
1418 |             # Caller options override the dialect's default pool settings
1419 |             engine_opts = {**pool_opts, **options}
1420 |             # Never log credentials
1421 |             log_opts = {k: v for k, v in engine_opts.items() if k != "password"}
1422 |             logger.debug(f"Creating engine for {db_type} with options: {log_opts}")
1423 |             connect_args = engine_opts.pop("connect_args", {})
1424 |             # Snowflake-specific execution option
1425 |             exec_opts_base = {"async_execution": True} if db_type == "snowflake" else {}
1426 |             # Pass other engine options directly
1427 |             execution_options = {**exec_opts_base, **engine_opts.pop("execution_options", {})}
1428 | 
1429 |             # Create the async engine with pooled and caller-supplied options
1430 |             eng = create_async_engine(
1431 |                 url,
1432 |                 echo=echo,
1433 |                 connect_args=connect_args,
1434 |                 execution_options=execution_options,
1435 |                 **engine_opts,  # Pass remaining options like pool settings
1436 |             )
1437 | 
1438 |             try:
1439 |                 # Smoke-test the new connection with a trivial query
1440 |                 test_sql = "SELECT CURRENT_TIMESTAMP" if db_type != "sqlite" else "SELECT 1"
1442 |                 await _sql_exec(
1443 |                     eng,
1444 |                     test_sql,
1445 |                     None,
1446 |                     limit=1,
1447 |                     tool_name=tool_name,
1448 |                     action_name="connect_test",
1449 |                     timeout=15,
1450 |                 )
1451 |                 logger.info(f"Connection test successful for {cid} ({db_type}).")
1452 |             except ToolError as test_err:
1453 |                 logger.error(f"Connection test failed for {cid} ({db_type}): {test_err}")
1454 |                 await eng.dispose()
1455 |                 # Preserve structured details from the failed test
1456 |                 err_details = getattr(test_err, "details", None)
1457 |                 raise ToolError(
1458 |                     f"Connection test failed: {test_err}", http_status_code=400, details=err_details
1459 |                 ) from test_err
1460 |             except Exception as e:
1461 |                 logger.error(
1462 |                     f"Unexpected error during connection test for {cid} ({db_type}): {e}",
1463 |                     exc_info=True,
1464 |                 )
1465 |                 await eng.dispose()
1466 |                 raise ToolError(
1467 |                     f"Unexpected error during connection test: {e}", http_status_code=500
1468 |                 ) from e
1469 | 
1470 |             await _connection_manager.add_connection(cid, eng)
1472 |             await _sql_audit(
1473 |                 tool_name=tool_name,
1474 |                 action="connect",
1475 |                 connection_id=cid,
1476 |                 sql=None,
1477 |                 tables=None,
1478 |                 row_count=None,
1479 |                 success=True,
1480 |                 error=None,
1481 |                 user_id=user_id,
1482 |                 session_id=session_id,
1483 |                 database_type=db_type,
1484 |                 echo=echo,
1485 |                 **audit_extras,
1486 |             )
1488 |             return {
1489 |                 "action": "connect",
1490 |                 "connection_id": cid,
1491 |                 "database_type": db_type,
1492 |                 "success": True,
1493 |             }
1494 | 
1495 |         elif action == "disconnect":
1496 |             if not connection_id:
1497 |                 raise ToolInputError(
1498 |                     "connection_id is required for 'disconnect'", param_name="connection_id"
1499 |                 )
1500 |             logger.info(f"Attempting to disconnect connection_id: {connection_id}")
1501 |             db_dialect_for_audit = "unknown"  # Default if engine retrieval fails
1502 |             try:
1503 |                 # Look up the engine first so the audit record can include its dialect
1504 |                 engine_to_close = await _connection_manager.get_connection(connection_id)
1505 |                 db_dialect_for_audit = engine_to_close.dialect.name
1506 |             except ToolInputError:
1507 |                 # This error means connection_id wasn't found by get_connection
1508 |                 logger.warning(f"Disconnect requested for unknown connection_id: {connection_id}")
1510 |                 await _sql_audit(
1511 |                     tool_name=tool_name,
1512 |                     action="disconnect",
1513 |                     connection_id=connection_id,
1514 |                     sql=None,
1515 |                     tables=None,
1516 |                     row_count=None,
1517 |                     success=False,
1518 |                     error="Connection ID not found",
1519 |                     user_id=user_id,
1520 |                     session_id=session_id,
1521 |                     **audit_extras,
1522 |                 )
1524 |                 return {
1525 |                     "action": "disconnect",
1526 |                     "connection_id": connection_id,
1527 |                     "success": False,
1528 |                     "message": "Connection ID not found",
1529 |                 }
1530 |             except Exception as e:
1531 |                 # Catch other errors during engine retrieval itself
1532 |                 logger.error(f"Error retrieving engine for disconnect ({connection_id}): {e}")
1533 |                 # Proceed to attempt close, but audit will likely show failure or non-existence
1534 | 
1535 |             # Attempt closing even if retrieval had issues (it might have been removed between check and close)
1536 |             success = await _connection_manager.close_connection(connection_id)
1538 |             error_msg = None if success else "Failed to close or already closed/not found"
1540 |             await _sql_audit(
1541 |                 tool_name=tool_name,
1542 |                 action="disconnect",
1543 |                 connection_id=connection_id,
1544 |                 sql=None,
1545 |                 tables=None,
1546 |                 row_count=None,
1547 |                 success=success,
1548 |                 error=error_msg,
1549 |                 user_id=user_id,
1550 |                 session_id=session_id,
1551 |                 database_type=db_dialect_for_audit,
1552 |                 **audit_extras,
1553 |             )
1555 |             return {"action": "disconnect", "connection_id": connection_id, "success": success}
1556 | 
1557 |         elif action == "test":
1558 |             if not connection_id:
1559 |                 raise ToolInputError(
1560 |                     "connection_id is required for 'test'", param_name="connection_id"
1561 |                 )
1562 |             logger.info(f"Testing connection_id: {connection_id}")
1563 |             eng = await _sql_get_engine(connection_id)
1564 |             db_dialect = eng.dialect.name  # Now dialect is known for sure
1565 |             t0 = time.perf_counter()
1566 |             # Version query differs per dialect
1567 |             vsql = (
1568 |                 "SELECT sqlite_version()"
1569 |                 if db_dialect == "sqlite"
1570 |                 else "SELECT CURRENT_VERSION()"
1571 |                 if db_dialect == "snowflake"
1572 |                 else "SELECT version()"
1573 |             )
1575 |             cols, rows, _ = await _sql_exec(
1576 |                 eng, vsql, None, limit=1, tool_name=tool_name, action_name="test", timeout=10
1577 |             )
1578 |             latency = time.perf_counter() - t0
1580 |             has_rows_and_cols = rows and cols
1581 |             version_info = rows[0].get(cols[0], "N/A") if has_rows_and_cols else "N/A"
1582 |             log_msg = f"Connection test successful for {connection_id}. Version: {version_info}, Latency: {latency:.3f}s"
1583 |             logger.info(log_msg)
1585 |             return {
1586 |                 "action": "test",
1587 |                 "connection_id": connection_id,
1588 |                 "response_time_seconds": round(latency, 3),
1589 |                 "version": version_info,
1590 |                 "database_type": db_dialect,
1591 |                 "success": True,
1592 |             }
1593 | 
1594 |         elif action == "status":
1595 |             logger.info("Retrieving connection status.")
1596 |             connections_info = {}
1597 |             current_time = time.time()
1598 |             # Snapshot the connection map under the manager's lock for safe iteration
1599 |             conn_items = []
1600 |             async with _connection_manager._lock:  # Access lock directly for iteration safety
1602 |                 conn_items = list(_connection_manager.connections.items())
1603 | 
1604 |             for conn_id, (eng, last_access) in conn_items:
1605 |                 try:
1606 |                     url_display_raw = str(eng.url)
1607 |                     parsed_url = make_url(url_display_raw)
1608 |                     url_display = url_display_raw  # Default
1609 |                     if parsed_url.password:
1611 |                         url_masked = parsed_url.set(password="***")
1612 |                         url_display = str(url_masked)
1613 |                     # Summarize connection details for the status report
1614 |                     conn_info_dict = {}
1615 |                     conn_info_dict["url_summary"] = url_display
1616 |                     conn_info_dict["dialect"] = eng.dialect.name
1617 |                     last_access_dt = dt.datetime.fromtimestamp(last_access)
1618 |                     conn_info_dict["last_accessed"] = last_access_dt.isoformat()
1619 |                     idle_seconds = current_time - last_access
1620 |                     conn_info_dict["idle_time_seconds"] = round(idle_seconds, 1)
1621 |                     connections_info[conn_id] = conn_info_dict
1622 |                 except Exception as status_err:
1623 |                     logger.error(f"Error retrieving status for connection {conn_id}: {status_err}")
1624 |                     connections_info[conn_id] = {"error": str(status_err)}
1626 |             return {
1627 |                 "action": "status",
1628 |                 "active_connections_count": len(connections_info),
1629 |                 "connections": connections_info,
1630 |                 "cleanup_interval_seconds": _connection_manager.cleanup_interval,
1631 |                 "success": True,
1632 |             }
1633 | 
1634 |         else:
1635 |             logger.error(f"Invalid action specified for manage_database: {action}")
1636 |             details = {"action": action}
1637 |             msg = f"Unknown action: '{action}'. Valid actions: connect, disconnect, test, status"
1638 |             raise ToolInputError(msg, param_name="action", details=details)
1639 | 
1640 |     except ToolInputError as tie:
1642 |         await _sql_audit(
1643 |             tool_name=tool_name,
1644 |             action=action,
1645 |             connection_id=connection_id,
1646 |             sql=None,
1647 |             tables=None,
1648 |             row_count=None,
1649 |             success=False,
1650 |             error=str(tie),
1651 |             user_id=user_id,
1652 |             session_id=session_id,
1653 |             database_type=db_dialect,
1654 |             **audit_extras,
1655 |         )
1656 |         raise tie
1657 |     except ToolError as te:
1659 |         await _sql_audit(
1660 |             tool_name=tool_name,
1661 |             action=action,
1662 |             connection_id=connection_id,
1663 |             sql=None,
1664 |             tables=None,
1665 |             row_count=None,
1666 |             success=False,
1667 |             error=str(te),
1668 |             user_id=user_id,
1669 |             session_id=session_id,
1670 |             database_type=db_dialect,
1671 |             **audit_extras,
1672 |         )
1673 |         raise te
1674 |     except Exception as e:
1675 |         log_msg = f"Unexpected error in manage_database (action: {action}): {e}"
1676 |         logger.error(log_msg, exc_info=True)
1677 |         error_str = f"Unexpected error: {e}"
1679 |         await _sql_audit(
1680 |             tool_name=tool_name,
1681 |             action=action,
1682 |             connection_id=connection_id,
1683 |             sql=None,
1684 |             tables=None,
1685 |             row_count=None,
1686 |             success=False,
1687 |             error=error_str,
1688 |             user_id=user_id,
1689 |             session_id=session_id,
1690 |             database_type=db_dialect,
1691 |             **audit_extras,
1692 |         )
1693 |         raise ToolError(
1694 |             f"An unexpected error occurred in manage_database: {e}", http_status_code=500
1695 |         ) from e
1696 | 
1697 | 
1698 | @with_tool_metrics
1699 | @with_error_handling
1700 | async def execute_sql(
1701 |     connection_id: str,
1702 |     query: Optional[str] = None,
1703 |     natural_language: Optional[str] = None,
1704 |     parameters: Optional[Dict[str, Any]] = None,
1705 |     pagination: Optional[Dict[str, int]] = None,
1706 |     read_only: bool = True,
1707 |     export: Optional[Dict[str, Any]] = None,
1708 |     timeout: float = 60.0,
1709 |     validate_schema: Optional[Any] = None,
1710 |     max_rows: Optional[int] = 1000,
1711 |     confidence_threshold: float = 0.6,
1712 |     user_id: Optional[str] = None,
1713 |     session_id: Optional[str] = None,
1714 |     ctx: Optional[Dict] = None,  # Added ctx
1715 |     **options: Any,
1716 | ) -> Dict[str, Any]:
1717 |     """
1718 |     Unified SQL query execution tool.
1719 | 
1720 |     Handles direct SQL execution, NL-to-SQL conversion, pagination,
1721 |     result masking, safety checks, validation, and export.
1722 | 
1723 |     Args:
1724 |         connection_id: The ID of the database connection to use.
1725 |         query: The SQL query string to execute. (Use instead of natural_language).
1726 |         natural_language: A natural language question to convert to SQL. (Use instead of query).
1727 |         parameters: Dictionary of parameters for parameterized queries.
1728 |         pagination: Dict with "page" (>=1) and "page_size" (>=1) for paginated results.
1729 |                     When pagination is used, max_rows is ignored; page_size bounds each page.
1730 |         read_only: If True (default), enforces safety checks against write operations (UPDATE, DELETE, etc.). Set to False only if writes are explicitly intended and allowed.
1731 |         export: Dictionary with "format" ('pandas', 'excel', 'csv') and optional "path" (string) for exporting results.
1732 |         timeout: Maximum execution time in seconds (default: 60.0).
1733 |         validate_schema: A Pandera schema object to validate the results DataFrame against.
1734 |         max_rows: Maximum number of rows to return in the result (default: 1000). Set to None or -1 for unlimited (potentially dangerous).
1735 |         confidence_threshold: Minimum confidence score (0.0-1.0) required from the LLM for NL-to-SQL conversion (default: 0.6).
1736 |         user_id: Optional user identifier for audit logging.
1737 |         session_id: Optional session identifier for audit logging.
1738 |         ctx: Optional context from MCP server.
1739 |         **options: Additional options for audit logging or future extensions.
1740 | 
1741 |     Returns:
1742 |         A dictionary containing:
1743 |         - columns (List[str]): List of column names.
1744 |         - rows (List[Dict[str, Any]]): List of data rows (masked).
1745 |         - row_count (int): Number of rows returned in this batch/page.
1746 |         - truncated (bool): True if max_rows limited the results.
1747 |         - pagination (Optional[Dict]): Info about the current page if pagination was used.
1748 |         - generated_sql (Optional[str]): The SQL query generated from natural language, if applicable.
1749 |         - confidence (Optional[float]): The confidence score from the NL-to-SQL conversion, if applicable.
1750 |         - validation_status (Optional[str]): 'success', 'failed', 'skipped'.
1751 |         - validation_errors (Optional[Any]): Details if validation failed.
1752 |         - export_status (Optional[str]): Status message if export was attempted.
1753 |         - <format>_path (Optional[str]): Path to the exported file if export to file was successful.
1754 |         - dataframe (Optional[pd.DataFrame]): The raw Pandas DataFrame if export format was 'pandas'.
1755 |         - success (bool): Always True if no exception was raised.
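     |
     |     Example (illustrative; connection id and table are hypothetical):
     |         result = await execute_sql(
     |             connection_id=conn_id,
     |             query="SELECT id, email FROM users WHERE active = :active",
     |             parameters={"active": 1},
     |             pagination={"page": 1, "page_size": 50},
     |         )
     |         rows = result["rows"]
     |         has_more = result["pagination"]["has_next_page"]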
1756 |     """
1757 |     tool_name = "execute_sql"
1758 |     action_name = "query"  # Default, may change
1759 |     original_query_input = query  # Keep track of original SQL input
1760 |     original_nl_input = natural_language  # Keep track of NL input
1761 |     generated_sql = None
1762 |     confidence = None
1763 |     final_query: str
1765 |     final_params = parameters or {}
1766 |     result: Dict[str, Any] = {}
1767 |     tables: List[str] = []
1769 |     audit_extras = {**options}
1770 | 
1771 |     try:
1772 |         # 1. Determine Query
1773 |         use_nl = natural_language and not query
1774 |         use_sql = query and not natural_language
1775 |         is_ambiguous = natural_language and query
1776 |         no_input = not natural_language and not query
1777 | 
1778 |         if is_ambiguous:
1779 |             msg = "Provide either 'query' or 'natural_language', not both."
1780 |             raise ToolInputError(msg, param_name="query/natural_language")
1781 |         if no_input:
1782 |             msg = "Either 'query' or 'natural_language' must be provided."
1783 |             raise ToolInputError(msg, param_name="query/natural_language")
1784 | 
1785 |         if use_nl:
1786 |             action_name = "nl_to_sql_exec"
1787 |             nl_preview = natural_language[:100]
1788 |             log_msg = (
1789 |                 f"Received natural language query for connection {connection_id}: '{nl_preview}...'"
1790 |             )
1791 |             logger.info(log_msg)
1792 |             try:
1793 |                 # Pass user_id/session_id to NL converter for lineage/audit trail consistency if needed
1795 |                 nl_result = await _sql_convert_nl_to_sql(
1796 |                     connection_id, natural_language, confidence_threshold, user_id, session_id
1797 |                 )
1798 |                 final_query = nl_result["sql"]
1799 |                 generated_sql = final_query
1800 |                 confidence = nl_result["confidence"]
1801 |                 # original_query_input remains None; original_nl_input holds the NL question
1802 |                 audit_extras["generated_sql"] = generated_sql
1803 |                 audit_extras["confidence"] = confidence
1804 |                 query_preview = final_query[:150]
1805 |                 log_msg = f"Successfully converted NL to SQL (Confidence: {confidence:.2f}): {query_preview}..."
1806 |                 logger.info(log_msg)
1807 |                 read_only = True  # Ensure read-only for generated SQL
1808 |             except ToolError as nl_err:
1809 |                 # Audit NL failure
1810 |                 await _sql_audit(
1811 |                     tool_name=tool_name,
1812 |                     action="nl_to_sql_fail",
1813 |                     connection_id=connection_id,
1814 |                     sql=natural_language,  # Log the NL query that failed
1815 |                     tables=None,
1816 |                     row_count=None,
1817 |                     success=False,
1818 |                     error=str(nl_err),
1819 |                     user_id=user_id,
1820 |                     session_id=session_id,
1821 |                     **audit_extras,
1822 |                 )
1823 |                 raise nl_err  # Re-raise the error
1824 |         elif use_sql:
1825 |             # Action name remains 'query'
1826 |             final_query = query
1827 |             query_preview = final_query[:150]
1828 |             logger.info(f"Executing direct SQL query on {connection_id}: {query_preview}...")
1829 |             # original_query_input has the SQL, original_nl_input is None
1830 |         # else case already handled by initial checks
1831 | 
1832 |         # 2. Check Safety
1833 |         _sql_check_safe(final_query, read_only)
1834 |         tables = _sql_extract_tables(final_query)
1835 |         logger.debug(f"Query targets tables: {tables}")
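     |         # e.g. for "SELECT u.id FROM users u JOIN orders o ON o.user_id = u.id",
     |         # tables would be ["users", "orders"] (illustrative).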
1836 | 
1837 |         # 3. Get Engine
1838 |         eng = await _sql_get_engine(connection_id)
1839 | 
1840 |         # 4. Handle Pagination or Standard Execution
1841 |         if pagination:
1842 |             action_name = "query_paginated"
1843 |             page = pagination.get("page", 1)
1844 |             page_size = pagination.get("page_size", 100)
1845 |             is_page_valid = isinstance(page, int) and page >= 1
1846 |             is_page_size_valid = isinstance(page_size, int) and page_size >= 1
1847 |             if not is_page_valid:
1848 |                 raise ToolInputError(
1849 |                     "Pagination 'page' must be an integer >= 1.", param_name="pagination.page"
1850 |                 )
1851 |             if not is_page_size_valid:
1852 |                 raise ToolInputError(
1853 |                     "Pagination 'page_size' must be an integer >= 1.",
1854 |                     param_name="pagination.page_size",
1855 |                 )
1856 | 
1857 |             offset = (page - 1) * page_size
1858 |             db_dialect = eng.dialect.name
1859 |             paginated_query: str
1860 |             if db_dialect == "mssql":  # SQLAlchemy reports SQL Server as "mssql"
1861 |                 query_lower = final_query.lower()
1862 |                 has_order_by = "order by" in query_lower
1863 |                 if not has_order_by:
1864 |                     raise ToolInputError(
1865 |                         "SQL Server pagination requires an ORDER BY clause in the query.",
1866 |                         param_name="query",
1867 |                     )
1868 |                 paginated_query = (
1869 |                     f"{final_query} OFFSET :_page_offset ROWS FETCH NEXT :_page_size ROWS ONLY"
1870 |                 )
1871 |             elif db_dialect == "oracle":
1872 |                 paginated_query = (
1873 |                     f"{final_query} OFFSET :_page_offset ROWS FETCH NEXT :_page_size ROWS ONLY"
1874 |                 )
1875 |             else:  # Default LIMIT/OFFSET for others (MySQL, PostgreSQL, SQLite)
1876 |                 paginated_query = f"{final_query} LIMIT :_page_size OFFSET :_page_offset"
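     |
     |             # For example (illustrative), on PostgreSQL/MySQL/SQLite a query like
     |             #   SELECT id FROM users ORDER BY id
     |             # runs as
     |             #   SELECT id FROM users ORDER BY id LIMIT :_page_size OFFSET :_page_offset
     |             # with :_page_size bound to page_size + 1 (the extra row signals a next page).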
1877 | 
1878 |             # Fetch one extra row to check for next page
1879 |             fetch_size = page_size + 1
1880 |             # Merge pagination bind parameters with the caller's parameters
1881 |             paginated_params = {**final_params, "_page_size": fetch_size, "_page_offset": offset}
1882 |             log_msg = (
1883 |                 f"Executing paginated query (Page: {page}, Size: {page_size}): {paginated_query}"
1884 |             )
1885 |             logger.debug(log_msg)
1887 |             cols, rows_with_extra, fetched_count_paged = await _sql_exec(
1888 |                 eng,
1889 |                 paginated_query,
1890 |                 paginated_params,
1891 |                 limit=None,  # Limit is applied in SQL for pagination
1892 |                 tool_name=tool_name,
1893 |                 action_name=action_name,
1894 |                 timeout=timeout,
1895 |             )
1896 | 
1897 |             # Check if more rows exist than requested page size
1898 |             has_next_page = len(rows_with_extra) > page_size
1899 |             returned_rows = rows_with_extra[:page_size]
1900 |             returned_row_count = len(returned_rows)
1901 | 
1902 |             # Build result dict piece by piece
1903 |             pagination_info = {}
1904 |             pagination_info["page"] = page
1905 |             pagination_info["page_size"] = page_size
1906 |             pagination_info["has_next_page"] = has_next_page
1907 |             pagination_info["has_previous_page"] = page > 1
1908 | 
1909 |             result = {}
1910 |             result["columns"] = cols
1911 |             result["rows"] = returned_rows
1912 |             result["row_count"] = returned_row_count
1913 |             result["pagination"] = pagination_info
1914 |             result["truncated"] = False  # Not truncated by max_rows in pagination mode
1915 |             result["success"] = True
1916 | 
1917 |         else:  # Standard execution (no pagination dict)
1918 |             action_name = "query_standard"
1919 |             # Fetch one extra row beyond max_rows so truncation can be detected
1920 |             needs_limit = max_rows is not None and max_rows >= 0
1921 |             fetch_limit = (max_rows + 1) if needs_limit else None
1922 | 
1923 |             query_preview = final_query[:150]
1924 |             log_msg = f"Executing standard query (Max rows: {max_rows}): {query_preview}..."
1925 |             logger.debug(log_msg)
1927 |             cols, rows_maybe_extra, fetched_count = await _sql_exec(
1928 |                 eng,
1929 |                 final_query,
1930 |                 final_params,
1931 |                 limit=fetch_limit,  # Use fetch_limit (max_rows + 1 or None)
1932 |                 tool_name=tool_name,
1933 |                 action_name=action_name,
1934 |                 timeout=timeout,
1935 |             )
1936 | 
1937 |             # Determine truncation based on fetch_limit
1938 |             truncated = fetch_limit is not None and fetched_count >= fetch_limit
1939 |             # Apply actual max_rows limit to returned data
1941 |             returned_rows = rows_maybe_extra[:max_rows] if needs_limit else rows_maybe_extra
1942 |             returned_row_count = len(returned_rows)
1943 | 
1944 |             # Build result dict piece by piece
1945 |             result = {}
1946 |             result["columns"] = cols
1947 |             result["rows"] = returned_rows
1948 |             result["row_count"] = returned_row_count
1949 |             result["truncated"] = truncated
1950 |             result["success"] = True
1951 |             # No pagination key in standard mode
1952 | 
1953 |         # Add NL->SQL info if applicable
1954 |         if generated_sql:
1955 |             result["generated_sql"] = generated_sql
1956 |             result["confidence"] = confidence
1957 | 
1958 |         # 5. Handle Validation
1959 |         if validate_schema:
1960 |             temp_df = None
1961 |             validation_status = "skipped (unknown reason)"
1962 |             validation_errors = None
1963 |             if pd:
1964 |                 try:
1966 |                     df_data = result["rows"]
1967 |                     df_cols = result["columns"]
1968 |                     temp_df = (
1969 |                         pd.DataFrame(df_data, columns=df_cols)
1970 |                         if df_data
1971 |                         else pd.DataFrame(columns=df_cols)
1972 |                     )
1973 |                     try:
1975 |                         await _sql_validate_df(temp_df, validate_schema)
1976 |                         validation_status = "success"
1977 |                         logger.info("Pandera validation passed.")
1978 |                     except ToolError as val_err:
1979 |                         logger.warning(f"Pandera validation failed: {val_err}")
1980 |                         validation_status = "failed"
1982 |                         validation_errors = getattr(val_err, "validation_errors", str(val_err))
1983 |                 except Exception as df_err:
1984 |                     logger.error(f"Error creating DataFrame for validation: {df_err}")
1985 |                     validation_status = f"skipped (Failed to create DataFrame: {df_err})"
1986 |             else:
1987 |                 logger.warning("Pandas not installed, skipping Pandera validation.")
1988 |                 validation_status = "skipped (Pandas not installed)"
1989 | 
1990 |             result["validation_status"] = validation_status
1991 |             if validation_errors:
1992 |                 result["validation_errors"] = validation_errors
1993 | 
1994 |         # 6. Handle Export
1995 |         export_requested = export and export.get("format")
1996 |         if export_requested:
1997 |             export_format = export["format"]  # original case; the path key itself uses lowercase
1998 |             export_format_lower = export_format.lower()
1999 |             req_path = export.get("path")
2000 |             log_msg = f"Export requested: Format={export_format}, Path={req_path or 'Temporary'}"
2001 |             logger.info(log_msg)
2002 |             export_status = "failed (unknown reason)"
2003 |             try:
2005 |                 dataframe, export_path = _sql_export_rows(
2006 |                     result["columns"], result["rows"], export_format_lower, req_path
2007 |                 )
2008 |                 export_status = "success"
2009 |                 if dataframe is not None:  # Only if format was 'pandas'
2010 |                     result["dataframe"] = dataframe
2011 |                 if export_path:  # If file was created
2012 |                     path_key = f"{export_format_lower}_path"  # Use lowercase format for key
2013 |                     result[path_key] = export_path
2014 |                 log_msg = f"Export successful. Format: {export_format}, Path: {export_path or 'In-memory DataFrame'}"
2015 |                 logger.info(log_msg)
2016 |                 audit_extras["export_format"] = export_format
2017 |                 audit_extras["export_path"] = export_path
2018 |             except (ToolError, ToolInputError) as export_err:
2019 |                 logger.error(f"Export failed: {export_err}")
2020 |                 export_status = f"failed: {export_err}"
2021 |             result["export_status"] = export_status
2022 | 
2023 |         # 7. Audit Success
2024 |         # Determine which query to log based on input
2025 |         audit_sql = original_nl_input if use_nl else original_query_input
2026 |         audit_row_count = result.get("row_count", 0)
2027 |         audit_val_status = result.get("validation_status")
2028 |         audit_exp_status = result.get("export_status", "not requested")
2030 |         await _sql_audit(
2031 |             tool_name=tool_name,
2032 |             action=action_name,
2033 |             connection_id=connection_id,
2034 |             sql=audit_sql,
2035 |             tables=tables,
2036 |             row_count=audit_row_count,
2037 |             success=True,
2038 |             error=None,
2039 |             user_id=user_id,
2040 |             session_id=session_id,
2041 |             read_only=read_only,
2042 |             pagination_used=bool(pagination),
2043 |             validation_status=audit_val_status,
2044 |             export_status=audit_exp_status,
2045 |             **audit_extras,
2046 |         )
2047 |         return result
2048 | 
2049 |     except ToolInputError as tie:
2050 |         # Audit failure, use original inputs for logging context
2051 |         audit_sql = original_nl_input if original_nl_input else original_query_input
2053 |         await _sql_audit(
2054 |             tool_name=tool_name,
2055 |             action=action_name + "_fail",
2056 |             connection_id=connection_id,
2057 |             sql=audit_sql,
2058 |             tables=tables,
2059 |             row_count=0,
2060 |             success=False,
2061 |             error=str(tie),
2062 |             user_id=user_id,
2063 |             session_id=session_id,
2064 |             **audit_extras,
2065 |         )
2066 |         raise tie
2067 |     except ToolError as te:
2068 |         # Audit failure
2069 |         audit_sql = original_nl_input if original_nl_input else original_query_input
2071 |         await _sql_audit(
2072 |             tool_name=tool_name,
2073 |             action=action_name + "_fail",
2074 |             connection_id=connection_id,
2075 |             sql=audit_sql,
2076 |             tables=tables,
2077 |             row_count=0,
2078 |             success=False,
2079 |             error=str(te),
2080 |             user_id=user_id,
2081 |             session_id=session_id,
2082 |             **audit_extras,
2083 |         )
2084 |         raise te
2085 |     except Exception as e:
2086 |         log_msg = f"Unexpected error in execute_sql (action: {action_name}): {e}"
2087 |         logger.error(log_msg, exc_info=True)
2088 |         # Audit failure
2089 |         audit_sql = original_nl_input if original_nl_input else original_query_input
2090 |         error_str = f"Unexpected error: {e}"
2092 |         await _sql_audit(
2093 |             tool_name=tool_name,
2094 |             action=action_name + "_fail",
2095 |             connection_id=connection_id,
2096 |             sql=audit_sql,
2097 |             tables=tables,
2098 |             row_count=0,
2099 |             success=False,
2100 |             error=error_str,
2101 |             user_id=user_id,
2102 |             session_id=session_id,
2103 |             **audit_extras,
2104 |         )
2105 |         raise ToolError(
2106 |             f"An unexpected error occurred during SQL execution: {e}", http_status_code=500
2107 |         ) from e
2108 | 
2109 | 
2110 | @with_tool_metrics
2111 | @with_error_handling
2112 | async def explore_database(
2113 |     connection_id: str,
2114 |     action: str,
2115 |     table_name: Optional[str] = None,
2116 |     column_name: Optional[str] = None,
2117 |     schema_name: Optional[str] = None,
2118 |     user_id: Optional[str] = None,
2119 |     session_id: Optional[str] = None,
2120 |     ctx: Optional[Dict] = None,  # Added ctx
2121 |     **options: Any,
2122 | ) -> Dict[str, Any]:
2123 |     """
2124 |     Unified database schema exploration and documentation tool.
2125 | 
2126 |     Performs actions like listing schemas, tables, views, columns,
2127 |     getting table/column details, finding relationships, and generating documentation.
2128 | 
2129 |     Args:
2130 |         connection_id: The ID of the database connection to use.
2131 |         action: The exploration action:
2132 |             - "schema": Get full schema details (tables, views, columns, relationships).
2133 |             - "table": Get details for a specific table (columns, PK, FKs, indexes, optionally sample data/stats). Requires `table_name`.
2134 |             - "column": Get statistics for a specific column (nulls, distinct, optionally histogram). Requires `table_name` and `column_name`.
2135 |             - "relationships": Find related tables via foreign keys up to a certain depth. Requires `table_name`.
2136 |             - "documentation": Generate schema documentation (markdown or JSON).
2137 |         table_name: Name of the table for 'table', 'column', 'relationships' actions.
2138 |         column_name: Name of the column for 'column' action.
2139 |         schema_name: Specific schema to inspect (if supported by dialect and needed). Defaults to connection's default schema.
2140 |         user_id: Optional user identifier for audit logging.
2141 |         session_id: Optional session identifier for audit logging.
2142 |         ctx: Optional context from MCP server.
2143 |         **options: Additional options depending on the action:
2144 |             - schema: include_indexes (bool), include_foreign_keys (bool), detailed (bool)
2145 |             - table: include_sample_data (bool), sample_size (int), include_statistics (bool)
2146 |             - column: histogram (bool), num_buckets (int)
2147 |             - relationships: depth (int)
2148 |             - documentation: output_format ('markdown'|'json'), include_indexes (bool), include_foreign_keys (bool)
2149 | 
2150 |     Returns:
2151 |         A dictionary containing the results of the exploration action and a 'success' flag.
2152 |         Structure varies significantly based on the action.
2153 |     """
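     |     # Illustrative call patterns (argument values here are hypothetical):
     |     #   await explore_database(cid, action="schema", detailed=True)
     |     #   await explore_database(cid, action="table", table_name="users",
     |     #                          include_sample_data=True, sample_size=10)
     |     #   await explore_database(cid, action="column", table_name="users",
     |     #                          column_name="email", histogram=True)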
2154 |     tool_name = "explore_database"
2155 |     # Gather audit metadata: caller options plus the identifying parameters
2156 |     audit_extras = {}
2157 |     audit_extras.update(options)
2158 |     audit_extras["table_name"] = table_name
2159 |     audit_extras["column_name"] = column_name
2160 |     audit_extras["schema_name"] = schema_name
2161 | 
2162 |     try:
2163 |         log_msg = f"Exploring database for connection {connection_id}. Action: {action}, Table: {table_name}, Column: {column_name}, Schema: {schema_name}"
2164 |         logger.info(log_msg)
2165 |         eng = await _sql_get_engine(connection_id)
2166 |         db_dialect = eng.dialect.name
2167 |         audit_extras["database_type"] = db_dialect
2168 | 
2169 |         # Define sync inspection helper (runs within connect block)
2170 |         def _run_sync_inspection(
2171 |             inspector_target: Union[AsyncConnection, AsyncEngine], func_to_run: callable
2172 |         ):
2173 |             # Build a synchronous inspector bound to the given connection/engine
2174 |             sync_inspector = sa_inspect(inspector_target)
2175 |             return func_to_run(sync_inspector)
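     |         # SQLAlchemy's inspection API is synchronous, so inspector-based helpers
     |         # run via AsyncConnection.run_sync() rather than being awaited directly.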
2176 | 
2177 |         async with eng.connect() as conn:
2178 |             # --- Action: schema ---
2179 |             if action == "schema":
2180 |                 include_indexes = options.get("include_indexes", True)
2181 |                 include_foreign_keys = options.get("include_foreign_keys", True)
2182 |                 detailed = options.get("detailed", False)
2183 |                 filter_schema = schema_name  # Use provided schema or None for default
2184 | 
2185 |                 def _get_full_schema(sync_conn) -> Dict[str, Any]:
2186 |                     # Resolve inspector and target schema (dialect default when unspecified)
2187 |                     insp = sa_inspect(sync_conn)
2188 |                     target_schema = filter_schema or getattr(insp, "default_schema_name", None)
2189 | 
2190 |                     log_msg = f"Inspecting schema: {target_schema or 'Default'}. Detailed: {detailed}, Indexes: {include_indexes}, FKs: {include_foreign_keys}"
2191 |                     logger.info(log_msg)
2192 |                     tables_data: List[Dict[str, Any]] = []
2193 |                     views_data: List[Dict[str, Any]] = []
2194 |                     relationships: List[Dict[str, Any]] = []
2195 |                     try:
2196 |                         table_names = insp.get_table_names(schema=target_schema)
2197 |                         view_names = insp.get_view_names(schema=target_schema)
2198 |                     except Exception as inspect_err:
2199 |                         msg = f"Failed to list tables/views for schema '{target_schema}': {inspect_err}"
2200 |                         raise ToolError(msg, http_status_code=500) from inspect_err
2201 | 
2202 |                     for tbl_name in table_names:
2203 |                         try:
2204 |                             # Collect per-table metadata
2205 |                             t_info: Dict[str, Any] = {}
2206 |                             t_info["name"] = tbl_name
2207 |                             t_info["columns"] = []
2208 |                             if target_schema:
2209 |                                 t_info["schema"] = target_schema
2210 | 
2211 |                             columns_raw = insp.get_columns(tbl_name, schema=target_schema)
2212 |                             for c in columns_raw:
2213 |                                 # Summarize each column
2214 |                                 col_info = {}
2215 |                                 col_info["name"] = c["name"]
2216 |                                 col_info["type"] = str(c["type"])
2217 |                                 col_info["nullable"] = c["nullable"]
2218 |                                 col_info["primary_key"] = bool(c.get("primary_key"))
2219 |                                 if detailed:
2220 |                                     col_info["default"] = c.get("default")
2221 |                                     col_info["comment"] = c.get("comment")
2222 |                                     col_info["autoincrement"] = c.get("autoincrement", "auto")
2223 |                                 t_info["columns"].append(col_info)
2224 | 
2225 |                             if include_indexes:
2226 |                                 try:
2227 |                                     idxs_raw = insp.get_indexes(tbl_name, schema=target_schema)
2228 |                                     # Summarize each index: name, columns, uniqueness
2229 |                                     t_info["indexes"] = [
2230 |                                         {
2231 |                                             "name": i["name"],
2232 |                                             "columns": i["column_names"],
2233 |                                             "unique": i.get("unique", False),
2234 |                                         }
2235 |                                         for i in idxs_raw
2236 |                                     ]
2237 |                                 except Exception as idx_err:
2238 |                                     logger.warning(
2239 |                                         f"Could not retrieve indexes for table {tbl_name}: {idx_err}"
2240 |                                     )
2241 |                                     t_info["indexes"] = []
2242 |                             if include_foreign_keys:
2243 |                                 try:
2244 |                                     fks_raw = insp.get_foreign_keys(tbl_name, schema=target_schema)
2245 |                                     if fks_raw:
2246 |                                         t_info["foreign_keys"] = []
2247 |                                         for fk in fks_raw:
2248 |                                             # Describe the foreign-key constraint
2249 |                                             fk_info = {}
2250 |                                             fk_info["name"] = fk.get("name")
2251 |                                             fk_info["constrained_columns"] = fk[
2252 |                                                 "constrained_columns"
2253 |                                             ]
2254 |                                             fk_info["referred_schema"] = fk.get("referred_schema")
2255 |                                             fk_info["referred_table"] = fk["referred_table"]
2256 |                                             fk_info["referred_columns"] = fk["referred_columns"]
2257 |                                             t_info["foreign_keys"].append(fk_info)
2258 | 
2259 |                                             # Also record the FK as a schema-level relationship edge
2260 |                                             rel_info = {}
2261 |                                             rel_info["source_schema"] = target_schema
2262 |                                             rel_info["source_table"] = tbl_name
2263 |                                             rel_info["source_columns"] = fk["constrained_columns"]
2264 |                                             rel_info["target_schema"] = fk.get("referred_schema")
2265 |                                             rel_info["target_table"] = fk["referred_table"]
2266 |                                             rel_info["target_columns"] = fk["referred_columns"]
2267 |                                             relationships.append(rel_info)
2268 |                                 except Exception as fk_err:
2269 |                                     logger.warning(
2270 |                                         f"Could not retrieve foreign keys for table {tbl_name}: {fk_err}"
2271 |                                     )
2272 |                             tables_data.append(t_info)
2273 |                         except Exception as tbl_err:
2274 |                             log_msg = f"Failed to inspect table '{tbl_name}' in schema '{target_schema}': {tbl_err}"
2275 |                             logger.error(log_msg, exc_info=True)
2276 |                             # Record the failure for this table and keep inspecting the rest
2277 |                             error_entry = {
2278 |                                 "name": tbl_name,
2279 |                                 "schema": target_schema,
2280 |                                 "error": f"Failed to inspect: {tbl_err}",
2281 |                             }
2282 |                             tables_data.append(error_entry)
2283 | 
2284 |                     for view_name in view_names:
2285 |                         try:
2286 |                             # Collect per-view metadata
2287 |                             view_info: Dict[str, Any] = {}
2288 |                             view_info["name"] = view_name
2289 |                             if target_schema:
2290 |                                 view_info["schema"] = target_schema
2291 |                             try:
2292 |                                 view_def_raw = insp.get_view_definition(
2293 |                                     view_name, schema=target_schema
2294 |                                 )
2295 |                                 # Normalize a missing definition to an empty string
2296 |                                 view_def = view_def_raw or ""
2297 |                                 view_info["definition"] = view_def
2298 |                             except Exception as view_def_err:
2299 |                                 log_msg = f"Could not retrieve definition for view {view_name}: {view_def_err}"
2300 |                                 logger.warning(log_msg)
2301 |                                 view_info["definition"] = "Error retrieving definition"
2302 |                             try:
2303 |                                 view_cols_raw = insp.get_columns(view_name, schema=target_schema)
2304 |                                 # Views get a lighter column summary (name and type only)
2305 |                                 view_info["columns"] = [
2306 |                                     {"name": vc["name"], "type": str(vc["type"])}
2307 |                                     for vc in view_cols_raw
2308 |                                 ]
2309 |                             except Exception:
2310 |                                 pass  # Ignore column errors for views if definition failed etc.
2311 |                             views_data.append(view_info)
2312 |                         except Exception as view_err:
2313 |                             log_msg = f"Failed to inspect view '{view_name}' in schema '{target_schema}': {view_err}"
2314 |                             logger.error(log_msg, exc_info=True)
2315 |                             # Record the failure for this view and continue
2316 |                             error_entry = {
2317 |                                 "name": view_name,
2318 |                                 "schema": target_schema,
2319 |                                 "error": f"Failed to inspect: {view_err}",
2320 |                             }
2321 |                             views_data.append(error_entry)
2322 | 
2323 |                     # Assemble the schema payload
2324 |                     schema_result: Dict[str, Any] = {}
2325 |                     schema_result["action"] = "schema"
2326 |                     schema_result["database_type"] = db_dialect
2327 |                     schema_result["inspected_schema"] = target_schema or "Default"
2328 |                     schema_result["tables"] = tables_data
2329 |                     schema_result["views"] = views_data
2330 |                     schema_result["relationships"] = relationships
2331 |                     schema_result["success"] = True
2332 | 
2333 |                     # Schema Hashing and Lineage
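     |                     # A per-connection hash history (_SCHEMA_VERSIONS) lets repeated
     |                     # inspections detect drift; each detected change appends an entry
     |                     # to the in-memory _LINEAGE log.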
2334 |                     try:
2335 |                         # Serialize deterministically (sorted keys) so equal schemas hash equally
2336 |                         schema_json = json.dumps(schema_result, sort_keys=True, default=str)
2337 |                         schema_bytes = schema_json.encode()
2338 |                         # Fingerprint the schema to detect drift between inspections
2339 |                         schema_hash = hashlib.sha256(schema_bytes).hexdigest()
2340 | 
2341 |                         timestamp = _sql_now()
2342 |                         last_hash = _SCHEMA_VERSIONS.get(connection_id)
2343 |                         schema_changed = last_hash != schema_hash
2344 | 
2345 |                         if schema_changed:
2346 |                             _SCHEMA_VERSIONS[connection_id] = schema_hash
2347 |                             # Record the change event for lineage
2348 |                             lineage_entry = {}
2349 |                             lineage_entry["connection_id"] = connection_id
2350 |                             lineage_entry["timestamp"] = timestamp
2351 |                             lineage_entry["schema_hash"] = schema_hash
2352 |                             lineage_entry["previous_hash"] = last_hash
2353 |                             lineage_entry["user_id"] = user_id  # Include user from outer scope
2354 |                             lineage_entry["tables_count"] = len(tables_data)
2355 |                             lineage_entry["views_count"] = len(views_data)
2356 |                             lineage_entry["action_source"] = f"{tool_name}/{action}"
2357 |                             _LINEAGE.append(lineage_entry)
2358 | 
2359 |                             hash_preview = schema_hash[:8]
2360 |                             prev_hash_preview = last_hash[:8] if last_hash else "None"
2361 |                             log_msg = f"Schema change detected or initial capture for {connection_id}. New hash: {hash_preview}..., Previous: {prev_hash_preview}"
2362 |                             logger.info(log_msg)
2363 |                             schema_result["schema_hash"] = schema_hash
2364 |                             # True only when a previous hash existed (first capture is not a change)
2365 |                             schema_result["schema_change_detected"] = bool(last_hash)
2366 |                     except Exception as hash_err:
2367 |                         log_msg = f"Error generating schema hash or recording lineage: {hash_err}"
2368 |                         logger.error(log_msg, exc_info=True)
2369 | 
2370 |                     return schema_result
2371 | 
2372 |                 # Wrapper gives run_sync a single-argument callable
2373 |                 def sync_func(sync_conn_arg):
2374 |                     return _get_full_schema(sync_conn_arg)
2375 | 
2376 |                 result = await conn.run_sync(sync_func)  # Pass sync connection
2377 | 
2378 |             # --- Action: table ---
2379 |             elif action == "table":
2380 |                 if not table_name:
2381 |                     raise ToolInputError(
2382 |                         "`table_name` is required for 'table'", param_name="table_name"
2383 |                     )
2384 |                 include_sample = options.get("include_sample_data", False)
2385 |                 sample_size_raw = options.get("sample_size", 5)
2386 |                 sample_size = int(sample_size_raw)
2387 |                 include_stats = options.get("include_statistics", False)
2388 |                 if sample_size < 0:
2389 |                     sample_size = 0
2390 | 
2391 |                 def _get_basic_table_meta(sync_conn) -> Dict[str, Any]:
2392 |                     # Resolve inspector and target schema (dialect default when unspecified)
2393 |                     insp = sa_inspect(sync_conn)
2394 |                     target_schema = schema_name or getattr(insp, "default_schema_name", None)
2395 |                     logger.info(f"Inspecting table details: {target_schema}.{table_name}")
2396 |                     try:
2397 |                         all_tables = insp.get_table_names(schema=target_schema)
2398 |                         if table_name not in all_tables:
2399 |                             msg = f"Table '{table_name}' not found in schema '{target_schema}'."
2400 |                             raise ToolInputError(msg, param_name="table_name")
2401 |                     except Exception as list_err:
2402 |                         msg = f"Could not verify if table '{table_name}' exists: {list_err}"
2403 |                         raise ToolError(msg, http_status_code=500) from list_err
2404 | 
2405 |                     # Initialize meta parts
2406 |                     cols = []
2407 |                     idx = []
2408 |                     fks = []
2409 |                     pk_constraint = {}
2410 |                     table_comment_text = None
2411 | 
2412 |                     cols = insp.get_columns(table_name, schema=target_schema)
2413 |                     try:
2414 |                         idx = insp.get_indexes(table_name, schema=target_schema)
2415 |                     except Exception as idx_err:
2416 |                         logger.warning(f"Could not get indexes for table {table_name}: {idx_err}")
2417 |                     try:
2418 |                         fks = insp.get_foreign_keys(table_name, schema=target_schema)
2419 |                     except Exception as fk_err:
2420 |                         logger.warning(
2421 |                             f"Could not get foreign keys for table {table_name}: {fk_err}"
2422 |                         )
2423 |                     try:
2424 |                         pk_info = insp.get_pk_constraint(table_name, schema=target_schema)
2425 |                         # Only record a PK constraint when it has constrained columns
2426 |                         if pk_info and pk_info.get("constrained_columns"):
2427 |                             pk_constraint = {
2428 |                                 "name": pk_info.get("name"),
2429 |                                 "columns": pk_info["constrained_columns"],
2430 |                             }
2431 |                         # else pk_constraint remains {}
2432 |                     except Exception as pk_err:
2433 |                         logger.warning(f"Could not get PK constraint for {table_name}: {pk_err}")
2434 |                     try:
2435 |                         table_comment_raw = insp.get_table_comment(table_name, schema=target_schema)
2436 |                         # get_table_comment returns a mapping with a "text" key
2437 |                         table_comment_text = (
2438 |                             table_comment_raw.get("text") if table_comment_raw else None
2439 |                         )
2440 |                     except Exception as cmt_err:
2441 |                         logger.warning(f"Could not get table comment for {table_name}: {cmt_err}")
2442 | 
2443 |                     # Assemble the metadata payload
2444 |                     meta_result = {}
2445 |                     meta_result["columns"] = cols
2446 |                     meta_result["indexes"] = idx
2447 |                     meta_result["foreign_keys"] = fks
2448 |                     meta_result["pk_constraint"] = pk_constraint
2449 |                     meta_result["table_comment"] = table_comment_text
2450 |                     meta_result["schema_name"] = target_schema  # Add schema name for reference
2451 |                     return meta_result
2452 | 
2453 |                 # Same run_sync wrapper pattern as above
2454 |                 def sync_func_meta(sync_conn_arg):
2455 |                     return _get_basic_table_meta(sync_conn_arg)
2456 | 
2457 |                 meta = await conn.run_sync(sync_func_meta)  # Pass sync connection
2458 | 
2459 |                 # Assemble the table-level result
2460 |                 result = {}
2461 |                 result["action"] = "table"
2462 |                 result["table_name"] = table_name
2463 |                 # Use schema name returned from meta function
2464 |                 result["schema_name"] = meta.get("schema_name")
2465 |                 result["comment"] = meta.get("table_comment")
2466 |                 # Normalize column metadata for the response
2467 |                 result["columns"] = [
2468 |                     {
2469 |                         "name": c["name"],
2470 |                         "type": str(c["type"]),
2471 |                         "nullable": c["nullable"],
2472 |                         "primary_key": bool(c.get("primary_key")),
2473 |                         "default": c.get("default"),
2474 |                         "comment": c.get("comment"),
2475 |                     }
2476 |                     for c in meta["columns"]
2477 |                 ]
2478 |                 result["primary_key"] = meta.get("pk_constraint")
2479 |                 result["indexes"] = meta.get("indexes", [])
2480 |                 result["foreign_keys"] = meta.get("foreign_keys", [])
2481 |                 result["success"] = True
2482 | 
2483 |                 # Quote identifiers via the dialect preparer so reserved/mixed-case names are safe
2484 |                 id_prep = eng.dialect.identifier_preparer
2485 |                 quoted_table_name = id_prep.quote(table_name)
2486 |                 quoted_schema_name = id_prep.quote(schema_name) if schema_name else None
2487 |                 # Prefix with the schema when one was provided
2488 |                 full_table_name = (
2489 |                     f"{quoted_schema_name}.{quoted_table_name}"
2490 |                     if quoted_schema_name
2491 |                     else quoted_table_name
2492 |                 )
2493 | 
2494 |                 # Row count
2495 |                 try:
2496 |                     # COUNT(*) with a 30s timeout
2497 |                     _, count_rows, _ = await _sql_exec(
2498 |                         eng,
2499 |                         f"SELECT COUNT(*) AS row_count FROM {full_table_name}",
2500 |                         None,
2501 |                         limit=1,
2502 |                         tool_name=tool_name,
2503 |                         action_name="table_count",
2504 |                         timeout=30,
2505 |                     )
2506 |                     # Empty result → 0
2507 |                     result["row_count"] = count_rows[0]["row_count"] if count_rows else 0
2508 |                 except Exception as count_err:
2509 |                     logger.warning(
2510 |                         f"Could not get row count for table {full_table_name}: {count_err}"
2511 |                     )
2512 |                     result["row_count"] = "Error"
2513 | 
2514 |                 # Sample data
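     |                 # Identifiers cannot be bound parameters, so the quoted table name is
     |                 # interpolated, while the row limit (:n) is passed as a bound parameter.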
2515 |                 if include_sample and sample_size > 0:
2516 |                     try:
2517 |                         # Fetch up to sample_size rows
2518 |                         sample_cols, sample_rows, _ = await _sql_exec(
2519 |                             eng,
2520 |                             f"SELECT * FROM {full_table_name} LIMIT :n",
2521 |                             {"n": sample_size},
2522 |                             limit=sample_size,
2523 |                             tool_name=tool_name,
2524 |                             action_name="table_sample",
2525 |                             timeout=30,
2526 |                         )
2527 |                         # Package columns and rows together
2528 |                         result["sample_data"] = {"columns": sample_cols, "rows": sample_rows}
2529 |                     except Exception as sample_err:
2530 |                         logger.warning(
2531 |                             f"Could not get sample data for table {full_table_name}: {sample_err}"
2532 |                         )
2533 |                         # Surface the failure in the payload instead of aborting
2534 |                         result["sample_data"] = {
2535 |                             "error": f"Failed to retrieve sample data: {sample_err}"
2536 |                         }
2537 | 
2538 |                 # Statistics
2539 |                 if include_stats:
2540 |                     stats = {}
2541 |                     logger.debug(f"Calculating basic statistics for columns in {full_table_name}")
2542 |                     columns_to_stat = result.get("columns", [])
2543 |                     for c in columns_to_stat:
2544 |                         col_name = c["name"]
2545 |                         quoted_col = id_prep.quote(col_name)
2546 |                         col_stat_data = {}
2547 |                         try:
2548 |                             # Null count (bounded by a 20s timeout so one slow
2549 |                             # column does not stall the whole request)
2550 |                             _, null_rows, _ = await _sql_exec(
2551 |                                 eng,
2552 |                                 f"SELECT COUNT(*) AS null_count FROM {full_table_name} WHERE {quoted_col} IS NULL",
2553 |                                 None,
2554 |                                 limit=1,
2555 |                                 tool_name=tool_name,
2556 |                                 action_name="col_stat_null",
2557 |                                 timeout=20,
2558 |                             )
2559 |                             # "Error" sentinel when nothing came back
2560 |                             null_count = null_rows[0]["null_count"] if null_rows else "Error"
2561 | 
2562 |                             # Distinct count (45s timeout: COUNT(DISTINCT) is
2563 |                             # often the slowest per-column query)
2564 |                             _, distinct_rows, _ = await _sql_exec(
2565 |                                 eng,
2566 |                                 f"SELECT COUNT(DISTINCT {quoted_col}) AS distinct_count FROM {full_table_name}",
2567 |                                 None,
2568 |                                 limit=1,
2569 |                                 tool_name=tool_name,
2570 |                                 action_name="col_stat_distinct",
2571 |                                 timeout=45,
2572 |                             )
2573 |                             # "Error" sentinel when nothing came back
2574 |                             distinct_count = (
2575 |                                 distinct_rows[0]["distinct_count"] if distinct_rows else "Error"
2576 |                             )
2577 | 
2578 |                             # Per-column stats entry
2579 |                             col_stat_data = {
2580 |                                 "null_count": null_count,
2581 |                                 "distinct_count": distinct_count,
2582 |                             }
2583 |                         except Exception as stat_err:
2584 |                             log_msg = f"Could not calculate statistics for column {col_name} in {full_table_name}: {stat_err}"
2585 |                             logger.warning(log_msg)
2586 |                             # Record the failure but keep processing other columns
2587 |                             col_stat_data = {"error": f"Failed: {stat_err}"}
2588 |                         stats[col_name] = col_stat_data
2589 |                     result["statistics"] = stats
2590 | 
2591 |             # --- Action: column ---
2592 |             elif action == "column":
2593 |                 if not table_name:
2594 |                     raise ToolInputError(
2595 |                         "`table_name` required for 'column'", param_name="table_name"
2596 |                     )
2597 |                 if not column_name:
2598 |                     raise ToolInputError(
2599 |                         "`column_name` required for 'column'", param_name="column_name"
2600 |                     )
2601 | 
2602 |                 generate_histogram = options.get("histogram", False)
2603 |                 num_buckets_raw = options.get("num_buckets", 10)
2604 |                 num_buckets = int(num_buckets_raw)
2605 |                 num_buckets = max(1, num_buckets)  # Ensure at least one bucket
2606 | 
2607 |                 # Quote identifiers
2608 |                 id_prep = eng.dialect.identifier_preparer
2609 |                 quoted_table = id_prep.quote(table_name)
2610 |                 quoted_column = id_prep.quote(column_name)
2611 |                 quoted_schema = id_prep.quote(schema_name) if schema_name else None
2612 |                 # Schema-qualified name when a schema was provided
2613 |                 full_table_name = (
2614 |                     f"{quoted_schema}.{quoted_table}" if quoted_schema else quoted_table
2615 |                 )
2616 |                 logger.info(f"Analyzing column {full_table_name}.{quoted_column}")
2617 | 
2618 |                 stats_data: Dict[str, Any] = {}
2619 |                 try:
2620 |                     # Total rows: the baseline for the null/distinct
2621 |                     # percentages computed below
2622 |                     _, total_rows_res, _ = await _sql_exec(
2623 |                         eng,
2624 |                         f"SELECT COUNT(*) as cnt FROM {full_table_name}",
2625 |                         None,
2626 |                         limit=1,
2627 |                         tool_name=tool_name,
2628 |                         action_name="col_stat_total",
2629 |                         timeout=30,
2630 |                     )
2631 |                     # Empty result → 0
2632 |                     total_rows_count = total_rows_res[0]["cnt"] if total_rows_res else 0
2633 |                     stats_data["total_rows"] = total_rows_count
2634 | 
2635 |                     # Null count: rows where the column IS NULL, plus a
2636 |                     # percentage of total rows
2637 |                     _, null_rows_res, _ = await _sql_exec(
2638 |                         eng,
2639 |                         f"SELECT COUNT(*) as cnt FROM {full_table_name} WHERE {quoted_column} IS NULL",
2640 |                         None,
2641 |                         limit=1,
2642 |                         tool_name=tool_name,
2643 |                         action_name="col_stat_null",
2644 |                         timeout=30,
2645 |                     )
2646 |                     # Empty result → 0
2647 |                     null_count = null_rows_res[0]["cnt"] if null_rows_res else 0
2648 |                     stats_data["null_count"] = null_count
2649 |                     # Guard against division by zero on empty tables
2650 |                     null_perc = (
2651 |                         round((null_count / total_rows_count) * 100, 2) if total_rows_count else 0
2652 |                     )
2653 |                     stats_data["null_percentage"] = null_perc
2654 | 
2655 |                     # Distinct count: COUNT(DISTINCT) can be expensive,
2656 |                     # hence the longer 60s timeout
2657 |                     _, distinct_rows_res, _ = await _sql_exec(
2658 |                         eng,
2659 |                         f"SELECT COUNT(DISTINCT {quoted_column}) as cnt FROM {full_table_name}",
2660 |                         None,
2661 |                         limit=1,
2662 |                         tool_name=tool_name,
2663 |                         action_name="col_stat_distinct",
2664 |                         timeout=60,
2665 |                     )
2666 |                     # Empty result → 0
2667 |                     distinct_count = distinct_rows_res[0]["cnt"] if distinct_rows_res else 0
2668 |                     stats_data["distinct_count"] = distinct_count
2669 |                     # Guard against division by zero on empty tables
2670 |                     distinct_perc = (
2671 |                         round((distinct_count / total_rows_count) * 100, 2)
2672 |                         if total_rows_count
2673 |                         else 0
2674 |                     )
2675 |                     stats_data["distinct_percentage"] = distinct_perc
2676 |                 except Exception as stat_err:
2677 |                     log_msg = f"Failed to get basic statistics for column {full_table_name}.{quoted_column}: {stat_err}"
2678 |                     logger.error(log_msg, exc_info=True)
2679 |                     stats_data["error"] = f"Failed to retrieve some statistics: {stat_err}"
2680 | 
2681 |                 # Assemble the column-level result
2682 |                 result = {}
2683 |                 result["action"] = "column"
2684 |                 result["table_name"] = table_name
2685 |                 result["column_name"] = column_name
2686 |                 result["schema_name"] = schema_name
2687 |                 result["statistics"] = stats_data
2688 |                 result["success"] = True
2689 | 
2690 |                 if generate_histogram:
2691 |                     logger.debug(f"Generating histogram for {full_table_name}.{quoted_column}")
2692 |                     histogram_data: Optional[Dict[str, Any]] = None
2693 |                     try:
2694 |                         hist_query = f"SELECT {quoted_column} FROM {full_table_name} WHERE {quoted_column} IS NOT NULL"
2695 |                         # Fetch every non-null value (limit=None); can be heavy on large tables
2696 |                         _, value_rows, _ = await _sql_exec(
2697 |                             eng,
2698 |                             hist_query,
2699 |                             None,
2700 |                             limit=None,  # Fetch all non-null values
2701 |                             tool_name=tool_name,
2702 |                             action_name="col_hist_fetch",
2703 |                             timeout=90,
2704 |                         )
2705 |                         # Extract the raw values keyed by column name
2706 |                         values = [r[column_name] for r in value_rows]
2707 | 
2708 |                         if not values:
2709 |                             histogram_data = {"type": "empty", "buckets": []}
2710 |                         else:
2711 |                             first_val = values[0]
2712 |                             # Pick the histogram strategy from the first value's type
2713 |                             is_numeric = isinstance(first_val, (int, float))
2714 | 
2715 |                             if is_numeric:
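     |                                 # Equal-width binning: bucket i covers
     |                                 # [min + i*w, min + (i+1)*w) with w = (max - min) / num_buckets;
     |                                 # the final bucket is closed at max.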
2716 |                                 try:
2717 |                                     min_val = min(values)
2718 |                                     max_val = max(values)
2719 |                                     buckets = []
2720 |                                     if min_val == max_val:
2721 |                                         # Degenerate case: every value is identical
2722 |                                         bucket = {"range": f"{min_val}", "count": len(values)}
2723 |                                         buckets.append(bucket)
2724 |                                     else:
2725 |                                         # Bin width for equal-width buckets
2726 |                                         val_range = max_val - min_val
2727 |                                         bin_width = val_range / num_buckets
2728 |                                         # Nominal (start, end) span for each bucket
2729 |                                         bucket_ranges_raw = [
2730 |                                             (min_val + i * bin_width, min_val + (i + 1) * bin_width)
2731 |                                             for i in range(num_buckets)
2732 |                                         ]
2733 |                                         # Pin the last bucket's end to max_val (float rounding)
2734 |                                         last_bucket_idx = num_buckets - 1
2735 |                                         last_bucket_start = bucket_ranges_raw[last_bucket_idx][0]
2736 |                                         bucket_ranges_raw[last_bucket_idx] = (
2737 |                                             last_bucket_start,
2738 |                                             max_val,
2739 |                                         )
2740 |                                         bucket_ranges = bucket_ranges_raw
2741 | 
2742 |                                         # Label each bucket and zero its count
2743 |                                         buckets = [
2744 |                                             {"range": f"{r[0]:.4g} - {r[1]:.4g}", "count": 0}
2745 |                                             for r in bucket_ranges
2746 |                                         ]
2747 |                                         for v in values:
2748 |                                             # Map the value to its bucket index
2749 |                                             idx_float = (
2750 |                                                 (v - min_val) / bin_width if bin_width > 0 else 0
2751 |                                             )
2752 |                                             idx_int = int(idx_float)
2753 |                                             # Ensure index is within bounds
2754 |                                             idx = min(idx_int, num_buckets - 1)
2755 |                                             # Force max_val into the final bucket despite float precision
2756 |                                             if v == max_val:
2757 |                                                 idx = num_buckets - 1
2758 |                                             buckets[idx]["count"] += 1
2759 | 
2760 |                                     # Package the numeric histogram
2761 |                                     histogram_data = {
2762 |                                         "type": "numeric",
2763 |                                         "min": min_val,
2764 |                                         "max": max_val,
2765 |                                         "buckets": buckets,
2766 |                                     }
2767 |                                 except Exception as num_hist_err:
2768 |                                     log_msg = f"Error generating numeric histogram: {num_hist_err}"
2769 |                                     logger.error(log_msg, exc_info=True)
2770 |                                     # Report the failure inside the payload
2771 |                                     histogram_data = {
2772 |                                         "error": f"Failed to generate numeric histogram: {num_hist_err}"
2773 |                                     }
2774 |                             else:  # Categorical / Frequency
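     |                                 # Non-numeric columns get a top-N frequency table instead of
     |                                 # ranges; values beyond the top N are rolled into a single
     |                                 # "other" count.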
2775 |                                 try:
2776 |                                     # Local import: Counter is only needed on this path
2777 |                                     from collections import Counter
2778 | 
2779 |                                     # Stringify so mixed/unhashable types count uniformly
2780 |                                     str_values = map(str, values)
2781 |                                     value_counts = Counter(str_values)
2782 |                                     # Top-N most frequent values
2783 |                                     top_buckets_raw = value_counts.most_common(num_buckets)
2784 |                                     # Truncate long values to keep the payload bounded
2785 |                                     buckets_data = [
2786 |                                         {"value": str(k)[:100], "count": v}  # Limit value length
2787 |                                         for k, v in top_buckets_raw
2788 |                                     ]
2789 |                                     # Count of values outside the top N
2790 |                                     top_n_count = sum(b["count"] for b in buckets_data)
2791 |                                     other_count = len(values) - top_n_count
2792 | 
2793 |                                     # Package the frequency histogram
2794 |                                     histogram_data = {
2795 |                                         "type": "frequency",
2796 |                                         "top_n": num_buckets,
2797 |                                         "buckets": buckets_data,
2798 |                                     }
2799 |                                     if other_count > 0:
2800 |                                         histogram_data["other_values_count"] = other_count
2801 |                                 except Exception as freq_hist_err:
2802 |                                     log_msg = (
2803 |                                         f"Error generating frequency histogram: {freq_hist_err}"
2804 |                                     )
2805 |                                     logger.error(log_msg, exc_info=True)
2806 |                                     # Report the failure inside the payload
2807 |                                     histogram_data = {
2808 |                                         "error": f"Failed to generate frequency histogram: {freq_hist_err}"
2809 |                                     }
2810 |                     except Exception as hist_err:
2811 |                         log_msg = f"Failed to generate histogram for column {full_table_name}.{quoted_column}: {hist_err}"
2812 |                         logger.error(log_msg, exc_info=True)
2813 |                         # Report the failure inside the payload
2814 |                         histogram_data = {"error": f"Histogram generation failed: {hist_err}"}
2815 |                     result["histogram"] = histogram_data
2816 | 
2817 |             # --- Action: relationships ---
2818 |             elif action == "relationships":
2819 |                 if not table_name:
2820 |                     raise ToolInputError(
2821 |                         "`table_name` required for 'relationships'", param_name="table_name"
2822 |                     )
2823 |                 depth_raw = options.get("depth", 1)
2824 |                 depth_int = int(depth_raw)
2825 |                 # Clamp depth to 1–5 to bound the recursion
2826 |                 depth = max(1, min(depth_int, 5))
2827 | 
2828 |                 log_msg = f"Finding relationships for table '{table_name}' (depth: {depth}, schema: {schema_name})"
2829 |                 logger.info(log_msg)
2830 |                 # Reuse the "schema" action to gather tables and FK metadata
2831 |                 schema_info = await explore_database(
2832 |                     connection_id=connection_id,
2833 |                     action="schema",
2834 |                     schema_name=schema_name,
2835 |                     include_indexes=False,  # Don't need indexes for relationships
2836 |                     include_foreign_keys=True,  # Need FKs
2837 |                     detailed=False,  # Don't need detailed column info
2838 |                 )
2839 |                 # Abort if the underlying schema inspection failed
2840 |                 schema_success = schema_info.get("success", False)
2841 |                 if not schema_success:
2842 |                     raise ToolError(
2843 |                         "Failed to retrieve schema information needed to find relationships."
2844 |                     )
2845 | 
2846 |                 # Index tables by name for fast lookup during traversal
2847 |                 tables_list = schema_info.get("tables", [])
2848 |                 tables_by_name: Dict[str, Dict] = {t["name"]: t for t in tables_list}
2849 | 
2850 |                 if table_name not in tables_by_name:
2851 |                     msg = f"Starting table '{table_name}' not found in schema '{schema_name}'."
2852 |                     raise ToolInputError(msg, param_name="table_name")
2853 | 
2854 |                 visited_nodes = set()  # Track visited nodes to prevent cycles
2855 | 
2856 |                 # Define the recursive helper function *inside* this action block
2857 |                 # so it has access to tables_by_name and visited_nodes
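     |                 # Depth-first traversal over FK edges: "parents" are tables this one
     |                 # references, "children" are tables that reference it. visited_nodes
     |                 # is add/remove-balanced (backtracking), so shared subtrees expand on
     |                 # every path while true cycles are cut off.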
2858 |                 def _build_relationship_graph_standalone(
2859 |                     current_table: str, current_depth: int
2860 |                 ) -> Dict[str, Any]:
2861 |                     # Schema-qualified node id keeps cycle detection unambiguous
2862 |                     current_schema = schema_name or "default"
2863 |                     node_id = f"{current_schema}.{current_table}"
2864 | 
2865 |                     is_max_depth = current_depth >= depth
2866 |                     is_visited = node_id in visited_nodes
2867 |                     if is_max_depth or is_visited:
2868 |                         # Stop expanding and record why (depth limit or cycle)
2869 |                         return {
2870 |                             "table": current_table,
2871 |                             "schema": schema_name,  # Use original schema_name context
2872 |                             "max_depth_reached": is_max_depth,
2873 |                             "cyclic_reference": is_visited,
2874 |                         }
2875 | 
2876 |                     visited_nodes.add(node_id)
2877 |                     node_info = tables_by_name.get(current_table)
2878 | 
2879 |                     if not node_info:
2880 |                         visited_nodes.remove(node_id)  # Backtrack
2881 |                         # Defensive: should not happen since callers check membership
2882 |                         return {
2883 |                             "table": current_table,
2884 |                             "schema": schema_name,
2885 |                             "error": "Table info not found",
2886 |                         }
2887 | 
2888 |                     # Node with parent (outgoing FK) and child (incoming FK) edges
2889 |                     graph_node: Dict[str, Any] = {}
2890 |                     graph_node["table"] = current_table
2891 |                     graph_node["schema"] = schema_name
2892 |                     graph_node["children"] = []
2893 |                     graph_node["parents"] = []
2894 | 
2895 |                     # Find Parents (current table's FKs point to parents)
2896 |                     foreign_keys_list = node_info.get("foreign_keys", [])
2897 |                     for fk in foreign_keys_list:
2898 |                         ref_table = fk["referred_table"]
2899 |                         ref_schema = fk.get(
2900 |                             "referred_schema", schema_name
2901 |                         )  # Assume same schema if not specified
2902 | 
2903 |                         if ref_table in tables_by_name:
2904 |                             # Recurse into the referenced (parent) table
2905 |                             parent_node = _build_relationship_graph_standalone(
2906 |                                 ref_table, current_depth + 1
2907 |                             )
2908 |                         else:
2909 |                             # Referenced table is outside the inspected scope; do not recurse
2910 |                             parent_node = {
2911 |                                 "table": ref_table,
2912 |                                 "schema": ref_schema,
2913 |                                 "outside_scope": True,
2914 |                             }
2915 | 
2916 |                         # Human-readable edge label
2917 |                         constrained_cols_str = ",".join(fk["constrained_columns"])
2918 |                         referred_cols_str = ",".join(fk["referred_columns"])
2919 |                         rel_str = f"{current_table}.({constrained_cols_str}) -> {ref_table}.({referred_cols_str})"
2920 |                         # Edge from this table to its parent
2921 |                         graph_node["parents"].append(
2922 |                             {"relationship": rel_str, "target": parent_node}
2923 |                         )
2924 | 
2925 |                     # Find Children (other tables' FKs point to current table)
2926 |                     for other_table_name, other_table_info in tables_by_name.items():
2927 |                         if other_table_name == current_table:
2928 |                             continue  # Skip self-reference check here
2929 | 
2930 |                         other_fks = other_table_info.get("foreign_keys", [])
2931 |                         for fk in other_fks:
2932 |                             points_to_current = fk["referred_table"] == current_table
2933 |                             # Check schema match (use original schema_name context)
2934 |                             referred_schema_matches = (
2935 |                                 fk.get("referred_schema", schema_name) == schema_name
2936 |                             )
2937 |                             if points_to_current and referred_schema_matches:
2938 |                                 # Recurse into the referencing (child) table
2939 |                                 child_node = _build_relationship_graph_standalone(
2940 |                                     other_table_name, current_depth + 1
2941 |                                 )
2942 |                                 # Human-readable edge label
2943 |                                 constrained_cols_str = ",".join(fk["constrained_columns"])
2944 |                                 referred_cols_str = ",".join(fk["referred_columns"])
2945 |                                 rel_str = f"{other_table_name}.({constrained_cols_str}) -> {current_table}.({referred_cols_str})"
2946 |                                 # Edge from the child back to this table
2947 |                                 graph_node["children"].append(
2948 |                                     {"relationship": rel_str, "source": child_node}
2949 |                                 )
2950 | 
2951 |                     visited_nodes.remove(node_id)  # Backtrack visited state
2952 |                     return graph_node
2953 | 
2954 |                 # Initial call to the recursive function
2955 |                 relationship_graph = _build_relationship_graph_standalone(table_name, 0)
2956 |                 # Assemble the relationships result
2957 |                 result = {}
2958 |                 result["action"] = "relationships"
2959 |                 result["source_table"] = table_name
2960 |                 result["schema_name"] = schema_name
2961 |                 result["max_depth"] = depth
2962 |                 result["relationship_graph"] = relationship_graph
2963 |                 result["success"] = True
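     |                 # Shape sketch of relationship_graph (keys per node):
     |                 #   {"table", "schema", "parents": [{"relationship", "target"}],
     |                 #    "children": [{"relationship", "source"}]}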
2964 | 
2965 |             # --- Action: documentation ---
2966 |             elif action == "documentation":
2967 |                 output_format_raw = options.get("output_format", "markdown")
2968 |                 output_format = output_format_raw.lower()
2969 |                 valid_formats = ["markdown", "json"]
2970 |                 if output_format not in valid_formats:
2971 |                     msg = "Invalid 'output_format'. Use 'markdown' or 'json'."
2972 |                     raise ToolInputError(msg, param_name="output_format")
2973 | 
2974 |                 doc_include_indexes = options.get("include_indexes", True)
2975 |                 doc_include_fks = options.get("include_foreign_keys", True)
2976 |                 log_msg = f"Generating database documentation (Format: {output_format}, Schema: {schema_name})"
2977 |                 logger.info(log_msg)
2978 | 
2979 |                 # Reuse the "schema" action with full details for the docs
2980 |                 schema_data = await explore_database(
2981 |                     connection_id=connection_id,
2982 |                     action="schema",
2983 |                     schema_name=schema_name,
2984 |                     include_indexes=doc_include_indexes,
2985 |                     include_foreign_keys=doc_include_fks,
2986 |                     detailed=True,  # Need details for documentation
2987 |                 )
2988 |                 schema_success = schema_data.get("success", False)
2989 |                 if not schema_success:
2990 |                     raise ToolError(
2991 |                         "Failed to retrieve schema information needed for documentation."
2992 |                     )
2993 | 
2994 |                 if output_format == "json":
2995 |                     # JSON format: embed the schema payload directly
2996 |                     result = {}
2997 |                     result["action"] = "documentation"
2998 |                     result["format"] = "json"
2999 |                     result["documentation"] = schema_data  # Embed the schema result directly
3000 |                     result["success"] = True
3001 |                 else:  # Markdown
3002 |                     # --- Markdown Generation ---
3003 |                     lines = []
3004 |                     lines.append(f"# Database Documentation ({db_dialect})")
3005 |                     db_schema_name = schema_data.get("inspected_schema", "Default Schema")
3006 |                     lines.append(f"Schema: **{db_schema_name}**")
3007 |                     now_str = _sql_now()
3008 |                     lines.append(f"Generated: {now_str}")
3009 |                     schema_hash_val = schema_data.get("schema_hash")
3010 |                     if schema_hash_val:
3011 |                         hash_preview = schema_hash_val[:12]
3012 |                         lines.append(f"Schema Version (Hash): `{hash_preview}`")
3013 |                     lines.append("")  # Blank line
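     |                     # The header renders roughly as (values illustrative):
     |                     #   # Database Documentation (sqlite)
     |                     #   Schema: **Default Schema**
     |                     #   Generated: <timestamp>
     |                     #   Schema Version (Hash): `a1b2c3d4e5f6`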
3014 | 
3015 |                     lines.append("## Tables")
3016 |                     lines.append("")
3017 |                     # Sort tables by name for deterministic output
3018 |                     tables_list_raw = schema_data.get("tables", [])
3019 |                     tables = sorted(tables_list_raw, key=lambda x: x["name"])
3020 | 
3021 |                     if not tables:
3022 |                         lines.append("*No tables found in this schema.*")
3023 | 
3024 |                     for t in tables:
3025 |                         table_name_doc = t["name"]
3026 |                         if t.get("error"):
3027 |                             lines.append(f"### {table_name_doc} (Error)")
3028 |                             lines.append(f"```\n{t['error']}\n```")
3029 |                             lines.append("")
3030 |                             continue  # Skip rest for this table
3031 | 
3032 |                         lines.append(f"### {table_name_doc}")
3033 |                         lines.append("")
3034 |                         table_comment = t.get("comment")
3035 |                         if table_comment:
3036 |                             lines.append(f"> {table_comment}")
3037 |                             lines.append("")
3038 | 
3039 |                         # Column Header
3040 |                         lines.append("| Column | Type | Nullable | PK | Default | Comment |")
3041 |                         lines.append("|--------|------|----------|----|---------|---------|")
3042 |                         columns_list = t.get("columns", [])
3043 |                         for c in columns_list:
3044 |                             # Check-mark flags for PK and nullable
3045 |                             pk_flag = "✅" if c["primary_key"] else ""
3046 |                             null_flag = "✅" if c["nullable"] else ""
3047 |                             default_raw = c.get("default")
3048 |                             # Backtick-quote the default; blank when none
3049 |                             default_val_str = f"`{default_raw}`" if default_raw is not None else ""
3050 |                             comment_val = c.get("comment") or ""
3051 |                             col_name_str = f"`{c['name']}`"
3052 |                             col_type_str = f"`{c['type']}`"
3053 |                             # One markdown row per column
3054 |                             line = f"| {col_name_str} | {col_type_str} | {null_flag} | {pk_flag} | {default_val_str} | {comment_val} |"
3055 |                             lines.append(line)
3056 |                         lines.append("")  # Blank line after table
3057 | 
3058 |                         # Primary Key section
3059 |                         pk_info = t.get("primary_key")
3060 |                         pk_cols = pk_info.get("columns") if pk_info else None
3061 |                         if pk_info and pk_cols:
3062 |                             pk_name = pk_info.get("name", "PK")
3063 |                             # Backtick-quote the PK columns
3064 |                             pk_cols_formatted = [f"`{c}`" for c in pk_cols]
3065 |                             pk_cols_str = ", ".join(pk_cols_formatted)
3066 |                             lines.append(f"**Primary Key:** `{pk_name}` ({pk_cols_str})")
3067 |                             lines.append("")
3068 | 
3069 |                         # Indexes section
3070 |                         indexes_list = t.get("indexes")
3071 |                         if doc_include_indexes and indexes_list:
3072 |                             lines.append("**Indexes:**")
3073 |                             lines.append("")
3074 |                             lines.append("| Name | Columns | Unique |")
3075 |                             lines.append("|------|---------|--------|")
3076 |                             for idx in indexes_list:
3077 |                                 # Check-mark for unique indexes
3078 |                                 unique_flag = "✅" if idx["unique"] else ""
3079 |                                 # Backtick-quote the indexed columns
3080 |                                 idx_cols_formatted = [f"`{c}`" for c in idx["columns"]]
3081 |                                 cols_str = ", ".join(idx_cols_formatted)
3082 |                                 idx_name_str = f"`{idx['name']}`"
3083 |                                 # One markdown row per index
3084 |                                 line = f"| {idx_name_str} | {cols_str} | {unique_flag} |"
3085 |                                 lines.append(line)
3086 |                             lines.append("")

                        # Foreign Keys section
                        fks_list = t.get("foreign_keys")
                        if doc_include_fks and fks_list:
                            lines.append("**Foreign Keys:**")
                            lines.append("")
                            lines.append("| Name | Column(s) | References |")
                            lines.append("|------|-----------|------------|")
                            for fk in fks_list:
                                constrained_cols_fmt = [f"`{c}`" for c in fk["constrained_columns"]]
                                constrained_cols_str = ", ".join(constrained_cols_fmt)

                                # FKs may omit the schema; assume the table's schema in that case
                                ref_schema = fk.get("referred_schema", db_schema_name)
                                ref_table_name = fk["referred_table"]
                                ref_table_str = f"`{ref_schema}`.`{ref_table_name}`"

                                ref_cols_fmt = [f"`{c}`" for c in fk["referred_columns"]]
                                ref_cols_str = ", ".join(ref_cols_fmt)

                                # Unnamed constraints may report name=None, so fall back to "FK"
                                fk_name = fk.get("name") or "FK"
                                fk_name_str = f"`{fk_name}`"
                                ref_full_str = f"{ref_table_str} ({ref_cols_str})"
                                line = (
                                    f"| {fk_name_str} | {constrained_cols_str} | {ref_full_str} |"
                                )
                                lines.append(line)
                            lines.append("")
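                            # Illustrative example: an assumed FK from orders.user_id
                            # to public.users.id would render as:
                            # | `fk_orders_user` | `user_id` | `public`.`users` (`id`) |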

                    # Views Section
                    views_list_raw = schema_data.get("views", [])
                    views = sorted(views_list_raw, key=lambda x: x["name"])
                    if views:
                        lines.append("## Views")
                        lines.append("")
                        for v in views:
                            view_name_doc = v["name"]
                            if v.get("error"):
                                lines.append(f"### {view_name_doc} (Error)")
                                lines.append(f"```\n{v['error']}\n```")
                                lines.append("")
                                continue  # Skip rest for this view

                            lines.append(f"### {view_name_doc}")
                            lines.append("")
                            view_columns = v.get("columns")
                            if view_columns:
                                view_cols_fmt = [
                                    f"`{vc['name']}` ({vc['type']})" for vc in view_columns
                                ]
                                view_cols_str = ", ".join(view_cols_fmt)
                                lines.append(f"**Columns:** {view_cols_str}")
                                lines.append("")

                            view_def = v.get("definition")
                            # Some dialects cannot report view definitions; treat that as missing
                            is_valid_def = (
                                view_def and view_def != "N/A (Not Implemented by Dialect)"
                            )
                            if is_valid_def:
                                lines.append("**Definition:**")
                                lines.append("```sql")
                                lines.append(view_def)
                                lines.append("```")
                                lines.append("")
                            else:
                                lines.append(
                                    "**Definition:** *Not available or not implemented by dialect.*"
                                )
                                lines.append("")
                    # --- End Markdown Generation ---
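                    # Illustrative skeleton of the generated document (assumed data):
                    #   ### users
                    #   | Column | Type | Nullable | PK | Default | Comment |
                    #   **Primary Key:** `pk_users` (`id`)
                    #   ## Views
                    #   ### active_users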

                    markdown_output = "\n".join(lines)
                    result = {
                        "action": "documentation",
                        "format": "markdown",
                        "documentation": markdown_output,
                        "success": True,
                    }

            else:
                logger.error(f"Invalid action specified for explore_database: {action}")
                details = {"action": action}
                valid_actions = "schema, table, column, relationships, documentation"
                msg = f"Unknown action: '{action}'. Valid actions: {valid_actions}"
                raise ToolInputError(msg, param_name="action", details=details)

            # Audit success for all successful actions
            audit_table = [table_name] if table_name else None
            await _sql_audit(
                tool_name=tool_name,
                action=action,
                connection_id=connection_id,
                sql=None,
                tables=audit_table,
                row_count=None,
                success=True,
                error=None,
                user_id=user_id,
                session_id=session_id,
                **audit_extras,
            )
            return result  # Return the constructed result dict

    except ToolInputError as tie:
        # Audit the failure before re-raising
        audit_table = [table_name] if table_name else None
        action_fail = action + "_fail"
        await _sql_audit(
            tool_name=tool_name,
            action=action_fail,
            connection_id=connection_id,
            sql=None,
            tables=audit_table,
            row_count=None,
            success=False,
            error=str(tie),
            user_id=user_id,
            session_id=session_id,
            **audit_extras,
        )
        raise
    except ToolError as te:
        # Audit the failure before re-raising
        audit_table = [table_name] if table_name else None
        action_fail = action + "_fail"
        await _sql_audit(
            tool_name=tool_name,
            action=action_fail,
            connection_id=connection_id,
            sql=None,
            tables=audit_table,
            row_count=None,
            success=False,
            error=str(te),
            user_id=user_id,
            session_id=session_id,
            **audit_extras,
        )
        raise
    except Exception as e:
        log_msg = f"Unexpected error in explore_database (action: {action}): {e}"
        logger.error(log_msg, exc_info=True)
        # Audit the failure, then wrap in a generic ToolError
        audit_table = [table_name] if table_name else None
        action_fail = action + "_fail"
        error_str = f"Unexpected error: {e}"
        await _sql_audit(
            tool_name=tool_name,
            action=action_fail,
            connection_id=connection_id,
            sql=None,
            tables=audit_table,
            row_count=None,
            success=False,
            error=error_str,
            user_id=user_id,
            session_id=session_id,
            **audit_extras,
        )
        raise ToolError(
            f"An unexpected error occurred during database exploration: {e}", http_status_code=500
        ) from e


@with_tool_metrics
@with_error_handling
async def access_audit_log(
    action: str = "view",
    export_format: Optional[str] = None,
    limit: Optional[int] = 100,
    user_id: Optional[str] = None,
    connection_id: Optional[str] = None,
    ctx: Optional[Dict] = None,  # Optional context from the MCP server (currently unused)
) -> Dict[str, Any]:
    """
    Access and export the in-memory SQL audit log.

    Allows viewing recent log entries or exporting them to a file.
    Note: The audit log is currently stored only in memory and will be lost on server restart.

    Args:
        action: "view" (default) or "export".
        export_format: Required if action is "export". Supports "json", "excel", "csv".
        limit: For "view", the maximum number of most recent records to return (default: 100).
            Use None or a negative value for all records; 0 returns none.
        user_id: Filter log entries by this user ID.
        connection_id: Filter log entries by this connection ID.
        ctx: Optional context from MCP server.

    Returns:
        Dict containing results:
        - For "view": {action: "view", records: List[Dict], filtered_record_count: int, total_records_in_log: int, filters_applied: Dict, success: True}
        - For "export": {action: "export", path: str, format: str, record_count: int, success: True} or {action: "export", message: str, record_count: 0, success: True} if no records.
    """
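    # Example calls (hypothetical values, shown for illustration only):
    #   await access_audit_log(action="view", limit=20, user_id="alice")
    #   await access_audit_log(action="export", export_format="csv")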
    tool_name = "access_audit_log"  # noqa: F841

    # Snapshot the global _AUDIT_LOG under the lock; filtering happens on the copy
    async with _audit_lock:  # Need lock to safely read/copy log
        full_log_copy = list(_AUDIT_LOG)
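    # Copying first keeps the critical section short: the lock is held only for
    # the list() call, not for the filtering and export work below.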
    total_records_in_log = len(full_log_copy)

    # Start with the full copy
    filtered_log = full_log_copy

    # Apply filters sequentially
    if user_id:
        filtered_log = [r for r in filtered_log if r.get("user_id") == user_id]
    if connection_id:
        filtered_log = [r for r in filtered_log if r.get("connection_id") == connection_id]
    filtered_record_count = len(filtered_log)

    if action == "view":
        # None or a negative limit means "all". Note that filtered_log[-0:] would
        # return every record, so limit == 0 must be handled explicitly.
        if limit is None or limit < 0:
            records_to_return = filtered_log
        elif limit == 0:
            records_to_return = []
        else:
            records_to_return = filtered_log[-limit:]
        num_returned = len(records_to_return)
        log_msg = f"View audit log requested. Returning {num_returned}/{filtered_record_count} filtered records (Total in log: {total_records_in_log})."
        logger.info(log_msg)

        filters_applied = {"user_id": user_id, "connection_id": connection_id}
        result = {
            "action": "view",
            "records": records_to_return,
            "filtered_record_count": filtered_record_count,
            "total_records_in_log": total_records_in_log,
            "filters_applied": filters_applied,
            "success": True,
        }
        return result

    elif action == "export":
        if not export_format:
            raise ToolInputError(
                "`export_format` is required for 'export'", param_name="export_format"
            )
        export_format_lower = export_format.lower()
        log_msg = f"Export audit log requested. Format: {export_format_lower}. Records to export: {filtered_record_count}"
        logger.info(log_msg)

        if not filtered_log:
            logger.warning("Audit log is empty or filtered log is empty, nothing to export.")
            return {
                "action": "export",
                "message": "No audit records found matching filters to export.",
                "record_count": 0,
                "success": True,
            }

        if export_format_lower == "json":
            path = ""  # Initialize path for cleanup in the except block
            try:
                fd, temp_path = tempfile.mkstemp(suffix=".json", prefix="mcp_audit_export_")
                path = temp_path  # Assign path now we know mkstemp succeeded
                os.close(fd)
                # Use sync write for simplicity here
                with open(path, "w", encoding="utf-8") as f:
                    json.dump(filtered_log, f, indent=2, default=str)
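                    # default=str makes non-JSON-serializable fields (e.g. datetime
                    # timestamps) export as strings instead of raising TypeError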
                log_msg = (
                    f"Successfully exported {filtered_record_count} audit records to JSON: {path}"
                )
                logger.info(log_msg)
                return {
                    "action": "export",
                    "path": path,
                    "format": "json",
                    "record_count": filtered_record_count,
                    "success": True,
                }
            except Exception as e:
                log_msg = f"Failed to export audit log to JSON: {e}"
                logger.error(log_msg, exc_info=True)
                # Clean up temp file if created
                if path and Path(path).exists():
                    try:
                        Path(path).unlink()
                    except OSError:
                        logger.warning(f"Could not clean up failed JSON export file: {path}")
                raise ToolError(
                    f"Failed to export audit log to JSON: {e}", http_status_code=500
                ) from e

        elif export_format_lower in ["excel", "csv"]:
            if pd is None:
                details = {"library": "pandas"}
                msg = f"Pandas library not installed, cannot export audit log to '{export_format_lower}'."
                raise ToolError(msg, http_status_code=501, details=details)
            path = ""  # Initialize path for cleanup in the except block
            try:
                df = pd.DataFrame(filtered_log)
                # Pick suffix, writer method, and engine based on the target format
                is_excel = export_format_lower == "excel"
                suffix = ".xlsx" if is_excel else ".csv"
                writer_func = df.to_excel if is_excel else df.to_csv
                engine = "xlsxwriter" if is_excel else None

                fd, temp_path = tempfile.mkstemp(suffix=suffix, prefix="mcp_audit_export_")
                path = temp_path  # Assign path
                os.close(fd)

                export_kwargs: Dict[str, Any] = {"index": False}
                if engine:
                    export_kwargs["engine"] = engine

                writer_func(path, **export_kwargs)
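                # The dispatch above is equivalent to one of these direct calls
                # (illustration only):
                #   df.to_excel(path, index=False, engine="xlsxwriter")
                #   df.to_csv(path, index=False)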

                log_msg = f"Successfully exported {filtered_record_count} audit records to {export_format_lower.upper()}: {path}"
                logger.info(log_msg)
                return {
                    "action": "export",
                    "path": path,
                    "format": export_format_lower,
                    "record_count": filtered_record_count,
                    "success": True,
                }
            except Exception as e:
                log_msg = f"Failed to export audit log to {export_format_lower}: {e}"
                logger.error(log_msg, exc_info=True)
                # Clean up temp file if created
                if path and Path(path).exists():
                    try:
                        Path(path).unlink()
                    except OSError:
                        logger.warning(f"Could not clean up temporary export file: {path}")
                msg = f"Failed to export audit log to {export_format_lower}: {e}"
                raise ToolError(msg, http_status_code=500) from e
        else:
            details = {"format": export_format}
            valid_formats = "'excel', 'csv', or 'json'"
            msg = f"Unsupported export format: '{export_format}'. Use {valid_formats}."
            raise ToolInputError(msg, param_name="export_format", details=details)
    else:
        details = {"action": action}
        msg = f"Unknown action: '{action}'. Use 'view' or 'export'."
        raise ToolInputError(msg, param_name="action", details=details)

```