This is page 18 of 35. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│ ├── __init__.py
│ ├── advanced_agent_flows_using_unified_memory_system_demo.py
│ ├── advanced_extraction_demo.py
│ ├── advanced_unified_memory_system_demo.py
│ ├── advanced_vector_search_demo.py
│ ├── analytics_reporting_demo.py
│ ├── audio_transcription_demo.py
│ ├── basic_completion_demo.py
│ ├── cache_demo.py
│ ├── claude_integration_demo.py
│ ├── compare_synthesize_demo.py
│ ├── cost_optimization.py
│ ├── data
│ │ ├── sample_event.txt
│ │ ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│ │ └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│ ├── docstring_refiner_demo.py
│ ├── document_conversion_and_processing_demo.py
│ ├── entity_relation_graph_demo.py
│ ├── filesystem_operations_demo.py
│ ├── grok_integration_demo.py
│ ├── local_text_tools_demo.py
│ ├── marqo_fused_search_demo.py
│ ├── measure_model_speeds.py
│ ├── meta_api_demo.py
│ ├── multi_provider_demo.py
│ ├── ollama_integration_demo.py
│ ├── prompt_templates_demo.py
│ ├── python_sandbox_demo.py
│ ├── rag_example.py
│ ├── research_workflow_demo.py
│ ├── sample
│ │ ├── article.txt
│ │ ├── backprop_paper.pdf
│ │ ├── buffett.pdf
│ │ ├── contract_link.txt
│ │ ├── legal_contract.txt
│ │ ├── medical_case.txt
│ │ ├── northwind.db
│ │ ├── research_paper.txt
│ │ ├── sample_data.json
│ │ └── text_classification_samples
│ │ ├── email_classification.txt
│ │ ├── news_samples.txt
│ │ ├── product_reviews.txt
│ │ └── support_tickets.txt
│ ├── sample_docs
│ │ └── downloaded
│ │ └── attention_is_all_you_need.pdf
│ ├── sentiment_analysis_demo.py
│ ├── simple_completion_demo.py
│ ├── single_shot_synthesis_demo.py
│ ├── smart_browser_demo.py
│ ├── sql_database_demo.py
│ ├── sse_client_demo.py
│ ├── test_code_extraction.py
│ ├── test_content_detection.py
│ ├── test_ollama.py
│ ├── text_classification_demo.py
│ ├── text_redline_demo.py
│ ├── tool_composition_examples.py
│ ├── tournament_code_demo.py
│ ├── tournament_text_demo.py
│ ├── unified_memory_system_demo.py
│ ├── vector_search_demo.py
│ ├── web_automation_instruction_packs.py
│ └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│ └── smart_browser_internal
│ ├── locator_cache.db
│ ├── readability.js
│ └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── integration
│ │ ├── __init__.py
│ │ └── test_server.py
│ ├── manual
│ │ ├── test_extraction_advanced.py
│ │ └── test_extraction.py
│ └── unit
│ ├── __init__.py
│ ├── test_cache.py
│ ├── test_providers.py
│ └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│ ├── __init__.py
│ ├── __main__.py
│ ├── cli
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── commands.py
│ │ ├── helpers.py
│ │ └── typer_cli.py
│ ├── clients
│ │ ├── __init__.py
│ │ ├── completion_client.py
│ │ └── rag_client.py
│ ├── config
│ │ └── examples
│ │ └── filesystem_config.yaml
│ ├── config.py
│ ├── constants.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── evaluation
│ │ │ ├── base.py
│ │ │ └── evaluators.py
│ │ ├── providers
│ │ │ ├── __init__.py
│ │ │ ├── anthropic.py
│ │ │ ├── base.py
│ │ │ ├── deepseek.py
│ │ │ ├── gemini.py
│ │ │ ├── grok.py
│ │ │ ├── ollama.py
│ │ │ ├── openai.py
│ │ │ └── openrouter.py
│ │ ├── server.py
│ │ ├── state_store.py
│ │ ├── tournaments
│ │ │ ├── manager.py
│ │ │ ├── tasks.py
│ │ │ └── utils.py
│ │ └── ums_api
│ │ ├── __init__.py
│ │ ├── ums_database.py
│ │ ├── ums_endpoints.py
│ │ ├── ums_models.py
│ │ └── ums_services.py
│ ├── exceptions.py
│ ├── graceful_shutdown.py
│ ├── services
│ │ ├── __init__.py
│ │ ├── analytics
│ │ │ ├── __init__.py
│ │ │ ├── metrics.py
│ │ │ └── reporting.py
│ │ ├── cache
│ │ │ ├── __init__.py
│ │ │ ├── cache_service.py
│ │ │ ├── persistence.py
│ │ │ ├── strategies.py
│ │ │ └── utils.py
│ │ ├── cache.py
│ │ ├── document.py
│ │ ├── knowledge_base
│ │ │ ├── __init__.py
│ │ │ ├── feedback.py
│ │ │ ├── manager.py
│ │ │ ├── rag_engine.py
│ │ │ ├── retriever.py
│ │ │ └── utils.py
│ │ ├── prompts
│ │ │ ├── __init__.py
│ │ │ ├── repository.py
│ │ │ └── templates.py
│ │ ├── prompts.py
│ │ └── vector
│ │ ├── __init__.py
│ │ ├── embeddings.py
│ │ └── vector_service.py
│ ├── tool_token_counter.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── audio_transcription.py
│ │ ├── base.py
│ │ ├── completion.py
│ │ ├── docstring_refiner.py
│ │ ├── document_conversion_and_processing.py
│ │ ├── enhanced-ums-lookbook.html
│ │ ├── entity_relation_graph.py
│ │ ├── excel_spreadsheet_automation.py
│ │ ├── extraction.py
│ │ ├── filesystem.py
│ │ ├── html_to_markdown.py
│ │ ├── local_text_tools.py
│ │ ├── marqo_fused_search.py
│ │ ├── meta_api_tool.py
│ │ ├── ocr_tools.py
│ │ ├── optimization.py
│ │ ├── provider.py
│ │ ├── pyodide_boot_template.html
│ │ ├── python_sandbox.py
│ │ ├── rag.py
│ │ ├── redline-compiled.css
│ │ ├── sentiment_analysis.py
│ │ ├── single_shot_synthesis.py
│ │ ├── smart_browser.py
│ │ ├── sql_databases.py
│ │ ├── text_classification.py
│ │ ├── text_redline_tools.py
│ │ ├── tournament.py
│ │ ├── ums_explorer.html
│ │ └── unified_memory_system.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── async_utils.py
│ │ ├── display.py
│ │ ├── logging
│ │ │ ├── __init__.py
│ │ │ ├── console.py
│ │ │ ├── emojis.py
│ │ │ ├── formatter.py
│ │ │ ├── logger.py
│ │ │ ├── panels.py
│ │ │ ├── progress.py
│ │ │ └── themes.py
│ │ ├── parse_yaml.py
│ │ ├── parsing.py
│ │ ├── security.py
│ │ └── text.py
│ └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/ultimate_mcp_server/core/providers/ollama.py:
--------------------------------------------------------------------------------
```python
"""Ollama provider implementation for the Ultimate MCP Server.
This module implements the Ollama provider, enabling interaction with locally running
Ollama models through a standard interface. Ollama is an open-source framework for
running LLMs locally with minimal setup.
The implementation supports:
- Text completion (generate) and chat completions
- Streaming responses
- Model listing and information retrieval
- Embeddings generation
- Cost tracking (estimated since Ollama is free to use locally)
Ollama must be installed and running locally (by default on localhost:11434)
for this provider to work properly.
"""
import asyncio
import json
import re
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import aiohttp
import httpx
from pydantic import BaseModel
from ultimate_mcp_server.config import get_config
from ultimate_mcp_server.constants import COST_PER_MILLION_TOKENS, Provider
from ultimate_mcp_server.core.providers.base import (
BaseProvider,
ModelResponse,
)
from ultimate_mcp_server.exceptions import ProviderError
from ultimate_mcp_server.utils import get_logger
logger = get_logger("ultimate_mcp_server.providers.ollama")
# Define the Model class locally since it's not available in base.py
class Model(dict):
    """Plain-dict description of one model that a provider can serve.

    Being a ``dict`` subclass, an instance can be serialized or indexed like
    any mapping; the four required keys come first, followed by any extra
    metadata in the order it was supplied.
    """

    def __init__(self, id: str, name: str, description: str, provider: str, **kwargs):
        """Populate the mapping with the required keys plus optional metadata.

        Args:
            id: Model identifier (e.g., "llama3.2")
            name: Human-readable model name
            description: Longer description of the model
            provider: Provider name
            **kwargs: Additional model metadata stored as extra keys
        """
        entries = {"id": id, "name": name, "description": description, "provider": provider}
        # The named parameters can never appear in kwargs, so this only appends extras.
        entries.update(kwargs)
        super().__init__(entries)
# Define ProviderFeatures locally since it's not available in base.py
class ProviderFeatures:
    """Capability flags advertised by a provider, plus its retry budget."""

    def __init__(
        self,
        supports_chat_completions: bool = False,
        supports_streaming: bool = False,
        supports_function_calling: bool = False,
        supports_multiple_functions: bool = False,
        supports_embeddings: bool = False,
        supports_json_mode: bool = False,
        max_retries: int = 3,
    ):
        """Store each argument as an instance attribute of the same name.

        Args:
            supports_chat_completions: Provider can handle chat-style message lists
            supports_streaming: Provider can stream partial responses
            supports_function_calling: Provider supports function/tool calling
            supports_multiple_functions: Provider accepts several functions per request
            supports_embeddings: Provider can produce embeddings
            supports_json_mode: Provider can be asked for JSON output
            max_retries: Maximum number of retries for failed requests
        """
        flags = {
            "supports_chat_completions": supports_chat_completions,
            "supports_streaming": supports_streaming,
            "supports_function_calling": supports_function_calling,
            "supports_multiple_functions": supports_multiple_functions,
            "supports_embeddings": supports_embeddings,
            "supports_json_mode": supports_json_mode,
            "max_retries": max_retries,
        }
        for attr_name, attr_value in flags.items():
            setattr(self, attr_name, attr_value)
# Define ProviderStatus locally since it's not available in base.py
class ProviderStatus:
    """Point-in-time snapshot of a provider's configuration and availability."""

    def __init__(
        self,
        name: str,
        enabled: bool = False,
        available: bool = False,
        api_key_configured: bool = False,
        features: Optional[ProviderFeatures] = None,
        default_model: Optional[str] = None,
    ):
        """Copy every argument onto the instance under the same attribute name.

        Args:
            name: Provider name
            enabled: Whether the provider is enabled in configuration
            available: Whether the provider is currently reachable/usable
            api_key_configured: Whether an API key is configured
            features: Capability flags for the provider, if known
            default_model: Model used when a request names none
        """
        snapshot = (
            ("name", name),
            ("enabled", enabled),
            ("available", available),
            ("api_key_configured", api_key_configured),
            ("features", features),
            ("default_model", default_model),
        )
        for attr_name, attr_value in snapshot:
            setattr(self, attr_name, attr_value)
class OllamaConfig(BaseModel):
    """Settings for reaching an Ollama server.

    Every field has a default, so ``OllamaConfig()`` on its own describes an
    Ollama instance listening on the local loopback address.
    """

    # Base URL of the Ollama HTTP API; paths such as "/api/tags" are appended to it.
    api_url: str = "http://127.0.0.1:11434"
    # Model used when a request does not name one.
    default_model: str = "llama3.2"
    # Request timeout in seconds, handed to the HTTP clients (five minutes).
    request_timeout: int = 5 * 60
    # Whether this provider is switched on.
    enabled: bool = True
class OllamaProvider(BaseProvider):
"""
Provider implementation for Ollama.
Ollama allows running open-source language models locally with minimal setup.
This provider implementation connects to a locally running Ollama instance and
provides a standard interface for generating completions and embeddings.
Unlike cloud providers, Ollama runs models locally, so:
- No API key is required
- Costs are estimated (since running locally is free)
- Model availability depends on what models have been downloaded locally
The Ollama provider supports both chat completions and text completions,
as well as streaming responses. It requires that the Ollama service is
running and accessible at the configured endpoint.
"""
provider_name = Provider.OLLAMA
def __init__(self, api_key: Optional[str] = None, **kwargs):
    """Set up the Ollama provider without touching the network.

    Args:
        api_key: Ignored. Ollama needs no key; the parameter exists only so the
            constructor signature matches the other providers.
        **kwargs: Accepted for interface compatibility; currently unused.
    """
    super().__init__()  # api_key is intentionally not forwarded

    # The logger must exist before _load_config(), which logs through it.
    log = get_logger(f"provider.{Provider.OLLAMA}")
    self.logger = log
    log.info("Initializing Ollama provider...")

    cfg = self._load_config()
    self.config = cfg
    log.info(
        f"Loaded config: API URL={cfg.api_url}, default_model={cfg.default_model}, enabled={cfg.enabled}"
    )

    # The shared aiohttp session is built lazily (see the `session` property).
    self._session = None
    self.client_session_params = {
        "timeout": aiohttp.ClientTimeout(total=cfg.request_timeout)
    }

    # No key is needed, but the flag stays True so status output matches other providers.
    self._api_key_configured = True
    self._initialized = False

    self.features = ProviderFeatures(
        supports_chat_completions=True,
        supports_streaming=True,
        supports_function_calling=False,  # no native function calling in Ollama
        supports_multiple_functions=False,
        supports_embeddings=True,
        supports_json_mode=True,  # via the "format" option plus prompt instructions
        max_retries=3,
    )

    # Per-million-token rates used by _estimate_token_cost() when a model has no
    # entry in COST_PER_MILLION_TOKENS. Local inference has no API bill, so these
    # are token-tracking placeholders (effectively free).
    self._default_token_cost = {
        "input": 0.0001,
        "output": 0.0001,
    }
    log.info("Ollama provider initialization completed")
@property
async def session(self) -> aiohttp.ClientSession:
    """Awaitable accessor for the shared aiohttp session.

    Use as ``session = await provider.session``. A fresh ``ClientSession`` built
    from ``client_session_params`` replaces the cached one when none exists yet
    or the previous one has been closed.
    """
    needs_new = self._session is None or self._session.closed
    if needs_new:
        self._session = aiohttp.ClientSession(**self.client_session_params)
    return self._session
async def __aenter__(self):
    """Run ``initialize()`` on ``async with`` entry and yield the provider itself.

    The boolean result of ``initialize()`` is discarded: entering the context
    succeeds even when Ollama is unreachable (that failure is only logged).
    """
    _ = await self.initialize()
    return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Call ``shutdown()`` when leaving an ``async with`` block."""
    await self.shutdown()
    return None  # falsy: an exception raised inside the block is never suppressed
async def initialize(self) -> bool:
    """Probe the Ollama HTTP API and record whether it is reachable.

    Sends ``GET {api_url}/api/tags`` through a short-lived session with a
    5-second total timeout. If the connection itself fails and the configured
    URL uses ``localhost`` or ``127.0.0.1``, the other spelling is tried once;
    on success ``self.config.api_url`` is switched to the working URL. URLs
    containing neither spelling are not retried (the "alternate" would be the
    same address). All failures are logged (warnings for expected connection
    problems, an error for anything unexpected) instead of being raised.

    The long-lived request session is not created here; the ``session``
    property creates it lazily.

    Returns:
        bool: True if Ollama answered with HTTP 200, False otherwise.
    """
    install_hint = "Make sure Ollama is installed and running: https://ollama.com/download"
    try:
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=5.0)
        ) as check_session:
            self.logger.info(
                f"Attempting to connect to Ollama at {self.config.api_url}/api/tags",
                emoji_key="provider",
            )
            try:
                # The session-level 5 s timeout applies to this request.
                async with check_session.get(f"{self.config.api_url}/api/tags") as response:
                    if response.status == 200:
                        # Ollama is running; the main session is created lazily later.
                        self.logger.info(
                            "Ollama service is available and running", emoji_key="provider"
                        )
                        self._initialized = True
                        return True
                    self.logger.warning(
                        f"Ollama service responded with status {response.status}. "
                        "The service might be misconfigured.",
                        emoji_key="warning",
                    )
            except aiohttp.ClientConnectionError as e:
                alternate_url = self._alternate_loopback_url(self.config.api_url)
                if alternate_url is None:
                    # No localhost/127.0.0.1 in the URL: a retry would hit the same address.
                    self.logger.warning(
                        f"Could not connect to Ollama service: {str(e)}. {install_hint}",
                        emoji_key="warning",
                    )
                elif await self._probe_alternate_url(check_session, alternate_url):
                    return True
            except aiohttp.ClientError as e:
                self.logger.warning(
                    f"Could not connect to Ollama service: {str(e)}. {install_hint}",
                    emoji_key="warning",
                )
            except asyncio.TimeoutError:
                # A timeout usually means nothing is answering on that port.
                self.logger.warning(
                    f"Connection to Ollama service timed out. {install_hint}",
                    emoji_key="warning",
                )
        # No probe succeeded.
        self._initialized = False
        return False
    except Exception as e:
        # Last-resort guard so initialization never raises into callers.
        self.logger.error(
            f"Unexpected error initializing Ollama provider: {str(e)}", emoji_key="error"
        )
        self._initialized = False
        return False

@staticmethod
def _alternate_loopback_url(url: str) -> Optional[str]:
    """Return *url* with ``localhost`` and ``127.0.0.1`` swapped, or None if it uses neither."""
    if "localhost" in url:
        return url.replace("localhost", "127.0.0.1")
    if "127.0.0.1" in url:
        return url.replace("127.0.0.1", "localhost")
    return None

async def _probe_alternate_url(
    self, session: aiohttp.ClientSession, alternate_url: str
) -> bool:
    """Try ``{alternate_url}/api/tags``; on HTTP 200 adopt it as ``config.api_url``.

    Args:
        session: Open session to issue the probe with.
        alternate_url: Base URL to try instead of the configured one.

    Returns:
        bool: True (and ``_initialized`` set) on HTTP 200, False otherwise.
    """
    self.logger.info(
        f"Connection failed, trying alternate URL: {alternate_url}",
        emoji_key="provider",
    )
    try:
        async with session.get(f"{alternate_url}/api/tags") as response:
            if response.status == 200:
                self.logger.info(
                    f"Connected successfully using alternate URL: {alternate_url}",
                    emoji_key="provider",
                )
                # Remember the working URL for all later requests.
                self.config.api_url = alternate_url
                self._initialized = True
                return True
            self.logger.warning(
                f"Ollama service at alternate URL responded with status {response.status}. "
                "The service might be misconfigured.",
                emoji_key="warning",
            )
    except (aiohttp.ClientError, asyncio.TimeoutError) as e:
        self.logger.warning(
            f"Could not connect to alternate URL: {str(e)}. "
            "Make sure Ollama is installed and running: https://ollama.com/download",
            emoji_key="warning",
        )
    return False
async def shutdown(self) -> None:
    """Close the shared HTTP session if it is open, and mark the provider uninitialized.

    Errors raised while closing are logged as warnings and never propagate;
    ``_initialized`` is reset to False in every case.
    """
    try:
        open_session = self._session
        if open_session and not open_session.closed:
            await open_session.close()
            self._session = None
    except Exception as e:
        self.logger.warning(
            f"Error closing Ollama session during shutdown: {str(e)}", emoji_key="warning"
        )
    finally:
        self._initialized = False
def _load_config(self) -> OllamaConfig:
    """Build an ``OllamaConfig`` from the application's provider configuration.

    The app-level provider config uses generic field names (``base_url``,
    ``timeout``), while ``OllamaConfig`` uses ``api_url`` and ``request_timeout``.
    The name mapping is applied however the provider config is exposed
    (pydantic v1/v2 model, plain dict, or arbitrary object), so a configured
    base URL or timeout is never silently dropped. Keys that already use
    ``OllamaConfig``'s own names are honoured too and take precedence.
    ``None`` values are treated as "not configured". A trailing "/" is stripped
    from the URL so that appending "/api/..." never produces "//".

    Returns:
        The resolved ``OllamaConfig``. Defaults are used for anything missing,
        and for everything if loading fails.
    """
    # ProviderConfig field name -> OllamaConfig field name.
    field_mapping = {
        "base_url": "api_url",
        "default_model": "default_model",
        "timeout": "request_timeout",
        "enabled": "enabled",
    }
    native_keys = ("api_url", "request_timeout")
    try:
        self.logger.info("Loading Ollama config from app configuration")
        config = get_config()
        # NOTE: the full app config is deliberately not logged; it may contain
        # API keys for other providers.

        if not hasattr(config, "providers"):
            self.logger.warning("Config doesn't have 'providers' attribute")
            return OllamaConfig()

        if not hasattr(config.providers, Provider.OLLAMA):
            self.logger.warning(f"Config doesn't have '{Provider.OLLAMA}' provider configured")
            return OllamaConfig()

        provider_config = getattr(config.providers, Provider.OLLAMA, {})

        # Normalize the provider config into a plain dict of raw values.
        if isinstance(provider_config, dict):
            raw = dict(provider_config)
        elif hasattr(provider_config, "model_dump"):  # pydantic v2
            raw = provider_config.model_dump()
        elif hasattr(provider_config, "dict"):  # pydantic v1
            raw = provider_config.dict()
        else:
            raw = {
                key: getattr(provider_config, key)
                for key in (*field_mapping, *native_keys)
                if hasattr(provider_config, key)
            }

        config_dict: Dict[str, Any] = {}
        for provider_key, ollama_key in field_mapping.items():
            value = raw.get(provider_key)
            # None means "unset"; passing it on would fail OllamaConfig validation.
            if value is not None:
                config_dict[ollama_key] = value
        for ollama_key in native_keys:
            value = raw.get(ollama_key)
            if value is not None:
                config_dict[ollama_key] = value

        if "api_url" in config_dict:
            # str() turns URL objects into the plain string OllamaConfig expects.
            config_dict["api_url"] = str(config_dict["api_url"]).rstrip("/")

        self.logger.info(f"Created config dict: {config_dict}")
        return OllamaConfig(**config_dict)
    except Exception as e:
        self.logger.error(f"Error loading Ollama config: {e}", exc_info=True)
        return OllamaConfig()
def get_default_model(self) -> str:
    """Return the model name used when a request does not name one (``config.default_model``)."""
    configured = self.config
    return configured.default_model
def get_status(self) -> ProviderStatus:
    """Snapshot the provider's state as a ``ProviderStatus``.

    ``available`` mirrors the ``_initialized`` flag, which is set by the most
    recent successful probe of the Ollama service.
    """
    status_fields = dict(
        name=self.provider_name,
        enabled=self.config.enabled,
        available=self._initialized,
        api_key_configured=self._api_key_configured,
        features=self.features,
        default_model=self.get_default_model(),
    )
    return ProviderStatus(**status_fields)
async def check_api_key(self) -> bool:
    """
    Check whether the Ollama service is reachable.

    Ollama uses no API keys, so this is a health check. If the provider is not
    initialized yet, the result of ``initialize()`` is returned. Otherwise
    ``GET /api/tags`` is issued with a 3-second timeout. A failed check resets
    ``_initialized`` so that ``get_status()`` stops reporting the provider as
    available and the next call re-runs ``initialize()``. Errors are logged,
    never raised.

    Returns:
        bool: True if Ollama service is running and accessible, False otherwise
    """
    if not self._initialized:
        try:
            return await self.initialize()
        except Exception as e:
            self.logger.warning(
                f"Failed to initialize Ollama during service check: {str(e)}",
                emoji_key="warning",
            )
            return False

    try:
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3.0)) as session:
            try:
                async with session.get(self._build_api_url("tags")) as response:
                    healthy = response.status == 200
            except Exception as e:
                # Connection errors, timeouts and anything else all mean "not reachable".
                self.logger.warning(
                    f"Ollama service check failed: {str(e)}", emoji_key="warning"
                )
                healthy = False
    except Exception as e:
        self.logger.warning(
            f"Failed to create session for Ollama check: {str(e)}", emoji_key="warning"
        )
        healthy = False

    if not healthy:
        # Keep the availability flag in sync with reality.
        self._initialized = False
    return healthy
def _build_api_url(self, endpoint: str) -> str:
    """Return ``<config.api_url>/api/<endpoint>``, e.g. ``.../api/tags``."""
    base = self.config.api_url
    return f"{base}/api/{endpoint}"
def _estimate_token_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
    """
    Estimate the dollar cost of a completion from its token counts.

    Rates are per million tokens. They come from ``COST_PER_MILLION_TOKENS``
    when the model has an entry, otherwise from ``self._default_token_cost``;
    a missing "input"/"output" rate in the model's entry also falls back to
    the default. For local Ollama models these are nominal values.
    """
    fallback = self._default_token_cost
    rates = COST_PER_MILLION_TOKENS.get(model, fallback)

    input_cost = input_tokens / 1_000_000 * rates.get("input", fallback["input"])
    output_cost = output_tokens / 1_000_000 * rates.get("output", fallback["output"])
    return input_cost + output_cost
async def list_models(self) -> List[Model]:
    """
    Return the models installed in the local Ollama instance.

    Runs ``initialize()`` first if the provider is not initialized yet, then
    queries ``/api/tags`` through a dedicated session with a 10-second timeout,
    so the shared session is not involved. Failures of any kind are logged and
    produce an empty list rather than an exception.

    Returns:
        List of available Ollama models, or an empty list if Ollama is not available
    """
    if not self._initialized:
        try:
            if not await self.initialize():
                self.logger.warning(
                    "Cannot list Ollama models because the service is not available",
                    emoji_key="warning",
                )
                return []
        except Exception:
            self.logger.warning(
                "Cannot list Ollama models because initialization failed", emoji_key="warning"
            )
            return []

    listing_timeout = aiohttp.ClientTimeout(total=10.0)
    try:
        async with aiohttp.ClientSession(timeout=listing_timeout) as scratch_session:
            return await self._fetch_models(scratch_session)
    except Exception as e:
        self.logger.warning(
            f"Error listing Ollama models: {str(e)}. The service may not be running.",
            emoji_key="warning",
        )
        return []
async def _fetch_models(self, session: aiohttp.ClientSession) -> List[Model]:
    """Query ``/api/tags`` with *session* and convert each entry into a ``Model``.

    A byte ``size`` in an entry is rendered in GiB (two decimals) in both the
    description and the ``size`` field; otherwise ``size`` is "Unknown". A
    non-200 status, connection error, timeout or any other error is logged
    and yields an empty list.
    """

    def to_model(entry: Dict[str, Any]) -> Model:
        model_id = entry.get("name", "")
        description = f"Ollama model: {model_id}"
        size_bytes = entry.get("size", 0)
        size_gb = size_bytes / 1024 ** 3 if size_bytes else None
        if size_gb is not None:
            description += f" ({size_gb:.2f} GB)"
        return Model(
            id=model_id,
            name=model_id,
            description=description,
            provider=self.provider_name,
            size=f"{size_gb:.2f} GB" if size_gb else "Unknown",
        )

    try:
        async with session.get(self._build_api_url("tags")) as response:
            if response.status != 200:
                self.logger.warning(f"Failed to list Ollama models: {response.status}")
                return []
            payload = await response.json()
            return [to_model(entry) for entry in payload.get("models", [])]
    except aiohttp.ClientConnectionError:
        self.logger.warning(
            "Connection refused while listing Ollama models", emoji_key="warning"
        )
        return []
    except asyncio.TimeoutError:
        self.logger.warning("Timeout while listing Ollama models", emoji_key="warning")
        return []
    except Exception as e:
        self.logger.warning(f"Error fetching Ollama models: {str(e)}", emoji_key="warning")
        return []
async def generate_completion(
self,
prompt: Optional[str] = None,
messages: Optional[List[Dict[str, Any]]] = None,
model: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: float = 0.7,
stop: Optional[List[str]] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
mirostat: Optional[int] = None,
mirostat_tau: Optional[float] = None,
mirostat_eta: Optional[float] = None,
json_mode: bool = False,
**kwargs
) -> ModelResponse:
"""Generate a completion from Ollama.
Args:
prompt: Text prompt to send to Ollama (optional if messages provided)
messages: List of message dictionaries (optional if prompt provided)
model: Ollama model name (e.g., "llama2:13b")
max_tokens: Maximum tokens to generate
temperature: Controls randomness (0.0-1.0)
stop: List of strings that stop generation when encountered
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
frequency_penalty: Frequency penalty parameter
presence_penalty: Presence penalty parameter
mirostat: Mirostat sampling algorithm (0, 1, or 2)
mirostat_tau: Target entropy for mirostat
mirostat_eta: Learning rate for mirostat
json_mode: Request JSON-formatted response
**kwargs: Additional parameters
Returns:
ModelResponse object with completion result
"""
if not self.config.api_url:
raise ValueError("Ollama API URL not configured")
# Verify we have either prompt or messages
if prompt is None and not messages:
raise ValueError("Either prompt or messages must be provided to generate a completion")
# If model is None, use configured default
model = model or self.get_default_model()
# Only strip provider prefix if it's our provider name, keep organization prefixes
if "/" in model and model.startswith(f"{self.provider_name}/"):
model = model.split("/", 1)[1]
# If JSON mode is enabled, use the streaming implementation internally
# since Ollama's non-streaming JSON mode is inconsistent
if json_mode:
self.logger.debug("JSON mode requested, using streaming implementation internally for reliability")
return await self._generate_completion_via_streaming(
prompt=prompt,
messages=messages,
model=model,
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
mirostat=mirostat,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
json_mode=True, # Ensure json_mode is passed through
**kwargs
)
# Log request start
self.logger.info(
f"Generating Ollama completion (generate) with model {model}",
emoji_key=self.provider_name
)
# Convert messages to prompt if messages provided
using_messages = False
if messages and not prompt:
using_messages = True
# Convert messages to Ollama's chat format
chat_params = {"messages": []}
# Process messages into Ollama format
for msg in messages:
role = msg.get("role", "").lower()
content = msg.get("content", "")
# Map roles to Ollama's expected format
if role == "system":
ollama_role = "system"
elif role == "user":
ollama_role = "user"
elif role == "assistant":
ollama_role = "assistant"
else:
# Default unknown roles to user
self.logger.warning(f"Unknown message role '{role}', treating as 'user'")
ollama_role = "user"
chat_params["messages"].append({
"role": ollama_role,
"content": content
})
# Add model and parameters to chat_params
chat_params["model"] = model
# Add optional parameters if provided
if temperature is not None and temperature != 0.7:
chat_params["options"] = chat_params.get("options", {})
chat_params["options"]["temperature"] = temperature
if max_tokens is not None:
chat_params["options"] = chat_params.get("options", {})
chat_params["options"]["num_predict"] = max_tokens
if stop:
chat_params["options"] = chat_params.get("options", {})
chat_params["options"]["stop"] = stop
# Add other parameters if provided
for param_name, param_value in [
("top_p", top_p),
("top_k", top_k),
("frequency_penalty", frequency_penalty),
("presence_penalty", presence_penalty),
("mirostat", mirostat),
("mirostat_tau", mirostat_tau),
("mirostat_eta", mirostat_eta)
]:
if param_value is not None:
chat_params["options"] = chat_params.get("options", {})
chat_params["options"][param_name] = param_value
# Add json_mode if requested (as format option)
if json_mode:
chat_params["options"] = chat_params.get("options", {})
chat_params["options"]["format"] = "json"
# For Ollama non-streaming completions, we need to force the system message
# because the format param alone isn't reliable
kwargs["add_json_instructions"] = True
# Only add system message instruction as a fallback if explicitly requested
add_json_instructions = kwargs.pop("add_json_instructions", False)
# Add system message for json_mode only if requested
if add_json_instructions:
has_system = any(msg.get("role", "").lower() == "system" for msg in messages)
if not has_system:
# Add JSON instruction as a system message
chat_params["messages"].insert(0, {
"role": "system",
"content": "You must respond with valid JSON. Format your entire response as a JSON object with properly quoted keys and values."
})
self.logger.debug("Added JSON system instructions for chat_params")
# Add any additional kwargs as options
if kwargs:
chat_params["options"] = chat_params.get("options", {})
chat_params["options"].update(kwargs)
# Use chat endpoint
api_endpoint = self._build_api_url("chat")
response_type = "chat"
else:
# Using generate endpoint with prompt
# Prepare generate parameters
generate_params = {
"model": model,
"prompt": prompt
}
# Add optional parameters if provided
if temperature is not None and temperature != 0.7:
generate_params["options"] = generate_params.get("options", {})
generate_params["options"]["temperature"] = temperature
if max_tokens is not None:
generate_params["options"] = generate_params.get("options", {})
generate_params["options"]["num_predict"] = max_tokens
if stop:
generate_params["options"] = generate_params.get("options", {})
generate_params["options"]["stop"] = stop
# Add other parameters if provided
for param_name, param_value in [
("top_p", top_p),
("top_k", top_k),
("frequency_penalty", frequency_penalty),
("presence_penalty", presence_penalty),
("mirostat", mirostat),
("mirostat_tau", mirostat_tau),
("mirostat_eta", mirostat_eta)
]:
if param_value is not None:
generate_params["options"] = generate_params.get("options", {})
generate_params["options"][param_name] = param_value
# Add json_mode if requested (as format option)
if json_mode:
generate_params["options"] = generate_params.get("options", {})
generate_params["options"]["format"] = "json"
# For Ollama non-streaming completions, we need to force the JSON instructions
# because the format param alone isn't reliable
kwargs["add_json_instructions"] = True
# Only enhance prompt with JSON instructions if explicitly requested
add_json_instructions = kwargs.pop("add_json_instructions", False)
if add_json_instructions:
# Enhance prompt with JSON instructions for better compliance
generate_params["prompt"] = f"Please respond with valid JSON only. {prompt}\nEnsure your entire response is a valid, parseable JSON object with properly quoted keys and values."
self.logger.debug("Enhanced prompt with JSON instructions for generate_params")
# Add any additional kwargs as options
if kwargs:
generate_params["options"] = generate_params.get("options", {})
generate_params["options"].update(kwargs)
# Use generate endpoint
api_endpoint = self._build_api_url("generate")
response_type = "generate" # noqa: F841
# Start timer for tracking
start_time = time.time()
try:
# Make HTTP request to Ollama
async with httpx.AsyncClient(timeout=self.config.request_timeout) as client:
if using_messages:
# Using chat endpoint
response = await client.post(api_endpoint, json=chat_params)
else:
# Using generate endpoint
response = await client.post(api_endpoint, json=generate_params)
# Check for HTTP errors
response.raise_for_status()
# Parse response - handle multi-line JSON data which can happen with json_mode
try:
# First try regular JSON parsing
result = response.json()
except json.JSONDecodeError as e:
# If that fails, try parsing line by line and concatenate responses
self.logger.debug("Response contains multiple JSON objects, parsing line by line")
content = response.text
lines = content.strip().split('\n')
# If we have multiple JSON objects
if len(lines) > 1:
# For multiple objects, take the last one which should have the final response
# This happens in some Ollama versions when using format=json
try:
result = json.loads(lines[-1]) # Use the last line, which typically has the complete response
# Verify result has response/message field, if not try the first line
if using_messages and "message" not in result:
result = json.loads(lines[0])
elif not using_messages and "response" not in result:
result = json.loads(lines[0])
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse Ollama JSON response: {str(e)}. Response: {content[:200]}...") from e
else:
# If we only have one line but still got a JSON error
raise RuntimeError(f"Invalid JSON in Ollama response: {content[:200]}...") from e
# Calculate processing time
processing_time = time.time() - start_time
# Extract response text based on endpoint
if using_messages:
# Extract from chat endpoint
completion_text = result.get("message", {}).get("content", "")
else:
# Extract from generate endpoint
completion_text = result.get("response", "")
# Log the raw response for debugging
self.logger.debug(f"Raw Ollama response: {result}")
self.logger.debug(f"Extracted completion text: {completion_text[:500]}...")
# For JSON mode, ensure the completion text is properly formatted JSON
if json_mode and completion_text:
# Always use add_json_instructions for this model since it seems to need it
if "gemma" in model.lower():
# Force adding instructions for gemma models specifically
kwargs["add_json_instructions"] = True
try:
# First try to extract JSON using our comprehensive method
extracted_json = self._extract_json_from_text(completion_text)
self.logger.debug(f"Extracted JSON: {extracted_json[:500]}...")
# If we found valid JSON, parse and format it
json_data = json.loads(extracted_json)
# If successful, format it nicely with indentation
if isinstance(json_data, (dict, list)):
completion_text = json.dumps(json_data, indent=2)
self.logger.debug("Successfully parsed and formatted JSON response")
else:
self.logger.warning(f"JSON response is not a dict or list: {type(json_data)}")
except (json.JSONDecodeError, TypeError) as e:
self.logger.warning(f"Failed to extract valid JSON from response: {str(e)[:100]}...")
# Calculate token usage
prompt_tokens = result.get("prompt_eval_count", 0)
completion_tokens = result.get("eval_count", 0)
# Format the standardized response
model_response = ModelResponse(
text=completion_text,
model=f"{self.provider_name}/{model}",
provider=self.provider_name,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
processing_time=processing_time,
raw_response=result
)
# Add message field for chat_completion compatibility
model_response.message = {"role": "assistant", "content": completion_text}
# Ensure there's always a value returned for JSON mode to prevent empty displays
if json_mode and (not completion_text or not completion_text.strip()):
# If we got an empty response, create a default one
default_json = {
"response": "No content was returned by the model",
"error": "Empty response with json_mode enabled"
}
completion_text = json.dumps(default_json, indent=2)
model_response.text = completion_text
model_response.message["content"] = completion_text
self.logger.warning("Empty response with JSON mode, returning default JSON structure")
# Log success
self.logger.success(
f"Ollama completion successful with model {model}",
emoji_key="completion_success",
tokens={"input": prompt_tokens, "output": completion_tokens},
time=processing_time,
model=model
)
return model_response
except httpx.HTTPStatusError as http_err:
# Handle HTTP errors
processing_time = time.time() - start_time
try:
error_json = http_err.response.json()
error_msg = error_json.get("error", str(http_err))
except (json.JSONDecodeError, KeyError):
error_msg = f"HTTP error: {http_err.response.status_code} - {http_err.response.text}"
self.logger.error(
f"Ollama API error: {error_msg}",
emoji_key="error",
status_code=http_err.response.status_code,
model=model
)
raise ConnectionError(f"Ollama API error: {error_msg}") from http_err
except httpx.RequestError as req_err:
# Handle request errors (e.g., connection issues)
processing_time = time.time() - start_time
error_msg = f"Request error: {str(req_err)}"
self.logger.error(
f"Ollama request error: {error_msg}",
emoji_key="error",
model=model
)
raise ConnectionError(f"Ollama request error: {error_msg}") from req_err
except Exception as e:
# Handle other unexpected errors
processing_time = time.time() - start_time
self.logger.error(
f"Unexpected error calling Ollama: {str(e)}",
emoji_key="error",
model=model,
exc_info=True
)
raise RuntimeError(f"Unexpected error calling Ollama: {str(e)}") from e
async def generate_completion_stream(
    self,
    prompt: Optional[str] = None,
    messages: Optional[List[Dict[str, Any]]] = None,
    model: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    stop: Optional[List[str]] = None,
    system: Optional[str] = None,
    **kwargs: Any,
) -> AsyncGenerator[Tuple[str, Dict[str, Any]], None]:
    """Stream a completion from Ollama as ``(text_chunk, metadata)`` tuples.

    The ``chat`` endpoint is used when ``messages`` are given, when ``system`` is
    given, or when the model id starts with llama/gpt/claude/phi/mistral; otherwise
    the ``generate`` endpoint is used. Content chunks are yielded with
    ``metadata["finished"] == False``. A successful stream ends with an empty-text
    chunk whose metadata carries token counts, estimated cost, processing time,
    ``finish_reason`` and ``error`` (``None`` on success).

    Failures (provider not initialized, non-200 HTTP status, connection errors,
    timeouts, other unexpected exceptions) are reported as a final ``("", metadata)``
    tuple with ``metadata["error"]`` set instead of being raised. ``ProviderError``
    is re-raised.

    Args:
        prompt: Prompt text; required when ``messages`` is empty or None.
        messages: Chat history as dicts with ``role``/``content``. Roles other than
            system/user/assistant are sent as ``user``.
        model: Model id, optionally prefixed with ``"<provider_name>/"``. Defaults to
            ``self.get_default_model()``.
        temperature: Sent as ``options.temperature``.
        max_tokens: Sent as ``options.num_predict`` (skipped when falsy).
        stop: Sent as ``options.stop`` (skipped when falsy).
        system: System message; only used when ``messages`` is not given.
        **kwargs: ``json_mode`` (bool) sets root-level ``format="json"``. The keys
            seed, top_k, top_p, num_ctx, frequency_penalty, presence_penalty,
            mirostat, mirostat_tau and mirostat_eta are forwarded into ``options``
            when their value is not None. Other keys are ignored.

    Yields:
        ``(text_chunk, metadata)`` tuples.
    """
    # Sampling options forwarded from kwargs; mirrors the non-streaming path.
    forwarded_option_keys = frozenset({
        "seed", "top_k", "top_p", "num_ctx",
        "frequency_penalty", "presence_penalty",
        "mirostat", "mirostat_tau", "mirostat_eta",
    })
    # Bound before the try so every except handler below can reference it, even when
    # the failure happens before the model is resolved (otherwise e.g. the ValueError
    # for missing prompt/messages would surface as an UnboundLocalError).
    model_id: Optional[str] = model
    try:
        if prompt is None and not messages:
            raise ValueError("Either prompt or messages must be provided to generate a streaming completion")

        # Lazily initialize; report an unavailable service as a terminal error chunk.
        if not self._initialized:
            init_error: Optional[str] = None
            try:
                if not await self.initialize():
                    init_error = "Ollama service is not available. Make sure Ollama is installed and running: https://ollama.com/download"
            except Exception as e:
                init_error = f"Failed to initialize Ollama provider: {str(e)}. Make sure Ollama is installed and running: https://ollama.com/download"
            if init_error is not None:
                yield "", {
                    "model": f"{self.provider_name}/{model or self.get_default_model()}",
                    "provider": self.provider_name,
                    "error": init_error,
                    "finish_reason": "error",
                    "finished": True,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0,
                    "processing_time": 0.0,
                }
                return

        model_id = model or self.get_default_model()
        # Strip only our own provider prefix; organization prefixes are kept.
        if model_id.startswith(f"{self.provider_name}/"):
            model_id = model_id.split("/", 1)[1]

        # json_mode maps to Ollama's root-level 'format' parameter, not an option.
        json_mode = kwargs.pop("json_mode", False)
        format_param = "json" if json_mode else None
        if json_mode:
            self.logger.debug("Setting format='json' for Ollama streaming")

        options: Dict[str, Any] = {"temperature": temperature}
        if messages:
            ollama_messages = []
            for msg in messages:
                role = msg.get("role", "").lower()
                if role not in ("system", "user", "assistant"):
                    self.logger.warning(f"Unknown message role '{role}', treating as 'user'")
                    role = "user"
                ollama_messages.append({"role": role, "content": msg.get("content", "")})
            payload = {
                "model": model_id,
                "messages": ollama_messages,
                "stream": True,
                "options": options,
            }
            api_endpoint = "chat"
        elif system is not None or model_id.startswith(("llama", "gpt", "claude", "phi", "mistral")):
            # Wrap the prompt (and optional system message) as chat messages.
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            payload = {
                "model": model_id,
                "messages": messages,
                "stream": True,
                "options": options,
            }
            api_endpoint = "chat"
        else:
            payload = {
                "model": model_id,
                "prompt": prompt,
                "stream": True,
                "options": options,
            }
            api_endpoint = "generate"

        if max_tokens:
            options["num_predict"] = max_tokens
        if stop:
            options["stop"] = stop
        if format_param:
            payload["format"] = format_param
        # Skip None values so callers passing explicit Nones don't send JSON nulls.
        for key, value in kwargs.items():
            if key in forwarded_option_keys and value is not None:
                options[key] = value

        content_length = 0
        if messages:
            content_length = sum(len(m.get("content", "")) for m in messages)
        elif prompt:
            content_length = len(prompt)
        self.logger.info(
            f"Generating Ollama streaming completion ({api_endpoint}) with model {model_id}",
            emoji_key=self.provider_name,
            prompt_length=content_length,
            json_mode_requested=json_mode,
        )

        start_time = time.time()
        input_tokens = 0
        output_tokens = 0
        finish_reason = None
        final_error = None

        async with aiohttp.ClientSession(**self.client_session_params) as streaming_session:
            async with streaming_session.post(
                self._build_api_url(api_endpoint), json=payload
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    final_error = f"Ollama streaming API error: {response.status} - {error_text}"
                    yield "", {
                        "error": final_error,
                        "finished": True,
                        "provider": self.provider_name,
                        "model": model_id,
                    }
                    return

                buffer = ""
                chunk_index = 0
                async for line in response.content:
                    if not line.strip():
                        continue
                    buffer += line.decode("utf-8")
                    # Ollama streams newline-delimited JSON; handle each complete line.
                    while "\n" in buffer:
                        json_str, buffer = buffer.split("\n", 1)
                        if not json_str.strip():
                            continue
                        try:
                            data = json.loads(json_str)
                            chunk_index += 1
                            if api_endpoint == "chat":
                                text_chunk = data.get("message", {}).get("content", "")
                            else:
                                text_chunk = data.get("response", "")
                            if text_chunk:
                                yield text_chunk, {
                                    "provider": self.provider_name,
                                    "model": model_id,
                                    "chunk_index": chunk_index,
                                    "finished": False,
                                }
                            # The 'done' object carries the usage stats for the request.
                            if data.get("done", False):
                                input_tokens = data.get("prompt_eval_count", input_tokens)
                                output_tokens = data.get("eval_count", output_tokens)
                                finish_reason = data.get("done_reason", "stop")
                                break
                        except json.JSONDecodeError:
                            # Skip the malformed line and keep streaming.
                            self.logger.warning(
                                f"Could not decode JSON line: {json_str[:100]}..."
                            )
                        except Exception as parse_error:
                            self.logger.warning(f"Error processing stream chunk: {parse_error}")
                            final_error = f"Error processing stream: {parse_error}"
                            break
                    if final_error:
                        break

        processing_time = time.time() - start_time
        yield "", {
            "model": f"{self.provider_name}/{model_id}",
            "provider": self.provider_name,
            "finished": True,
            "finish_reason": finish_reason,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost": self._estimate_token_cost(model_id, input_tokens, output_tokens),
            "processing_time": processing_time,
            "error": final_error,
        }
    except aiohttp.ClientConnectionError as e:
        yield (
            "",
            {
                "error": f"Connection to Ollama failed: {str(e)}",
                "finished": True,
                "provider": self.provider_name,
                "model": model_id,
            },
        )
    except asyncio.TimeoutError:
        yield (
            "",
            {
                "error": "Connection to Ollama timed out",
                "finished": True,
                "provider": self.provider_name,
                "model": model_id,
            },
        )
    except Exception as e:
        if isinstance(e, ProviderError):
            raise
        yield (
            "",
            {
                "error": f"Error generating streaming completion: {str(e)}",
                "finished": True,
                "provider": self.provider_name,
                "model": model_id,
            },
        )
async def create_embeddings(
    self,
    texts: List[str],
    model: Optional[str] = None,
    **kwargs: Any,
) -> ModelResponse:
    """
    Generate embeddings for a list of texts using the Ollama API.

    Each text is sent as its own request to the ``embeddings`` endpoint with a
    30-second total timeout. A text whose request fails (non-200 status, empty
    embedding, unexpected error) is skipped and a message is appended to
    ``metadata["errors"]``, so ``metadata["embeddings"]`` may hold fewer entries
    than ``texts``. A connection error or timeout stops processing the remaining texts.

    Args:
        texts: List of texts to generate embeddings for.
        model: The model ID to use (defaults to provider's default); a leading
            ``"<provider_name>/"`` prefix is stripped.
        **kwargs: Extra top-level request fields. ``None`` values and the keys
            ``model``/``prompt`` are ignored.

    Returns:
        A ModelResponse whose ``metadata`` holds ``embeddings``, ``dimensions``,
        ``errors`` (or None) and an estimated ``cost``. Token counts are a
        whitespace word-count estimate. If Ollama is unavailable, or an error occurs
        outside the per-text loop, the error is returned in ``metadata["error"]``
        rather than raised (``ProviderError`` is re-raised).
    """
    # Lazily initialize; report an unavailable service in the response metadata.
    if not self._initialized:
        init_error: Optional[str] = None
        try:
            if not await self.initialize():
                init_error = "Ollama service is not available. Make sure Ollama is installed and running: https://ollama.com/download"
        except Exception as e:
            init_error = f"Failed to initialize Ollama provider: {str(e)}. Make sure Ollama is installed and running: https://ollama.com/download"
        if init_error is not None:
            return ModelResponse(
                text="",
                model=f"{self.provider_name}/{model or self.get_default_model()}",
                provider=self.provider_name,
                input_tokens=0,
                output_tokens=0,
                total_tokens=0,
                processing_time=0.0,
                metadata={
                    "error": init_error,
                    "embeddings": [],
                },
            )

    model_id = model or self.get_default_model()
    # Strip only our own provider prefix; organization prefixes are kept.
    if model_id.startswith(f"{self.provider_name}/"):
        model_id = model_id.split("/", 1)[1]

    # Ollama does not report token counts for embeddings; estimate from word count.
    total_tokens = sum(len(text.split()) for text in texts)

    result_embeddings: List[Any] = []
    errors: List[str] = []
    all_dimensions: Optional[int] = None
    # Loop-invariant request pieces: extra fields and the per-request timeout.
    extra_params = {
        key: value
        for key, value in kwargs.items()
        if key not in ("model", "prompt") and value is not None
    }
    request_timeout = aiohttp.ClientTimeout(total=30.0)
    try:
        start_time = time.time()
        async with aiohttp.ClientSession(**self.client_session_params) as session:
            for text in texts:
                payload = {"model": model_id, "prompt": text, **extra_params}
                try:
                    async with session.post(
                        self._build_api_url("embeddings"), json=payload, timeout=request_timeout
                    ) as response:
                        if response.status != 200:
                            error_text = await response.text()
                            errors.append(f"Ollama API error: {response.status} - {error_text}")
                            continue

                        data = await response.json()
                        embedding = data.get("embedding", [])
                        if not embedding:
                            errors.append(f"No embedding returned for text: {text[:50]}...")
                            continue

                        result_embeddings.append(embedding)
                        # Track the first seen dimension and flag any mismatch.
                        dimensions = len(embedding)
                        if all_dimensions is None:
                            all_dimensions = dimensions
                        elif dimensions != all_dimensions:
                            errors.append(
                                f"Inconsistent embedding dimensions: got {dimensions}, expected {all_dimensions}"
                            )
                except aiohttp.ClientConnectionError as e:
                    errors.append(
                        f"Connection to Ollama failed: {str(e)}. Make sure Ollama is running and accessible."
                    )
                    break
                except asyncio.TimeoutError:
                    errors.append(
                        "Connection to Ollama timed out. Check if the service is overloaded."
                    )
                    break
                except Exception as e:
                    errors.append(f"Error generating embedding: {str(e)}")

        processing_time = time.time() - start_time
        # Nominal cost estimate for local embeddings.
        estimated_cost = (total_tokens / 1_000_000) * 0.0001
        return ModelResponse(
            text="",
            model=f"{self.provider_name}/{model_id}",
            provider=self.provider_name,
            input_tokens=total_tokens,
            output_tokens=0,
            total_tokens=total_tokens,
            processing_time=processing_time,
            metadata={
                "embeddings": result_embeddings,
                "dimensions": all_dimensions or 0,
                "errors": errors if errors else None,
                "cost": estimated_cost,
            },
        )
    except aiohttp.ClientConnectionError as e:
        return ModelResponse(
            text="",
            model=f"{self.provider_name}/{model_id}",
            provider=self.provider_name,
            input_tokens=0,
            output_tokens=0,
            total_tokens=0,
            processing_time=0.0,
            metadata={
                "error": f"Connection to Ollama failed: {str(e)}. Make sure Ollama is running and accessible.",
                "embeddings": [],
                "cost": 0.0,
            },
        )
    except Exception as e:
        if isinstance(e, ProviderError):
            raise
        return ModelResponse(
            text="",
            model=f"{self.provider_name}/{model_id}",
            provider=self.provider_name,
            input_tokens=0,
            output_tokens=0,
            total_tokens=0,
            processing_time=0.0,
            metadata={
                "error": f"Error generating embeddings: {str(e)}",
                "embeddings": result_embeddings,
                "cost": 0.0,
            },
        )
def _extract_json_from_text(self, text: str) -> str:
    """Extract JSON content from text that might include markdown code blocks or explanatory text.

    Strategies, in order:
      1. ``text`` itself, if it already parses as JSON.
      2. The first fenced code block (optionally tagged ``json``), retried with
         trailing commas before ``}``/``]`` removed if it does not parse as-is.
      3. The JSON object/array starting at the earliest ``{`` or ``[`` in the text,
         falling back to the other bracket type. ``JSONDecoder.raw_decode`` is used,
         so brackets inside string values are handled correctly, and an array nested
         inside an object is not returned in place of the enclosing object.

    Args:
        text: The raw text response that might contain JSON

    Returns:
        The extracted JSON substring, or ``text`` unchanged if nothing parses.
    """
    try:
        json.loads(text)
        return text
    except json.JSONDecodeError:
        pass

    code_block_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', text)
    if code_block_match:
        code_content = code_block_match.group(1).strip()
        # Second candidate drops trailing commas, a common LLM JSON mistake.
        for candidate in (code_content, re.sub(r',\s*([}\]])', r'\1', code_content)):
            try:
                json.loads(candidate)
                return candidate
            except json.JSONDecodeError:
                continue

    stripped = text.strip()
    decoder = json.JSONDecoder()
    # Try the first '{' and first '[' in positional order so the outermost
    # structure wins (e.g. the object in 'x {"a": [1]}' rather than its array).
    starts = sorted(pos for pos in (stripped.find('{'), stripped.find('[')) if pos != -1)
    for start in starts:
        try:
            _, end = decoder.raw_decode(stripped, start)
        except json.JSONDecodeError:
            continue
        return stripped[start:end]

    return text
async def _generate_completion_via_streaming(
    self,
    prompt: Optional[str] = None,
    messages: Optional[List[Dict[str, Any]]] = None,
    model: Optional[str] = None,
    max_tokens: Optional[int] = None,
    temperature: float = 0.7,
    stop: Optional[List[str]] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    mirostat: Optional[int] = None,
    mirostat_tau: Optional[float] = None,
    mirostat_eta: Optional[float] = None,
    system: Optional[str] = None,
    json_mode: bool = False,  # Passed through to the streaming method
    **kwargs: Any,
) -> ModelResponse:
    """Generate a completion via streaming and collect the results.

    Workaround for Ollama's inconsistent behavior with JSON mode in non-streaming
    completions: the request is made through ``generate_completion_stream`` and
    all text chunks are concatenated into a single ``ModelResponse``.

    Args:
        Same as generate_completion and generate_completion_stream.

    Returns:
        ModelResponse: The complete response. Token counts and processing time come
        from the stream's final (``finished``) metadata chunk; ``raw_response`` is
        ``{"streaming_source": True, "metadata": <final metadata>}``.

    Raises:
        RuntimeError: If any chunk reports an error, or the stream raises.
    """
    self.logger.debug("Using streaming method internally to handle JSON mode reliably")
    stream_gen = self.generate_completion_stream(
        prompt=prompt,
        messages=messages,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=stop,
        top_p=top_p,
        top_k=top_k,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        mirostat=mirostat,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        system=system,
        json_mode=json_mode,
        **kwargs
    )

    text_parts: List[str] = []
    metadata: Dict[str, Any] = {}
    input_tokens = 0
    output_tokens = 0
    processing_time = 0.0
    try:
        async for chunk, chunk_metadata in stream_gen:
            if chunk_metadata.get("error"):
                raise RuntimeError(chunk_metadata["error"])
            text_parts.append(chunk)
            # The final chunk carries the usage/timing stats.
            if chunk_metadata.get("finished", False):
                metadata = chunk_metadata
                input_tokens = chunk_metadata.get("input_tokens", 0)
                output_tokens = chunk_metadata.get("output_tokens", 0)
                processing_time = chunk_metadata.get("processing_time", 0)
    except Exception as e:
        raise RuntimeError(f"Error in streaming completion: {str(e)}") from e
    finally:
        # Close the generator explicitly. When we stop early on an error chunk it is
        # still suspended inside its HTTP session, and this releases that session now
        # instead of at garbage collection.
        await stream_gen.aclose()

    combined_text = "".join(text_parts)
    result = ModelResponse(
        text=combined_text,
        model=metadata.get("model", f"{self.provider_name}/{model or self.get_default_model()}"),
        provider=self.provider_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        processing_time=processing_time,
        raw_response={"streaming_source": True, "metadata": metadata}
    )
    # Add message field for chat_completion compatibility
    result.message = {"role": "assistant", "content": combined_text}
    return result
```
--------------------------------------------------------------------------------
/examples/sql_database_demo.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python
"""Demonstration script for SQLTool in Ultimate MCP Server."""
import asyncio
import datetime as dt
import os
import sqlite3
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional
# Add project root to path for imports when running as script
sys.path.insert(0, str(Path(__file__).parent.parent))
# Rich imports for nice UI
import pandas as pd
import pandera as pa
from rich import box
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.progress import BarColumn, Progress, TextColumn
from rich.rule import Rule
from rich.syntax import Syntax
from rich.table import Table
from rich.traceback import install as install_rich_traceback
from rich.tree import Tree
from ultimate_mcp_server.core.server import Gateway # Import the actual Gateway
from ultimate_mcp_server.exceptions import ToolError, ToolInputError
# Import the SQLTool class from our module
from ultimate_mcp_server.tools.sql_databases import SQLTool
from ultimate_mcp_server.utils import get_logger
# --- Shared output setup ---
# One Rich console shared by all demo sections, plus a named logger for this script.
console = Console()
logger = get_logger("demo.sql_tool")
# Rich-formatted tracebacks sized to the console; locals are hidden to keep them short.
install_rich_traceback(width=console.width, show_locals=False)

# --- Configuration ---
# Default target: an in-memory SQLite database, so the demo needs no external server.
DEFAULT_CONNECTION_STRING = "sqlite:///:memory:"
# Example alternatives (SQLAlchemy URLs):
#   "postgresql://username:password@localhost:5432/demo_db"
#   "mysql+pymysql://username:password@localhost:3306/demo_db"
#   "mssql+pyodbc://username:password@localhost:1433/demo_db?driver=ODBC+Driver+17+for+SQL+Server"
# --- Demo Helper Functions ---
def display_result(title: str, result: Dict[str, Any], query_str: Optional[str] = None) -> None:
    """Render a SQL tool result dict to the console with Rich formatting.

    Prints a titled rule and, if given, the executed SQL. Then:
      * if ``result["success"]`` is falsy, an error panel with ``result["error"]``;
      * if ``"rows"`` is present, a table of the rows (dicts keyed by column name),
        with optional pagination info and a truncation warning;
      * if ``"documentation"`` is present, the documentation with syntax highlighting;
      * otherwise a panel listing every key except ``success``.

    All values taken from ``result`` are converted with ``str()`` and escaped as Rich
    markup, so non-string or bracket-containing values cannot break rendering.

    Args:
        title: Heading shown in the rule above the output.
        result: Result dictionary returned by a SQLTool operation.
        query_str: Optional SQL text to show before the result.
    """
    console.print(Rule(f"[bold cyan]{escape(title)}[/bold cyan]"))
    if query_str:
        console.print(Panel(
            Syntax(query_str.strip(), "sql", theme="default", line_numbers=False, word_wrap=True),
            title="Executed Query",
            border_style="blue",
            padding=(1, 2)
        ))

    if not result.get("success", False):
        # str(): the error may be None or a non-string object, which escape() rejects.
        error_msg = str(result.get("error") or "Unknown error")
        console.print(Panel(
            f"[bold red]:x: Operation Failed:[/]\n{escape(error_msg)}",
            title="Error",
            border_style="red",
            padding=(1, 2),
            expand=False
        ))
        return

    if "rows" in result:
        rows = result.get("rows", [])
        columns = result.get("columns", [])
        row_count = result.get("row_count", len(rows))
        if not rows:
            console.print(Panel("[yellow]No results returned for this operation.", padding=(0, 1), border_style="yellow"))
            return

        table_title = f"Results ({row_count} row{'s' if row_count != 1 else ''} returned)"
        if "pagination" in result:
            table_title += f" - Page {result['pagination'].get('page', '?')}"
        table = Table(title=table_title, box=box.ROUNDED, show_header=True, padding=(0, 1), border_style="bright_blue")

        # Name-based heuristic: right-align columns that look numeric.
        numeric_hints = ('id', 'count', 'price', 'amount', 'quantity', 'total')
        for name in columns:
            justify = "right" if any(hint in name.lower() for hint in numeric_hints) else "left"
            style = "cyan" if justify == "left" else "magenta"
            table.add_column(escape(str(name)), style=style, justify=justify, header_style=f"bold {style}")
        for row in rows:
            table.add_row(*[escape(str(row.get(col_name, ''))) for col_name in columns])
        console.print(table)

        if "pagination" in result:
            pagination = result["pagination"]
            pagination_info = Table(title="Pagination Info", show_header=False, box=box.SIMPLE, padding=(0, 1))
            pagination_info.add_column("Metric", style="cyan", justify="right")
            pagination_info.add_column("Value", style="white")
            pagination_info.add_row("Page", str(pagination.get("page")))
            pagination_info.add_row("Page Size", str(pagination.get("page_size")))
            pagination_info.add_row("Has Next", "[green]:heavy_check_mark:[/]" if pagination.get("has_next_page") else "[dim]:x:[/]")
            pagination_info.add_row("Has Previous", "[green]:heavy_check_mark:[/]" if pagination.get("has_previous_page") else "[dim]:x:[/]")
            console.print(pagination_info)

        if result.get("truncated"):
            console.print("[yellow]⚠ Results truncated (reached max_rows limit)[/yellow]")
    elif "documentation" in result:
        doc_content = result.get("documentation", "")
        # `or`: guards against an explicit None format, which would break .upper().
        format_type = result.get("format") or "markdown"
        console.print(Panel(
            Syntax(doc_content, format_type, theme="default", line_numbers=False, word_wrap=True),
            title=f"Documentation ({format_type.upper()})",
            border_style="magenta",
            padding=(1, 2)
        ))
    else:
        console.print(Panel(
            "\n".join(f"[cyan]{escape(str(k))}:[/] {escape(str(v))}" for k, v in result.items() if k != "success"),
            title="Operation Result",
            border_style="green",
            padding=(1, 2)
        ))
    console.print()  # Add spacing
# Add setup functionality directly to avoid import issues
def init_demo_database(db_path):
    """Set up a demo database with sample tables and data.

    Creates the ``customers``, ``products``, ``orders`` and ``order_items`` tables
    (``IF NOT EXISTS``) and inserts a fixed set of sample rows, including fake PII
    columns (``ssn``, ``credit_card``). Commits on success. On ``sqlite3.Error`` it
    rolls back and re-raises. The connection is always closed.

    Args:
        db_path: Path passed to ``sqlite3.connect``.

    Returns:
        ``db_path`` unchanged.
    """
    def _table_name(sql):
        """Return the table name of a CREATE TABLE [IF NOT EXISTS] / INSERT INTO statement."""
        tokens = sql.split()
        # "CREATE TABLE IF NOT EXISTS name" puts the name at index 5, not 2.
        index = 5 if [t.upper() for t in tokens[2:5]] == ["IF", "NOT", "EXISTS"] else 2
        return tokens[index].split("(")[0]

    logger.info(f"Setting up demo database at: {db_path}")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    setup_queries = [
        """
        CREATE TABLE IF NOT EXISTS customers (
            customer_id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            email TEXT UNIQUE,
            signup_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            status TEXT CHECK(status IN ('active', 'inactive', 'pending')) DEFAULT 'pending',
            ssn TEXT,
            credit_card TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS products (
            product_id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            description TEXT,
            price DECIMAL(10,2) NOT NULL,
            category TEXT,
            in_stock BOOLEAN DEFAULT 1
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS orders (
            order_id INTEGER PRIMARY KEY,
            customer_id INTEGER NOT NULL,
            order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            total_amount DECIMAL(10,2) NOT NULL,
            status TEXT DEFAULT 'pending',
            FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS order_items (
            item_id INTEGER PRIMARY KEY,
            order_id INTEGER NOT NULL,
            product_id INTEGER NOT NULL,
            quantity INTEGER NOT NULL,
            price_per_unit DECIMAL(10,2) NOT NULL,
            FOREIGN KEY (order_id) REFERENCES orders(order_id),
            FOREIGN KEY (product_id) REFERENCES products(product_id)
        )
        """
    ]

    sample_data_queries = [
        # Customers with fake PII. Emails must be distinct because of `email TEXT UNIQUE`.
        """
        INSERT INTO customers (customer_id, name, email, status, ssn, credit_card) VALUES
        (1, 'Alice Johnson', 'alice@example.com', 'active', '123-45-6789', '4111-1111-1111-1111'),
        (2, 'Bob Smith', 'bob@example.com', 'active', '234-56-7890', '4222-2222-2222-2222'),
        (3, 'Charlie Davis', 'charlie@example.com', 'inactive', '345-67-8901', '4333-3333-3333-3333'),
        (4, 'Diana Miller', 'diana@example.com', 'active', '456-78-9012', '4444-4444-4444-4444'),
        (5, 'Ethan Garcia', 'ethan@example.com', 'pending', '567-89-0123', '4555-5555-5555-5555')
        """,
        """
        INSERT INTO products (product_id, name, description, price, category, in_stock) VALUES
        (1, 'Laptop Pro X', 'High-performance laptop with 16GB RAM', 1499.99, 'Electronics', 1),
        (2, 'Smartphone Z', 'Latest flagship smartphone', 999.99, 'Electronics', 1),
        (3, 'Wireless Earbuds', 'Noise-cancelling earbuds', 179.99, 'Audio', 1),
        (4, 'Smart Coffee Maker', 'WiFi-enabled coffee machine', 119.99, 'Kitchen', 0),
        (5, 'Fitness Tracker', 'Waterproof fitness band with GPS', 79.99, 'Wearables', 1)
        """,
        """
        INSERT INTO orders (order_id, customer_id, total_amount, status) VALUES
        (1, 1, 1499.98, 'completed'),
        (2, 2, 89.99, 'processing'),
        (3, 1, 249.99, 'completed'),
        (4, 3, 1099.98, 'completed'),
        (5, 4, 49.99, 'processing')
        """,
        """
        INSERT INTO order_items (item_id, order_id, product_id, quantity, price_per_unit) VALUES
        (1, 1, 1, 1, 1499.99),
        (2, 2, 5, 1, 79.99),
        (3, 3, 3, 1, 179.99),
        (4, 3, 4, 1, 119.99),
        (5, 4, 2, 1, 999.99),
        (6, 4, 5, 1, 79.99),
        (7, 5, 4, 1, 119.99)
        """
    ]

    try:
        for query in setup_queries:
            cursor.execute(query)
            logger.info(f"Created table: {_table_name(query)}")
        for query in sample_data_queries:
            cursor.execute(query)
            logger.info(f"Inserted {cursor.rowcount} rows into {_table_name(query)}")
        conn.commit()
        logger.info("Database setup complete")
    except sqlite3.Error as e:
        logger.error(f"SQLite error: {e}")
        conn.rollback()
        raise
    finally:
        conn.close()
    return db_path
# --- Demo Functions ---
async def connection_demo(sql_tool: SQLTool, conn_string: Optional[str] = None) -> Optional[str]:
    """Connect to a database, then show a health test and the active-connection status.

    Args:
        sql_tool: SQLTool instance used for the ``connect``, ``test`` and
            ``status`` actions of ``manage_database``.
        conn_string: Connection string to use; ``DEFAULT_CONNECTION_STRING``
            is used when it is ``None`` or empty.

    Returns:
        The ``connection_id`` reported by a successful ``connect`` action, or
        ``None`` if connecting failed or raised before an ID was obtained.
        Errors are logged and shown on the console, never raised.
    """
    console.print(Rule("[bold green]1. Database Connection Demo[/bold green]", style="green"))
    logger.info("Starting database connection demo")
    connection_id = None
    connection_string = conn_string or DEFAULT_CONNECTION_STRING
    with console.status("[bold cyan]Connecting to database...", spinner="earth"):
        try:
            # Connect to database
            connection_result = await sql_tool.manage_database(
                action="connect",
                connection_string=connection_string,
                echo=False,  # Disable SQLAlchemy logging for cleaner output
            )
            if connection_result.get("success"):
                connection_id = connection_result.get("connection_id")
                db_type = connection_result.get("database_type", "Unknown")
                logger.success(f"Connected to database with ID: {connection_id}")
                # str() first: rich's escape() only accepts strings and the
                # tool's result values are not guaranteed to be str.
                console.print(Panel(
                    f"Connection ID: [bold cyan]{escape(str(connection_id))}[/]\n"
                    f"Database Type: [blue]{escape(str(db_type))}[/]",
                    title="[bold green]:link: Connected[/]",
                    border_style="green",
                    padding=(1, 2),
                    expand=False,
                ))

                # Test the connection
                console.print("[cyan]Testing connection health...[/]")
                test_result = await sql_tool.manage_database(
                    action="test",
                    connection_id=connection_id,
                )
                if test_result.get("success"):
                    # `or 0` also covers an explicit None, which :.4f would reject.
                    resp_time = test_result.get("response_time_seconds") or 0
                    version = test_result.get("version", "N/A")
                    console.print(Panel(
                        f"[green]:heavy_check_mark: Connection test OK\n"
                        f"Response time: {resp_time:.4f}s\n"
                        f"DB Version: {escape(str(version))}",
                        border_style="green",
                        padding=(1, 2),
                    ))
                else:
                    console.print(Panel(
                        f"[bold red]:x: Connection test failed:[/]\n"
                        f"{escape(str(test_result.get('error', 'Unknown error')))}",
                        border_style="red",
                        padding=(1, 2),
                    ))

                # Get connection status
                console.print("[cyan]Fetching database status...[/]")
                status_result = await sql_tool.manage_database(
                    action="status",
                    connection_id=connection_id,
                )
                if status_result.get("success"):
                    status_table = Table(title="Active Connections", box=box.HEAVY, padding=(0, 1), border_style="blue")
                    status_table.add_column("Connection ID", style="cyan")
                    status_table.add_column("Database", style="blue")
                    status_table.add_column("Last Accessed", style="dim")
                    status_table.add_column("Idle Time", style="yellow")
                    connections = status_result.get("connections") or {}
                    for conn_id, conn_info in connections.items():
                        idle_seconds = conn_info.get("idle_time_seconds") or 0
                        # Table.add_row only accepts str/renderables, so coerce
                        # values such as timestamps or None to (escaped) text.
                        status_table.add_row(
                            escape(str(conn_id)),
                            escape(str(conn_info.get("dialect", "unknown"))),
                            escape(str(conn_info.get("last_accessed", "N/A"))),
                            f"{idle_seconds:.1f}s",
                        )
                    console.print(status_table)
                else:
                    console.print(Panel(
                        f"[bold red]:x: Failed to get database status:[/]\n"
                        f"{escape(str(status_result.get('error', 'Unknown error')))}",
                        border_style="red",
                        padding=(1, 2),
                    ))
            else:
                error_msg = str(connection_result.get('error', 'Unknown error'))
                logger.error(f"Failed to connect to database: {error_msg}")
                console.print(Panel(
                    f"[bold red]:x: Connection failed:[/]\n{escape(error_msg)}",
                    border_style="red",
                    padding=(1, 2),
                ))
        except Exception as e:
            logger.error(f"Unexpected error in connection demo: {e}", exc_info=True)
            console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(e))}")
    console.print()  # Spacing
    return connection_id
async def schema_discovery_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Discover the database schema and render it as a tree of tables and views.

    Each table node lists its columns (with PK / NOT NULL markers) and its
    foreign keys; views are listed by name. A schema hash, when the tool
    reports one, is printed underneath. Failures are logged and displayed in
    a red panel instead of being raised.
    """
    console.print(Rule("[bold green]2. Schema Discovery Demo[/bold green]", style="green"))
    logger.info("Starting schema discovery demo")
    with console.status("[bold cyan]Discovering database schema...", spinner="dots"):
        try:
            result = await sql_tool.explore_database(
                connection_id=connection_id,
                action="schema",
                include_indexes=True,
                include_foreign_keys=True,
                detailed=True,
            )
            if not result.get("success"):
                failure = result.get('error', 'Unknown error')
                logger.error(f"Failed to discover schema: {failure}")
                console.print(Panel(
                    f"[bold red]:x: Schema discovery failed:[/]\n{escape(failure)}",
                    border_style="red",
                    padding=(1, 2),
                ))
            else:
                table_infos = result.get("tables", [])
                view_infos = result.get("views", [])
                rel_infos = result.get("relationships", [])
                n_tables = len(table_infos)
                n_views = len(view_infos)
                logger.success(
                    f"Schema discovered: {n_tables} tables, {n_views} views, {len(rel_infos)} relationships"
                )
                schema_tree = Tree(
                    f"[bold bright_blue]:database: Database Schema ({n_tables} Tables, {n_views} Views)[/]",
                    guide_style="bright_blue",
                )

                if table_infos:
                    tables_node = schema_tree.add("[bold cyan]:page_facing_up: Tables[/]")
                    for info in table_infos:
                        node = tables_node.add(f"[cyan]{escape(info.get('name', 'Unknown'))}[/]")

                        column_infos = info.get("columns", [])
                        if column_infos:
                            columns_node = node.add("[bold yellow]:heavy_minus_sign: Columns[/]")
                            for column in column_infos:
                                label_name = column.get("name", "?")
                                label_type = column.get("type", "?")
                                markers = ""
                                if column.get("primary_key", False):
                                    markers += " [bold magenta](PK)[/]"
                                if not column.get("nullable", True):
                                    markers += " [dim]NOT NULL[/]"
                                columns_node.add(
                                    f"[yellow]{escape(label_name)}[/]: {escape(label_type)}{markers}"
                                )

                        fk_infos = info.get("foreign_keys", [])
                        if fk_infos:
                            fk_node = node.add("[bold blue]:link: Foreign Keys[/]")
                            for fk in fk_infos:
                                target = fk.get("referred_table", "?")
                                local_cols = ', '.join(fk.get("constrained_columns", []))
                                target_cols = ', '.join(fk.get("referred_columns", []))
                                fk_node.add(
                                    f"[blue]({escape(local_cols)})[/] -> [cyan]{escape(target)}[/]({escape(target_cols)})"
                                )

                if view_infos:
                    views_node = schema_tree.add("[bold magenta]:scroll: Views[/]")
                    for view in view_infos:
                        views_node.add(f"[magenta]{escape(view.get('name', 'Unknown'))}[/]")

                console.print(Panel(schema_tree, title="Schema Overview", border_style="bright_blue", padding=(1, 2)))

                schema_hash = result.get("schema_hash")
                if schema_hash:
                    console.print(f"[dim]Schema Hash: {schema_hash}[/dim]")
        except Exception as exc:
            logger.error(f"Unexpected error in schema discovery demo: {exc}")
            console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(exc))}")
    console.print()  # Spacing
async def table_details_demo(sql_tool: SQLTool, connection_id: str, table_name: str) -> None:
    """Show columns, sample rows, row count and per-column statistics for one table.

    Args:
        sql_tool: SQLTool used for the ``explore_database(action="table")`` call.
        connection_id: Connection to inspect.
        table_name: Name of the table to describe.

    Tool-reported failures and unexpected exceptions are logged and displayed,
    not raised.
    """
    console.print(Rule(f"[bold green]3. Table Details: [cyan]{escape(table_name)}[/cyan][/bold green]", style="green"))
    logger.info(f"Getting details for table: {table_name}")
    sample_size = 3  # Rows requested from the tool; also used in the sample table title.
    try:
        table_result = await sql_tool.explore_database(
            connection_id=connection_id,
            action="table",
            table_name=table_name,
            include_sample_data=True,
            sample_size=sample_size,
            include_statistics=True,
        )
        if table_result.get("success"):
            logger.success(f"Successfully retrieved details for table: {table_name}")
            console.print(Panel(f"[green]:heavy_check_mark: Details retrieved for [cyan]{escape(table_name)}[/]", border_style="green", padding=(0, 1)))

            # Display columns
            columns = table_result.get("columns") or []
            if columns:
                cols_table = Table(title="Columns", box=box.ROUNDED, show_header=True, padding=(0, 1), border_style="yellow")
                cols_table.add_column("Name", style="yellow", header_style="bold yellow")
                cols_table.add_column("Type", style="white")
                cols_table.add_column("Nullable", style="dim")
                cols_table.add_column("PK", style="magenta")
                cols_table.add_column("Default", style="dim")
                for column in columns:
                    default = column.get("default")
                    cols_table.add_row(
                        escape(str(column.get("name", "?"))),
                        # str(): the reported type is not guaranteed to be a string.
                        escape(str(column.get("type", "?"))),
                        # A missing "nullable" flag means nullable, which is the SQL
                        # default unless a column is declared NOT NULL.
                        ":heavy_check_mark:" if column.get("nullable", True) else ":x:",
                        "[bold magenta]:key:[/]" if column.get("primary_key", False) else "",
                        # Show no default as blank rather than the text "None".
                        "" if default is None else escape(str(default)),
                    )
                console.print(cols_table)

            # Display sample data
            sample_data = table_result.get("sample_data") or {}
            sample_rows = sample_data.get("rows") or []
            sample_cols = sample_data.get("columns") or []
            if sample_rows:
                sample_table = Table(title=f"Sample Data (first {sample_size} rows)", box=box.ROUNDED, show_header=True, padding=(0, 1), border_style="green")
                for col_name in sample_cols:
                    sample_table.add_column(escape(str(col_name)), style="dim cyan", header_style="bold cyan")
                for row in sample_rows:
                    # Rows may be mappings keyed by column name or plain sequences.
                    values = [row.get(col, "") for col in sample_cols] if isinstance(row, dict) else list(row)
                    sample_table.add_row(*[escape(str(value)) for value in values])
                console.print(sample_table)

            # Display row count
            row_count = table_result.get("row_count", "N/A")
            console.print(f"[cyan]Total Rows:[/] [yellow]{row_count}[/yellow]")

            # Display statistics, skipping columns whose stats failed; the table
            # is only printed when at least one column has usable stats.
            statistics = table_result.get("statistics") or {}
            stats_rows = [
                (
                    escape(str(col_name)),
                    str(stats.get("null_count", "N/A")),
                    str(stats.get("distinct_count", "N/A")),
                )
                for col_name, stats in statistics.items()
                if isinstance(stats, dict) and "error" not in stats
            ]
            if stats_rows:
                stats_table = Table(title="Column Statistics", box=box.SIMPLE, show_header=True, padding=(0, 1), border_style="magenta")
                stats_table.add_column("Column", style="cyan")
                stats_table.add_column("Null Count", style="yellow", justify="right")
                stats_table.add_column("Distinct Count", style="blue", justify="right")
                for stats_row in stats_rows:
                    stats_table.add_row(*stats_row)
                console.print(stats_table)
        else:
            error_msg = str(table_result.get('error', 'Unknown error'))
            logger.error(f"Failed to get table details: {error_msg}")
            console.print(Panel(
                f"[bold red]:x: Failed to get table details:[/]\n{escape(error_msg)}",
                border_style="red",
                padding=(1, 2),
            ))
    except Exception as e:
        logger.error(f"Unexpected error in table details demo: {e}", exc_info=True)
        console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(e))}")
    console.print()  # Spacing
async def find_related_tables_demo(sql_tool: SQLTool, connection_id: str, table_name: str) -> None:
    """Render the tables that ``table_name`` references and the tables that reference it.

    Relationships are explored to depth 2. Parents (tables this one points
    at) and children (tables pointing at this one) become two branches of a
    tree. Failures are logged and displayed, never raised.
    """
    console.print(Rule(f"[bold green]4. Related Tables: [cyan]{escape(table_name)}[/cyan][/bold green]", style="green"))
    logger.info(f"Finding tables related to {table_name}")
    try:
        outcome = await sql_tool.explore_database(
            connection_id=connection_id,
            action="relationships",
            table_name=table_name,
            depth=2,  # Explore relationships to depth 2
        )
        if not outcome.get("success"):
            err = outcome.get('error', 'Unknown error')
            logger.error(f"Failed to find relationships: {err}")
            console.print(Panel(
                f"[bold red]:x: Failed to find relationships:[/]\n{escape(err)}",
                border_style="red",
                padding=(1, 2),
            ))
        else:
            graph = outcome.get("relationship_graph", {})
            upstream = graph.get("parents", [])
            downstream = graph.get("children", [])
            if not (upstream or downstream):
                logger.info(f"No direct relationships found for {table_name}")
                console.print(Panel(f"[yellow]No direct relationships found for '{escape(table_name)}'", border_style="yellow", padding=(0, 1)))
            else:
                logger.success(f"Found relationships for table: {table_name}")
                graph_tree = Tree(f"[bold blue]:link: Relationships for [cyan]{escape(table_name)}[/][/]", guide_style="blue")

                # Tables this table references
                if upstream:
                    up_branch = graph_tree.add("[bold green]:arrow_up: References (Parents)[/]")
                    for edge in upstream:
                        rel_text = edge.get("relationship", "")
                        dest = edge.get("target", {}).get("table", "?")
                        up_branch.add(f"[blue]{escape(rel_text)}[/] -> [green]{escape(dest)}[/]")

                # Tables that reference this table
                if downstream:
                    down_branch = graph_tree.add("[bold magenta]:arrow_down: Referenced By (Children)[/]")
                    for edge in downstream:
                        rel_text = edge.get("relationship", "")
                        origin = edge.get("source", {}).get("table", "?")
                        down_branch.add(f"[magenta]{escape(origin)}[/] -> [blue]{escape(rel_text)}[/]")

                console.print(Panel(graph_tree, title="Table Relationships", border_style="blue", padding=(1, 2)))
    except Exception as exc:
        logger.error(f"Unexpected error in relationship discovery demo: {exc}")
        console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(exc))}")
    console.print()  # Spacing
async def column_statistics_demo(sql_tool: SQLTool, connection_id: str, table_name: str, column_name: str) -> None:
    """Show summary statistics and an 8-bucket value histogram for one column.

    Args:
        sql_tool: SQLTool used for the ``explore_database(action="column")`` call.
        connection_id: Connection to inspect.
        table_name: Table containing the column.
        column_name: Column to analyze.

    Failures are logged and displayed, never raised.
    """
    console.print(Rule(f"[bold green]5. Column Statistics: [cyan]{escape(table_name)}.[yellow]{escape(column_name)}[/yellow][/cyan][/bold green]", style="green"))
    logger.info(f"Analyzing statistics for column {table_name}.{column_name}")
    try:
        stats_result = await sql_tool.explore_database(
            connection_id=connection_id,
            action="column",
            table_name=table_name,
            column_name=column_name,
            histogram=True,
            num_buckets=8,
        )
        if stats_result.get("success"):
            logger.success(f"Successfully analyzed statistics for {table_name}.{column_name}")

            # Display basic statistics
            statistics = stats_result.get("statistics") or {}
            if statistics:
                stats_table = Table(title=f"Statistics for {escape(column_name)}", box=box.ROUNDED, show_header=False, padding=(1, 1), border_style="cyan")
                stats_table.add_column("Metric", style="cyan", justify="right")
                stats_table.add_column("Value", style="white")
                for key, value in statistics.items():
                    # Escape values: cells are parsed as rich markup.
                    stats_table.add_row(str(key).replace("_", " ").title(), escape(str(value)))
                console.print(stats_table)

            # Display histogram if available
            buckets = (stats_result.get("histogram") or {}).get("buckets") or []
            if buckets:
                console.print("[bold cyan]Value Distribution:[/]")
                counts = [bucket.get("count") or 0 for bucket in buckets]
                max_count = max(counts)
                total_count = sum(counts)
                # Bar length is scaled to the largest bucket so the shape is easy
                # to read; the percentage shown is each bucket's share of all
                # counted values (not its size relative to the largest bucket).
                progress = Progress(
                    TextColumn("[cyan]{task.description}", justify="right"),
                    BarColumn(bar_width=40),
                    TextColumn("[magenta]{task.fields[count]} ({task.fields[share]:>5.1f}%)"),
                    console=console,  # Render on the demo console, not rich's global one.
                )
                with progress:
                    for bucket, count in zip(buckets, counts):
                        progress.add_task(
                            description=escape(str(bucket.get("range", "?"))),
                            total=100,
                            completed=(count / max_count) * 100 if max_count > 0 else 0,
                            count=count,
                            share=(count / total_count) * 100 if total_count > 0 else 0.0,
                        )
        else:
            error_msg = str(stats_result.get('error', 'Unknown error'))
            logger.error(f"Failed to analyze column statistics: {error_msg}")
            console.print(Panel(
                f"[bold red]:x: Failed to analyze column statistics:[/]\n{escape(error_msg)}",
                border_style="red",
                padding=(1, 2),
            ))
    except Exception as e:
        logger.error(f"Unexpected error in column statistics demo: {e}", exc_info=True)
        console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(e))}")
    console.print()  # Spacing
async def query_execution_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Run a fixed series of read-only example queries and display each result.

    The series is: a plain SELECT capped at 10 rows, a query with bound
    parameters, two pages of a paginated query, and a two-table join. An
    unexpected exception stops the remaining queries and is reported.
    """
    console.print(Rule("[bold green]6. Query Execution Demo[/bold green]", style="green"))
    logger.info("Demonstrating query execution capabilities")
    try:
        active_customers_sql = "SELECT customer_id, name, email, status FROM customers WHERE status = 'active'"
        electronics_sql = "SELECT product_id, name, price FROM products WHERE category = :category AND price < :max_price ORDER BY price DESC"
        electronics_params = {"category": "Electronics", "max_price": 1000.00}
        by_price_sql = "SELECT product_id, name, category, price FROM products ORDER BY price DESC"
        orders_join_sql = """
SELECT c.name AS customer_name, o.order_id, o.order_date, o.total_amount, o.status
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
WHERE c.status = 'active'
ORDER BY o.order_date DESC
"""
        # Each entry: (log line, spinner text, execute_sql kwargs, result title,
        # SQL shown with the result or None to omit it).
        runs = [
            (
                "Executing simple query...",
                "[cyan]Running simple query...[/]",
                {"query": active_customers_sql, "read_only": True, "max_rows": 10},
                "Simple Query: Active Customers",
                active_customers_sql,
            ),
            (
                f"Executing parameterized query with params: {electronics_params}",
                "[cyan]Running parameterized query...[/]",
                {"query": electronics_sql, "parameters": electronics_params, "read_only": True},
                "Parameterized Query: Electronics under $1000",
                electronics_sql,
            ),
            (
                "Executing query with pagination (Page 1)",
                "[cyan]Running paginated query (Page 1)...[/]",
                {"query": by_price_sql, "pagination": {"page": 1, "page_size": 2}, "read_only": True},
                "Paginated Query: Products by Price (Page 1)",
                by_price_sql,
            ),
            (
                "Executing query with pagination (Page 2)",
                "[cyan]Running paginated query (Page 2)...[/]",
                {"query": by_price_sql, "pagination": {"page": 2, "page_size": 2}, "read_only": True},
                "Paginated Query: Products by Price (Page 2)",
                None,  # Same SQL as page 1; not shown again.
            ),
            (
                "Executing join query",
                "[cyan]Running join query...[/]",
                {"query": orders_join_sql, "read_only": True},
                "Join Query: Orders by Active Customers",
                orders_join_sql,
            ),
        ]
        for log_line, spinner_text, call_kwargs, heading, shown_sql in runs:
            logger.info(log_line)
            with console.status(spinner_text):
                outcome = await sql_tool.execute_sql(connection_id=connection_id, **call_kwargs)
            if shown_sql is None:
                display_result(heading, outcome)
            else:
                display_result(heading, outcome, query_str=shown_sql)
    except Exception as exc:
        logger.error(f"Unexpected error in query execution demo: {exc}")
        console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(exc))}")
    console.print()  # Spacing
async def _run_nl_to_sql_example(
    sql_tool: SQLTool,
    connection_id: str,
    natural_language: str,
    qualifier: str = "",
) -> None:
    """Convert one natural-language question to SQL, run it, and display the outcome.

    Args:
        sql_tool: SQLTool used to translate and execute the question.
        connection_id: Connection to run the generated SQL against.
        natural_language: The question to translate.
        qualifier: Optional adjective (e.g. ``"complex"``) woven into the log,
            spinner and panel text so several examples can be told apart.
    """
    prefix = f"{qualifier} " if qualifier else ""
    logger.info(f"Converting {prefix}natural language to SQL: '{natural_language}'")
    with console.status(f"[cyan]Converting {prefix}natural language to SQL...[/]"):
        result = await sql_tool.execute_sql(
            connection_id=connection_id,
            natural_language=natural_language,
            read_only=True,
        )
    if result.get("success"):
        # `or` guards against explicit None values, which Syntax and the
        # :.2f format would reject.
        generated_sql = result.get("generated_sql") or ""
        confidence = result.get("confidence") or 0.0
        sql_title = f"Generated SQL for {qualifier} query" if qualifier else "Generated SQL"
        console.print(Panel(
            Syntax(generated_sql, "sql", theme="default", line_numbers=False, word_wrap=True),
            title=f"{sql_title} (Confidence: {confidence:.2f})",
            border_style="green",
            padding=(1, 2),
        ))
        display_result(f"{prefix.title()}Natural Language Query Results", result)
    else:
        error_msg = str(result.get('error', 'Unknown error'))
        logger.error(f"Failed to convert {prefix}natural language to SQL: {error_msg}")
        heading = f"{prefix}natural language conversion failed".capitalize()
        console.print(Panel(
            f"[bold red]:x: {heading}:[/]\n{escape(error_msg)}",
            border_style="red",
            padding=(1, 2),
        ))


async def nl_to_sql_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Demonstrate natural language to SQL conversion on a simple and a complex question.

    Each example is isolated: an unexpected exception in one is logged and
    displayed, and the next example still runs.
    """
    console.print(Rule("[bold green]7. Natural Language to SQL Demo[/bold green]", style="green"))
    logger.info("Demonstrating natural language to SQL conversion")
    examples = (
        ("Show me all active customers and their total order value", ""),
        ("What's the average price of products by category?", "complex"),
    )
    for natural_language, qualifier in examples:
        try:
            await _run_nl_to_sql_example(sql_tool, connection_id, natural_language, qualifier)
        except Exception as e:
            logger.error(f"Unexpected error in NL to SQL demo: {e}", exc_info=True)
            console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(e))}")
    console.print()  # Spacing
async def documentation_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Generate Markdown documentation for the database and save it to a temp file.

    The documentation result is displayed. If it contains non-empty
    documentation text, that text is written to a new ``db_doc_*.md`` file
    created with ``tempfile.mkstemp``, and the path is printed. Failures are
    logged and displayed, never raised.
    """
    console.print(Rule("[bold green]8. Database Documentation Demo[/bold green]", style="green"))
    logger.info("Demonstrating database documentation generation")
    try:
        # Generate database documentation
        logger.info("Generating database documentation")
        with console.status("[cyan]Generating database documentation...[/]"):
            doc_result = await sql_tool.explore_database(
                connection_id=connection_id,
                action="documentation",
                output_format="markdown",
            )
        if doc_result.get("success"):
            logger.success("Successfully generated database documentation")
            # Display the documentation
            display_result("Database Documentation", doc_result)
            # Save the documentation to a temporary file whenever any was produced
            documentation = doc_result.get("documentation", "")
            if documentation:
                fd, doc_path = tempfile.mkstemp(suffix=".md", prefix="db_doc_")
                # fdopen takes ownership of the descriptor and closes it on exit.
                # UTF-8 is explicit so non-ASCII names/comments don't depend on
                # the platform's default encoding.
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    f.write(str(documentation))
                console.print(f"[green]Documentation saved to: [cyan]{escape(doc_path)}[/cyan][/green]")
        else:
            error_msg = str(doc_result.get('error', 'Unknown error'))
            logger.error(f"Failed to generate documentation: {error_msg}")
            console.print(Panel(
                f"[bold red]:x: Documentation generation failed:[/]\n{escape(error_msg)}",
                border_style="red",
                padding=(1, 2),
            ))
    except Exception as e:
        logger.error(f"Unexpected error in documentation demo: {e}", exc_info=True)
        console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(e))}")
    console.print()  # Spacing
async def _security_pii_masking(sql_tool: SQLTool, connection_id: str) -> None:
    """Section 10.1: query PII columns so SQLTool's result masking is visible."""
    console.print(Rule("[bold blue]10.1 PII Data Masking[/bold blue]", style="blue"))
    logger.info("Demonstrating PII data masking")
    # Nothing is inserted here: the ssn/credit_card values queried below come
    # from the sample customer rows loaded when the demo database is set up.
    console.print("[green]Using PII sample data (SSNs, credit card numbers) from the customers table.[/]")
    pii_select_query = """
SELECT customer_id, name, email, ssn, credit_card
FROM customers
ORDER BY customer_id
"""
    with console.status("[cyan]Executing query with PII data...[/]"):
        pii_result = await sql_tool.execute_sql(
            connection_id=connection_id,
            query=pii_select_query,
            read_only=True,
        )
    display_result("PII Masking Demo: Automatically Masked Sensitive Data", pii_result, pii_select_query)
    console.print(Panel(
        "Notice how the [bold]SSN[/bold], [bold]credit card numbers[/bold], and [bold]email addresses[/bold] are "
        "automatically masked according to SQLTool's masking rules, protecting sensitive information.",
        title="PII Masking Explanation",
        border_style="cyan",
        padding=(1, 2),
    ))


async def _security_prohibited_statements(sql_tool: SQLTool, connection_id: str) -> None:
    """Section 10.2: check that destructive / privilege-changing statements are rejected.

    Each statement is expected to raise ToolError. A statement that executes
    without error is reported as a protection failure.
    """
    console.print(Rule("[bold blue]10.2 Prohibited Statement Detection[/bold blue]", style="blue"))
    logger.info("Demonstrating prohibited statement detection")
    prohibited_queries = [
        "DROP TABLE customers",
        "DELETE FROM products",
        "TRUNCATE TABLE orders",
        "ALTER TABLE customers DROP COLUMN name",
        "GRANT ALL PRIVILEGES ON products TO user",
        "CREATE USER hacker WITH PASSWORD 'password'",
    ]
    prohibited_table = Table(title="Prohibited Statement Detection", box=box.ROUNDED, show_header=True, padding=(0, 1), border_style="red")
    prohibited_table.add_column("Prohibited SQL", style="yellow")
    prohibited_table.add_column("Result", style="green")
    for query in prohibited_queries:
        try:
            with console.status(f"[cyan]Testing: {escape(query)}[/]"):
                await sql_tool.execute_sql(
                    connection_id=connection_id,
                    query=query,
                    read_only=True,
                )
            # Reaching here means no exception was raised: protection failed.
            prohibited_table.add_row(escape(query), "[red]FAILED - Statement was allowed![/]")
        except ToolError as e:
            # Expected: the statement was blocked. Show the error's leading part.
            reason = str(e).split(':')[0]
            prohibited_table.add_row(escape(query), f"[green]SUCCESS - Blocked: {escape(reason)}[/]")
        except Exception as e:
            # Escape error text: raw driver messages can contain "[...]".
            prohibited_table.add_row(escape(query), f"[yellow]ERROR: {escape(str(e)[:50])}...[/]")
    console.print(prohibited_table)


async def _security_acl_controls(sql_tool: SQLTool, connection_id: str) -> None:
    """Section 10.3: show table- and column-level ACLs denying queries.

    The ACL is always cleared on exit, even if a query raises something other
    than ToolError, so later demos are not blocked by leftover restrictions.
    """
    console.print(Rule("[bold blue]10.3 Access Control Lists (ACL)[/bold blue]", style="blue"))
    logger.info("Demonstrating ACL controls")
    console.print("[cyan]Setting up ACL restrictions...[/]")
    # Restrict the whole 'customers' table plus the 'credit_card'/'ssn' columns.
    sql_tool.update_acl(tables=["customers"], columns=["credit_card", "ssn"])
    try:
        console.print(Panel(
            "Access control lists configured:\n"
            "- Restricted tables: [red]customers[/]\n"
            "- Restricted columns: [red]credit_card, ssn[/]",
            title="ACL Configuration",
            border_style="yellow",
            padding=(1, 2),
        ))

        # Try to access restricted table
        restricted_table_query = "SELECT * FROM customers"
        console.print("\n[cyan]Attempting to query restricted table:[/]")
        console.print(Syntax(restricted_table_query, "sql", theme="default"))
        try:
            with console.status("[cyan]Executing query on restricted table...[/]"):
                await sql_tool.execute_sql(
                    connection_id=connection_id,
                    query=restricted_table_query,
                    read_only=True,
                )
            console.print("[red]ACL FAILURE: Query was allowed on restricted table![/]")
        except ToolError as e:
            console.print(Panel(
                f"[green]✅ ACL WORKING: Access denied as expected:[/]\n{escape(str(e))}",
                border_style="green",
                padding=(1, 2),
            ))

        # Lift the table restriction before the column test. With 'customers'
        # still blocked, any query touching credit_card would be denied by the
        # table rule, so the column rule would never be exercised.
        sql_tool.update_acl(tables=[], columns=["credit_card", "ssn"])
        console.print("\n[cyan]Narrowing ACL to column restrictions only (credit_card, ssn).[/]")

        # Try to access restricted column (valid SQL against the demo schema)
        restricted_column_query = "SELECT customer_id, name, credit_card FROM customers"
        console.print("\n[cyan]Attempting to query restricted column:[/]")
        console.print(Syntax(restricted_column_query, "sql", theme="default"))
        try:
            with console.status("[cyan]Executing query with restricted column...[/]"):
                await sql_tool.execute_sql(
                    connection_id=connection_id,
                    query=restricted_column_query,
                    read_only=True,
                )
            console.print("[red]ACL FAILURE: Query was allowed with restricted column![/]")
        except ToolError as e:
            console.print(Panel(
                f"[green]✅ ACL WORKING: Access denied as expected:[/]\n{escape(str(e))}",
                border_style="green",
                padding=(1, 2),
            ))
    finally:
        # Clear ACL restrictions for further demos
        sql_tool.update_acl(tables=[], columns=[])
        console.print("[cyan]ACL restrictions cleared for following demos.[/]")


async def _security_schema_drift(sql_tool: SQLTool, connection_id: str) -> None:
    """Section 10.4: compare schema hashes before and after an ALTER TABLE.

    Note: this adds a ``last_updated`` column to ``products`` in the connected
    database (executed with ``read_only=False``).
    """
    console.print(Rule("[bold blue]10.4 Schema Drift Detection[/bold blue]", style="blue"))
    logger.info("Demonstrating schema drift detection")
    # First run schema discovery to capture initial state
    console.print("[cyan]Capturing initial schema state...[/]")
    with console.status("[cyan]Performing initial schema discovery...[/]"):
        initial_schema = await sql_tool.explore_database(
            connection_id=connection_id,
            action="schema",
            include_indexes=True,
            include_foreign_keys=True,
        )
    # `or` also covers an explicit None hash, which the [:16] slice would reject.
    initial_hash = str(initial_schema.get("schema_hash") or "unknown")
    console.print(f"[green]Initial schema captured with hash: [bold]{initial_hash[:16]}...[/][/]")

    # Now make a schema change
    schema_change_query = "ALTER TABLE products ADD COLUMN last_updated TIMESTAMP"
    console.print("[cyan]Making a schema change...[/]")
    console.print(Syntax(schema_change_query, "sql", theme="default"))
    with console.status("[cyan]Executing schema change...[/]"):
        await sql_tool.execute_sql(
            connection_id=connection_id,
            query=schema_change_query,
            read_only=False,  # Need to disable read-only for ALTER TABLE
        )

    # Now run schema discovery again to detect the change
    with console.status("[cyan]Performing follow-up schema discovery to detect changes...[/]"):
        new_schema = await sql_tool.explore_database(
            connection_id=connection_id,
            action="schema",
            include_indexes=True,
            include_foreign_keys=True,
        )
    new_hash = str(new_schema.get("schema_hash") or "unknown")
    schema_changed = new_schema.get("schema_change_detected", False)
    if initial_hash != new_hash:
        console.print(Panel(
            f"[green]✅ SCHEMA DRIFT DETECTED:[/]\n"
            f"- Initial hash: [dim]{initial_hash[:16]}...[/]\n"
            f"- New hash: [bold]{new_hash[:16]}...[/]\n"
            f"- Change detected by system: {'[green]Yes[/]' if schema_changed else '[red]No[/]'}",
            title="Schema Drift Detection Result",
            border_style="green",
            padding=(1, 2),
        ))
    else:
        console.print(Panel(
            "[red]Schema drift detection did not identify a change in hash even though schema was modified.[/]",
            border_style="red",
            padding=(1, 2),
        ))


async def security_features_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Demonstrate SQLTool security features.

    Runs, in order: PII masking (10.1), prohibited-statement blocking (10.2),
    access control lists (10.3) and schema drift detection (10.4). An
    unexpected exception is logged with a traceback, shown in a red panel,
    and skips the remaining sections.
    """
    console.print(Rule("[bold green]10. Security Features Demo[/bold green]", style="green"))
    logger.info("Demonstrating security features")
    sections = (
        _security_pii_masking,
        _security_prohibited_statements,
        _security_acl_controls,
        _security_schema_drift,
    )
    try:
        for section in sections:
            await section(sql_tool, connection_id)
    except Exception as e:
        logger.error(f"Error in security features demo: {e}", exc_info=True)
        console.print(Panel(
            f"[bold red]Error in security features demo:[/]\n{escape(str(e))}",
            border_style="red",
            padding=(1, 2),
        ))
    console.print()  # Spacing
async def advanced_export_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Export one product-sales aggregate query as a DataFrame, an Excel file, and a CSV.

    Sections:
        11.1 Export to a pandas DataFrame. Print its shape, memory and dtypes,
             some summary statistics, and a preview with an added discount column.
        11.2 Export to an Excel file at a fresh ``tempfile.mkstemp`` path.
        11.3 Export to a timestamped CSV under ``~/sql_demo_exports`` (the
             directory is created if missing) and preview up to 3 lines.

    Errors are logged and displayed rather than raised. An unexpected
    exception skips the remaining sections.
    """
    from itertools import islice  # Only needed for the CSV preview.

    console.print(Rule("[bold green]11. Advanced Export Options Demo[/bold green]", style="green"))
    logger.info("Demonstrating advanced export options")
    # Query to export
    export_query = """
SELECT p.product_id, p.name AS product_name, p.category, p.price,
SUM(oi.quantity) AS units_sold,
SUM(oi.quantity * oi.price_per_unit) AS total_revenue
FROM products p
LEFT JOIN order_items oi ON p.product_id = oi.product_id
GROUP BY p.product_id, p.name, p.category, p.price
ORDER BY total_revenue DESC
"""
    try:
        # --- PANDAS DATAFRAME EXPORT ---
        console.print(Rule("[bold blue]11.1 Pandas DataFrame Export[/bold blue]", style="blue"))
        logger.info("Demonstrating Pandas DataFrame export")
        console.print(Syntax(export_query, "sql", theme="default", line_numbers=False))
        with console.status("[cyan]Executing query and exporting to Pandas DataFrame...[/]"):
            df_result = await sql_tool.execute_sql(
                connection_id=connection_id,
                query=export_query,
                read_only=True,
                export={"format": "pandas"},
            )
        if df_result.get("success") and "dataframe" in df_result:
            df = df_result["dataframe"]
            # Display DataFrame info
            df_info = [
                f"Shape: {df.shape[0]} rows × {df.shape[1]} columns",
                f"Memory usage: {df.memory_usage(deep=True).sum() / 1024:.2f} KB",
                f"Column dtypes: {', '.join([f'{col}: {dtype}' for col, dtype in df.dtypes.items()])}",
            ]
            console.print(Panel(
                escape("\n".join(df_info)),
                title="Pandas DataFrame Export Result",
                border_style="green",
                padding=(1, 2),
            ))
            # Show DataFrame operations
            console.print("[cyan]Demonstrating DataFrame operations:[/]")
            stats_table = Table(title="DataFrame Statistics", box=box.ROUNDED, padding=(0, 1), border_style="blue")
            stats_table.add_column("Statistic", style="cyan")
            stats_table.add_column("Value", style="yellow")
            revenue = df['total_revenue']
            stats_table.add_row("Average Price", f"${df['price'].mean():.2f}")
            stats_table.add_row("Max Price", f"${df['price'].max():.2f}")
            stats_table.add_row("Min Price", f"${df['price'].min():.2f}")
            stats_table.add_row("Total Revenue", f"${revenue.sum():.2f}")
            # idxmax raises on an empty or all-NaN column (e.g. nothing sold yet).
            if revenue.notna().any():
                top_product = df.loc[revenue.idxmax(), 'product_name']
                stats_table.add_row("Highest Revenue Product", escape(str(top_product)))
            else:
                stats_table.add_row("Highest Revenue Product", "N/A")
            console.print(stats_table)

            # Create a simple DataFrame transformation
            console.print("\n[cyan]Demonstrating DataFrame transformation - Adding discount column:[/]")
            df['discount_price'] = df['price'] * 0.9
            table = Table(title="Transformed DataFrame (First 3 Rows)", box=box.ROUNDED, show_header=True)
            for col in df.columns:
                justify = "right" if df[col].dtype.kind in 'ifc' else "left"
                table.add_column(str(col), style="cyan", justify=justify)
            for _, row in df.head(3).iterrows():
                formatted_row = []
                for col in df.columns:
                    val = row[col]
                    if pd.api.types.is_numeric_dtype(df[col].dtype):  # Check column dtype, not row value
                        if 'price' in col or 'revenue' in col:
                            formatted_row.append(f"${val:.2f}")
                        else:
                            formatted_row.append(f"{val:,.2f}" if isinstance(val, float) else f"{val:,}")
                    else:
                        formatted_row.append(escape(str(val)))
                table.add_row(*formatted_row)
            console.print(table)
        else:
            console.print(Panel(
                f"[red]Failed to export to DataFrame: {escape(str(df_result.get('error', 'Unknown error')))}[/]",
                border_style="red",
                padding=(1, 2),
            ))

        # --- EXCEL EXPORT WITH FORMATTING ---
        console.print(Rule("[bold blue]11.2 Excel Export with Formatting[/bold blue]", style="blue"))
        logger.info("Demonstrating Excel export with formatting")
        excel_fd, excel_path = tempfile.mkstemp(suffix=".xlsx", prefix="sql_demo_export_")
        os.close(excel_fd)  # Close file descriptor, as we only need the path
        with console.status("[cyan]Executing query and exporting to formatted Excel...[/]"):
            excel_result = await sql_tool.execute_sql(
                connection_id=connection_id,
                query=export_query,
                read_only=True,
                export={
                    "format": "excel",
                    "path": excel_path,
                    # Note: Additional formatting options might be available in your implementation
                },
            )
        if excel_result.get("success") and "excel_path" in excel_result:
            export_path = excel_result["excel_path"]
            file_size = os.path.getsize(export_path) / 1024  # Size in KB
            console.print(Panel(
                f"[green]✅ Successfully exported to Excel:[/]\n"
                f"Path: [cyan]{escape(str(export_path))}[/]\n"
                f"Size: [yellow]{file_size:.2f} KB[/]",
                title="Excel Export Result",
                border_style="green",
                padding=(1, 2),
            ))
        else:
            console.print(Panel(
                f"[red]Failed to export to Excel: {escape(str(excel_result.get('error', 'Unknown error')))}[/]",
                border_style="red",
                padding=(1, 2),
            ))

        # --- CUSTOM EXPORT PATH (CSV) ---
        console.print(Rule("[bold blue]11.3 Custom Export Path (CSV)[/bold blue]", style="blue"))
        logger.info("Demonstrating custom export path")
        # Create a custom path in the user's home directory
        user_home = os.path.expanduser("~")
        custom_dir = os.path.join(user_home, "sql_demo_exports")
        os.makedirs(custom_dir, exist_ok=True)
        timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
        custom_path = os.path.join(custom_dir, f"product_sales_{timestamp}.csv")
        console.print(f"[cyan]Exporting to custom path: [/][yellow]{escape(custom_path)}[/]")
        with console.status("[cyan]Executing query and exporting to custom CSV path...[/]"):
            csv_result = await sql_tool.execute_sql(
                connection_id=connection_id,
                query=export_query,
                read_only=True,
                export={
                    "format": "csv",
                    "path": custom_path,
                },
            )
        if csv_result.get("success") and "csv_path" in csv_result:
            export_path = csv_result["csv_path"]
            file_size = os.path.getsize(export_path) / 1024  # Size in KB
            # islice tolerates files shorter than 3 lines (next() would raise
            # StopIteration); errors="replace" keeps a preview from failing on
            # unexpected bytes.
            with open(export_path, "r", encoding="utf-8", errors="replace") as f:
                first_lines = list(islice(f, 3))
            console.print(Panel(
                f"[green]✅ Successfully exported to custom CSV path:[/]\n"
                f"Path: [cyan]{escape(str(export_path))}[/]\n"
                f"Size: [yellow]{file_size:.2f} KB[/]\n\n"
                f"[dim]Preview (first 3 lines):[/]\n"
                f"[white]{escape(''.join(first_lines))}[/]",
                title="Custom CSV Export Result",
                border_style="green",
                padding=(1, 2),
            ))
        else:
            console.print(Panel(
                f"[red]Failed to export to custom CSV path: {escape(str(csv_result.get('error', 'Unknown error')))}[/]",
                border_style="red",
                padding=(1, 2),
            ))
    except Exception as e:
        logger.error(f"Error in advanced export demo: {e}", exc_info=True)
        console.print(Panel(
            f"[bold red]Error in advanced export demo:[/]\n{escape(str(e))}",
            border_style="red",
            padding=(1, 2),
        ))
    console.print()  # Spacing
async def schema_validation_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Demonstrate Pandera schema validation for query results.

    Runs a query against ``products``, validates the returned rows client-side
    against a Pandera ``DataFrameSchema``, then validates a hand-built
    DataFrame containing deliberately invalid rows to show a failing case.

    Args:
        sql_tool: SQLTool used to execute the validation query.
        connection_id: Connection on which the query is executed.

    Errors are logged and printed to the console; nothing is raised.
    """
    console.print(Rule("[bold green]12. Schema Validation Demo[/bold green]", style="green"))
    logger.info("Demonstrating Pandera schema validation")
    try:
        # Query whose result set will be validated
        validation_query = """
SELECT
    product_id,
    name AS product_name,
    price,
    category,
    in_stock
FROM products
"""
        console.print("[cyan]We'll validate that query results conform to a specified schema:[/]")
        console.print(Syntax(validation_query, "sql", theme="default"))

        # Source text displayed to the user; keep in sync with `product_schema` below.
        schema_code = """
# Define a Pandera schema for validation using DataFrameSchema
product_schema = pa.DataFrameSchema({
    "product_id": pa.Column(int, checks=pa.Check.greater_than(0)),
    "product_name": pa.Column(str, nullable=False),
    "price": pa.Column(float, checks=[
        pa.Check.greater_than(0, error="price must be positive"),
        pa.Check.less_than(2000.0, error="price must be under $2000")
    ]),
    "category": pa.Column(
        str,
        checks=pa.Check.isin(["Electronics", "Audio", "Kitchen", "Wearables"]),
        nullable=False
    ),
    "in_stock": pa.Column(bool)
})
"""
        console.print(Panel(
            Syntax(schema_code, "python", theme="default"),
            title="Pandera Validation Schema",
            border_style="cyan",
            padding=(1, 2)
        ))

        # Check pandera version
        version = getattr(pa, '__version__', 'unknown')
        console.print(f"[dim]Using pandera version: {version}")

        # The schema actually used for validation
        product_schema = pa.DataFrameSchema({
            "product_id": pa.Column(int, checks=pa.Check.greater_than(0)),
            "product_name": pa.Column(str, nullable=False),
            "price": pa.Column(float, checks=[
                pa.Check.greater_than(0, error="price must be positive"),
                pa.Check.less_than(2000.0, error="price must be under $2000")
            ]),
            "category": pa.Column(
                str,
                checks=pa.Check.isin(["Electronics", "Audio", "Kitchen", "Wearables"]),
                nullable=False
            ),
            "in_stock": pa.Column(bool)
        })

        # WORKAROUND: Instead of using built-in validation (which has an error),
        # we fetch the data first, then validate it manually.
        console.print("[cyan]Executing query to fetch data...[/]")
        with console.status("[cyan]Running query...[/]"):
            query_result = await sql_tool.execute_sql(
                connection_id=connection_id,
                query=validation_query,
                read_only=True
            )

        if query_result.get("success"):
            display_result("Data Retrieved for Validation", query_result)
            console.print("[cyan]Now validating results with Pandera...[/]")
            if pd is not None:
                try:
                    df = pd.DataFrame(query_result.get("rows", []), columns=query_result.get("columns", []))
                    # in_stock may not arrive with a bool dtype (e.g. stored as 0/1); coerce it
                    # so the `pa.Column(bool)` check is meaningful.
                    if "in_stock" in df.columns and df["in_stock"].dtype != bool:
                        df["in_stock"] = df["in_stock"].astype(bool)
                    console.print(f"[dim]Created DataFrame with shape {df.shape} for validation")
                    with console.status("[cyan]Validating against schema...[/]"):
                        try:
                            product_schema.validate(df)
                            console.print(Panel(
                                "[green]✅ Schema validation passed![/]\n"
                                "All data meets the requirements defined in the schema.",
                                title="Validation Result",
                                border_style="green",
                                padding=(1, 2)
                            ))
                        except Exception as val_err:
                            # Escape: validation messages are arbitrary text and must not be parsed as Rich markup.
                            console.print(Panel(
                                f"[yellow]⚠ Schema validation failed![/]\n"
                                f"Error: {escape(str(val_err))}",
                                title="Validation Result",
                                border_style="yellow",
                                padding=(1, 2)
                            ))
                except Exception as df_err:
                    console.print(f"[red]Error creating DataFrame: {escape(str(df_err))}[/]")
            else:
                console.print("[yellow]Pandas is not available, cannot perform validation.[/]")
        else:
            console.print(Panel(
                f"[red]Failed to execute query: {escape(str(query_result.get('error', 'Unknown error')))}[/]",
                border_style="red",
                padding=(1, 2)
            ))

        # Simulate a failing validation case
        console.print("\n[cyan]Simulating validation failure with invalid data...[/]")
        if pd is not None:
            test_data = [
                # Valid data
                {"product_id": 1, "product_name": "Laptop Pro X", "price": 1499.99, "category": "Electronics", "in_stock": True},
                {"product_id": 2, "product_name": "Smartphone Z", "price": 999.99, "category": "Electronics", "in_stock": True},
                # Invalid data (negative price)
                {"product_id": 6, "product_name": "Invalid Product", "price": -10.0, "category": "Electronics", "in_stock": True},
                # Invalid data (unknown category)
                {"product_id": 7, "product_name": "Test Product", "price": 50.0, "category": "Invalid Category", "in_stock": True}
            ]
            test_df = pd.DataFrame(test_data)

            test_table = Table(title="Test Data for Validation", box=box.ROUNDED, show_header=True, padding=(0, 1), border_style="yellow")
            for col in test_df.columns:
                test_table.add_column(str(col), style="cyan")
            for _, row in test_df.iterrows():
                test_table.add_row(*[str(val) for val in row])
            console.print(test_table)

            console.print("[cyan]Attempting to validate this data...[/]")
            try:
                # lazy=True collects every failure instead of stopping at the first one
                product_schema.validate(test_df, lazy=True)
                console.print(Panel(
                    "[red]Unexpected result: Validation passed when it should have failed![/]",
                    border_style="red",
                    padding=(1, 2)
                ))
            except Exception as val_err:
                console.print(Panel(
                    f"[green]✅ Validation correctly failed as expected![/]\n"
                    f"Error: {escape(str(val_err))}",
                    title="Expected Validation Failure (Simulated)",
                    border_style="green",
                    padding=(1, 2)
                ))
        else:
            console.print("[yellow]Pandas not available, cannot demonstrate validation failure.[/]")
    except Exception as e:
        logger.error(f"Error in schema validation demo: {e}", exc_info=True)
        console.print(Panel(
            f"[bold red]Error in schema validation demo:[/]\n{escape(str(e))}",
            border_style="red",
            padding=(1, 2)
        ))
    console.print()  # Spacing
async def audit_log_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Demonstrate audit log functionality.

    Shows up to 10 audit records from ``access_audit_log(action="view")`` as a
    table, plus a detail panel for the first record. Then exports the audit
    log to CSV with ``access_audit_log(action="export")``.

    Args:
        sql_tool: SQLTool whose audit log is queried.
        connection_id: Not used by this demo; presumably accepted so every
            demo function shares the same signature.

    Errors are logged and printed to the console; nothing is raised.
    """
    console.print(Rule("[bold green]9. Audit Log Demo[/bold green]", style="green"))
    logger.info("Demonstrating audit log functionality")
    # Record fields rendered as plain table columns (the "success" column is rendered separately).
    table_fields = ("audit_id", "timestamp", "tool_name", "action", "connection_id")
    # Fields already visible in the table; the detail panel lists everything else.
    summary_fields = table_fields + ("success",)
    try:
        logger.info("Viewing audit log")
        with console.status("[cyan]Retrieving audit log...[/]"):
            audit_result = await sql_tool.access_audit_log(
                action="view",
                limit=10
            )
        if audit_result.get("success"):
            logger.success("Successfully retrieved audit log")
            records = audit_result.get("records", [])
            if records:
                audit_table = Table(title="Audit Log", box=box.ROUNDED, show_header=True, padding=(0, 1), border_style="blue")
                audit_table.add_column("ID", style="dim")
                audit_table.add_column("Timestamp", style="cyan")
                audit_table.add_column("Tool", style="green")
                audit_table.add_column("Action", style="yellow")
                audit_table.add_column("Connection ID", style="magenta")
                audit_table.add_column("Success", style="cyan")
                for record in records:
                    cells = []
                    for field in table_fields:
                        value = record.get(field, "?")
                        # Rich table cells must be str or renderables (None renders blank); audit
                        # values are not guaranteed to be strings, so stringify and escape them.
                        cells.append("" if value is None else escape(str(value)))
                    success_cell = "[green]:heavy_check_mark:[/]" if record.get("success") else "[red]:x:[/]"
                    audit_table.add_row(*cells, success_cell)
                console.print(audit_table)

                # Show the remaining details of the first record
                sample_record = records[0]
                detail_lines = [
                    f"[cyan]{k}:[/] {escape(str(v))}"
                    for k, v in sample_record.items()
                    if k not in summary_fields
                ]
                console.print(Panel(
                    "\n".join(detail_lines),
                    title=f"Audit Record Details: {sample_record.get('audit_id', '?')}",
                    border_style="dim",
                    padding=(1, 2)
                ))
            else:
                console.print(Panel("[yellow]No audit records found.", border_style="yellow", padding=(0, 1)))
        else:
            error_msg = str(audit_result.get('error', 'Unknown error'))
            logger.error(f"Failed to retrieve audit log: {error_msg}")
            console.print(Panel(
                f"[bold red]:x: Audit log retrieval failed:[/]\n{escape(error_msg)}",
                border_style="red",
                padding=(1, 2)
            ))

        # Export the audit log
        logger.info("Exporting audit log")
        with console.status("[cyan]Exporting audit log to CSV...[/]"):
            export_result = await sql_tool.access_audit_log(
                action="export",
                export_format="csv"
            )
        if export_result.get("success"):
            export_path = export_result.get("path", "")
            record_count = export_result.get("record_count", 0)
            logger.success(f"Successfully exported {record_count} audit records to CSV")
            console.print(Panel(
                f"[green]:heavy_check_mark: Exported {record_count} audit records to:[/]\n[cyan]{escape(str(export_path))}[/]",
                border_style="green",
                padding=(1, 2)
            ))
        else:
            error_msg = str(export_result.get('error', 'Unknown error'))
            logger.error(f"Failed to export audit log: {error_msg}")
            console.print(Panel(
                f"[bold red]:x: Audit log export failed:[/]\n{escape(error_msg)}",
                border_style="red",
                padding=(1, 2)
            ))
    except Exception as e:
        logger.error(f"Unexpected error in audit log demo: {e}", exc_info=True)
        console.print(f"[bold red]:x: Unexpected Error:[/]\n{escape(str(e))}")
    console.print()  # Spacing
async def cleanup_demo(sql_tool: SQLTool, connection_id: str) -> None:
    """Demonstrate disconnecting from the database.

    Calls ``manage_database(action="disconnect")`` for ``connection_id`` and
    reports the outcome on the console.

    Args:
        sql_tool: SQLTool that owns the connection.
        connection_id: Identifier of the connection to close.

    Errors are logged and printed to the console; nothing is raised.
    """
    console.print(Rule("[bold green]Database Cleanup and Disconnection[/bold green]", style="green"))
    logger.info("Disconnecting from database")
    try:
        disconnect_result = await sql_tool.manage_database(
            action="disconnect",
            connection_id=connection_id
        )
        if disconnect_result.get("success"):
            logger.success(f"Successfully disconnected from database (ID: {connection_id})")
            console.print(Panel(
                f"[green]:heavy_check_mark: Successfully disconnected from database. Connection ID: [dim]{connection_id}[/dim][/]",
                border_style="green",
                padding=(0, 1)
            ))
        else:
            # `or` also covers an explicit None value; escape() requires a str.
            error_msg = str(disconnect_result.get('error') or 'Unknown error')
            logger.error(f"Failed to disconnect: {error_msg}")
            console.print(Panel(
                f"[bold red]:x: Failed to disconnect:[/]\n{escape(error_msg)}",
                border_style="red",
                padding=(1, 2)
            ))
    except Exception as e:
        logger.error(f"Error in cleanup demo: {e}", exc_info=True)
        console.print(f"[bold red]:x: Error in cleanup:[/]\n{escape(str(e))}")
    console.print()
async def verify_demo_database(sql_tool, connection_id: str) -> None:
    """Verify the demo database has been set up correctly.

    Runs ``SELECT COUNT(*) as count FROM customers`` and reports whether the
    table contains rows. Tool errors, unsuccessful query results and empty
    result sets are reported on the console rather than raised.

    Args:
        sql_tool: SQLTool used to run the check query.
        connection_id: Connection on which the query is executed.
    """
    logger.info("Verifying database setup...")
    # For consistency, we still display the setup status
    console.print(Panel("[green]:heavy_check_mark: Using prepared sample database.", padding=(0, 1), border_style="green"))
    try:
        result = await sql_tool.execute_sql(
            connection_id=connection_id,
            query="SELECT COUNT(*) as count FROM customers",
            read_only=True
        )
    except (ToolError, ToolInputError) as e:
        logger.error(f"Error checking database setup: {e}")
        console.print(Panel(f"[bold red]:x: Database Setup Error:[/]\n{escape(str(e))}", padding=(1, 2), border_style="red"))
        return

    # A failed query must be reported as an error, not mistaken for an empty table.
    if not result.get("success"):
        error_msg = str(result.get("error") or "Unknown error")
        logger.error(f"Error checking database setup: {error_msg}")
        console.print(Panel(f"[bold red]:x: Database Setup Error:[/]\n{escape(error_msg)}", padding=(1, 2), border_style="red"))
        return

    # Guard against an empty/missing row list, which would otherwise raise IndexError.
    rows = result.get("rows") or []
    count = rows[0].get("count", 0) if rows else 0
    if count > 0:
        logger.info(f"Verified database setup: {count} customers found")
        console.print(Panel(f"[green]:heavy_check_mark: Sample database verified with {count} customer records.", padding=(0, 1), border_style="green"))
    else:
        logger.warning("Database tables found but they appear to be empty")
        console.print(Panel("[yellow]⚠ Database tables found but they appear to be empty.", padding=(0, 1), border_style="yellow"))
# --- Main Function ---
async def main() -> int:
    """Run the SQL database tools demo.

    Recreates ``demo.db`` beside this script via ``init_demo_database``.
    Connects to it through ``SQLTool``, runs each demo section in order, then
    shuts the tool down.

    Returns:
        Process exit code: 0 on success; 1 if the database could not be
        created, the connection failed, or a demo raised an unexpected error.
    """
    console.print(Rule("[bold magenta]SQL Database Tools Demo[/bold magenta]"))
    exit_code = 0
    connection_id = None
    # Stays None until SQLTool is constructed, so `finally` knows whether a shutdown is needed.
    sql_tool = None
    # Path to the demo database file
    db_file = os.path.join(os.path.dirname(__file__), "demo.db")

    # Force recreate the demo database so its schema matches what the demos expect
    if os.path.exists(db_file):
        try:
            os.remove(db_file)
            console.print("[yellow]Removed existing database file to ensure correct schema.[/]")
        except OSError as e:
            console.print(f"[yellow]Warning: Could not remove existing database: {e}[/]")

    # Create the demo database if it does not exist (it will not, unless removal failed)
    if not os.path.exists(db_file):
        console.print("[yellow]Demo database not found. Creating it now...[/]")
        try:
            init_demo_database(db_file)
            console.print("[green]Demo database created successfully.[/]")
        except Exception as e:
            console.print(f"[red]Failed to create demo database: {e}[/]")
            return 1

    # Connection string for the file-based SQLite database
    file_connection_string = f"sqlite:///{db_file}"
    try:
        # Created inside the try so a construction failure yields exit code 1 like any other error
        gateway = Gateway("sql-database-demo", register_tools=False)
        sql_tool = SQLTool(gateway)

        connection_id = await connection_demo(sql_tool, file_connection_string)
        if connection_id:
            await verify_demo_database(sql_tool, connection_id)
            await schema_discovery_demo(sql_tool, connection_id)
            await table_details_demo(sql_tool, connection_id, "customers")
            await find_related_tables_demo(sql_tool, connection_id, "orders")
            await column_statistics_demo(sql_tool, connection_id, "products", "price")
            await query_execution_demo(sql_tool, connection_id)
            await nl_to_sql_demo(sql_tool, connection_id)
            await documentation_demo(sql_tool, connection_id)
            await audit_log_demo(sql_tool, connection_id)
            await security_features_demo(sql_tool, connection_id)
            await advanced_export_demo(sql_tool, connection_id)
            await schema_validation_demo(sql_tool, connection_id)
            await cleanup_demo(sql_tool, connection_id)
        else:
            logger.error("Skipping demonstrations due to connection failure")
            exit_code = 1
    except Exception as e:
        logger.critical(f"Demo failed with unexpected error: {e}")
        console.print(f"[bold red]CRITICAL ERROR: {escape(str(e))}[/]")
        exit_code = 1
    finally:
        if sql_tool is not None:
            try:
                await sql_tool.shutdown()
                logger.info("SQLTool shut down successfully")
            except Exception as shutdown_err:
                logger.error(f"Error during SQLTool shutdown: {shutdown_err}")
        # NOTE(review): db_file always points at demo.db beside this script, so this
        # 'sql_demo_export' guard never matches and the file is left in place (it is
        # recreated at the start of every run). Confirm whether cleanup is intended.
        try:
            if os.path.exists(db_file) and 'sql_demo_export' in db_file:
                os.remove(db_file)
                logger.info(f"Cleaned up demo database file: {db_file}")
        except Exception as clean_err:
            logger.warning(f"Could not clean up demo database: {clean_err}")
    return exit_code
if __name__ == "__main__":
    import logging

    # Stdlib logging configuration used while the demo runs
    _log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=_log_format, level=logging.INFO)

    # main() returns an integer status; hand it straight to the interpreter
    sys.exit(asyncio.run(main()))
```