This is page 2 of 3. To view the full context, open http://codebase.md/ckanthony/openapi-mcp?lines=true&page={x}, replacing {x} with a page number (1–3).
# Directory Structure
```
├── .dockerignore
├── .github
│   └── workflows
│       ├── ci.yml
│       └── publish.yml
├── .gitignore
├── cmd
│   └── openapi-mcp
│       └── main.go
├── Dockerfile
├── example
│   ├── agent_demo.png
│   ├── docker-compose.yml
│   └── weather
│       ├── .env.example
│       └── weatherbitio-swagger.json
├── go.mod
├── go.sum
├── openapi-mcp.png
├── pkg
│   ├── config
│   │   ├── config_test.go
│   │   └── config.go
│   ├── mcp
│   │   └── types.go
│   ├── parser
│   │   ├── parser_test.go
│   │   └── parser.go
│   └── server
│       ├── manager_test.go
│       ├── manager.go
│       ├── server_test.go
│       └── server.go
└── README.md
```
# Files
--------------------------------------------------------------------------------
/pkg/server/server.go:
--------------------------------------------------------------------------------
```go
1 | package server
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "io"
9 | "log"
10 | "net/http"
11 | "net/url"
12 | "strings"
13 | "sync"
14 | "time"
15 |
16 | // "fmt" // No longer needed here
17 | // "sync" // No longer needed here
18 |
19 | "github.com/ckanthony/openapi-mcp/pkg/config"
20 | "github.com/ckanthony/openapi-mcp/pkg/mcp"
21 | "github.com/google/uuid" // Import UUID package
22 | )
23 |
// --- JSON-RPC Structures (used for the handshake and all messages) ---

// jsonRPCRequest is a JSON-RPC 2.0 request or notification. Incoming POST
// bodies are decoded into it, and the server also reuses it to emit its own
// notifications ("mcp-ready", "ping") over the SSE stream.
type jsonRPCRequest struct {
	Jsonrpc string      `json:"jsonrpc"`
	Method  string      `json:"method"`
	Params  interface{} `json:"params,omitempty"`
	// ID can be a string, number, or null. With omitempty a nil ID is left
	// out when marshalling; when decoding, an absent "id" and an explicit
	// null both leave it nil.
	ID interface{} `json:"id,omitempty"`
}

// jsonRPCResponse is a JSON-RPC 2.0 response. Responses are queued on a
// connection's channel and written to the client as SSE "message" events.
// Per JSON-RPC 2.0 only one of Result/Error should be set; nothing in this
// type enforces that.
type jsonRPCResponse struct {
	Jsonrpc string      `json:"jsonrpc"`
	Result  interface{} `json:"result,omitempty"`
	Error   *jsonError  `json:"error,omitempty"`
	ID      interface{} `json:"id"` // ID should match the request ID; no omitempty, so a nil ID is emitted as null
}

// jsonError is the JSON-RPC 2.0 error object: a numeric code, a
// human-readable message, and optional extra data.
type jsonError struct {
	Code    int         `json:"code"`
	Message string      `json:"message"`
	Data    interface{} `json:"data,omitempty"`
}
45 |
// --- MCP Message Structures (document the payloads carried inside JSON-RPC) ---

// MCPMessage is a generic MCP message envelope.
// NOTE(review): it is not constructed or decoded anywhere in the visible part
// of this file; it appears to be kept as documentation of payload shapes.
// Confirm before relying on it.
type MCPMessage struct {
	Type    string          `json:"type"`                   // e.g., "initialize", "tools/list", "tools/call", "tool_result", "error"
	ID      string          `json:"id,omitempty"`           // Unique message ID (less relevant for JSON-RPC wrapper)
	Payload json.RawMessage `json:"payload,omitempty"`      // Content specific to the message type
	ConnID  string          `json:"connectionId,omitempty"` // Included in responses related to a connection
}

// MCPError is a structured error attached to a tool result through
// ToolResultPayload.Error. For example, httpMethodPostHandler uses it with
// code -32700 when the request body cannot be read.
type MCPError struct {
	Code    int         `json:"code,omitempty"` // Optional error code; omitted when 0
	Message string      `json:"message"`
	Data    interface{} `json:"data,omitempty"` // Optional additional data
}

// ToolCallParams is the payload of a "tools/call" request: the tool to run
// and its arguments. executeToolCall reads ToolName to look up the operation
// and Input to build the outgoing HTTP request.
// Presumably decoded from the JSON-RPC "params" field — verify in
// handleToolCallJSONRPC.
type ToolCallParams struct {
	ToolName string                 `json:"name"`      // Aligning with gin-mcp JSON-RPC 'name'
	Input    map[string]interface{} `json:"arguments"` // Aligning with gin-mcp JSON-RPC 'arguments'
}

// ToolResultContent is one item in the "content" array of a tool result.
// Only text content is modelled here.
type ToolResultContent struct {
	Type string `json:"type"`
	Text string `json:"text"` // Assuming text/JSON string result
	// Add other content types if needed
}

// ToolResultPayload is the "result" of a tool-call response. When IsError is
// true, Error may carry details.
type ToolResultPayload struct {
	Content []ToolResultContent `json:"content"`         // Array of content items
	IsError bool                `json:"isError"`         // Aligning with gin-mcp
	Error   *MCPError           `json:"error,omitempty"` // Detailed error info if IsError is true
	// NOTE(review): snake_case tag, unlike the camelCase keys used elsewhere.
	ToolCallID string `json:"tool_call_id,omitempty"` // Optional: Can be helpful
}
87 |
88 | // --- Server State ---
89 |
90 | // activeConnections stores channels for sending messages back to active SSE clients.
91 | var activeConnections = make(map[string]chan jsonRPCResponse) // Changed value type
92 | var connMutex sync.RWMutex
93 |
94 | // Channel buffer size
95 | const messageChannelBufferSize = 10
96 |
97 | // --- Server Implementation ---
98 |
99 | // ServeMCP starts an HTTP server handling MCP communication.
100 | func ServeMCP(addr string, toolSet *mcp.ToolSet, cfg *config.Config) error {
101 | log.Printf("Preparing ToolSet for MCP...")
102 |
103 | // --- Handler Functions ---
104 | mcpHandler := func(w http.ResponseWriter, r *http.Request) {
105 | // CORS Headers (Apply to all relevant requests)
106 | w.Header().Set("Access-Control-Allow-Origin", "*") // Be more specific in production
107 | w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
108 | w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-Connection-ID")
109 | w.Header().Set("Access-Control-Expose-Headers", "X-Connection-ID")
110 |
111 | if r.Method == http.MethodOptions {
112 | log.Println("Responding to OPTIONS request")
113 | w.WriteHeader(http.StatusNoContent) // Use 204 No Content for OPTIONS
114 | return
115 | }
116 |
117 | if r.Method == http.MethodGet {
118 | httpMethodGetHandler(w, r) // Handle SSE connection setup
119 | } else if r.Method == http.MethodPost {
120 | httpMethodPostHandler(w, r, toolSet, cfg) // Pass the cfg object here
121 | } else {
122 | log.Printf("Method Not Allowed: %s", r.Method)
123 | http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
124 | }
125 | }
126 |
127 | // Setup server mux
128 | mux := http.NewServeMux()
129 | mux.HandleFunc("/mcp", mcpHandler) // Single endpoint for GET/POST/OPTIONS
130 |
131 | log.Printf("MCP server listening on %s/mcp", addr)
132 | return http.ListenAndServe(addr, mux)
133 | }
134 |
135 | // httpMethodGetHandler handles the initial GET request to establish the SSE connection.
136 | func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) {
137 | connectionID := uuid.New().String()
138 | log.Printf("SSE client connecting: %s (Assigning ID: %s)", r.RemoteAddr, connectionID)
139 |
140 | flusher, ok := w.(http.Flusher)
141 | if !ok {
142 | http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
143 | log.Println("Error: Client connection does not support flushing")
144 | return
145 | }
146 |
147 | // --- Set headers FIRST ---
148 | w.Header().Set("Content-Type", "text/event-stream")
149 | w.Header().Set("Cache-Control", "no-cache")
150 | w.Header().Set("Connection", "keep-alive")
151 | // CORS headers are set in the main handler
152 | w.Header().Set("X-Connection-ID", connectionID)
153 | w.Header().Set("X-Accel-Buffering", "no") // Useful for proxies like Nginx
154 | w.WriteHeader(http.StatusOK) // Write headers and status code
155 | flusher.Flush() // Ensure headers are sent immediately
156 |
157 | // --- Send initial :ok --- (Must happen *after* headers)
158 | if _, err := fmt.Fprintf(w, ":ok\n\n"); err != nil {
159 | log.Printf("Error sending SSE preamble to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
160 | return // Cannot proceed if preamble fails
161 | }
162 | flusher.Flush()
163 | log.Printf("Sent :ok preamble to %s (ID: %s)", r.RemoteAddr, connectionID)
164 |
165 | // --- Send initial SSE events --- (endpoint, mcp-ready)
166 | endpointURL := fmt.Sprintf("/mcp?sessionId=%s", connectionID) // Assuming /mcp is the mount path
167 | if err := writeSSEEvent(w, "endpoint", endpointURL); err != nil {
168 | log.Printf("Error sending SSE endpoint event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
169 | return
170 | }
171 | flusher.Flush()
172 | log.Printf("Sent endpoint event to %s (ID: %s)", r.RemoteAddr, connectionID)
173 |
174 | readyMsg := jsonRPCRequest{ // Use request struct for notification format
175 | Jsonrpc: "2.0",
176 | Method: "mcp-ready",
177 | Params: map[string]interface{}{ // Put data in params
178 | "connectionId": connectionID,
179 | "status": "connected",
180 | "protocol": "2.0",
181 | },
182 | }
183 | if err := writeSSEEvent(w, "message", readyMsg); err != nil {
184 | log.Printf("Error sending SSE mcp-ready event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
185 | return
186 | }
187 | flusher.Flush()
188 | log.Printf("Sent mcp-ready event to %s (ID: %s)", r.RemoteAddr, connectionID)
189 |
190 | // --- Setup message channel and store connection ---
191 | msgChan := make(chan jsonRPCResponse, messageChannelBufferSize) // Channel for responses
192 | connMutex.Lock()
193 | activeConnections[connectionID] = msgChan
194 | connMutex.Unlock()
195 | log.Printf("Registered channel for connection %s. Active connections: %d", connectionID, len(activeConnections))
196 |
197 | // --- Cleanup function ---
198 | cleanup := func() {
199 | connMutex.Lock()
200 | delete(activeConnections, connectionID)
201 | connMutex.Unlock()
202 | close(msgChan) // Close channel when connection ends
203 | log.Printf("Removed connection %s. Active connections: %d", connectionID, len(activeConnections))
204 | }
205 | defer cleanup()
206 |
207 | // --- Goroutine to write messages from channel to SSE stream ---
208 | ctx, cancel := context.WithCancel(r.Context())
209 | defer cancel()
210 |
211 | go func() {
212 | log.Printf("[SSE Writer %s] Starting message writer goroutine", connectionID)
213 | defer log.Printf("[SSE Writer %s] Exiting message writer goroutine", connectionID)
214 | for {
215 | select {
216 | case <-ctx.Done():
217 | return // Exit if main context is cancelled
218 | case resp, ok := <-msgChan:
219 | if !ok {
220 | log.Printf("[SSE Writer %s] Message channel closed.", connectionID)
221 | return // Exit if channel is closed
222 | }
223 | log.Printf("[SSE Writer %s] Sending message (ID: %v) via SSE", connectionID, resp.ID)
224 | if err := writeSSEEvent(w, "message", resp); err != nil {
225 | log.Printf("[SSE Writer %s] Error writing message to SSE stream: %v. Cancelling context.", connectionID, err)
226 | cancel() // Signal main loop to exit on write error
227 | return
228 | }
229 | flusher.Flush() // Flush after writing message
230 | }
231 | }
232 | }()
233 |
234 | // --- Keep connection alive (main loop) ---
235 | keepAliveTicker := time.NewTicker(20 * time.Second)
236 | defer keepAliveTicker.Stop()
237 |
238 | log.Printf("[SSE %s] Entering keep-alive loop", connectionID)
239 | for {
240 | select {
241 | case <-ctx.Done():
242 | log.Printf("[SSE %s] Context done. Exiting keep-alive loop.", connectionID)
243 | return // Exit loop if context cancelled (client disconnect or write error)
244 | case <-keepAliveTicker.C:
245 | // Send JSON-RPC ping notification instead of SSE comment
246 | pingMsg := jsonRPCRequest{ // Use request struct for notification format
247 | Jsonrpc: "2.0",
248 | Method: "ping",
249 | Params: map[string]interface{}{ // Include timestamp like gin-mcp
250 | "timestamp": time.Now().Unix(),
251 | },
252 | }
253 | if err := writeSSEEvent(w, "message", pingMsg); err != nil {
254 | log.Printf("[SSE %s] Error sending ping notification: %v. Closing connection.", connectionID, err)
255 | cancel() // Signal writer goroutine and exit
256 | return
257 | }
258 | flusher.Flush()
259 | }
260 | }
261 | }
262 |
// writeSSEEvent encodes data as one Server-Sent Event frame and sends it to
// w with a single Write call.
//
// The frame is an optional "event: <eventName>" line, one "data: " line per
// line of payload, and a terminating blank line. A string passed with the
// "endpoint" event is sent verbatim. Everything else, including strings for
// other events, is JSON-encoded first.
//
// It returns an error if JSON encoding fails (nothing is written in that
// case) or if the write to w fails. Flushing is left to the caller.
func writeSSEEvent(w http.ResponseWriter, eventName string, data interface{}) error {
	payload, isString := data.(string)
	sendRaw := isString && eventName == "endpoint"
	if !sendRaw {
		encoded, err := json.Marshal(data)
		if err != nil {
			return fmt.Errorf("failed to marshal data for SSE event '%s': %w", eventName, err)
		}
		payload = string(encoded)
	}

	var frame strings.Builder
	if eventName != "" {
		frame.WriteString("event: " + eventName + "\n")
	}
	// SSE data cannot span lines, so each payload line gets its own "data:" field.
	for _, line := range strings.Split(payload, "\n") {
		frame.WriteString("data: " + line + "\n")
	}
	frame.WriteString("\n") // blank line terminates the event

	if _, err := w.Write([]byte(frame.String())); err != nil {
		return fmt.Errorf("failed to write SSE event '%s': %w", eventName, err)
	}
	return nil
}
298 |
299 | // httpMethodPostHandler handles incoming POST requests containing MCP messages.
300 | func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp.ToolSet, cfg *config.Config) {
301 | // --- Original Logic (Restored) ---
302 | connID := r.Header.Get("X-Connection-ID") // Try header first
303 | if connID == "" {
304 | connID = r.URL.Query().Get("sessionId") // Fallback to query parameter
305 | log.Printf("X-Connection-ID header missing, checking sessionId query param: found='%s'", connID)
306 | }
307 |
308 | if connID == "" {
309 | log.Println("Error: POST request received without X-Connection-ID header or sessionId query parameter")
310 | http.Error(w, "Missing X-Connection-ID header or sessionId query parameter", http.StatusBadRequest)
311 | return
312 | }
313 |
314 | // Find the corresponding message channel for this connection
315 | connMutex.RLock()
316 | msgChan, isActive := activeConnections[connID]
317 | connMutex.RUnlock()
318 |
319 | if !isActive {
320 | log.Printf("Error: POST request received for inactive/unknown connection ID: %s", connID)
321 | // Still send sync error here, as we don't have a channel
322 | tryWriteHTTPError(w, http.StatusNotFound, "Invalid or expired connection ID")
323 | return
324 | }
325 |
326 | bodyBytes, err := io.ReadAll(r.Body)
327 | if err != nil {
328 | log.Printf("Error reading POST request body for %s: %v", connID, err)
329 | // Create error response in the ToolResultPayload format
330 | errPayload := ToolResultPayload{
331 | IsError: true,
332 | Error: &MCPError{
333 | Code: -32700, // JSON-RPC Parse Error Code
334 | Message: "Parse error reading request body",
335 | },
336 | // ToolCallID doesn't really apply here, maybe use connID or leave empty?
337 | // ToolCallID: connID,
338 | }
339 | errResp := jsonRPCResponse{
340 | Jsonrpc: "2.0",
341 | ID: nil, // ID is unknown if we can't read the body
342 | Result: errPayload,
343 | Error: nil, // Ensure top-level error is nil
344 | }
345 | // Attempt to send via SSE channel
346 | select {
347 | case msgChan <- errResp:
348 | log.Printf("Queued read error response (ID: %v) for %s onto SSE channel (as Result)", errResp.ID, connID)
349 | // Send HTTP 202 Accepted back to the POST request
350 | w.WriteHeader(http.StatusAccepted)
351 | fmt.Fprintln(w, "Request accepted (with parse error), response will be sent via SSE.")
352 | default:
353 | log.Printf("Error: Failed to queue read error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID)
354 | // Send an error back on the POST request if channel fails
355 | tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel")
356 | }
357 | return // Stop processing
358 | }
359 | // No defer r.Body.Close() needed here as io.ReadAll reads to EOF
360 |
361 | log.Printf("Received POST data for %s: %s", connID, string(bodyBytes))
362 |
363 | // Attempt to unmarshal into a temporary map first to extract ID if possible
364 | var rawReq map[string]interface{}
365 | var reqID interface{} // Keep track of ID even if full unmarshal fails
366 |
367 | // Try unmarshalling into raw map
368 | if err := json.Unmarshal(bodyBytes, &rawReq); err == nil {
369 | // Ensure reqID is treated as a string or number if possible, handle potential null
370 | if idVal, idExists := rawReq["id"]; idExists && idVal != nil {
371 | reqID = idVal
372 | } else {
373 | reqID = nil // Explicitly set to nil if missing or JSON null
374 | }
375 | } else {
376 | // Full unmarshal failed, log it but continue to try specific struct
377 | log.Printf("Warning: Initial unmarshal into map failed for %s: %v. Will attempt specific struct unmarshal.", connID, err)
378 | reqID = nil // ID is unknown
379 | }
380 |
381 | var req jsonRPCRequest // Expect JSON-RPC request
382 | if err := json.Unmarshal(bodyBytes, &req); err != nil {
383 | log.Printf("Error decoding JSON-RPC request for %s: %v", connID, err)
384 | // Use createJSONRPCError to correctly format the error response
385 | errResp := createJSONRPCError(reqID, -32700, "Parse error decoding JSON request", err.Error())
386 |
387 | // Attempt to send via SSE channel
388 | select {
389 | case msgChan <- errResp:
390 | log.Printf("Queued decode error response (ID: %v) for %s onto SSE channel", errResp.ID, connID)
391 | // Send HTTP 202 Accepted back to the POST request
392 | w.WriteHeader(http.StatusAccepted)
393 | // Use a specific message for decode errors
394 | fmt.Fprintln(w, "Request accepted (with decode error), response will be sent via SSE.")
395 | default:
396 | log.Printf("Error: Failed to queue decode error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID)
397 | // Send an error back on the POST request if channel fails
398 | tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel")
399 | }
400 | return // Stop processing
401 | }
402 |
403 | // If we successfully unmarshalled 'req', ensure reqID matches req.ID
404 | if req.ID != nil {
405 | reqID = req.ID
406 | } else {
407 | reqID = nil
408 | }
409 |
410 | // --- Variable to hold the final response to be sent via SSE ---
411 | var respToSend jsonRPCResponse
412 |
413 | // --- Validate JSON-RPC Request ---
414 | if req.Jsonrpc != "2.0" {
415 | log.Printf("Invalid JSON-RPC version ('%s') for %s, ID: %v", req.Jsonrpc, connID, reqID)
416 | respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: jsonrpc field must be \"2.0\"", nil)
417 | } else if req.Method == "" {
418 | log.Printf("Missing JSON-RPC method for %s, ID: %v", connID, reqID)
419 | respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: method field is missing or empty", nil)
420 | } else {
421 | // --- Process the valid request ---
422 | log.Printf("Processing JSON-RPC message for %s: Method=%s, ID=%v", connID, req.Method, reqID)
423 | switch req.Method {
424 | case "initialize":
425 | incomingInitializeJSON, _ := json.Marshal(req)
426 | log.Printf("DEBUG: Handling 'initialize' for %s. Incoming request: %s", connID, string(incomingInitializeJSON))
427 | respToSend = handleInitializeJSONRPC(connID, &req)
428 | outgoingInitializeJSON, _ := json.Marshal(respToSend)
429 | log.Printf("DEBUG: Prepared 'initialize' response for %s. Outgoing response: %s", connID, string(outgoingInitializeJSON))
430 | case "notifications/initialized":
431 | log.Printf("Received 'notifications/initialized' notification for %s. Ignoring.", connID)
432 | w.WriteHeader(http.StatusAccepted)
433 | fmt.Fprintln(w, "Notification received.")
434 | return // Return early, do not send anything on SSE channel
435 | case "tools/list":
436 | respToSend = handleToolsListJSONRPC(connID, &req, toolSet)
437 | case "tools/call":
438 | respToSend = handleToolCallJSONRPC(connID, &req, toolSet, cfg)
439 | default:
440 | log.Printf("Received unknown JSON-RPC method '%s' for %s", req.Method, connID)
441 | respToSend = createJSONRPCError(reqID, -32601, fmt.Sprintf("Method not found: %s", req.Method), nil)
442 | }
443 | }
444 |
445 | // --- Send response ASYNCHRONOUSLY via SSE channel (unless handled earlier) ---
446 | select {
447 | case msgChan <- respToSend:
448 | log.Printf("Queued response (ID: %v) for %s onto SSE channel", respToSend.ID, connID)
449 | // Send HTTP 202 Accepted back to the POST request
450 | w.WriteHeader(http.StatusAccepted)
451 | // Use the standard message for successfully queued responses
452 | fmt.Fprintln(w, "Request accepted, response will be sent via SSE.")
453 | default:
454 | log.Printf("Error: Failed to queue response (ID: %v) for %s - SSE channel likely full or closed.", respToSend.ID, connID)
455 | http.Error(w, "Failed to queue response for SSE channel", http.StatusInternalServerError)
456 | }
457 | }
458 |
459 | // --- JSON-RPC Message Handlers --- // Implementations returning jsonRPCResponse
460 |
461 | func handleInitializeJSONRPC(connID string, req *jsonRPCRequest) jsonRPCResponse {
462 | log.Printf("Handling 'initialize' (JSON-RPC) for %s", connID)
463 |
464 | // Construct the result payload based on gin-mcp's structure using map[string]interface{}
465 | resultPayload := map[string]interface{}{
466 | "protocolVersion": "2024-11-05", // Aligning with gin-mcp
467 | "capabilities": map[string]interface{}{
468 | "tools": map[string]interface{}{
469 | "enabled": true,
470 | "config": map[string]interface{}{
471 | "listChanged": false,
472 | },
473 | },
474 | "prompts": map[string]interface{}{
475 | "enabled": false,
476 | },
477 | "resources": map[string]interface{}{
478 | "enabled": true,
479 | },
480 | "logging": map[string]interface{}{
481 | "enabled": false,
482 | },
483 | "roots": map[string]interface{}{
484 | "listChanged": false,
485 | },
486 | },
487 | "serverInfo": map[string]interface{}{
488 | "name": "OpenAPI-MCP", // Or use config name if available
489 | "version": "openapi-mcp-0.1.0", // Your server version
490 | "apiVersion": "2024-11-05", // MCP API version
491 | },
492 | "connectionId": connID, // Include the connection ID
493 | }
494 |
495 | return jsonRPCResponse{
496 | Jsonrpc: "2.0",
497 | ID: req.ID, // Match request ID
498 | Result: resultPayload,
499 | }
500 | }
501 |
502 | func handleToolsListJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet) jsonRPCResponse {
503 | log.Printf("Handling 'tools/list' (JSON-RPC) for %s", connID)
504 |
505 | // Construct the result payload based on gin-mcp's structure
506 | resultPayload := map[string]interface{}{
507 | "tools": toolSet.Tools,
508 | "metadata": map[string]interface{}{
509 | "version": "2024-11-05", // Align with gin-mcp if possible
510 | "count": len(toolSet.Tools),
511 | },
512 | }
513 |
514 | return jsonRPCResponse{
515 | Jsonrpc: "2.0",
516 | ID: req.ID, // Match request ID
517 | Result: resultPayload,
518 | }
519 | }
520 |
521 | // executeToolCall performs the actual HTTP request based on the resolved operation and parameters.
522 | // It now correctly handles API key injection based on the *cfg* parameter.
523 | func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.Config) (*http.Response, error) {
524 | toolName := params.ToolName
525 | toolInput := params.Input // This is the map[string]interface{} from the client
526 |
527 | log.Printf("[ExecuteToolCall] Looking up details for tool: %s", toolName)
528 | operation, ok := toolSet.Operations[toolName]
529 | if !ok {
530 | log.Printf("[ExecuteToolCall] Error: Operation details not found for tool '%s'", toolName)
531 | return nil, fmt.Errorf("operation details for tool '%s' not found", toolName)
532 | }
533 | log.Printf("[ExecuteToolCall] Found operation: Method=%s, Path=%s", operation.Method, operation.Path)
534 |
535 | // --- Resolve API Key (using cfg passed from main) ---
536 | resolvedKey := cfg.GetAPIKey()
537 | apiKeyName := cfg.APIKeyName
538 | apiKeyLocation := cfg.APIKeyLocation
539 | hasServerKey := resolvedKey != "" && apiKeyName != "" && apiKeyLocation != ""
540 |
541 | log.Printf("[ExecuteToolCall] API Key Details: Name='%s', In='%s', HasServerValue=%t", apiKeyName, apiKeyLocation, resolvedKey != "")
542 |
543 | // --- Prepare Request Components ---
544 | baseURL := operation.BaseURL // Use BaseURL from the specific operation
545 | if cfg.ServerBaseURL != "" {
546 | baseURL = cfg.ServerBaseURL // Override if global base URL is set
547 | log.Printf("[ExecuteToolCall] Overriding base URL with global config: %s", baseURL)
548 | }
549 | if baseURL == "" {
550 | log.Printf("[ExecuteToolCall] Warning: No base URL found for operation %s and no global override set.", toolName)
551 | // For now, assume relative if empty.
552 | }
553 |
554 | path := operation.Path
555 | queryParams := url.Values{}
556 | pathParams := make(map[string]string)
557 | headerParams := make(http.Header) // For headers to add
558 | cookieParams := []*http.Cookie{} // For cookies to add
559 | bodyData := make(map[string]interface{}) // For building the request body
560 | requestBodyRequired := operation.Method == "POST" || operation.Method == "PUT" || operation.Method == "PATCH"
561 |
562 | // Create a map of expected parameters from the operation details for easier lookup
563 | expectedParams := make(map[string]string) // Map param name to its location ('in')
564 | for _, p := range operation.Parameters {
565 | expectedParams[p.Name] = p.In
566 | }
567 |
568 | // --- Process Input Parameters (Separating and Handling API Key Override) ---
569 | log.Printf("[ExecuteToolCall] Processing %d input parameters...", len(toolInput))
570 | for key, value := range toolInput {
571 | // --- API Key Override Check ---
572 | // If this input param is the API key AND we have a valid server key config,
573 | // skip processing the client's value entirely.
574 | if hasServerKey && key == apiKeyName {
575 | log.Printf("[ExecuteToolCall] Skipping client-provided param '%s' due to server API key override.", key)
576 | continue
577 | }
578 | // --- End API Key Override ---
579 |
580 | paramLocation, knownParam := expectedParams[key]
581 | pathPlaceholder := "{" + key + "}" // OpenAPI uses {param}
582 |
583 | if strings.Contains(path, pathPlaceholder) {
584 | // Handle path parameter substitution
585 | pathParams[key] = fmt.Sprintf("%v", value)
586 | log.Printf("[ExecuteToolCall] Found path parameter %s=%v", key, value)
587 | } else if knownParam {
588 | // Handle parameters defined in the spec (query, header, cookie)
589 | switch paramLocation {
590 | case "query":
591 | queryParams.Add(key, fmt.Sprintf("%v", value))
592 | log.Printf("[ExecuteToolCall] Found query parameter %s=%v (from spec)", key, value)
593 | case "header":
594 | headerParams.Add(key, fmt.Sprintf("%v", value))
595 | log.Printf("[ExecuteToolCall] Found header parameter %s=%v (from spec)", key, value)
596 | case "cookie":
597 | cookieParams = append(cookieParams, &http.Cookie{Name: key, Value: fmt.Sprintf("%v", value)})
598 | log.Printf("[ExecuteToolCall] Found cookie parameter %s=%v (from spec)", key, value)
599 | // case "formData": // TODO: Handle form data if needed
600 | // bodyData[key] = value // Or handle differently based on content type
601 | // log.Printf("[ExecuteToolCall] Found formData parameter %s=%v (from spec)", key, value)
602 | default:
603 | // Known parameter but location handling is missing or mismatched.
604 | if paramLocation == "path" && (operation.Method == "GET" || operation.Method == "DELETE") {
605 | // If spec says 'path' but it wasn't in the actual path, and it's a GET/DELETE,
606 | // treat it as a query parameter as a fallback.
607 | log.Printf("[ExecuteToolCall] Warning: Parameter '%s' is 'path' in spec but not in URL path '%s'. Adding to query parameters as fallback for GET/DELETE.", key, operation.Path)
608 | queryParams.Add(key, fmt.Sprintf("%v", value))
609 | } else {
610 | // Otherwise, log the warning and ignore.
611 | log.Printf("[ExecuteToolCall] Warning: Parameter '%s' has unsupported or unhandled location '%s' in spec. Ignoring.", key, paramLocation)
612 | }
613 | }
614 | } else if requestBodyRequired {
615 | // If parameter is not in path or defined in spec params, and method expects a body,
616 | // assume it belongs in the request body.
617 | bodyData[key] = value
618 | log.Printf("[ExecuteToolCall] Added body parameter %s=%v (assumed)", key, value)
619 | } else {
620 | // Parameter not in path, not in spec, and not a body method.
621 | // This could be an extraneous parameter like 'explanation'. Log it.
622 | log.Printf("[ExecuteToolCall] Ignoring parameter '%s' as it doesn't match path or known parameter location for method %s.", key, operation.Method)
623 | }
624 | }
625 |
626 | // --- Substitute Path Parameters ---
627 | for key, value := range pathParams {
628 | path = strings.Replace(path, "{"+key+"}", value, -1)
629 | }
630 |
631 | // --- Inject Server API Key (if applicable) ---
632 | if hasServerKey {
633 | log.Printf("[ExecuteToolCall] Injecting server API key (Name: %s, Location: %s)", apiKeyName, string(apiKeyLocation))
634 | switch apiKeyLocation {
635 | case config.APIKeyLocationQuery:
636 | queryParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value
637 | log.Printf("[ExecuteToolCall] Injected API key '%s' into query parameters", apiKeyName)
638 | case config.APIKeyLocationHeader:
639 | headerParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value
640 | log.Printf("[ExecuteToolCall] Injected API key '%s' into headers", apiKeyName)
641 | case config.APIKeyLocationPath:
642 | pathPlaceholder := "{" + apiKeyName + "}"
643 | if strings.Contains(path, pathPlaceholder) {
644 | path = strings.Replace(path, pathPlaceholder, resolvedKey, -1)
645 | log.Printf("[ExecuteToolCall] Injected API key into path parameter '%s'", apiKeyName)
646 | } else {
647 | log.Printf("[ExecuteToolCall] Warning: API key location is 'path' but placeholder '%s' not found in final path '%s' for injection.", pathPlaceholder, path)
648 | }
649 | case config.APIKeyLocationCookie:
650 | // Check if cookie already exists from input, replace if so
651 | foundCookie := false
652 | for i, c := range cookieParams {
653 | if c.Name == apiKeyName {
654 | log.Printf("[ExecuteToolCall] Replacing existing cookie '%s' with injected API key.", apiKeyName)
655 | cookieParams[i] = &http.Cookie{Name: apiKeyName, Value: resolvedKey} // Replace existing
656 | foundCookie = true
657 | break
658 | }
659 | }
660 | if !foundCookie {
661 | log.Printf("[ExecuteToolCall] Adding new cookie '%s' with injected API key.", apiKeyName)
662 | cookieParams = append(cookieParams, &http.Cookie{Name: apiKeyName, Value: resolvedKey}) // Append new
663 | }
664 | default:
665 | // Use log.Printf for consistency
666 | log.Printf("Warning: Unsupported API key location specified in config: '%s'", apiKeyLocation)
667 | }
668 | } else {
669 | log.Printf("[ExecuteToolCall] Skipping server API key injection (config incomplete or key unresolved).")
670 | }
671 |
672 | // --- Final URL Construction ---
673 | // Reconstruct query string *after* potential API key injection
674 | targetURL := baseURL + path
675 | if len(queryParams) > 0 {
676 | targetURL += "?" + queryParams.Encode()
677 | }
678 | log.Printf("[ExecuteToolCall] Final Target URL: %s %s", operation.Method, targetURL)
679 |
680 | // --- Prepare Request Body ---
681 | var reqBody io.Reader
682 | var bodyBytes []byte // Keep for logging
683 | if requestBodyRequired && len(bodyData) > 0 {
684 | var err error
685 | bodyBytes, err = json.Marshal(bodyData)
686 | if err != nil {
687 | log.Printf("[ExecuteToolCall] Error marshalling request body: %v", err)
688 | return nil, fmt.Errorf("error marshalling request body: %w", err)
689 | }
690 | reqBody = bytes.NewBuffer(bodyBytes)
691 | log.Printf("[ExecuteToolCall] Request body: %s", string(bodyBytes))
692 | }
693 |
694 | // --- Create HTTP Request ---
695 | req, err := http.NewRequest(operation.Method, targetURL, reqBody)
696 | if err != nil {
697 | log.Printf("[ExecuteToolCall] Error creating HTTP request: %v", err)
698 | return nil, fmt.Errorf("error creating request: %w", err)
699 | }
700 |
701 | // --- Set Headers ---
702 | // Default headers
703 | req.Header.Set("Accept", "application/json") // Assume JSON response typical for APIs
704 | if reqBody != nil {
705 | req.Header.Set("Content-Type", "application/json") // Assume JSON body if body exists
706 | }
707 |
708 | // Add headers collected from input/spec AND potentially injected API key
709 | for key, values := range headerParams {
710 | // Note: We use Set, assuming single value per header from input typically.
711 | // If multi-value headers are needed from spec/input, use Add.
712 | if len(values) > 0 {
713 | req.Header.Set(key, values[0])
714 | }
715 | }
716 |
717 | // Add custom headers from config (comma-separated)
718 | if cfg.CustomHeaders != "" {
719 | headers := strings.Split(cfg.CustomHeaders, ",")
720 | for _, h := range headers {
721 | parts := strings.SplitN(h, ":", 2)
722 | if len(parts) == 2 {
723 | headerName := strings.TrimSpace(parts[0])
724 | headerValue := strings.TrimSpace(parts[1])
725 | if headerName != "" {
726 | req.Header.Set(headerName, headerValue) // Set overrides potential input
727 | log.Printf("[ExecuteToolCall] Added custom header from config: %s", headerName)
728 | }
729 | }
730 | }
731 | }
732 |
733 | // --- Add Cookies ---
734 | for _, cookie := range cookieParams {
735 | req.AddCookie(cookie)
736 | }
737 |
738 | log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header)
739 | if len(req.Cookies()) > 0 {
740 | log.Printf("[ExecuteToolCall] Sending request with cookies: %+v", req.Cookies())
741 | }
742 |
743 | // --- Execute HTTP Request ---
744 | log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header)
745 | client := &http.Client{Timeout: 30 * time.Second}
746 | resp, err := client.Do(req)
747 | if err != nil {
748 | log.Printf("[ExecuteToolCall] Error executing HTTP request: %v", err)
749 | return nil, fmt.Errorf("error executing request: %w", err)
750 | }
751 |
752 | log.Printf("[ExecuteToolCall] Request executed. Status Code: %d", resp.StatusCode)
753 | // Note: Don't close resp.Body here, the caller (handleToolCallJSONRPC) needs it.
754 | return resp, nil
755 | }
756 |
757 | func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet, cfg *config.Config) jsonRPCResponse {
758 | // req.Params is interface{}, but should contain json.RawMessage for tools/call
759 | rawParams, ok := req.Params.(json.RawMessage)
760 | if !ok {
761 | // If it's not RawMessage, maybe it was already decoded to a map? Handle that case too.
762 | if paramsMap, mapOk := req.Params.(map[string]interface{}); mapOk {
763 | // Attempt to marshal the map back to JSON bytes
764 | var marshalErr error
765 | rawParams, marshalErr = json.Marshal(paramsMap)
766 | if marshalErr != nil {
767 | log.Printf("Error marshalling params map for %s: %v", connID, marshalErr)
768 | return createJSONRPCError(req.ID, -32602, "Invalid parameters format (map marshal failed)", marshalErr.Error())
769 | }
770 | log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from map)", connID, string(rawParams))
771 | } else {
772 | log.Printf("Invalid parameters format for tools/call (not json.RawMessage or map[string]interface{}): %T", req.Params)
773 | return createJSONRPCError(req.ID, -32602, "Invalid parameters format (expected JSON object)", nil)
774 | }
775 | } else {
776 | log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from RawMessage)", connID, string(rawParams))
777 | }
778 |
779 | // Now, unmarshal the rawParams ([]byte) into ToolCallParams
780 | var params ToolCallParams
781 | if err := json.Unmarshal(rawParams, ¶ms); err != nil {
782 | log.Printf("Error unmarshalling tools/call params for %s: %v", connID, err)
783 | return createJSONRPCError(req.ID, -32602, "Invalid parameters structure (unmarshal)", err.Error())
784 | }
785 |
786 | log.Printf("Executing tool '%s' for %s with input: %+v", params.ToolName, connID, params.Input)
787 |
788 | // --- Execute the actual tool call ---
789 | httpResp, execErr := executeToolCall(¶ms, toolSet, cfg)
790 |
791 | // --- Process Response ---
792 | var resultPayload ToolResultPayload
793 | if execErr != nil {
794 | log.Printf("Error executing tool call '%s': %v", params.ToolName, execErr)
795 | resultPayload = ToolResultPayload{
796 | IsError: true,
797 | Error: &MCPError{
798 | Message: fmt.Sprintf("Failed to execute tool '%s': %v", params.ToolName, execErr),
799 | },
800 | ToolCallID: fmt.Sprintf("%v", req.ID),
801 | }
802 | } else {
803 | defer httpResp.Body.Close() // Ensure body is closed
804 | bodyBytes, readErr := io.ReadAll(httpResp.Body)
805 | if readErr != nil {
806 | log.Printf("Error reading response body for tool '%s': %v", params.ToolName, readErr)
807 | resultPayload = ToolResultPayload{
808 | IsError: true,
809 | Error: &MCPError{
810 | Message: fmt.Sprintf("Failed to read response from tool '%s': %v", params.ToolName, readErr),
811 | },
812 | ToolCallID: fmt.Sprintf("%v", req.ID),
813 | }
814 | } else {
815 | log.Printf("Received response body for tool '%s': %s", params.ToolName, string(bodyBytes))
816 | // Check status code for API-level errors
817 | if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
818 | resultPayload = ToolResultPayload{
819 | IsError: true,
820 | Error: &MCPError{
821 | Code: httpResp.StatusCode,
822 | Message: fmt.Sprintf("Tool '%s' API call failed with status %s", params.ToolName, httpResp.Status),
823 | Data: string(bodyBytes), // Include response body in error data
824 | },
825 | ToolCallID: fmt.Sprintf("%v", req.ID),
826 | }
827 | } else {
828 | // Successful execution
829 | resultContent := []ToolResultContent{
830 | {
831 | Type: "text", // TODO: Handle JSON responses properly if Content-Type indicates it
832 | Text: string(bodyBytes),
833 | },
834 | }
835 | resultPayload = ToolResultPayload{
836 | Content: resultContent,
837 | IsError: false,
838 | ToolCallID: fmt.Sprintf("%v", req.ID),
839 | }
840 | }
841 | }
842 | }
843 |
844 | // --- Send Response ---
845 | return jsonRPCResponse{
846 | Jsonrpc: "2.0",
847 | ID: req.ID, // Match request ID
848 | Result: resultPayload, // Use the actual result payload
849 | }
850 | }
851 |
852 | // --- Helper Functions (Updated for JSON-RPC) ---
853 |
854 | // sendJSONRPCResponse sends a JSON-RPC response *synchronously*.
855 | // Keep this for now for sending synchronous errors on POST decode/read failures.
856 | func sendJSONRPCResponse(w http.ResponseWriter, resp jsonRPCResponse) {
857 | w.Header().Set("Content-Type", "application/json")
858 | if err := json.NewEncoder(w).Encode(resp); err != nil {
859 | log.Printf("Error encoding JSON-RPC response (ID: %v) for ConnID %v: %v", resp.ID, resp.Error, err)
860 | // Attempt to send a plain text error if JSON encoding fails
861 | tryWriteHTTPError(w, http.StatusInternalServerError, "Internal Server Error encoding JSON-RPC response")
862 | }
863 | log.Printf("Sent JSON-RPC response: Method=%s, ID=%v", getMethodFromResponse(resp), resp.ID)
864 | }
865 |
866 | // createJSONRPCError creates a JSON-RPC error response.
867 | func createJSONRPCError(id interface{}, code int, message string, data interface{}) jsonRPCResponse {
868 | jsonErr := &jsonError{Code: code, Message: message, Data: data}
869 | return jsonRPCResponse{
870 | Jsonrpc: "2.0",
871 | ID: id, // Error response should echo the request ID
872 | Error: jsonErr,
873 | }
874 | }
875 |
876 | // sendJSONRPCError sends a JSON-RPC error response.
877 | func sendJSONRPCError(w http.ResponseWriter, connID string, id interface{}, code int, message string, data interface{}) {
878 | resp := createJSONRPCError(id, code, message, data)
879 | log.Printf("Sending JSON-RPC Error for ConnID %s, ID %v: Code=%d, Message='%s'", connID, id, code, message)
880 | sendJSONRPCResponse(w, resp)
881 | }
882 |
883 | // Helper to get the method name for logging purposes (from the result/error structure if possible)
884 | func getMethodFromResponse(resp jsonRPCResponse) string {
885 | if resp.Result != nil {
886 | // Attempt to infer method from result structure if it has a type field
887 | if resMap, ok := resp.Result.(map[string]interface{}); ok {
888 | if methodType, typeOk := resMap["type"].(string); typeOk {
889 | return methodType + "_result"
890 | }
891 | }
892 | // Infer based on known result types if possible
893 | if _, ok := resp.Result.(map[string]interface{}); ok && resp.Result.(map[string]interface{})["tools"] != nil {
894 | return "tool_set"
895 | }
896 | // If not easily identifiable, just indicate success
897 | return "success"
898 | } else if resp.Error != nil {
899 | return "error"
900 | }
901 | return "unknown"
902 | }
903 |
// tryWriteHTTPError makes a best-effort attempt to send a plain-text error
// response with the given HTTP status code. A failed write is logged rather
// than returned, because callers have no better recovery path.
//
// The Content-Type is switched to text/plain so that a previously set
// "application/json" header does not mislabel the message. If a status line
// has already been sent, WriteHeader has no effect and the message is simply
// appended to the body.
func tryWriteHTTPError(w http.ResponseWriter, code int, message string) {
	h := w.Header()
	h.Set("Content-Type", "text/plain; charset=utf-8")
	h.Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(code) // previously omitted, so clients saw an implicit 200
	if _, err := w.Write([]byte(message)); err != nil {
		log.Printf("Error writing plain HTTP error response: %v", err)
	}
	log.Printf("Sent plain HTTP error: %s (Code: %d)", message, code)
}
911 |
```
--------------------------------------------------------------------------------
/pkg/server/server_test.go:
--------------------------------------------------------------------------------
```go
1 | package server
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "io"
9 | "log"
10 | "net/http"
11 | "net/http/httptest"
12 | "strings"
13 | "sync"
14 | "testing"
15 | "time"
16 |
17 | "github.com/ckanthony/openapi-mcp/pkg/config"
18 | "github.com/ckanthony/openapi-mcp/pkg/mcp"
19 | "github.com/google/uuid"
20 | "github.com/stretchr/testify/assert"
21 | "github.com/stretchr/testify/require"
22 | )
23 |
24 | // --- Re-added Helper Functions ---
25 |
26 | // Helper function to create a simple ToolSet for testing tool calls
27 | func createTestToolSetForCall() *mcp.ToolSet {
28 | return &mcp.ToolSet{
29 | Name: "Call Test API",
30 | Tools: []mcp.Tool{
31 | {
32 | Name: "get_user",
33 | Description: "Get user details",
34 | InputSchema: mcp.Schema{
35 | Type: "object",
36 | Properties: map[string]mcp.Schema{
37 | "user_id": {Type: "string"},
38 | },
39 | Required: []string{"user_id"},
40 | },
41 | },
42 | {
43 | Name: "post_data",
44 | Description: "Post some data",
45 | InputSchema: mcp.Schema{
46 | Type: "object",
47 | Properties: map[string]mcp.Schema{
48 | "data": {Type: "string"},
49 | },
50 | Required: []string{"data"},
51 | },
52 | },
53 | },
54 | Operations: map[string]mcp.OperationDetail{
55 | "get_user": {
56 | Method: "GET",
57 | Path: "/users/{user_id}",
58 | Parameters: []mcp.ParameterDetail{
59 | {Name: "user_id", In: "path"},
60 | },
61 | },
62 | "post_data": {
63 | Method: "POST",
64 | Path: "/data",
65 | Parameters: []mcp.ParameterDetail{}, // Body params assumed
66 | },
67 | },
68 | }
69 | }
70 |
71 | // Helper to safely manage activeConnections for tests
72 | func setupTestConnection(connID string) chan jsonRPCResponse {
73 | msgChan := make(chan jsonRPCResponse, 1) // Buffer of 1 sufficient for most tests
74 | connMutex.Lock()
75 | activeConnections[connID] = msgChan
76 | connMutex.Unlock()
77 | return msgChan
78 | }
79 |
80 | func cleanupTestConnection(connID string) {
81 | connMutex.Lock()
82 | msgChan, exists := activeConnections[connID]
83 | if exists {
84 | delete(activeConnections, connID)
85 | close(msgChan)
86 | }
87 | connMutex.Unlock()
88 | }
89 |
90 | // --- End Re-added Helper Functions ---
91 |
92 | func TestHttpMethodPostHandler(t *testing.T) {
93 | // --- Setup common test items ---
94 | toolSet := createTestToolSetForCall() // Use the helper
95 | cfg := &config.Config{} // Basic config
96 | // NOTE: connID is now generated within each subtest to ensure isolation
97 |
98 | // --- Define Test Cases ---
99 | tests := []struct {
100 | name string
101 | requestBodyFn func(connID string) string // Function to generate body with dynamic connID
102 | expectedSyncStatus int // Expected status code for the immediate POST response
103 | expectedSyncBody string // Expected body for the immediate POST response
104 | checkAsyncResponse func(t *testing.T, resp jsonRPCResponse) // Function to check async response
105 | mockBackend http.HandlerFunc // Optional mock backend for tool calls
106 | setupChannelDirectly func(connID string) chan jsonRPCResponse // Optional: For specific channel setups
107 | }{
108 | {
109 | name: "Valid Initialize Request",
110 | requestBodyFn: func(connID string) string {
111 | return fmt.Sprintf(`{
112 | "jsonrpc": "2.0",
113 | "method": "initialize",
114 | "id": "init-post-1",
115 | "params": {"connectionId": "%s"}
116 | }`, connID)
117 | },
118 | expectedSyncStatus: http.StatusAccepted,
119 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
120 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
121 | assert.Equal(t, "init-post-1", resp.ID)
122 | assert.Nil(t, resp.Error)
123 | resultMap, ok := resp.Result.(map[string]interface{})
124 | require.True(t, ok)
125 | assert.Contains(t, resultMap, "connectionId") // Check existence, actual ID checked separately
126 | assert.Equal(t, "2024-11-05", resultMap["protocolVersion"])
127 | },
128 | },
129 | {
130 | name: "Valid Tools List Request",
131 | requestBodyFn: func(connID string) string {
132 | return `{
133 | "jsonrpc": "2.0",
134 | "method": "tools/list",
135 | "id": "list-post-1"
136 | }`
137 | },
138 | expectedSyncStatus: http.StatusAccepted,
139 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
140 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
141 | assert.Equal(t, "list-post-1", resp.ID)
142 | assert.Nil(t, resp.Error)
143 | resultMap, ok := resp.Result.(map[string]interface{})
144 | require.True(t, ok)
145 | assert.Contains(t, resultMap, "metadata")
146 | assert.Contains(t, resultMap, "tools")
147 | metadata, _ := resultMap["metadata"].(map[string]interface{})
148 | assert.Equal(t, 2, metadata["count"]) // Corrected: Expect int(2)
149 | },
150 | },
151 | {
152 | name: "Valid Tool Call Request (Success)",
153 | requestBodyFn: func(connID string) string {
154 | return `{
155 | "jsonrpc": "2.0",
156 | "method": "tools/call",
157 | "id": "call-post-1",
158 | "params": {"name": "get_user", "arguments": {"user_id": "postUser"}}
159 | }`
160 | },
161 | expectedSyncStatus: http.StatusAccepted,
162 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
163 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
164 | assert.Equal(t, "call-post-1", resp.ID)
165 | assert.Nil(t, resp.Error)
166 | resultPayload, ok := resp.Result.(ToolResultPayload)
167 | require.True(t, ok)
168 | assert.False(t, resultPayload.IsError)
169 | require.Len(t, resultPayload.Content, 1)
170 | assert.JSONEq(t, `{"id":"postUser"}`, resultPayload.Content[0].Text)
171 | },
172 | mockBackend: func(w http.ResponseWriter, r *http.Request) {
173 | w.WriteHeader(http.StatusOK)
174 | fmt.Fprintln(w, `{"id":"postUser"}`)
175 | },
176 | },
177 | {
178 | name: "Valid Tool Call Request (Tool Not Found)",
179 | requestBodyFn: func(connID string) string {
180 | return `{
181 | "jsonrpc": "2.0",
182 | "method": "tools/call",
183 | "id": "call-post-err-1",
184 | "params": {"name": "nonexistent_tool", "arguments": {}}
185 | }`
186 | },
187 | expectedSyncStatus: http.StatusAccepted,
188 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
189 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
190 | assert.Equal(t, "call-post-err-1", resp.ID)
191 | assert.Nil(t, resp.Error)
192 | resultPayload, ok := resp.Result.(ToolResultPayload)
193 | require.True(t, ok)
194 | assert.True(t, resultPayload.IsError)
195 | require.NotNil(t, resultPayload.Error)
196 | assert.Contains(t, resultPayload.Error.Message, "operation details for tool 'nonexistent_tool' not found")
197 | },
198 | },
199 | {
200 | name: "Malformed JSON Request",
201 | requestBodyFn: func(connID string) string {
202 | return `{"jsonrpc": "2.0", "method": "initialize"`
203 | },
204 | expectedSyncStatus: http.StatusAccepted, // Even decode errors return 202, error is sent async
205 | expectedSyncBody: "Request accepted (with decode error), response will be sent via SSE.\n",
206 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
207 | assert.Nil(t, resp.ID) // ID might be nil if request parsing failed early
208 | require.NotNil(t, resp.Error)
209 | assert.Equal(t, -32700, resp.Error.Code) // Parse Error
210 | assert.Equal(t, "Parse error decoding JSON request", resp.Error.Message) // Corrected assertion
211 | },
212 | },
213 | {
214 | name: "Missing JSON-RPC Version",
215 | requestBodyFn: func(connID string) string {
216 | return `{
217 | "method": "initialize",
218 | "id": "rpc-err-1"
219 | }`
220 | },
221 | expectedSyncStatus: http.StatusAccepted,
222 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
223 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
224 | assert.Equal(t, "rpc-err-1", resp.ID)
225 | require.NotNil(t, resp.Error)
226 | assert.Equal(t, -32600, resp.Error.Code) // Invalid Request
227 | assert.Contains(t, resp.Error.Message, "jsonrpc field must be \"2.0\"")
228 | },
229 | },
230 | {
231 | name: "Unknown Method",
232 | requestBodyFn: func(connID string) string {
233 | return `{
234 | "jsonrpc": "2.0",
235 | "method": "unknown/method",
236 | "id": "rpc-err-2"
237 | }`
238 | },
239 | expectedSyncStatus: http.StatusAccepted,
240 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
241 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
242 | assert.Equal(t, "rpc-err-2", resp.ID)
243 | require.NotNil(t, resp.Error)
244 | assert.Equal(t, -32601, resp.Error.Code) // Method not found
245 | assert.Contains(t, resp.Error.Message, "Method not found")
246 | },
247 | },
248 | {
249 | name: "Missing Method",
250 | requestBodyFn: func(connID string) string {
251 | return `{
252 | "jsonrpc": "2.0",
253 | "id": "rpc-err-3"
254 | }`
255 | },
256 | expectedSyncStatus: http.StatusAccepted,
257 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
258 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
259 | assert.Equal(t, "rpc-err-3", resp.ID)
260 | require.NotNil(t, resp.Error)
261 | assert.Equal(t, -32600, resp.Error.Code) // Invalid Request
262 | assert.Equal(t, "Invalid Request: method field is missing or empty", resp.Error.Message) // Corrected assertion
263 | },
264 | },
265 | {
266 | name: "Error Queuing Response To SSE",
267 | requestBodyFn: func(connID string) string { // Use a simple valid request like tools/list
268 | return `{
269 | "jsonrpc": "2.0",
270 | "method": "tools/list",
271 | "id": "list-post-err-queue"
272 | }`
273 | },
274 | expectedSyncStatus: http.StatusInternalServerError, // Expect 500 when channel is blocked
275 | expectedSyncBody: "Failed to queue response for SSE channel\n", // Specific error message expected
276 | setupChannelDirectly: func(connID string) chan jsonRPCResponse {
277 | // Create a NON-BUFFERED channel to simulate blocking/full channel
278 | msgChan := make(chan jsonRPCResponse) // No buffer size!
279 | connMutex.Lock()
280 | activeConnections[connID] = msgChan
281 | connMutex.Unlock()
282 | // Important: Do NOT start a reader for this channel
283 | return msgChan
284 | },
285 | checkAsyncResponse: nil, // No async response should be successfully sent
286 | },
287 | }
288 |
289 | // --- Run Test Cases ---
290 | for _, tc := range tests {
291 | t.Run(tc.name, func(t *testing.T) {
292 | connID := uuid.NewString() // Generate unique connID for each subtest
293 |
294 | // Setup mock backend if needed for this test case
295 | var backendServer *httptest.Server
296 | // --- Add Connection ID before test ---
297 | var msgChan chan jsonRPCResponse
298 | if tc.setupChannelDirectly != nil {
299 | // Use custom setup if provided (e.g., for blocking channel test)
300 | msgChan = tc.setupChannelDirectly(connID)
301 | } else {
302 | // Default setup using the helper with buffered channel
303 | msgChan = setupTestConnection(connID)
304 | }
305 | defer cleanupTestConnection(connID) // Ensure cleanup after test
306 |
307 | if tc.mockBackend != nil {
308 | backendServer = httptest.NewServer(tc.mockBackend)
309 | defer backendServer.Close()
310 | // IMPORTANT: Update the toolset's BaseURL for the relevant operation
311 | if strings.Contains(tc.requestBodyFn(connID), "get_user") { // Simple check based on request
312 | op := toolSet.Operations["get_user"]
313 | op.BaseURL = backendServer.URL
314 | toolSet.Operations["get_user"] = op
315 | }
316 | // Update post_data BaseURL if needed
317 | if strings.Contains(tc.requestBodyFn(connID), "post_data") {
318 | op := toolSet.Operations["post_data"]
319 | op.BaseURL = backendServer.URL
320 | toolSet.Operations["post_data"] = op
321 | }
322 | }
323 |
324 | reqBody := tc.requestBodyFn(connID) // Generate request body
325 | req := httptest.NewRequest(http.MethodPost, "/mcp", strings.NewReader(reqBody))
326 | req.Header.Set("Content-Type", "application/json")
327 | req.Header.Set("X-Connection-ID", connID) // Use the generated connID
328 | rr := httptest.NewRecorder()
329 |
330 | httpMethodPostHandler(rr, req, toolSet, cfg)
331 |
332 | // 1. Check synchronous response
333 | assert.Equal(t, tc.expectedSyncStatus, rr.Code, "Unexpected status code for sync response")
334 | // Trim space for comparison as http.Error might add a newline our literal doesn't have
335 | assert.Equal(t, strings.TrimSpace(tc.expectedSyncBody), strings.TrimSpace(rr.Body.String()), "Unexpected body for sync response")
336 |
337 | // 2. Check asynchronous response (sent via SSE channel)
338 | if tc.checkAsyncResponse != nil {
339 | select {
340 | case asyncResp := <-msgChan:
341 | tc.checkAsyncResponse(t, asyncResp)
342 | case <-time.After(100 * time.Millisecond): // Add a timeout
343 | t.Fatal("Timeout waiting for async response on SSE channel")
344 | }
345 | } else {
346 | // If no async check is defined, ensure nothing was sent (e.g., for queue error test)
347 | select {
348 | case unexpectedResp, ok := <-msgChan:
349 | if ok { // Only fail if the channel wasn't closed AND we got a message
350 | t.Errorf("Received unexpected async response when none was expected: %+v", unexpectedResp)
351 | }
352 | // If !ok, channel was closed, which is fine/expected after cleanup
353 | case <-time.After(50 * time.Millisecond):
354 | // Success - no message received quickly, channel likely blocked as expected
355 | }
356 | }
357 | })
358 | }
359 | }
360 |
361 | func TestHttpMethodGetHandler(t *testing.T) {
362 | // --- Setup ---
363 | // Reset global state for this test
364 | connMutex.Lock()
365 | originalConnections := activeConnections
366 | activeConnections = make(map[string]chan jsonRPCResponse)
367 | connMutex.Unlock()
368 |
369 | req, err := http.NewRequest("GET", "/mcp", nil)
370 | require.NoError(t, err, "Failed to create request")
371 |
372 | rr := httptest.NewRecorder()
373 |
374 | // Ensure cleanup happens regardless of test outcome
375 | defer func() {
376 | connMutex.Lock()
377 | // Clean up any connections potentially left by the test
378 | for id, ch := range activeConnections {
379 | close(ch)
380 | delete(activeConnections, id)
381 | log.Printf("[DEFER Cleanup] Closed channel and removed connection %s", id)
382 | }
383 | activeConnections = originalConnections // Restore the original map
384 | connMutex.Unlock()
385 | }()
386 |
387 | // --- Execute Handler (in a goroutine as it blocks waiting for context) ---
388 | ctx, cancel := context.WithCancel(context.Background())
389 | req = req.WithContext(ctx)
390 |
391 | hwg := sync.WaitGroup{}
392 | hwg.Add(1)
393 | go func() {
394 | defer hwg.Done()
395 | // Simulate some work before handler returns
396 | // In a real scenario, this would block on ctx.Done() or keepAliveTicker
397 | // For the test, we just call cancel() after a short delay
398 | // to simulate the connection ending gracefully.
399 | time.AfterFunc(100*time.Millisecond, cancel) // Allow handler to start and write initial data
400 | httpMethodGetHandler(rr, req)
401 | }()
402 |
403 | // Wait for the handler goroutine to finish.
404 | // This ensures all writes to rr are complete before we read.
405 | if !waitTimeout(&hwg, 2*time.Second) { // Use a reasonable timeout
406 | t.Fatal("Handler goroutine did not exit cleanly after context cancellation")
407 | }
408 |
409 | // --- Assertions (Performed *after* handler completion) ---
410 | assert.Equal(t, http.StatusOK, rr.Code, "Status code should be OK")
411 |
412 | // Check headers are set correctly
413 | assert.Equal(t, "text/event-stream", rr.Header().Get("Content-Type"))
414 | assert.Equal(t, "no-cache", rr.Header().Get("Cache-Control"))
415 | assert.Equal(t, "keep-alive", rr.Header().Get("Connection"))
416 | connID := rr.Header().Get("X-Connection-ID")
417 | assert.NotEmpty(t, connID, "X-Connection-ID header should be set")
418 |
419 | // Check connection was registered and then cleaned up
420 | connMutex.RLock()
421 | _, exists := originalConnections[connID] // Check original map after cleanup
422 | connMutex.RUnlock()
423 | assert.False(t, exists, "Connection ID should be removed from map after handler exits")
424 |
425 | // Check initial body content is present
426 | bodyContent := rr.Body.String()
427 | assert.Contains(t, bodyContent, ":ok\n\n", "Body should contain :ok preamble")
428 | // Construct the expected endpoint data string accurately
429 | expectedEndpointData := "data: /mcp?sessionId=" + connID + "\n\n"
430 | assert.Contains(t, bodyContent, "event: endpoint\n"+expectedEndpointData, "Body should contain endpoint event")
431 | assert.Contains(t, bodyContent, "event: message\ndata: {", "Body should contain start of a message event (e.g., mcp-ready)")
432 | // Check if connectionId is present in the ready message (adjust based on actual JSON structure)
433 | assert.Contains(t, bodyContent, `"connectionId":"`+connID+`"`, "Body should contain mcp-ready event with correct connection ID")
434 |
435 | // The explicit cleanupTestConnection call is not needed because the handler's defer and the test's defer handle it.
436 | }
437 |
438 | func TestExecuteToolCall(t *testing.T) {
439 | tests := []struct {
440 | name string
441 | params ToolCallParams
442 | opDetail mcp.OperationDetail
443 | cfg *config.Config
444 | expectError bool
445 | containsError string
446 | requestAsserter func(t *testing.T, r *http.Request) // Function to assert details of the received HTTP request
447 | backendResponse string // Response body from mock backend
448 | backendStatusCode int // Status code from mock backend
449 | }{
450 | // --- Basic GET with Path Param ---
451 | {
452 | name: "GET with path parameter",
453 | params: ToolCallParams{
454 | ToolName: "get_item",
455 | Input: map[string]interface{}{"item_id": "item123"},
456 | },
457 | opDetail: mcp.OperationDetail{
458 | Method: "GET",
459 | Path: "/items/{item_id}",
460 | Parameters: []mcp.ParameterDetail{{Name: "item_id", In: "path"}},
461 | },
462 | cfg: &config.Config{},
463 | expectError: false,
464 | backendStatusCode: http.StatusOK,
465 | backendResponse: `{"status":"ok"}`,
466 | requestAsserter: func(t *testing.T, r *http.Request) {
467 | assert.Equal(t, http.MethodGet, r.Method)
468 | assert.Equal(t, "/items/item123", r.URL.Path)
469 | assert.Empty(t, r.URL.RawQuery)
470 | },
471 | },
472 | // --- POST with Query, Header, Cookie, and Body Params ---
473 | {
474 | name: "POST with various params",
475 | params: ToolCallParams{
476 | ToolName: "create_resource",
477 | Input: map[string]interface{}{
478 | "queryArg": "value1",
479 | "X-Custom-Hdr": "headerValue",
480 | "sessionToken": "cookieValue",
481 | "bodyFieldA": "A",
482 | "bodyFieldB": 123,
483 | },
484 | },
485 | opDetail: mcp.OperationDetail{
486 | Method: "POST",
487 | Path: "/resources",
488 | Parameters: []mcp.ParameterDetail{
489 | {Name: "queryArg", In: "query"},
490 | {Name: "X-Custom-Hdr", In: "header"},
491 | {Name: "sessionToken", In: "cookie"},
492 | // Body fields are implicitly handled
493 | },
494 | },
495 | cfg: &config.Config{},
496 | expectError: false,
497 | backendStatusCode: http.StatusCreated,
498 | backendResponse: `{"id":"res456"}`,
499 | requestAsserter: func(t *testing.T, r *http.Request) {
500 | assert.Equal(t, http.MethodPost, r.Method)
501 | assert.Equal(t, "/resources", r.URL.Path)
502 | assert.Equal(t, "value1", r.URL.Query().Get("queryArg"))
503 | assert.Equal(t, "headerValue", r.Header.Get("X-Custom-Hdr"))
504 | cookie, err := r.Cookie("sessionToken")
505 | require.NoError(t, err)
506 | assert.Equal(t, "cookieValue", cookie.Value)
507 | bodyBytes, _ := io.ReadAll(r.Body)
508 | assert.JSONEq(t, `{"bodyFieldA":"A", "bodyFieldB":123}`, string(bodyBytes))
509 | },
510 | },
511 | // --- API Key Injection (Header) ---
512 | {
513 | name: "API Key Injection (Header)",
514 | params: ToolCallParams{
515 | ToolName: "get_secure",
516 | Input: map[string]interface{}{}, // No client key provided
517 | },
518 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure"},
519 | cfg: &config.Config{
520 | APIKey: "secret-server-key",
521 | APIKeyName: "Authorization",
522 | APIKeyLocation: config.APIKeyLocationHeader,
523 | },
524 | expectError: false,
525 | backendStatusCode: http.StatusOK,
526 | requestAsserter: func(t *testing.T, r *http.Request) {
527 | assert.Equal(t, "secret-server-key", r.Header.Get("Authorization"))
528 | },
529 | },
530 | // --- API Key Injection (Query) ---
531 | {
532 | name: "API Key Injection (Query)",
533 | params: ToolCallParams{
534 | ToolName: "get_secure",
535 | Input: map[string]interface{}{"otherParam": "abc"},
536 | },
537 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure", Parameters: []mcp.ParameterDetail{{Name: "otherParam", In: "query"}}},
538 | cfg: &config.Config{
539 | APIKey: "secret-server-key-q",
540 | APIKeyName: "api_key",
541 | APIKeyLocation: config.APIKeyLocationQuery,
542 | },
543 | expectError: false,
544 | backendStatusCode: http.StatusOK,
545 | requestAsserter: func(t *testing.T, r *http.Request) {
546 | assert.Equal(t, "secret-server-key-q", r.URL.Query().Get("api_key"))
547 | assert.Equal(t, "abc", r.URL.Query().Get("otherParam")) // Ensure other params are preserved
548 | },
549 | },
550 | // --- API Key Injection (Path) ---
551 | {
552 | name: "API Key Injection (Path)",
553 | params: ToolCallParams{
554 | ToolName: "get_secure_path",
555 | Input: map[string]interface{}{}, // Key comes from config
556 | },
557 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure/{apiKey}/data"},
558 | cfg: &config.Config{
559 | APIKey: "path-key-123",
560 | APIKeyName: "apiKey", // Matches the placeholder name
561 | APIKeyLocation: config.APIKeyLocationPath,
562 | },
563 | expectError: false,
564 | backendStatusCode: http.StatusOK,
565 | requestAsserter: func(t *testing.T, r *http.Request) {
566 | assert.Equal(t, "/secure/path-key-123/data", r.URL.Path)
567 | },
568 | },
569 | // --- API Key Injection (Cookie) ---
570 | {
571 | name: "API Key Injection (Cookie)",
572 | params: ToolCallParams{
573 | ToolName: "get_secure_cookie",
574 | Input: map[string]interface{}{}, // Key comes from config
575 | },
576 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure_cookie"},
577 | cfg: &config.Config{
578 | APIKey: "cookie-key-abc",
579 | APIKeyName: "AuthToken",
580 | APIKeyLocation: config.APIKeyLocationCookie,
581 | },
582 | expectError: false,
583 | backendStatusCode: http.StatusOK,
584 | requestAsserter: func(t *testing.T, r *http.Request) {
585 | cookie, err := r.Cookie("AuthToken")
586 | require.NoError(t, err)
587 | assert.Equal(t, "cookie-key-abc", cookie.Value)
588 | },
589 | },
590 | // --- Base URL Handling Tests ---
591 | {
592 | name: "Base URL from Default (Mock Server)",
593 | params: ToolCallParams{ToolName: "get_default_url", Input: map[string]interface{}{}},
594 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/path1"}, // No BaseURL here
595 | cfg: &config.Config{}, // No global override
596 | expectError: false,
597 | backendStatusCode: http.StatusOK,
598 | requestAsserter: func(t *testing.T, r *http.Request) {
599 | // Should hit the mock server at the correct path
600 | assert.Equal(t, "/path1", r.URL.Path)
601 | },
602 | },
603 | {
604 | name: "Base URL from Global Config Override",
605 | params: ToolCallParams{ToolName: "get_global_url", Input: map[string]interface{}{}},
606 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/path2", BaseURL: "http://should-be-ignored.com"},
607 | // cfg will be updated in test loop to point ServerBaseURL to mock server
608 | cfg: &config.Config{},
609 | expectError: false,
610 | backendStatusCode: http.StatusOK,
611 | requestAsserter: func(t *testing.T, r *http.Request) {
612 | // Should hit the mock server (set via cfg override) at the correct path
613 | assert.Equal(t, "/path2", r.URL.Path)
614 | },
615 | },
616 | // --- Error Case (Tool Not Found in ToolSet) ---
617 | {
618 | name: "Error - Tool Not Found",
619 | params: ToolCallParams{
620 | ToolName: "nonexistent",
621 | Input: map[string]interface{}{},
622 | },
623 | opDetail: mcp.OperationDetail{}, // Not used, error occurs before this
624 | cfg: &config.Config{},
625 | expectError: true,
626 | containsError: "operation details for tool 'nonexistent' not found",
627 | requestAsserter: nil, // No request should be made
628 | backendStatusCode: 0, // Not applicable
629 | },
630 | }
631 |
632 | for _, tc := range tests {
633 | t.Run(tc.name, func(t *testing.T) {
634 | // --- Mock Backend Setup ---
635 | var backendServer *httptest.Server
636 | backendServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
637 | if tc.requestAsserter != nil {
638 | tc.requestAsserter(t, r)
639 | }
640 | w.WriteHeader(tc.backendStatusCode)
641 | fmt.Fprint(w, tc.backendResponse)
642 | }))
643 | defer backendServer.Close()
644 |
645 | // --- Prepare ToolSet (using mock server URL if needed) ---
646 | toolSet := &mcp.ToolSet{
647 | Operations: make(map[string]mcp.OperationDetail),
648 | }
649 |
650 | // Clone config to avoid modifying the template test case config
651 | testCfg := *tc.cfg
652 |
653 | // Special handling for the global override test case
654 | if tc.name == "Base URL from Global Config Override" {
655 | testCfg.ServerBaseURL = backendServer.URL // Point global override to mock server
656 | }
657 |
658 | // If the opDetail needs a BaseURL, set it to the mock server ONLY if it wasn't
659 | // already set in the test case definition AND the global override isn't being used.
660 | if tc.opDetail.Method != "" { // Only add if it's a valid detail for the test
661 | if tc.opDetail.BaseURL == "" && testCfg.ServerBaseURL == "" {
662 | tc.opDetail.BaseURL = backendServer.URL
663 | }
664 | toolSet.Operations[tc.params.ToolName] = tc.opDetail
665 | }
666 |
667 | // --- Execute Function ---
668 | httpResp, err := executeToolCall(&tc.params, toolSet, &testCfg) // Use the potentially modified testCfg
669 |
670 | // --- Assertions ---
671 | if tc.expectError {
672 | assert.Error(t, err)
673 | if tc.containsError != "" {
674 | assert.Contains(t, err.Error(), tc.containsError)
675 | }
676 | assert.Nil(t, httpResp)
677 | } else {
678 | assert.NoError(t, err)
679 | require.NotNil(t, httpResp)
680 | defer httpResp.Body.Close()
681 | assert.Equal(t, tc.backendStatusCode, httpResp.StatusCode)
682 | bodyBytes, _ := io.ReadAll(httpResp.Body)
683 | assert.Equal(t, tc.backendResponse, string(bodyBytes))
684 | }
685 | })
686 | }
687 | }
688 |
689 | func TestWriteSSEEvent(t *testing.T) {
690 | tests := []struct {
691 | name string
692 | eventName string
693 | data interface{}
694 | expectedOut string
695 | expectError bool
696 | }{
697 | {
698 | name: "Simple String Data",
699 | eventName: "endpoint",
700 | data: "/mcp?sessionId=123",
701 | expectedOut: "event: endpoint\ndata: /mcp?sessionId=123\n\n",
702 | expectError: false,
703 | },
704 | {
705 | name: "Struct Data (JSON-RPC Request)",
706 | eventName: "message",
707 | data: jsonRPCRequest{
708 | Jsonrpc: "2.0",
709 | Method: "mcp-ready",
710 | Params: map[string]interface{}{"connectionId": "abc"},
711 | },
712 | // Note: JSON marshaling order isn't guaranteed, so use JSONEq or check fields
713 | expectedOut: "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"mcp-ready\",\"params\":{\"connectionId\":\"abc\"}}\n\n",
714 | expectError: false,
715 | },
716 | {
717 | name: "Struct Data (JSON-RPC Response)",
718 | eventName: "message",
719 | data: jsonRPCResponse{
720 | Jsonrpc: "2.0",
721 | Result: map[string]interface{}{"status": "ok"},
722 | ID: "req-1",
723 | },
724 | expectedOut: "event: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{\"status\":\"ok\"},\"id\":\"req-1\"}\n\n",
725 | expectError: false,
726 | },
727 | {
728 | name: "Error - Unmarshalable Data",
729 | eventName: "error",
730 | data: make(chan int), // Channels cannot be marshaled to JSON
731 | expectError: true,
732 | },
733 | }
734 |
735 | for _, tc := range tests {
736 | t.Run(tc.name, func(t *testing.T) {
737 | rr := httptest.NewRecorder()
738 | err := writeSSEEvent(rr, tc.eventName, tc.data)
739 |
740 | if tc.expectError {
741 | assert.Error(t, err)
742 | } else {
743 | assert.NoError(t, err)
744 | // For struct data, use JSONEq for robust comparison
745 | if _, isStruct := tc.data.(jsonRPCRequest); isStruct {
746 | prefix := fmt.Sprintf("event: %s\ndata: ", tc.eventName)
747 | suffix := "\n\n"
748 | require.True(t, strings.HasPrefix(rr.Body.String(), prefix))
749 | require.True(t, strings.HasSuffix(rr.Body.String(), suffix))
750 | actualJSON := strings.TrimSuffix(strings.TrimPrefix(rr.Body.String(), prefix), suffix)
751 | expectedJSONBytes, _ := json.Marshal(tc.data)
752 | assert.JSONEq(t, string(expectedJSONBytes), actualJSON)
753 | } else if _, isStruct := tc.data.(jsonRPCResponse); isStruct {
754 | prefix := fmt.Sprintf("event: %s\ndata: ", tc.eventName)
755 | suffix := "\n\n"
756 | require.True(t, strings.HasPrefix(rr.Body.String(), prefix))
757 | require.True(t, strings.HasSuffix(rr.Body.String(), suffix))
758 | actualJSON := strings.TrimSuffix(strings.TrimPrefix(rr.Body.String(), prefix), suffix)
759 | expectedJSONBytes, _ := json.Marshal(tc.data)
760 | assert.JSONEq(t, string(expectedJSONBytes), actualJSON)
761 | } else {
762 | // For simple types, direct string comparison is fine
763 | assert.Equal(t, tc.expectedOut, rr.Body.String())
764 | }
765 | }
766 | })
767 | }
768 | }
769 |
770 | func TestTryWriteHTTPError(t *testing.T) {
771 | rr := httptest.NewRecorder()
772 | message := "Test Error Message"
773 | code := http.StatusInternalServerError
774 |
775 | tryWriteHTTPError(rr, code, message)
776 |
777 | // Note: tryWriteHTTPError doesn't set the status code, it only writes the body.
778 | // The calling function is expected to have set the code earlier.
779 | // So, we only check the body content here.
780 | assert.Equal(t, message, rr.Body.String())
781 | }
782 |
783 | func TestGetMethodFromResponse(t *testing.T) {
784 | tests := []struct {
785 | name string
786 | response jsonRPCResponse
787 | expected string
788 | }{
789 | {
790 | name: "Error Response",
791 | response: jsonRPCResponse{
792 | Error: &jsonError{Code: -32600, Message: "..."},
793 | },
794 | expected: "error",
795 | },
796 | {
797 | name: "Tool List Response",
798 | response: jsonRPCResponse{
799 | Result: map[string]interface{}{"tools": []interface{}{}, "metadata": map[string]interface{}{}},
800 | },
801 | expected: "tool_set",
802 | },
803 | {
804 | name: "Initialize Response (Result is Map)",
805 | response: jsonRPCResponse{
806 | Result: map[string]interface{}{"protocolVersion": "...", "capabilities": map[string]interface{}{}},
807 | },
808 | expected: "success", // Falls back to 'success' as type isn't explicitly set
809 | },
810 | {
811 | name: "Tool Call Response (Result is ToolResultPayload)",
812 | response: jsonRPCResponse{
813 | Result: ToolResultPayload{Content: []ToolResultContent{{Type: "text", Text: "..."}}},
814 | },
815 | expected: "success", // Falls back to 'success'
816 | },
817 | {
818 | name: "Empty Response",
819 | response: jsonRPCResponse{},
820 | expected: "unknown",
821 | },
822 | }
823 |
824 | for _, tc := range tests {
825 | t.Run(tc.name, func(t *testing.T) {
826 | actual := getMethodFromResponse(tc.response)
827 | assert.Equal(t, tc.expected, actual)
828 | })
829 | }
830 | }
831 |
// --- Mock ResponseWriter for error simulation ---

// sseMockResponseWriter is a test double implementing http.ResponseWriter and
// http.Flusher for exercising SSE handlers. It records the status code, the
// body bytes, and whether Flush was called. It can inject write failures either
// immediately (forceError) or starting at the Nth Write call (failAfterNWrites).
//
// It has no internal locking; the test must not touch it while the handler
// goroutine may still be writing.
type sseMockResponseWriter struct {
	hdr        http.Header   // headers returned by Header()
	statusCode int           // last value passed to WriteHeader
	body       *bytes.Buffer // bytes accepted by successful writes
	flushed    bool          // set by Flush while no write failure is pending
	// forceError, when non-nil, makes every Write return it and turns Flush into
	// a no-op. It is also set automatically once failAfterNWrites trips.
	forceError error
	// failAfterNWrites makes the Nth counted Write (1-based) fail, together with
	// every later Write; 0 and 1 both fail the very first write. -1 disables it.
	failAfterNWrites int
	// writesMade counts Write calls made while no failure was pending.
	writesMade int
}

// Compile-time checks that the mock satisfies the interfaces SSE handlers use.
var (
	_ http.ResponseWriter = (*sseMockResponseWriter)(nil)
	_ http.Flusher        = (*sseMockResponseWriter)(nil)
)

// newSseMockResponseWriter returns a mock with empty headers and body and with
// count-based write failure disabled.
func newSseMockResponseWriter() *sseMockResponseWriter {
	return &sseMockResponseWriter{
		hdr:              make(http.Header),
		body:             &bytes.Buffer{},
		failAfterNWrites: -1,
	}
}

// Header returns the mock's header map; mutations are visible to later calls.
func (m *sseMockResponseWriter) Header() http.Header {
	return m.hdr
}

// WriteHeader records statusCode; it has no other effect.
func (m *sseMockResponseWriter) WriteHeader(statusCode int) {
	m.statusCode = statusCode
}

// Write appends p to the recorded body. It returns 0 and an error without
// recording anything if a failure is already pending, or if this call reaches
// the failAfterNWrites threshold (which then makes the failure sticky).
func (m *sseMockResponseWriter) Write(p []byte) (int, error) {
	if m.forceError != nil {
		return 0, m.forceError
	}

	m.writesMade++

	// Trip the injected failure on the Nth counted write. Storing it in
	// forceError keeps every subsequent Write failing too.
	if m.failAfterNWrites >= 0 && m.writesMade >= m.failAfterNWrites {
		m.forceError = fmt.Errorf("forced write error after %d writes", m.failAfterNWrites)
		log.Printf("DEBUG: sseMockResponseWriter triggering error: %v", m.forceError)
		return 0, m.forceError
	}

	return m.body.Write(p)
}

// Flush marks the writer as flushed. Once a write failure is pending it does
// nothing, so it never sets flushed after a failure. Flush has no count-based
// failure mode of its own.
func (m *sseMockResponseWriter) Flush() {
	if m.forceError != nil {
		return
	}
	m.flushed = true
}

// String returns everything successfully written to the body so far.
func (m *sseMockResponseWriter) String() string {
	return m.body.String()
}

// --- End Mock ResponseWriter ---
904 |
905 | func TestHttpMethodGetHandler_WriteErrors(t *testing.T) {
906 | tests := []struct {
907 | name string
908 | errorOnStage string // "preamble", "endpoint", "ready", "ping", "message"
909 | forceError error // Error to set on the mock writer *before* handler runs
910 | expectConnRemoved bool
911 | }{
912 | {"Error on Preamble (:ok)", "preamble", fmt.Errorf("forced write error during preamble"), true},
913 | // Removed: {"Error on Endpoint Event", "endpoint", nil, true}, // Hard to simulate reliably without patching
914 | // Removed: {"Error on MCP-Ready Event", "ready", nil, true}, // Hard to simulate reliably without patching
915 | // TODO: Add test for error during keep-alive ping
916 | // TODO: Add test for error during message write from channel
917 | }
918 |
919 | for _, tc := range tests {
920 | t.Run(tc.name, func(t *testing.T) {
921 | // Use renamed mock writer
922 | mockWriter := newSseMockResponseWriter()
923 | req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
924 | var connID string // Variable to capture assigned ID
925 |
926 | // Set the error on the writer *before* calling the handler
927 | if tc.forceError != nil {
928 | mockWriter.forceError = tc.forceError
929 | }
930 |
931 | // Need to capture connID *if* headers get written before error
932 | // We can check mockWriter.Header() after the handler potentially runs
933 |
934 | // Inject error based on the test stage - REMOVED FUNCTION PATCHING
935 | /*
936 | originalWriteSSE := writeSSEEvent
937 | defer func() { writeSSEEvent = originalWriteSSE }() // Restore original
938 |
939 | writeSSEEvent = func(w http.ResponseWriter, eventName string, data interface{}) error {
940 | // ... removed patching logic ...
941 | }
942 | */
943 |
944 | // Execute handler in goroutine as it might block briefly before erroring
945 | done := make(chan struct{})
946 | go func() {
947 | defer close(done)
948 | httpMethodGetHandler(mockWriter, req)
949 | }()
950 |
951 | // Wait for the handler goroutine to finish or timeout
952 | select {
953 | case <-done:
954 | // Handler finished (presumably due to error)
955 | case <-time.After(200 * time.Millisecond): // Generous timeout
956 | t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after injected error")
957 | }
958 |
959 | // Capture ConnID *after* handler exit, in case headers were set before error
960 | connID = mockWriter.Header().Get("X-Connection-ID")
961 |
962 | // Assert connection removal
963 | if tc.expectConnRemoved && connID != "" {
964 | connMutex.RLock()
965 | _, exists := activeConnections[connID]
966 | connMutex.RUnlock()
967 | assert.False(t, exists, "Connection %s should have been removed from activeConnections after write error", connID)
968 | } else if tc.expectConnRemoved && connID == "" {
969 | t.Log("Cannot assert connection removal as ConnID was not captured before error")
970 | }
971 | })
972 | }
973 | }
974 |
975 | func TestHttpMethodGetHandler_GoroutineErrors(t *testing.T) {
976 | t.Run("Error_on_Message_Write", func(t *testing.T) {
977 | // Estimate writes before first message: :ok(1), endpoint(1), ready(1) = 3 writes
978 | // Target failure on the 4th write (first write of the actual message event line)
979 | mockWriter := newSseMockResponseWriter()
980 | mockWriter.failAfterNWrites = 4 // Fail on the 4th write overall
981 |
982 | req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
983 | var connID string
984 | var msgChan chan jsonRPCResponse
985 |
986 | // Clean connections before test
987 | connMutex.Lock()
988 | activeConnections = make(map[string]chan jsonRPCResponse)
989 | connMutex.Unlock()
990 | defer func() {
991 | // Clean up after test, ensure channel is closed if exists
992 | connMutex.Lock()
993 | if msgChan != nil {
994 | // Only delete from map, handler is responsible for closing channel
995 | delete(activeConnections, connID)
996 | }
997 | activeConnections = make(map[string]chan jsonRPCResponse) // Reset for other tests
998 | connMutex.Unlock()
999 | }()
1000 |
1001 | done := make(chan struct{})
1002 | go func() {
1003 | defer close(done)
1004 | httpMethodGetHandler(mockWriter, req)
1005 | log.Println("DEBUG: httpMethodGetHandler goroutine exited")
1006 | }()
1007 |
1008 | // Wait for the connection to be established
1009 | assert.Eventually(t, func() bool {
1010 | connMutex.RLock()
1011 | defer connMutex.RUnlock()
1012 | for id, ch := range activeConnections {
1013 | connID = id
1014 | msgChan = ch
1015 | log.Printf("DEBUG: Connection established: %s", connID)
1016 | return true
1017 | }
1018 | return false
1019 | }, 200*time.Millisecond, 20*time.Millisecond, "Connection not established in time")
1020 |
1021 | require.NotEmpty(t, connID, "connID should have been captured")
1022 | require.NotNil(t, msgChan, "msgChan should have been captured")
1023 |
1024 | // Send a message that should trigger the write error
1025 | testResp := jsonRPCResponse{Jsonrpc: "2.0", ID: "test-msg-1", Result: "test data"}
1026 | log.Printf("DEBUG: Sending test message to channel for %s", connID)
1027 | select {
1028 | case msgChan <- testResp:
1029 | log.Printf("DEBUG: Test message sent to channel for %s", connID)
1030 | case <-time.After(100 * time.Millisecond):
1031 | t.Fatal("Timeout sending message to channel")
1032 | }
1033 |
1034 | // Wait for the handler goroutine to finish due to the write error
1035 | select {
1036 | case <-done:
1037 | log.Printf("DEBUG: Handler goroutine finished as expected after message write error")
1038 | // Handler finished (presumably due to write error)
1039 | case <-time.After(1000 * time.Millisecond): // Increased timeout to 1 second
1040 | t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after message write error")
1041 | }
1042 |
1043 | // Assert connection removal
1044 | connMutex.RLock()
1045 | _, exists := activeConnections[connID]
1046 | connMutex.RUnlock()
1047 | assert.False(t, exists, "Connection %s should have been removed after message write error", connID)
1048 | })
1049 |
1050 | // TODO: Add sub-test for Error_on_Ping_Write
1051 | }
1052 |
// waitTimeout waits for wg's counter to reach zero, but gives up after timeout.
// It reports true if the WaitGroup finished in time and false otherwise. On
// timeout, the helper goroutine stays parked in wg.Wait until the group
// eventually completes.
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
	finished := make(chan struct{})
	go func() {
		wg.Wait()
		close(finished)
	}()

	select {
	case <-time.After(timeout):
		return false
	case <-finished:
		return true
	}
}
1067 |
```