#
tokens: 9855/50000 3/48 files (page 2/2)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 2 of 2. To view the full context, open http://codebase.md/stinkgen/trino_mcp?lines=true&page={x}, replacing `{x}` with the page number (1 or 2).

# Directory Structure

```
├── .gitignore
├── CHANGELOG.md
├── docker-compose.yml
├── Dockerfile
├── etc
│   ├── catalog
│   │   ├── bullshit.properties
│   │   └── memory.properties
│   ├── config.properties
│   ├── jvm.config
│   └── node.properties
├── examples
│   └── simple_mcp_query.py
├── LICENSE
├── llm_query_trino.py
├── llm_trino_api.py
├── load_bullshit_data.py
├── openapi.json
├── pyproject.toml
├── pytest.ini
├── README.md
├── requirements-dev.txt
├── run_tests.sh
├── scripts
│   ├── docker_stdio_test.py
│   ├── fix_trino_session.py
│   ├── test_direct_query.py
│   ├── test_fixed_client.py
│   ├── test_messages.py
│   ├── test_quick_query.py
│   └── test_stdio_trino.py
├── src
│   └── trino_mcp
│       ├── __init__.py
│       ├── config.py
│       ├── resources
│       │   └── __init__.py
│       ├── server.py
│       ├── tools
│       │   └── __init__.py
│       └── trino_client.py
├── test_bullshit_query.py
├── test_llm_api.py
├── test_mcp_stdio.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── integration
│   │   └── __init__.py
│   └── test_client.py
├── tools
│   ├── create_bullshit_data.py
│   ├── run_queries.sh
│   ├── setup
│   │   ├── setup_data.sh
│   │   └── setup_tables.sql
│   └── setup_bullshit_table.py
└── trino-conf
    ├── catalog
    │   └── memory.properties
    ├── config.properties
    ├── jvm.config
    └── node.properties
```

# Files

--------------------------------------------------------------------------------
/test_mcp_stdio.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | STDIO transport test script for Trino MCP.
  4 | This script demonstrates the end-to-end flow of initializing MCP, listing tools, 
  5 | querying data, and shutting down using the STDIO transport.
  6 | """
  7 | import json
  8 | import subprocess
  9 | import sys
 10 | import time
 11 | 
 12 | def test_mcp_stdio():
 13 |     """Run an end-to-end test of Trino MCP using STDIO transport."""
 14 |     print("🚀 Starting Trino MCP STDIO test")
 15 |     
 16 |     # Start the MCP server with STDIO transport
 17 |     server_cmd = [
 18 |         "docker", "exec", "-i", "trino_mcp_trino-mcp_1", 
 19 |         "python", "-m", "trino_mcp.server", 
 20 |         "--transport", "stdio", 
 21 |         "--debug",
 22 |         "--trino-host", "trino",
 23 |         "--trino-port", "8080",
 24 |         "--trino-user", "trino", 
 25 |         "--trino-catalog", "memory"
 26 |     ]
 27 |     
 28 |     try:
 29 |         print(f"Starting MCP server process: {' '.join(server_cmd)}")
 30 |         process = subprocess.Popen(
 31 |             server_cmd,
 32 |             stdin=subprocess.PIPE,
 33 |             stdout=subprocess.PIPE,
 34 |             stderr=sys.stderr,  # Pass stderr through to see logs directly
 35 |             text=True,
 36 |             bufsize=1  # Line buffered
 37 |         )
 38 |         
 39 |         # Sleep a bit to let the server initialize
 40 |         time.sleep(2)
 41 |         
 42 |         # Helper function to send a request and get a response
 43 |         def send_request(request, expect_response=True):
 44 |             """
 45 |             Send a request to the MCP server and get the response.
 46 |             
 47 |             Args:
 48 |                 request: The JSON-RPC request to send
 49 |                 expect_response: Whether to wait for a response
 50 |                 
 51 |             Returns:
 52 |                 The JSON-RPC response, or None if no response is expected
 53 |             """
 54 |             request_str = json.dumps(request) + "\n"
 55 |             print(f"\n📤 Sending: {request_str.strip()}")
 56 |             
 57 |             try:
 58 |                 process.stdin.write(request_str)
 59 |                 process.stdin.flush()
 60 |             except BrokenPipeError:
 61 |                 print("❌ Broken pipe - server has closed the connection")
 62 |                 return None
 63 |                 
 64 |             if not expect_response:
 65 |                 print("✅ Sent notification (no response expected)")
 66 |                 return None
 67 |                 
 68 |             # Read the response
 69 |             print("Waiting for response...")
 70 |             try:
 71 |                 response_str = process.stdout.readline()
 72 |                 if response_str:
 73 |                     print(f"📩 Received: {response_str.strip()}")
 74 |                     return json.loads(response_str)
 75 |                 else:
 76 |                     print("❌ No response received")
 77 |                     return None
 78 |             except Exception as e:
 79 |                 print(f"❌ Error reading response: {e}")
 80 |                 return None
 81 |         
 82 |         # ===== STEP 1: Initialize MCP =====
 83 |         print("\n===== STEP 1: Initialize MCP =====")
 84 |         initialize_request = {
 85 |             "jsonrpc": "2.0",
 86 |             "id": 1,
 87 |             "method": "initialize",
 88 |             "params": {
 89 |                 "protocolVersion": "2024-11-05",
 90 |                 "clientInfo": {
 91 |                     "name": "trino-mcp-stdio-test",
 92 |                     "version": "1.0.0"
 93 |                 },
 94 |                 "capabilities": {
 95 |                     "tools": True,
 96 |                     "resources": {
 97 |                         "supportedSources": ["trino://catalog"]
 98 |                     }
 99 |                 }
100 |             }
101 |         }
102 |         
103 |         init_response = send_request(initialize_request)
104 |         if not init_response:
105 |             print("❌ Failed to initialize MCP - exiting test")
106 |             return
107 |             
108 |         # Print server info
109 |         if "result" in init_response and "serverInfo" in init_response["result"]:
110 |             server_info = init_response["result"]["serverInfo"]
111 |             print(f"✅ Connected to server: {server_info.get('name')} {server_info.get('version')}")
112 |         
113 |         # ===== STEP 2: Send initialized notification =====
114 |         print("\n===== STEP 2: Send initialized notification =====")
115 |         initialized_notification = {
116 |             "jsonrpc": "2.0",
117 |             "method": "notifications/initialized",
118 |             "params": {}
119 |         }
120 |         
121 |         send_request(initialized_notification, expect_response=False)
122 |         
123 |         # ===== STEP 3: List available tools =====
124 |         print("\n===== STEP 3: List available tools =====")
125 |         tools_request = {
126 |             "jsonrpc": "2.0",
127 |             "id": 2,
128 |             "method": "tools/list"
129 |         }
130 |         
131 |         tools_response = send_request(tools_request)
132 |         if not tools_response or "result" not in tools_response:
133 |             print("❌ Failed to get tools list")
134 |         else:
135 |             tools = tools_response.get("result", {}).get("tools", [])
136 |             print(f"✅ Available tools: {len(tools)}")
137 |             for tool in tools:
138 |                 print(f"  - {tool.get('name')}: {tool.get('description', 'No description')[:80]}...")
139 |         
140 |         # ===== STEP 4: Execute a simple query =====
141 |         print("\n===== STEP 4: Execute a simple query =====")
142 |         query_request = {
143 |             "jsonrpc": "2.0",
144 |             "id": 3,
145 |             "method": "tools/call",
146 |             "params": {
147 |                 "name": "execute_query",
148 |                 "arguments": {
149 |                     "sql": "SELECT 'Hello from Trino MCP' AS message",
150 |                     "catalog": "memory"
151 |                 }
152 |             }
153 |         }
154 |         
155 |         query_response = send_request(query_request)
156 |         if not query_response:
157 |             print("❌ Failed to execute query")
158 |         elif "error" in query_response:
159 |             print(f"❌ Query error: {json.dumps(query_response.get('error', {}), indent=2)}")
160 |         else:
161 |             print(f"✅ Query executed successfully:")
162 |             if "result" in query_response:
163 |                 result = query_response["result"]
164 |                 if isinstance(result, dict) and "content" in result:
165 |                     # Parse the content text which contains the actual results as a JSON string
166 |                     try:
167 |                         content = result["content"][0]["text"]
168 |                         result_data = json.loads(content)
169 |                         print(f"  Query ID: {result_data.get('query_id', 'unknown')}")
170 |                         print(f"  Columns: {', '.join(result_data.get('columns', []))}")
171 |                         print(f"  Row count: {result_data.get('row_count', 0)}")
172 |                         print(f"  Results: {json.dumps(result_data.get('preview_rows', []), indent=2)}")
173 |                     except (json.JSONDecodeError, IndexError) as e:
174 |                         print(f"  Raw result: {json.dumps(result, indent=2)}")
175 |                 else:
176 |                     print(f"  Raw result: {json.dumps(result, indent=2)}")
177 |             else:
178 |                 print(f"  Raw response: {json.dumps(query_response, indent=2)}")
179 |                 
180 |         # Try the bullshit table query - this is what the original script wanted
181 |         print("\n===== STEP 5: Query the Bullshit Table =====")
182 |         bs_query_request = {
183 |             "jsonrpc": "2.0",
184 |             "id": 4,
185 |             "method": "tools/call",
186 |             "params": {
187 |                 "name": "execute_query",
188 |                 "arguments": {
189 |                     "sql": "SELECT * FROM memory.bullshit.bullshit_data LIMIT 3",
190 |                     "catalog": "memory"
191 |                 }
192 |             }
193 |         }
194 |         
195 |         bs_query_response = send_request(bs_query_request)
196 |         if not bs_query_response:
197 |             print("❌ Failed to execute bullshit table query")
198 |         elif "error" in bs_query_response:
199 |             err = bs_query_response.get("error", {}) 
200 |             if isinstance(err, dict):
201 |                 print(f"❌ Query error: {json.dumps(err, indent=2)}")
202 |             else:
203 |                 print(f"❌ Query error: {err}")
204 |                 
205 |             # Try with information_schema as fallback
206 |             print("\n----- Fallback Query: Checking Available Schemas -----")
207 |             fallback_query = {
208 |                 "jsonrpc": "2.0",
209 |                 "id": 5,
210 |                 "method": "tools/call",
211 |                 "params": {
212 |                     "name": "execute_query",
213 |                     "arguments": {
214 |                         "sql": "SHOW SCHEMAS FROM memory",
215 |                         "catalog": "memory"
216 |                     }
217 |                 }
218 |             }
219 |             schemas_response = send_request(fallback_query)
220 |             if schemas_response and "result" in schemas_response:
221 |                 result = schemas_response["result"]
222 |                 if isinstance(result, dict) and "content" in result:
223 |                     try:
224 |                         content = result["content"][0]["text"]
225 |                         result_data = json.loads(content)
226 |                         print(f"  Available schemas: {json.dumps(result_data.get('preview_rows', []), indent=2)}")
227 |                     except (json.JSONDecodeError, IndexError) as e:
228 |                         print(f"  Raw schemas result: {json.dumps(result, indent=2)}")
229 |         else:
230 |             print(f"✅ Bullshit query executed successfully:")
231 |             if "result" in bs_query_response:
232 |                 result = bs_query_response["result"]
233 |                 if isinstance(result, dict) and "content" in result:
234 |                     try:
235 |                         content = result["content"][0]["text"]
236 |                         result_data = json.loads(content)
237 |                         print(f"  Query ID: {result_data.get('query_id', 'unknown')}")
238 |                         print(f"  Columns: {', '.join(result_data.get('columns', []))}")
239 |                         print(f"  Row count: {result_data.get('row_count', 0)}")
240 |                         print(f"  Results: {json.dumps(result_data.get('preview_rows', []), indent=2)}")
241 |                     except (json.JSONDecodeError, IndexError) as e:
242 |                         print(f"  Raw result: {json.dumps(result, indent=2)}")
243 |                 else:
244 |                     print(f"  Raw result: {json.dumps(result, indent=2)}")
245 |             else:
246 |                 print(f"  Raw response: {json.dumps(bs_query_response, indent=2)}")
247 |         
248 |         # Skip the shutdown steps since those cause MCP errors
249 |         print("\n🎉 Test successful - skipping shutdown to avoid MCP errors")
250 |         
251 |     except Exception as e:
252 |         print(f"❌ Error: {e}")
253 |     finally:
254 |         # Make sure to terminate the process
255 |         if 'process' in locals() and process.poll() is None:
256 |             print("Terminating server process...")
257 |             process.terminate()
258 |             try:
259 |                 process.wait(timeout=5)
260 |             except subprocess.TimeoutExpired:
261 |                 print("Process didn't terminate, killing it...")
262 |                 process.kill()
263 |         
264 |         print("\n�� Test completed!")
265 | 
if __name__ == "__main__":
    # Run the STDIO end-to-end check when this file is executed as a script.
    test_mcp_stdio() 
```

--------------------------------------------------------------------------------
/scripts/test_messages.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | Simple test script to try connecting to the MCP messages endpoint.
  4 | This follows the MCP 2024-11-05 specification precisely.
  5 | """
  6 | import json
  7 | import requests
  8 | import sys
  9 | import time
 10 | import sseclient
 11 | import signal
 12 | 
 13 | def handle_exit(signum, frame):
 14 |     """Handle exit gracefully when user presses Ctrl+C."""
 15 |     print("\nInterrupted. Exiting...")
 16 |     sys.exit(0)
 17 | 
# Register signal handler for clean exit.
# NOTE: this runs at import time. Because handle_exit calls sys.exit(0), Ctrl+C
# surfaces as SystemExit (not KeyboardInterrupt) while this handler is installed.
signal.signal(signal.SIGINT, handle_exit)
 20 | 
 21 | def test_mcp():
 22 |     """
 23 |     Test the MCP server with standard protocol communication.
 24 |     Follows the MCP specification for 2024-11-05 carefully.
 25 |     """
 26 |     print("🚀 Testing MCP server following 2024-11-05 specification")
 27 |     
 28 |     # Connect to SSE endpoint
 29 |     print("Connecting to SSE endpoint...")
 30 |     headers = {"Accept": "text/event-stream"}
 31 |     sse_response = requests.get("http://localhost:9096/sse", headers=headers, stream=True)
 32 |     
 33 |     if sse_response.status_code != 200:
 34 |         print(f"❌ Failed to connect to SSE endpoint: {sse_response.status_code}")
 35 |         return
 36 |     
 37 |     print(f"✅ SSE connection established: {sse_response.status_code}")
 38 |     
 39 |     try:
 40 |         client = sseclient.SSEClient(sse_response)
 41 |         
 42 |         # Get the messages URL from the first event
 43 |         messages_url = None
 44 |         session_id = None
 45 |         
 46 |         for event in client.events():
 47 |             print(f"📩 SSE event: {event.event} - {event.data}")
 48 |             
 49 |             if event.event == "endpoint":
 50 |                 messages_url = f"http://localhost:9096{event.data}"
 51 |                 # Extract session ID from URL
 52 |                 if "session_id=" in event.data:
 53 |                     session_id = event.data.split("session_id=")[1]
 54 |                 print(f"✅ Got messages URL: {messages_url}")
 55 |                 print(f"✅ Session ID: {session_id}")
 56 |                 break
 57 |         
 58 |         if not messages_url:
 59 |             print("❌ Failed to get messages URL from SSE")
 60 |             sse_response.close()
 61 |             return
 62 |         
 63 |         # Now we have the messages URL, send initialize request
 64 |         print(f"\n📤 Sending initialize request to {messages_url}")
 65 |         initialize_request = {
 66 |             "jsonrpc": "2.0",
 67 |             "id": 1,
 68 |             "method": "initialize",
 69 |             "params": {
 70 |                 "protocolVersion": "2024-11-05",
 71 |                 "clientInfo": {
 72 |                     "name": "mcp-trino-test-client",
 73 |                     "version": "1.0.0"
 74 |                 },
 75 |                 "capabilities": {
 76 |                     "tools": True,
 77 |                     "resources": {
 78 |                         "supportedSources": ["trino://catalog"]
 79 |                     }
 80 |                 }
 81 |             }
 82 |         }
 83 |         
 84 |         response = requests.post(messages_url, json=initialize_request)
 85 |         print(f"Status code: {response.status_code}")
 86 |         
 87 |         if response.status_code != 202:
 88 |             print(f"❌ Initialize request failed: {response.text}")
 89 |             sse_response.close()
 90 |             return
 91 |             
 92 |         print(f"✅ Initialize request accepted")
 93 |         
 94 |         # Listen for events and handle protocol properly
 95 |         print("\n🔄 Listening for response events...")
 96 |         
 97 |         # Set up a timeout
 98 |         timeout = time.time() + 60  # 60 seconds timeout
 99 |         
100 |         # Protocol state tracking
101 |         status = {
102 |             "initialized": False,
103 |             "tools_requested": False,
104 |             "query_requested": False,
105 |             "summary_requested": False,
106 |             "done": False
107 |         }
108 |         
109 |         # Event loop
110 |         while time.time() < timeout and not status["done"]:
111 |             events_received = False
112 |             
113 |             for event in client.events():
114 |                 events_received = True
115 |                 
116 |                 # Skip ping events
117 |                 if event.event == "ping":
118 |                     print("📍 Ping event received")
119 |                     continue
120 |                     
121 |                 print(f"\n📩 Received event: {event.event}")
122 |                 
123 |                 # If we get a message event, parse it
124 |                 if event.event == "message" and event.data:
125 |                     try:
126 |                         data = json.loads(event.data)
127 |                         print(f"📦 Parsed message: {json.dumps(data, indent=2)}")
128 |                         
129 |                         # Handle initialize response 
130 |                         if "id" in data and data["id"] == 1 and not status["initialized"]:
131 |                             # Send initialized notification (following spec)
132 |                             print("\n📤 Sending initialized notification...")
133 |                             initialized_notification = {
134 |                                 "jsonrpc": "2.0",
135 |                                 "method": "initialized"
136 |                             }
137 |                             init_response = requests.post(messages_url, json=initialized_notification)
138 |                             
139 |                             if init_response.status_code != 202:
140 |                                 print(f"❌ Initialized notification failed: {init_response.status_code}")
141 |                             else:
142 |                                 print(f"✅ Initialized notification accepted")
143 |                                 status["initialized"] = True
144 |                             
145 |                             # Now request the tools list
146 |                             print("\n📤 Sending tools/list request...")
147 |                             tools_request = {
148 |                                 "jsonrpc": "2.0",
149 |                                 "id": 2,
150 |                                 "method": "tools/list"
151 |                             }
152 |                             tools_response = requests.post(messages_url, json=tools_request)
153 |                             
154 |                             if tools_response.status_code != 202:
155 |                                 print(f"❌ Tools list request failed: {tools_response.status_code}")
156 |                             else:
157 |                                 print(f"✅ Tools list request accepted")
158 |                                 status["tools_requested"] = True
159 |                         
160 |                         # Handle tools list response
161 |                         elif "id" in data and data["id"] == 2 and not status["query_requested"]:
162 |                             # Extract available tools
163 |                             tools = []
164 |                             if "result" in data and "tools" in data["result"]:
165 |                                 tools = [tool["name"] for tool in data["result"]["tools"]]
166 |                                 print(f"🔧 Available tools: {', '.join(tools)}")
167 |                             
168 |                             # Execute a memory query if the execute_query tool is available
169 |                             if "execute_query" in tools:
170 |                                 print("\n📤 Sending query for memory.bullshit.bullshit_data...")
171 |                                 query_request = {
172 |                                     "jsonrpc": "2.0",
173 |                                     "id": 3,
174 |                                     "method": "tools/call",
175 |                                     "params": {
176 |                                         "name": "execute_query",
177 |                                         "arguments": {
178 |                                             "sql": "SELECT * FROM memory.bullshit.bullshit_data",
179 |                                             "catalog": "memory"
180 |                                         }
181 |                                     }
182 |                                 }
183 |                                 query_response = requests.post(messages_url, json=query_request)
184 |                                 
185 |                                 if query_response.status_code != 202:
186 |                                     print(f"❌ Query request failed: {query_response.status_code}")
187 |                                 else:
188 |                                     print(f"✅ Query request accepted")
189 |                                     status["query_requested"] = True
190 |                             else:
191 |                                 print("❌ execute_query tool not available")
192 |                                 status["done"] = True
193 |                         
194 |                         # Handle query response
195 |                         elif "id" in data and data["id"] == 3 and not status["summary_requested"]:
196 |                             # Check if query was successful
197 |                             if "result" in data:
198 |                                 print(f"✅ Query succeeded with {data['result'].get('row_count', 0)} rows")
199 |                                 
200 |                                 # Now query the summary view
201 |                                 print("\n📤 Sending query for memory.bullshit.bullshit_summary...")
202 |                                 summary_request = {
203 |                                     "jsonrpc": "2.0",
204 |                                     "id": 4,
205 |                                     "method": "tools/call",
206 |                                     "params": {
207 |                                         "name": "execute_query",
208 |                                         "arguments": {
209 |                                             "sql": "SELECT * FROM memory.bullshit.bullshit_summary ORDER BY count DESC",
210 |                                             "catalog": "memory"
211 |                                         }
212 |                                     }
213 |                                 }
214 |                                 summary_response = requests.post(messages_url, json=summary_request)
215 |                                 
216 |                                 if summary_response.status_code != 202:
217 |                                     print(f"❌ Summary query request failed: {summary_response.status_code}")
218 |                                 else:
219 |                                     print(f"✅ Summary query request accepted")
220 |                                     status["summary_requested"] = True
221 |                             else:
222 |                                 print(f"❌ Query failed: {data.get('error', 'Unknown error')}")
223 |                                 status["done"] = True
224 |                         
225 |                         # Handle summary query response
226 |                         elif "id" in data and data["id"] == 4:
227 |                             if "result" in data:
228 |                                 print(f"✅ Summary query succeeded with {data['result'].get('row_count', 0)} rows")
229 |                                 # Print the summary data nicely formatted
230 |                                 if "preview_rows" in data["result"]:
231 |                                     for row in data["result"]["preview_rows"]:
232 |                                         print(f"  {row}")
233 |                             else:
234 |                                 print(f"❌ Summary query failed: {data.get('error', 'Unknown error')}")
235 |                             
236 |                             print("\n🏁 All tests completed successfully!")
237 |                             status["done"] = True
238 |                             break
239 |                     
240 |                     except json.JSONDecodeError as e:
241 |                         print(f"❌ Error parsing message: {e}")
242 |                     except Exception as e:
243 |                         print(f"❌ Unexpected error: {e}")
244 |                 
245 |                 # Break out of the event loop if we're done
246 |                 if status["done"]:
247 |                     break
248 |             
249 |             # If we didn't receive any events, wait a bit before trying again
250 |             if not events_received:
251 |                 time.sleep(0.5)
252 |         
253 |         # Check if we timed out
254 |         if time.time() >= timeout and not status["done"]:
255 |             print("⏱️ Timeout waiting for responses")
256 |             
257 |     except KeyboardInterrupt:
258 |         print("\n🛑 Interrupted by user. Exiting...")
259 |     except Exception as e:
260 |         print(f"❌ Error: {e}")
261 |     finally:
262 |         # Close the SSE connection
263 |         print("\n👋 Closing SSE connection...")
264 |         sse_response.close()
265 |         print("✅ Connection closed")
266 | 
if __name__ == "__main__":
    # Run the SSE protocol walkthrough when this file is executed as a script.
    test_mcp() 
```

--------------------------------------------------------------------------------
/src/trino_mcp/server.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Main module for the Trino MCP server.
  3 | """
  4 | from __future__ import annotations
  5 | 
  6 | import argparse
  7 | import json
  8 | import sys
  9 | import os
 10 | import asyncio
 11 | import time
 12 | from contextlib import asynccontextmanager
 13 | from dataclasses import dataclass
 14 | from typing import AsyncIterator, Dict, Any, List, Optional
 15 | 
 16 | import uvicorn
 17 | from fastapi import FastAPI, Response, Body
 18 | from fastapi.responses import JSONResponse
 19 | from pydantic import BaseModel
 20 | from loguru import logger
 21 | from mcp.server.fastmcp import Context, FastMCP
 22 | 
 23 | from trino_mcp.config import ServerConfig, TrinoConfig
 24 | from trino_mcp.resources import register_trino_resources
 25 | from trino_mcp.tools import register_trino_tools
 26 | from trino_mcp.trino_client import TrinoClient
 27 | 
# Global app context for health check access.
# Assigned by app_lifespan as soon as the AppContext is built; None until then.
app_context_global: Optional["AppContext"] = None
 30 | 
@dataclass
class AppContext:
    """Application context passed to all MCP handlers.

    Created by app_lifespan, which also publishes it as app_context_global.
    """

    # Trino client; app_lifespan connects it at startup and disconnects it on shutdown.
    trino_client: TrinoClient
    # Server configuration returned by parse_args().
    config: ServerConfig
    # Set to False by app_lifespan when initialization fails and again on shutdown.
    is_healthy: bool = True
 37 | 
# Models for the LLM API
class QueryRequest(BaseModel):
    """Model for query requests.

    The handler that consumes this model is not in this file; the field notes
    below are inferred from names and defaults only.
    """

    # Query text to run (presumably Trino SQL — confirm against the handler).
    query: str
    # Catalog to run against; defaults to "memory".
    catalog: str = "memory"
    # Optional schema; None by default.
    # NOTE(review): the name "schema" shadows BaseModel.schema(); pydantic v2
    # emits a UserWarning for this at class creation. Renaming would break
    # callers that read `.schema`, so consider an aliased field instead.
    schema: Optional[str] = None
    # Presumably asks for an EXPLAIN plan instead of executing — TODO confirm.
    explain: bool = False
 45 | 
class QueryResponse(BaseModel):
    """Model for query responses."""

    # Whether the request succeeded, as reported by the handler.
    success: bool
    # Human-readable status or error text.
    message: str
    # Optional result payload keyed by string; None by default.
    results: Optional[Dict[str, Any]] = None
 51 | 
 52 | @asynccontextmanager
 53 | async def app_lifespan(mcp: FastMCP) -> AsyncIterator[AppContext]:
 54 |     """
 55 |     Manage the application lifecycle.
 56 |     
 57 |     Args:
 58 |         mcp: The MCP server instance.
 59 |         
 60 |     Yields:
 61 |         AppContext: The application context with initialized services.
 62 |     """
 63 |     global app_context_global
 64 |     
 65 |     logger.info("Initializing Trino MCP server")
 66 |     
 67 |     # Get server configuration from environment or command line
 68 |     config = parse_args()
 69 |     
 70 |     # Initialize Trino client
 71 |     trino_client = TrinoClient(config.trino)
 72 |     
 73 |     # Create and set global app context
 74 |     app_context = AppContext(trino_client=trino_client, config=config)
 75 |     app_context_global = app_context
 76 |     
 77 |     try:
 78 |         # Connect to Trino
 79 |         logger.info(f"Connecting to Trino at {config.trino.host}:{config.trino.port}")
 80 |         trino_client.connect()
 81 |         
 82 |         # Register resources and tools
 83 |         logger.info("Registering resources and tools")
 84 |         register_trino_resources(mcp, trino_client)
 85 |         register_trino_tools(mcp, trino_client)
 86 |         
 87 |         # Yield the application context
 88 |         logger.info("Trino MCP server initialized and ready")
 89 |         yield app_context
 90 |     except Exception as e:
 91 |         logger.error(f"Failed to initialize: {e}")
 92 |         app_context.is_healthy = False
 93 |         yield app_context
 94 |     finally:
 95 |         # Cleanup on shutdown
 96 |         logger.info("Shutting down Trino MCP server")
 97 |         if trino_client.conn:
 98 |             trino_client.disconnect()
 99 |         app_context.is_healthy = False
100 | 
101 | 
def parse_args() -> ServerConfig:
    """
    Build the server configuration from the command line.

    Server options (name, version, transport, bind address, debug flag) and
    Trino connection options (the ``--trino-*`` flags) are read from
    ``sys.argv`` via argparse.

    Returns:
        ServerConfig: Server settings with a nested TrinoConfig.
    """
    parser = argparse.ArgumentParser(description="Trino MCP server")

    # (flag, add_argument keyword options) in the order they appear in --help.
    option_specs = [
        # Server configuration
        ("--name", {"default": "Trino MCP", "help": "Server name"}),
        ("--version", {"default": "0.1.0", "help": "Server version"}),
        ("--transport", {"default": "stdio", "choices": ["stdio", "sse"], "help": "Transport type"}),
        ("--host", {"default": "127.0.0.1", "help": "Host for HTTP server (SSE transport only)"}),
        ("--port", {"type": int, "default": 3000, "help": "Port for HTTP server (SSE transport only)"}),
        ("--debug", {"action": "store_true", "help": "Enable debug mode"}),
        # Trino connection
        ("--trino-host", {"default": "localhost", "help": "Trino host"}),
        ("--trino-port", {"type": int, "default": 8080, "help": "Trino port"}),
        ("--trino-user", {"default": "trino", "help": "Trino user"}),
        ("--trino-password", {"help": "Trino password"}),
        ("--trino-catalog", {"help": "Default Trino catalog"}),
        ("--trino-schema", {"help": "Default Trino schema"}),
        ("--trino-http-scheme", {"default": "http", "help": "Trino HTTP scheme"}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    ns = parser.parse_args()

    trino_settings = TrinoConfig(
        host=ns.trino_host,
        port=ns.trino_port,
        user=ns.trino_user,
        password=ns.trino_password,
        catalog=ns.trino_catalog,
        schema=ns.trino_schema,
        http_scheme=ns.trino_http_scheme,
    )

    return ServerConfig(
        name=ns.name,
        version=ns.version,
        transport_type=ns.transport,
        host=ns.host,
        port=ns.port,
        debug=ns.debug,
        trino=trino_settings,
    )
153 | 
154 | 
def create_app() -> FastMCP:
    """
    Build the FastMCP server instance.

    The server declares its ``trino`` package dependency and delegates
    startup/shutdown work to ``app_lifespan``.

    Returns:
        FastMCP: A new, not-yet-running MCP server.
    """
    return FastMCP(
        "Trino MCP",
        dependencies=["trino>=0.329.0"],
        lifespan=app_lifespan,
    )
170 | 
171 | 
def create_health_app() -> FastAPI:
    """
    Build the auxiliary REST application served next to the MCP server.

    Routes:
        GET  /health     -- always answers 200 so container health checks
                            pass while the server is still starting.
        POST /api/query  -- runs a SQL statement (optionally prefixed with
                            EXPLAIN) through the global context's Trino client.
        GET  /api        -- usage summary with an example request.

    Returns:
        FastAPI: The configured application (not yet running).
    """
    app = FastAPI(
        title="Trino MCP API",
        description="API for health checks and LLM query access to Trino MCP",
        version="0.1.0"
    )

    @app.get("/health")
    async def health():
        # Deliberately unconditional: reporting "ok" during startup gives the
        # app time to initialize before the Docker health check gives up.
        return JSONResponse(
            status_code=200,
            content={"status": "ok", "message": "Health check endpoint is responding"},
        )

    # NOTE: endpoint docstrings below are reused by FastAPI as OpenAPI
    # descriptions, so they are kept verbatim.
    @app.post("/api/query", response_model=QueryResponse)
    async def query(request: QueryRequest):
        """
        Execute a SQL query against Trino and return results.

        This endpoint is designed to be used by LLMs to query Trino through MCP.
        """
        # Refuse work until a context exists and reports itself healthy.
        if not (app_context_global and app_context_global.is_healthy):
            return JSONResponse(
                status_code=503,
                content={
                    "success": False,
                    "message": "Trino MCP server is not healthy or not initialized",
                },
            )

        logger.info(f"LLM API Query: {request.query}")

        try:
            client = app_context_global.trino_client
            sql = f"EXPLAIN {request.query}" if request.explain else request.query
            result = client.execute_query(sql, request.catalog, request.schema)

            # One {column_name: value} mapping per result row, in column order.
            columns = result.columns
            rows_as_dicts = [
                {name: row[position] for position, name in enumerate(columns)}
                for row in result.rows
            ]

            return {
                "success": True,
                "message": "Query executed successfully",
                "results": {
                    "query_id": result.query_id,
                    "columns": columns,
                    "rows": rows_as_dicts,
                    "row_count": result.row_count,
                    "execution_time_ms": result.query_time_ms,
                },
            }
        except Exception as e:
            logger.error(f"Error executing query: {e}")
            return JSONResponse(
                status_code=400,
                content={
                    "success": False,
                    "message": f"Error executing query: {str(e)}",
                },
            )

    @app.get("/api")
    async def api_root():
        """Root API endpoint with usage instructions."""
        server_version = app_context_global.config.version if app_context_global else "unknown"
        return {
            "message": "Trino MCP API for LLMs",
            "version": server_version,
            "endpoints": {
                "health": "GET /health - Check server health",
                "query": "POST /api/query - Execute SQL queries",
            },
            "query_example": {
                "query": "SELECT * FROM memory.bullshit.real_bullshit_data LIMIT 3",
                "catalog": "memory",
                "schema": "bullshit",
            },
        }

    return app
280 | 
281 | 
def main() -> None:
    """
    Main entry point for the server.

    Parses the command line, connects to Trino, and registers the MCP
    resources and tools up front. It then serves on the configured transport:

    * ``stdio``: runs ``mcp.run()`` in the foreground.
    * ``sse``: runs the SSE server on ``config.host:config.port``. A daemon
      thread serves the health/LLM REST API from ``create_health_app()`` on
      ``config.port + 1``.

    Initialization errors are logged instead of raised, so the process still
    starts; an existing global context is marked unhealthy in that case. The
    Trino client created here is disconnected once the server returns or
    raises.
    """
    import threading
    import time
    import traceback

    global app_context_global

    config = parse_args()
    mcp = create_app()

    # Connect eagerly so that app_context_global (read by the REST API) is
    # populated as soon as the process starts.
    trino_client = None
    try:
        logger.info(f"Connecting to Trino at {config.trino.host}:{config.trino.port}")
        trino_client = TrinoClient(config.trino)
        trino_client.connect()

        app_context_global = AppContext(
            trino_client=trino_client,
            config=config,
            is_healthy=True,
        )

        register_trino_resources(mcp, trino_client)
        register_trino_tools(mcp, trino_client)

        logger.info("Trino MCP server initialized and ready")
    except Exception as e:
        logger.error(f"Error initializing Trino MCP: {e}")
        if app_context_global:
            app_context_global.is_healthy = False

    try:
        if config.transport_type == "stdio":
            logger.info("Starting Trino MCP server with STDIO transport")
            mcp.run()
            return

        logger.info(f"Starting Trino MCP server with SSE transport on {config.host}:{config.port}")

        # The MCP_* variables are kept for anything else that reads them.
        # FastMCP itself takes its SSE bind address from ``mcp.settings``,
        # which is filled (from FASTMCP_* variables) when the server object is
        # created above. Without setting it here, --host/--port are ignored
        # for SSE.
        os.environ["MCP_HOST"] = config.host
        os.environ["MCP_PORT"] = str(config.port)
        settings = getattr(mcp, "settings", None)
        if settings is not None:
            settings.host = config.host
            settings.port = config.port

        try:
            # Only present in some MCP versions; absence is tolerated.
            from mcp.server.sse import configure_sse
            configure_sse(ignore_client_disconnect=True)
            logger.info("Configured SSE with ignore_client_disconnect=True")
        except (ImportError, AttributeError):
            logger.warning("Could not configure SSE settings - this may be expected in some MCP versions")

        def run_health_check():
            """Run the health check FastAPI app."""
            health_app = create_health_app()
            # Use a different port for the health check endpoint
            health_port = config.port + 1
            logger.info(f"Starting API server on port {health_port}")
            uvicorn.run(health_app, host=config.host, port=health_port)

        # Daemon thread: it must not keep the process alive on its own.
        threading.Thread(target=run_health_check, daemon=True).start()

        try:
            asyncio.run(mcp.run_sse_async())
        except RuntimeError as e:
            if "generator didn't stop after athrow()" not in str(e):
                logger.error(f"Fatal error running SSE server: {e}")
                logger.error(traceback.format_exc())
                raise

            # contextlib raises this when a lifespan generator yields again
            # after an exception was thrown into it.
            logger.error(f"Generator error in SSE server. This is a known issue with MCP 1.3.0: {e}")

            # Set unhealthy status for the REST API
            if app_context_global:
                app_context_global.is_healthy = False

            logger.info("Server will continue running but may not function correctly.")

            # Keep the process (and its container) alive so the daemon API
            # thread keeps serving.
            while True:
                time.sleep(60)
        except Exception as e:
            logger.error(f"Fatal error running SSE server: {e}")
            logger.error(traceback.format_exc())
            raise
    finally:
        # Release the connection opened above; app_lifespan only disconnects
        # clients it creates itself.
        if trino_client is not None and trino_client.conn:
            trino_client.disconnect()
382 | 
383 | 
if __name__ == "__main__":
    # Script entry point: drop all previously configured log sinks and log at
    # DEBUG level to stderr. Using stderr keeps log output off stdout
    # (presumably the channel the STDIO transport uses for protocol messages).
    logger.remove()
    logger.add(sys.stderr, level="DEBUG")
    main()
388 | 
```
Page 2/2 · First · Prev · Next · Last