This is page 2 of 2. Use http://codebase.md/stinkgen/trino_mcp?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .gitignore
├── CHANGELOG.md
├── docker-compose.yml
├── Dockerfile
├── etc
│   ├── catalog
│   │   ├── bullshit.properties
│   │   └── memory.properties
│   ├── config.properties
│   ├── jvm.config
│   └── node.properties
├── examples
│   └── simple_mcp_query.py
├── LICENSE
├── llm_query_trino.py
├── llm_trino_api.py
├── load_bullshit_data.py
├── openapi.json
├── pyproject.toml
├── pytest.ini
├── README.md
├── requirements-dev.txt
├── run_tests.sh
├── scripts
│   ├── docker_stdio_test.py
│   ├── fix_trino_session.py
│   ├── test_direct_query.py
│   ├── test_fixed_client.py
│   ├── test_messages.py
│   ├── test_quick_query.py
│   └── test_stdio_trino.py
├── src
│   └── trino_mcp
│       ├── __init__.py
│       ├── config.py
│       ├── resources
│       │   └── __init__.py
│       ├── server.py
│       ├── tools
│       │   └── __init__.py
│       └── trino_client.py
├── test_bullshit_query.py
├── test_llm_api.py
├── test_mcp_stdio.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── integration
│   │   └── __init__.py
│   └── test_client.py
├── tools
│   ├── create_bullshit_data.py
│   ├── run_queries.sh
│   ├── setup
│   │   ├── setup_data.sh
│   │   └── setup_tables.sql
│   └── setup_bullshit_table.py
└── trino-conf
    ├── catalog
    │   └── memory.properties
    ├── config.properties
    ├── jvm.config
    └── node.properties
```

# Files

--------------------------------------------------------------------------------
/test_mcp_stdio.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
STDIO transport test script for Trino MCP.
This script demonstrates the end-to-end flow of initializing MCP, listing tools,
querying data, and shutting down using the STDIO transport.
"""
import json
import subprocess
import sys
import time


def _tool_call_request(request_id, sql, catalog="memory"):
    """
    Build a JSON-RPC ``tools/call`` request for the ``execute_query`` tool.

    Args:
        request_id: JSON-RPC id used to match the server's response.
        sql: SQL text to execute.
        catalog: Trino catalog passed to the tool.

    Returns:
        The request as a dict ready for ``json.dumps``.
    """
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": "tools/call",
        "params": {
            "name": "execute_query",
            "arguments": {
                "sql": sql,
                "catalog": catalog,
            },
        },
    }


def _parse_tool_payload(result):
    """
    Decode the JSON document carried as text in a ``tools/call`` result.

    execute_query returns its results as a JSON string in
    ``result["content"][0]["text"]``.

    Returns:
        The decoded dict, or None when the result lacks that shape or the
        text is not a JSON object.
    """
    try:
        payload = json.loads(result["content"][0]["text"])
    except (KeyError, IndexError, TypeError, json.JSONDecodeError):
        return None
    return payload if isinstance(payload, dict) else None


def _is_tool_error(response):
    """Return True when a ``tools/call`` result is flagged with ``isError``."""
    result = response.get("result")
    return isinstance(result, dict) and bool(result.get("isError"))


def _send_request(process, request, expect_response=True):
    """
    Send a request to the MCP server and get the response.

    Args:
        process: The running server process (text-mode stdin/stdout pipes).
        request: The JSON-RPC request to send.
        expect_response: Whether to wait for a response.

    Returns:
        The JSON-RPC response, or None if no response is expected or none
        could be read.
    """
    request_str = json.dumps(request) + "\n"
    print(f"\n📤 Sending: {request_str.strip()}")

    try:
        process.stdin.write(request_str)
        process.stdin.flush()
    except BrokenPipeError:
        print("❌ Broken pipe - server has closed the connection")
        return None

    if not expect_response:
        print("✅ Sent notification (no response expected)")
        return None

    print("Waiting for response...")
    request_id = request.get("id")
    while True:
        try:
            response_str = process.stdout.readline()
        except Exception as e:
            print(f"❌ Error reading response: {e}")
            return None

        if not response_str:
            # EOF: the server closed its stdout.
            print("❌ No response received")
            return None

        print(f"📩 Received: {response_str.strip()}")
        try:
            message = json.loads(response_str)
        except json.JSONDecodeError as e:
            print(f"❌ Error reading response: {e}")
            return None

        # The server may emit notifications (no "id") or other messages before
        # the reply, so keep reading until the response to *this* request arrives.
        # Error replies to unparseable requests carry "id": null.
        if isinstance(message, dict) and (
            message.get("id") == request_id
            or (message.get("id") is None and "error" in message)
        ):
            return message
        print("   (not the response to this request - still waiting)")


def _print_query_details(response):
    """
    Print query metadata and rows from a successful ``tools/call`` response.

    Falls back to dumping the raw result (or raw response) when the JSON
    payload cannot be decoded.
    """
    if "result" not in response:
        print(f"   Raw response: {json.dumps(response, indent=2)}")
        return

    result = response["result"]
    payload = _parse_tool_payload(result)
    if payload is None:
        print(f"   Raw result: {json.dumps(result, indent=2)}")
        return

    print(f"   Query ID: {payload.get('query_id', 'unknown')}")
    print(f"   Columns: {', '.join(str(col) for col in payload.get('columns', []))}")
    print(f"   Row count: {payload.get('row_count', 0)}")
    print(f"   Results: {json.dumps(payload.get('preview_rows', []), indent=2)}")


def _show_available_schemas(process):
    """Run ``SHOW SCHEMAS FROM memory`` and print what the server returns."""
    print("\n----- Fallback Query: Checking Available Schemas -----")
    schemas_response = _send_request(process, _tool_call_request(5, "SHOW SCHEMAS FROM memory"))
    if not schemas_response or "result" not in schemas_response:
        return

    result = schemas_response["result"]
    if not (isinstance(result, dict) and "content" in result):
        return

    payload = _parse_tool_payload(result)
    if payload is None:
        print(f"   Raw schemas result: {json.dumps(result, indent=2)}")
    else:
        print(f"   Available schemas: {json.dumps(payload.get('preview_rows', []), indent=2)}")


def test_mcp_stdio():
    """Run an end-to-end test of Trino MCP using STDIO transport."""
    print("🚀 Starting Trino MCP STDIO test")

    # Start the MCP server with STDIO transport
    server_cmd = [
        "docker", "exec", "-i", "trino_mcp_trino-mcp_1",
        "python", "-m", "trino_mcp.server",
        "--transport", "stdio",
        "--debug",
        "--trino-host", "trino",
        "--trino-port", "8080",
        "--trino-user", "trino",
        "--trino-catalog", "memory",
    ]

    # Bound before the try so the cleanup in `finally` never sees an unbound name.
    process = None
    failures = []

    try:
        print(f"Starting MCP server process: {' '.join(server_cmd)}")
        process = subprocess.Popen(
            server_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=sys.stderr,  # Pass stderr through to see logs directly
            text=True,
            bufsize=1,  # Line buffered
        )

        # Sleep a bit to let the server initialize
        time.sleep(2)

        # ===== STEP 1: Initialize MCP =====
        print("\n===== STEP 1: Initialize MCP =====")
        initialize_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "clientInfo": {
                    "name": "trino-mcp-stdio-test",
                    "version": "1.0.0"
                },
                "capabilities": {
                    "tools": True,
                    "resources": {
                        "supportedSources": ["trino://catalog"]
                    }
                }
            }
        }

        init_response = _send_request(process, initialize_request)
        if not init_response:
            print("❌ Failed to initialize MCP - exiting test")
            return
        if "error" in init_response:
            print(f"❌ Initialize error: {json.dumps(init_response['error'], indent=2)} - exiting test")
            return

        # Print server info
        init_result = init_response.get("result")
        if isinstance(init_result, dict) and "serverInfo" in init_result:
            server_info = init_result["serverInfo"]
            print(f"✅ Connected to server: {server_info.get('name')} {server_info.get('version')}")

        # ===== STEP 2: Send initialized notification =====
        print("\n===== STEP 2: Send initialized notification =====")
        initialized_notification = {
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {}
        }
        _send_request(process, initialized_notification, expect_response=False)

        # ===== STEP 3: List available tools =====
        print("\n===== STEP 3: List available tools =====")
        tools_request = {
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/list"
        }
        tools_response = _send_request(process, tools_request)
        if not tools_response or "result" not in tools_response:
            print("❌ Failed to get tools list")
            failures.append("tools/list")
        else:
            tools = tools_response.get("result", {}).get("tools", [])
            print(f"✅ Available tools: {len(tools)}")
            for tool in tools:
                # `or` also covers an explicit null description.
                description = tool.get("description") or "No description"
                print(f"   - {tool.get('name')}: {description[:80]}...")

        # ===== STEP 4: Execute a simple query =====
        print("\n===== STEP 4: Execute a simple query =====")
        query_response = _send_request(
            process, _tool_call_request(3, "SELECT 'Hello from Trino MCP' AS message")
        )
        if not query_response:
            print("❌ Failed to execute query")
            failures.append("simple query")
        elif "error" in query_response:
            print(f"❌ Query error: {json.dumps(query_response.get('error', {}), indent=2)}")
            failures.append("simple query")
        elif _is_tool_error(query_response):
            # Tool failures come back as a normal result flagged with isError.
            print(f"❌ Query error: {json.dumps(query_response['result'], indent=2)}")
            failures.append("simple query")
        else:
            print("✅ Query executed successfully:")
            _print_query_details(query_response)

        # ===== STEP 5: Query the bullshit table =====
        print("\n===== STEP 5: Query the Bullshit Table =====")
        bs_query_response = _send_request(
            process, _tool_call_request(4, "SELECT * FROM memory.bullshit.bullshit_data LIMIT 3")
        )
        if not bs_query_response:
            print("❌ Failed to execute bullshit table query")
            failures.append("bullshit table query")
        elif "error" in bs_query_response or _is_tool_error(bs_query_response):
            err = bs_query_response.get("error", bs_query_response.get("result"))
            if isinstance(err, dict):
                print(f"❌ Query error: {json.dumps(err, indent=2)}")
            else:
                print(f"❌ Query error: {err}")
            failures.append("bullshit table query")
            # The table may not exist; list the schemas to help diagnose.
            _show_available_schemas(process)
        else:
            print("✅ Bullshit query executed successfully:")
            _print_query_details(bs_query_response)

        # Skip the shutdown steps since those cause MCP errors
        if failures:
            print(f"\n⚠️ Test finished with failures ({', '.join(failures)}) - skipping shutdown to avoid MCP errors")
        else:
            print("\n🎉 Test successful - skipping shutdown to avoid MCP errors")

    except Exception as e:
        print(f"❌ Error: {e}")
    finally:
        # Make sure to terminate the process
        if process is not None and process.poll() is None:
            print("Terminating server process...")
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                print("Process didn't terminate, killing it...")
                process.kill()
                process.wait()  # reap the killed child

    print("\n🏁 Test completed!")


if __name__ == "__main__":
    test_mcp_stdio()
```

--------------------------------------------------------------------------------
/scripts/test_messages.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Simple test script to try connecting to the MCP messages endpoint.
This follows the MCP 2024-11-05 specification precisely.
"""
import json
import requests
import sys
import time
import sseclient
import signal

# Base URL of the MCP server's SSE transport.
BASE_URL = "http://localhost:9096"

# Stop waiting for responses after this many seconds.
RESPONSE_TIMEOUT_SECONDS = 60


def handle_exit(signum, frame):
    """Handle exit gracefully when user presses Ctrl+C."""
    print("\nInterrupted. Exiting...")
    sys.exit(0)


# Register signal handler for clean exit
signal.signal(signal.SIGINT, handle_exit)


def _post_message(messages_url, message, description):
    """
    POST one JSON-RPC message to the server's messages endpoint.

    The server acknowledges an accepted message with HTTP 202; any reply is
    delivered on the SSE stream.

    Returns:
        True if the server answered 202, False otherwise.
    """
    response = requests.post(messages_url, json=message)
    if response.status_code != 202:
        print(f"❌ {description} failed: {response.status_code}")
        return False
    print(f"✅ {description} accepted")
    return True


def _query_request(request_id, sql):
    """Build a ``tools/call`` request that runs *sql* through ``execute_query``."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": "tools/call",
        "params": {
            "name": "execute_query",
            "arguments": {
                "sql": sql,
                "catalog": "memory"
            }
        }
    }


def _tool_result(data):
    """
    Return the ``result`` of a ``tools/call`` response if the call succeeded.

    Returns None when there is no result dict (e.g. a JSON-RPC ``error``) or
    the tool result is flagged with ``isError``.
    """
    result = data.get("result")
    if not isinstance(result, dict) or result.get("isError"):
        return None
    return result


def _tool_payload(result):
    """
    Decode the JSON text in ``result["content"][0]["text"]``.

    execute_query reports ``row_count``/``preview_rows`` inside this text, not
    as keys of the result itself. Returns {} when the payload is missing or invalid.
    """
    try:
        payload = json.loads(result["content"][0]["text"])
    except (KeyError, IndexError, TypeError, json.JSONDecodeError):
        return {}
    return payload if isinstance(payload, dict) else {}


def _failure_reason(data):
    """Describe why a ``tools/call`` response did not succeed."""
    return data.get("error") or data.get("result") or "Unknown error"


def _handle_message(data, messages_url, status):
    """
    Advance the scripted protocol exchange in response to one server message.

    Args:
        data: Decoded JSON-RPC message received on the SSE stream.
        messages_url: URL to POST follow-up requests to.
        status: Mutable protocol-state flags, updated in place.
    """
    msg_id = data.get("id")

    # Handle initialize response
    if msg_id == 1 and not status["initialized"]:
        # The 2024-11-05 lifecycle names this notification "notifications/initialized".
        print("\n📤 Sending initialized notification...")
        initialized_notification = {
            "jsonrpc": "2.0",
            "method": "notifications/initialized"
        }
        if _post_message(messages_url, initialized_notification, "Initialized notification"):
            status["initialized"] = True

        # Now request the tools list
        print("\n📤 Sending tools/list request...")
        tools_request = {
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/list"
        }
        if _post_message(messages_url, tools_request, "Tools list request"):
            status["tools_requested"] = True

    # Handle tools list response
    elif msg_id == 2 and not status["query_requested"]:
        tools = []
        result = data.get("result")
        if isinstance(result, dict) and "tools" in result:
            tools = [tool["name"] for tool in result["tools"]]
            print(f"🔧 Available tools: {', '.join(tools)}")

        # Execute a memory query if the execute_query tool is available
        if "execute_query" in tools:
            print("\n📤 Sending query for memory.bullshit.bullshit_data...")
            query_request = _query_request(3, "SELECT * FROM memory.bullshit.bullshit_data")
            if _post_message(messages_url, query_request, "Query request"):
                status["query_requested"] = True
        else:
            print("❌ execute_query tool not available")
            status["done"] = True

    # Handle query response
    elif msg_id == 3 and not status["summary_requested"]:
        result = _tool_result(data)
        if result is None:
            print(f"❌ Query failed: {_failure_reason(data)}")
            status["done"] = True
            return

        payload = _tool_payload(result)
        print(f"✅ Query succeeded with {payload.get('row_count', 0)} rows")

        # Now query the summary view
        print("\n📤 Sending query for memory.bullshit.bullshit_summary...")
        summary_request = _query_request(
            4, "SELECT * FROM memory.bullshit.bullshit_summary ORDER BY count DESC"
        )
        if _post_message(messages_url, summary_request, "Summary query request"):
            status["summary_requested"] = True

    # Handle summary query response
    elif msg_id == 4:
        result = _tool_result(data)
        if result is None:
            print(f"❌ Summary query failed: {_failure_reason(data)}")
            print("\n🏁 Tests completed with errors.")
        else:
            payload = _tool_payload(result)
            print(f"✅ Summary query succeeded with {payload.get('row_count', 0)} rows")
            # Print the summary data nicely formatted
            for row in payload.get("preview_rows", []):
                print(f"   {row}")
            print("\n🏁 All tests completed successfully!")
        status["done"] = True


def test_mcp():
    """
    Test the MCP server with standard protocol communication.
    Follows the MCP specification for 2024-11-05 carefully.
    """
    print("🚀 Testing MCP server following 2024-11-05 specification")

    # Connect to SSE endpoint
    print("Connecting to SSE endpoint...")
    headers = {"Accept": "text/event-stream"}
    sse_response = requests.get(f"{BASE_URL}/sse", headers=headers, stream=True)

    if sse_response.status_code != 200:
        print(f"❌ Failed to connect to SSE endpoint: {sse_response.status_code}")
        sse_response.close()
        return

    print(f"✅ SSE connection established: {sse_response.status_code}")

    try:
        client = sseclient.SSEClient(sse_response)
        # Reuse ONE event iterator for the whole session instead of calling
        # client.events() again per loop. A fresh parser over a half-read stream
        # may drop buffered data (assumption about sseclient internals).
        events = client.events()

        # Get the messages URL from the endpoint event
        messages_url = None
        session_id = None
        for event in events:
            print(f"📩 SSE event: {event.event} - {event.data}")
            if event.event == "endpoint":
                messages_url = f"{BASE_URL}{event.data}"
                # Extract session ID from URL
                if "session_id=" in event.data:
                    session_id = event.data.split("session_id=")[1]
                print(f"✅ Got messages URL: {messages_url}")
                print(f"✅ Session ID: {session_id}")
                break

        if not messages_url:
            print("❌ Failed to get messages URL from SSE")
            return

        # Now we have the messages URL, send initialize request
        print(f"\n📤 Sending initialize request to {messages_url}")
        initialize_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "clientInfo": {
                    "name": "mcp-trino-test-client",
                    "version": "1.0.0"
                },
                "capabilities": {
                    "tools": True,
                    "resources": {
                        "supportedSources": ["trino://catalog"]
                    }
                }
            }
        }

        response = requests.post(messages_url, json=initialize_request)
        print(f"Status code: {response.status_code}")

        if response.status_code != 202:
            print(f"❌ Initialize request failed: {response.text}")
            return

        print("✅ Initialize request accepted")

        # Listen for events and handle protocol properly
        print("\n🔄 Listening for response events...")
        deadline = time.time() + RESPONSE_TIMEOUT_SECONDS

        # Protocol state tracking
        status = {
            "initialized": False,
            "tools_requested": False,
            "query_requested": False,
            "summary_requested": False,
            "done": False
        }

        # The deadline is checked each time an event arrives (ping events, if the
        # server sends them, keep the check running while we wait).
        for event in events:
            if time.time() >= deadline:
                break

            # Skip ping events
            if event.event == "ping":
                print("📍 Ping event received")
                continue

            print(f"\n📩 Received event: {event.event}")

            # If we get a message event, parse it
            if event.event == "message" and event.data:
                try:
                    data = json.loads(event.data)
                    print(f"📦 Parsed message: {json.dumps(data, indent=2)}")
                    _handle_message(data, messages_url, status)
                except json.JSONDecodeError as e:
                    print(f"❌ Error parsing message: {e}")
                except Exception as e:
                    print(f"❌ Unexpected error: {e}")

            if status["done"]:
                break

        if not status["done"]:
            if time.time() >= deadline:
                print("⏱️ Timeout waiting for responses")
            else:
                print("❌ SSE stream closed before all responses were received")

    except KeyboardInterrupt:
        print("\n🛑 Interrupted by user. Exiting...")
    except Exception as e:
        print(f"❌ Error: {e}")
    finally:
        # Close the SSE connection
        print("\n👋 Closing SSE connection...")
        sse_response.close()
        print("✅ Connection closed")


if __name__ == "__main__":
    test_mcp()
```

--------------------------------------------------------------------------------
/src/trino_mcp/server.py:
--------------------------------------------------------------------------------

```python
"""
Main module for the Trino MCP server.
"""
from __future__ import annotations

import argparse
import json
import sys
import os
import asyncio
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncIterator, Dict, Any, List, Optional

import uvicorn
from fastapi import FastAPI, Response, Body
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from loguru import logger
from mcp.server.fastmcp import Context, FastMCP

from trino_mcp.config import ServerConfig, TrinoConfig
from trino_mcp.resources import register_trino_resources
from trino_mcp.tools import register_trino_tools
from trino_mcp.trino_client import TrinoClient

# Global app context for health check access
app_context_global = None


@dataclass
class AppContext:
    """Application context passed to all MCP handlers."""
    trino_client: TrinoClient
    config: ServerConfig
    # Set to False when initialization fails or the context is shut down.
    is_healthy: bool = True


# Models for the LLM API
class QueryRequest(BaseModel):
    """Model for query requests."""
    query: str
    catalog: str = "memory"
    # NOTE(review): "schema" shadows BaseModel.schema and pydantic may warn about
    # it; kept because it is part of the public request format.
    schema: Optional[str] = None
    explain: bool = False


class QueryResponse(BaseModel):
    """Model for query responses."""
    success: bool
    message: str
    results: Optional[Dict[str, Any]] = None


@asynccontextmanager
async def app_lifespan(mcp: FastMCP) -> AsyncIterator[AppContext]:
    """
    Manage the application lifecycle.

    If a healthy shared context already exists (main() creates one before
    starting the server), it is yielded unchanged and left open. Re-entering
    this lifespan therefore neither reconnects to Trino nor marks the shared
    context unhealthy when it exits. NOTE(review): the MCP library may enter
    the lifespan once per session (e.g. per SSE connection); confirm for the
    pinned mcp version.

    Otherwise a new context is created, published as the global context, and
    torn down on exit.

    Args:
        mcp: The MCP server instance.

    Yields:
        AppContext: The application context with initialized services.
    """
    global app_context_global

    shared = app_context_global
    if shared is not None and shared.is_healthy:
        yield shared
        return

    logger.info("Initializing Trino MCP server")

    # Get server configuration from environment or command line
    config = parse_args()

    # Initialize Trino client
    trino_client = TrinoClient(config.trino)

    # Create and set global app context
    app_context = AppContext(trino_client=trino_client, config=config)
    app_context_global = app_context

    try:
        # Connect to Trino
        logger.info(f"Connecting to Trino at {config.trino.host}:{config.trino.port}")
        trino_client.connect()

        # Register resources and tools
        logger.info("Registering resources and tools")
        register_trino_resources(mcp, trino_client)
        register_trino_tools(mcp, trino_client)

        logger.info("Trino MCP server initialized and ready")
    except Exception as e:
        logger.error(f"Failed to initialize: {e}")
        app_context.is_healthy = False

    # Yield exactly once. A second yield from an `except` block after the first
    # yield makes contextlib raise "generator didn't stop after athrow()" when
    # the server body fails.
    try:
        yield app_context
    finally:
        # Cleanup on shutdown
        logger.info("Shutting down Trino MCP server")
        if trino_client.conn:
            trino_client.disconnect()
        app_context.is_healthy = False


def parse_args() -> ServerConfig:
    """
    Parse command line arguments and return server configuration.

    Returns:
        ServerConfig: The server configuration.
    """
    parser = argparse.ArgumentParser(description="Trino MCP server")

    # Server configuration
    parser.add_argument("--name", default="Trino MCP", help="Server name")
    parser.add_argument("--version", default="0.1.0", help="Server version")
    parser.add_argument("--transport", default="stdio", choices=["stdio", "sse"], help="Transport type")
    parser.add_argument("--host", default="127.0.0.1", help="Host for HTTP server (SSE transport only)")
    parser.add_argument("--port", type=int, default=3000, help="Port for HTTP server (SSE transport only)")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    # Trino connection
    parser.add_argument("--trino-host", default="localhost", help="Trino host")
    parser.add_argument("--trino-port", type=int, default=8080, help="Trino port")
    parser.add_argument("--trino-user", default="trino", help="Trino user")
    parser.add_argument("--trino-password", help="Trino password")
    parser.add_argument("--trino-catalog", help="Default Trino catalog")
    parser.add_argument("--trino-schema", help="Default Trino schema")
    parser.add_argument("--trino-http-scheme", default="http", help="Trino HTTP scheme")

    args = parser.parse_args()

    # Create Trino configuration
    trino_config = TrinoConfig(
        host=args.trino_host,
        port=args.trino_port,
        user=args.trino_user,
        password=args.trino_password,
        catalog=args.trino_catalog,
        schema=args.trino_schema,
        http_scheme=args.trino_http_scheme
    )

    # Create server configuration
    server_config = ServerConfig(
        name=args.name,
        version=args.version,
        transport_type=args.transport,
        host=args.host,
        port=args.port,
        debug=args.debug,
        trino=trino_config
    )

    return server_config


def create_app() -> FastMCP:
    """
    Create and configure the MCP server application.

    Returns:
        FastMCP: The configured MCP server.
    """
    # Create the MCP server with lifespan management
    mcp = FastMCP(
        "Trino MCP",
        dependencies=["trino>=0.329.0"],
        lifespan=app_lifespan
    )

    return mcp


def create_health_app() -> FastAPI:
    """
    Create a FastAPI app that provides a health check endpoint and LLM API.

    This function creates a FastAPI app with a health check endpoint and
    a query endpoint for LLMs to use.

    Returns:
        FastAPI: The FastAPI app with health check and LLM API endpoints.
    """
    app = FastAPI(
        title="Trino MCP API",
        description="API for health checks and LLM query access to Trino MCP",
        version="0.1.0"
    )

    @app.get("/health")
    async def health():
        # For Docker health check, always return 200 during startup
        # This gives the app time to initialize
        return JSONResponse(
            status_code=200,
            content={"status": "ok", "message": "Health check endpoint is responding"}
        )

    @app.post("/api/query", response_model=QueryResponse)
    async def query(request: QueryRequest):
        """
        Execute a SQL query against Trino and return results.

        This endpoint is designed to be used by LLMs to query Trino through MCP.
        """
        # Snapshot the global once so every check below sees the same context.
        context = app_context_global

        if not context or not context.is_healthy:
            return JSONResponse(
                status_code=503,
                content={
                    "success": False,
                    "message": "Trino MCP server is not healthy or not initialized"
                }
            )

        logger.info(f"LLM API Query: {request.query}")

        try:
            # Use the Trino client from the app context
            client = context.trino_client

            # Optionally add EXPLAIN
            sql = f"EXPLAIN {request.query}" if request.explain else request.query

            # NOTE(review): execute_query is not awaited; if it performs blocking
            # I/O it stalls this server's event loop for the query's duration.
            result = client.execute_query(sql, request.catalog, request.schema)

            # Convert each row to a dict keyed by column name
            formatted_rows = [dict(zip(result.columns, row)) for row in result.rows]

            return {
                "success": True,
                "message": "Query executed successfully",
                "results": {
                    "query_id": result.query_id,
                    "columns": result.columns,
                    "rows": formatted_rows,
                    "row_count": result.row_count,
                    "execution_time_ms": result.query_time_ms
                }
            }

        except Exception as e:
            logger.error(f"Error executing query: {e}")
            return JSONResponse(
                status_code=400,
                content={
                    "success": False,
                    "message": f"Error executing query: {str(e)}"
                }
            )

    @app.get("/api")
    async def api_root():
        """Root API endpoint with usage instructions."""
        return {
            "message": "Trino MCP API for LLMs",
            "version": app_context_global.config.version if app_context_global else "unknown",
            "endpoints": {
                "health": "GET /health - Check server health",
                "query": "POST /api/query - Execute SQL queries"
            },
            "query_example": {
                "query": "SELECT * FROM memory.bullshit.real_bullshit_data LIMIT 3",
                "catalog": "memory",
                "schema": "bullshit"
            }
        }

    return app


def main() -> None:
    """
    Main entry point for the server.

    Connects to Trino and registers MCP resources/tools up front, publishing
    the shared context for the HTTP API. It then serves MCP over STDIO, or over
    SSE with a FastAPI health/query app on ``port + 1`` in a daemon thread.
    """
    config = parse_args()
    mcp = create_app()

    # Initialize the shared context eagerly so the HTTP API can use it.
    global app_context_global
    try:
        # Initialize the Trino client
        logger.info(f"Connecting to Trino at {config.trino.host}:{config.trino.port}")
        trino_client = TrinoClient(config.trino)
        trino_client.connect()

        # Create application context
        app_context = AppContext(
            trino_client=trino_client,
            config=config,
            is_healthy=True
        )

        # Set global context
        app_context_global = app_context

        # Register resources and tools
        register_trino_resources(mcp, trino_client)
        register_trino_tools(mcp, trino_client)

        logger.info("Trino MCP server initialized and ready")
    except Exception as e:
        logger.error(f"Error initializing Trino MCP: {e}")
        if app_context_global:
            app_context_global.is_healthy = False

    if config.transport_type == "stdio":
        # For STDIO transport, run directly
        logger.info("Starting Trino MCP server with STDIO transport")
        mcp.run()
        return

    # For SSE transport, use run_sse_async method from MCP library
    logger.info(f"Starting Trino MCP server with SSE transport on {config.host}:{config.port}")

    # In MCP 1.3.0, run_sse_async takes no arguments
    # We set the environment variables to configure the host and port
    # NOTE(review): FastMCP presumably reads its settings when the FastMCP object
    # is constructed (in create_app(), before this point). These MCP_* variables
    # may therefore have no effect; confirm which host/port the SSE server binds.
    os.environ["MCP_HOST"] = config.host
    os.environ["MCP_PORT"] = str(config.port)

    # Configure more robust error handling for the server
    import traceback
    try:
        # Try to import and configure SSE settings if available in this version
        from mcp.server.sse import configure_sse
        configure_sse(ignore_client_disconnect=True)
        logger.info("Configured SSE with ignore_client_disconnect=True")
    except (ImportError, AttributeError):
        logger.warning("Could not configure SSE settings - this may be expected in some MCP versions")

    # Start a separate thread for the health check endpoint
    import threading

    def run_health_check():
        """Run the health check FastAPI app."""
        health_app = create_health_app()
        # Use a different port for the health check endpoint
        health_port = config.port + 1
        logger.info(f"Starting API server on port {health_port}")
        uvicorn.run(health_app, host=config.host, port=health_port)

    # Start the health check in a separate thread
    health_thread = threading.Thread(target=run_health_check)
    health_thread.daemon = True
    health_thread.start()

    # Now run the SSE server with robust error handling
    try:
        asyncio.run(mcp.run_sse_async())
    except RuntimeError as e:
        if "generator didn't stop after athrow()" in str(e):
            logger.error(f"Generator error in SSE server. This is a known issue with MCP 1.3.0: {e}")

            # Set unhealthy status for health checks
            if app_context_global:
                app_context_global.is_healthy = False

            logger.info("Server will continue running but may not function correctly.")

            # Keep the server alive despite the error (`time` is imported at module level)
            while True:
                time.sleep(60)  # Sleep to keep the container running

        else:
            logger.error(f"Fatal error running SSE server: {e}")
            logger.error(traceback.format_exc())
            raise
    except Exception as e:
        logger.error(f"Fatal error running SSE server: {e}")
        logger.error(traceback.format_exc())
        raise


if __name__ == "__main__":
    logger.remove()
    logger.add(sys.stderr, level="DEBUG")
    main()

```