#
tokens: 25263/50000 2/19 files (page 2/3)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 2 of 3. Use http://codebase.md/ckanthony/openapi-mcp?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .dockerignore
├── .github
│   └── workflows
│       ├── ci.yml
│       └── publish.yml
├── .gitignore
├── cmd
│   └── openapi-mcp
│       └── main.go
├── Dockerfile
├── example
│   ├── agent_demo.png
│   ├── docker-compose.yml
│   └── weather
│       ├── .env.example
│       └── weatherbitio-swagger.json
├── go.mod
├── go.sum
├── openapi-mcp.png
├── pkg
│   ├── config
│   │   ├── config_test.go
│   │   └── config.go
│   ├── mcp
│   │   └── types.go
│   ├── parser
│   │   ├── parser_test.go
│   │   └── parser.go
│   └── server
│       ├── manager_test.go
│       ├── manager.go
│       ├── server_test.go
│       └── server.go
└── README.md
```

# Files

--------------------------------------------------------------------------------
/pkg/server/server.go:
--------------------------------------------------------------------------------

```go
  1 | package server
  2 | 
  3 | import (
  4 | 	"bytes"
  5 | 	"context"
  6 | 	"encoding/json"
  7 | 	"fmt"
  8 | 	"io"
  9 | 	"log"
 10 | 	"net/http"
 11 | 	"net/url"
 12 | 	"strings"
 13 | 	"sync"
 14 | 	"time"
 15 | 
 16 | 	// "fmt" // No longer needed here
 17 | 	// "sync" // No longer needed here
 18 | 
 19 | 	"github.com/ckanthony/openapi-mcp/pkg/config"
 20 | 	"github.com/ckanthony/openapi-mcp/pkg/mcp"
 21 | 	"github.com/google/uuid" // Import UUID package
 22 | )
 23 | 
 24 | // --- JSON-RPC Structures (Re-introduced for Handshake/Messages) ---
 25 | 
 26 | type jsonRPCRequest struct {
 27 | 	Jsonrpc string      `json:"jsonrpc"`
 28 | 	Method  string      `json:"method"`
 29 | 	Params  interface{} `json:"params,omitempty"`
 30 | 	ID      interface{} `json:"id,omitempty"` // Can be string, number, or null
 31 | }
 32 | 
 33 | type jsonRPCResponse struct {
 34 | 	Jsonrpc string      `json:"jsonrpc"`
 35 | 	Result  interface{} `json:"result,omitempty"`
 36 | 	Error   *jsonError  `json:"error,omitempty"`
 37 | 	ID      interface{} `json:"id"` // ID should match the request ID
 38 | }
 39 | 
 40 | type jsonError struct {
 41 | 	Code    int         `json:"code"`
 42 | 	Message string      `json:"message"`
 43 | 	Data    interface{} `json:"data,omitempty"`
 44 | }
 45 | 
 46 | // --- MCP Message Structures (Kept for clarity on expected payloads) ---
 47 | 
 48 | // MCPMessage represents a generic message exchanged over the transport.
 49 | // Note: Adapt this structure based on the exact MCP spec requirements if needed.
 50 | // This structure is now more for understanding the *payloads* within JSON-RPC.
 51 | type MCPMessage struct {
 52 | 	Type    string          `json:"type"`                   // e.g., "initialize", "tools/list", "tools/call", "tool_result", "error"
 53 | 	ID      string          `json:"id,omitempty"`           // Unique message ID (less relevant for JSON-RPC wrapper)
 54 | 	Payload json.RawMessage `json:"payload,omitempty"`      // Content specific to the message type
 55 | 	ConnID  string          `json:"connectionId,omitempty"` // Included in responses related to a connection
 56 | }
 57 | 
 58 | // MCPError defines a structured error for MCP responses.
 59 | // This will be used within the 'Error.Data' field of a jsonRPCResponse.
 60 | type MCPError struct {
 61 | 	Code    int         `json:"code,omitempty"` // Optional error code
 62 | 	Message string      `json:"message"`
 63 | 	Data    interface{} `json:"data,omitempty"` // Optional additional data
 64 | }
 65 | 
 66 | // ToolCallParams represents the expected payload for a tools/call request.
 67 | // This will be the structure within the 'params' field of a jsonRPCRequest.
 68 | type ToolCallParams struct {
 69 | 	ToolName string                 `json:"name"`      // Aligning with gin-mcp JSON-RPC 'name'
 70 | 	Input    map[string]interface{} `json:"arguments"` // Aligning with gin-mcp JSON-RPC 'arguments'
 71 | }
 72 | 
 73 | // ToolResultContent represents an item in the 'content' array of a tool_result.
 74 | type ToolResultContent struct {
 75 | 	Type string `json:"type"`
 76 | 	Text string `json:"text"` // Assuming text/JSON string result
 77 | 	// Add other content types if needed
 78 | }
 79 | 
 80 | // ToolResultPayload represents the structure for the 'result' of a 'tool_result' JSON-RPC response.
 81 | type ToolResultPayload struct {
 82 | 	Content    []ToolResultContent `json:"content"`                // Array of content items
 83 | 	IsError    bool                `json:"isError"`                // Aligning with gin-mcp
 84 | 	Error      *MCPError           `json:"error,omitempty"`        // Detailed error info if IsError is true
 85 | 	ToolCallID string              `json:"tool_call_id,omitempty"` // Optional: Can be helpful
 86 | }
 87 | 
 88 | // --- Server State ---
 89 | 
 90 | // activeConnections stores channels for sending messages back to active SSE clients.
 91 | var activeConnections = make(map[string]chan jsonRPCResponse) // Changed value type
 92 | var connMutex sync.RWMutex
 93 | 
 94 | // Channel buffer size
 95 | const messageChannelBufferSize = 10
 96 | 
 97 | // --- Server Implementation ---
 98 | 
 99 | // ServeMCP starts an HTTP server handling MCP communication.
100 | func ServeMCP(addr string, toolSet *mcp.ToolSet, cfg *config.Config) error {
101 | 	log.Printf("Preparing ToolSet for MCP...")
102 | 
103 | 	// --- Handler Functions ---
104 | 	mcpHandler := func(w http.ResponseWriter, r *http.Request) {
105 | 		// CORS Headers (Apply to all relevant requests)
106 | 		w.Header().Set("Access-Control-Allow-Origin", "*") // Be more specific in production
107 | 		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
108 | 		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-Connection-ID")
109 | 		w.Header().Set("Access-Control-Expose-Headers", "X-Connection-ID")
110 | 
111 | 		if r.Method == http.MethodOptions {
112 | 			log.Println("Responding to OPTIONS request")
113 | 			w.WriteHeader(http.StatusNoContent) // Use 204 No Content for OPTIONS
114 | 			return
115 | 		}
116 | 
117 | 		if r.Method == http.MethodGet {
118 | 			httpMethodGetHandler(w, r) // Handle SSE connection setup
119 | 		} else if r.Method == http.MethodPost {
120 | 			httpMethodPostHandler(w, r, toolSet, cfg) // Pass the cfg object here
121 | 		} else {
122 | 			log.Printf("Method Not Allowed: %s", r.Method)
123 | 			http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
124 | 		}
125 | 	}
126 | 
127 | 	// Setup server mux
128 | 	mux := http.NewServeMux()
129 | 	mux.HandleFunc("/mcp", mcpHandler) // Single endpoint for GET/POST/OPTIONS
130 | 
131 | 	log.Printf("MCP server listening on %s/mcp", addr)
132 | 	return http.ListenAndServe(addr, mux)
133 | }
134 | 
135 | // httpMethodGetHandler handles the initial GET request to establish the SSE connection.
136 | func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) {
137 | 	connectionID := uuid.New().String()
138 | 	log.Printf("SSE client connecting: %s (Assigning ID: %s)", r.RemoteAddr, connectionID)
139 | 
140 | 	flusher, ok := w.(http.Flusher)
141 | 	if !ok {
142 | 		http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
143 | 		log.Println("Error: Client connection does not support flushing")
144 | 		return
145 | 	}
146 | 
147 | 	// --- Set headers FIRST ---
148 | 	w.Header().Set("Content-Type", "text/event-stream")
149 | 	w.Header().Set("Cache-Control", "no-cache")
150 | 	w.Header().Set("Connection", "keep-alive")
151 | 	// CORS headers are set in the main handler
152 | 	w.Header().Set("X-Connection-ID", connectionID)
153 | 	w.Header().Set("X-Accel-Buffering", "no") // Useful for proxies like Nginx
154 | 	w.WriteHeader(http.StatusOK)              // Write headers and status code
155 | 	flusher.Flush()                           // Ensure headers are sent immediately
156 | 
157 | 	// --- Send initial :ok --- (Must happen *after* headers)
158 | 	if _, err := fmt.Fprintf(w, ":ok\n\n"); err != nil {
159 | 		log.Printf("Error sending SSE preamble to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
160 | 		return // Cannot proceed if preamble fails
161 | 	}
162 | 	flusher.Flush()
163 | 	log.Printf("Sent :ok preamble to %s (ID: %s)", r.RemoteAddr, connectionID)
164 | 
165 | 	// --- Send initial SSE events --- (endpoint, mcp-ready)
166 | 	endpointURL := fmt.Sprintf("/mcp?sessionId=%s", connectionID) // Assuming /mcp is the mount path
167 | 	if err := writeSSEEvent(w, "endpoint", endpointURL); err != nil {
168 | 		log.Printf("Error sending SSE endpoint event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
169 | 		return
170 | 	}
171 | 	flusher.Flush()
172 | 	log.Printf("Sent endpoint event to %s (ID: %s)", r.RemoteAddr, connectionID)
173 | 
174 | 	readyMsg := jsonRPCRequest{ // Use request struct for notification format
175 | 		Jsonrpc: "2.0",
176 | 		Method:  "mcp-ready",
177 | 		Params: map[string]interface{}{ // Put data in params
178 | 			"connectionId": connectionID,
179 | 			"status":       "connected",
180 | 			"protocol":     "2.0",
181 | 		},
182 | 	}
183 | 	if err := writeSSEEvent(w, "message", readyMsg); err != nil {
184 | 		log.Printf("Error sending SSE mcp-ready event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
185 | 		return
186 | 	}
187 | 	flusher.Flush()
188 | 	log.Printf("Sent mcp-ready event to %s (ID: %s)", r.RemoteAddr, connectionID)
189 | 
190 | 	// --- Setup message channel and store connection ---
191 | 	msgChan := make(chan jsonRPCResponse, messageChannelBufferSize) // Channel for responses
192 | 	connMutex.Lock()
193 | 	activeConnections[connectionID] = msgChan
194 | 	connMutex.Unlock()
195 | 	log.Printf("Registered channel for connection %s. Active connections: %d", connectionID, len(activeConnections))
196 | 
197 | 	// --- Cleanup function ---
198 | 	cleanup := func() {
199 | 		connMutex.Lock()
200 | 		delete(activeConnections, connectionID)
201 | 		connMutex.Unlock()
202 | 		close(msgChan) // Close channel when connection ends
203 | 		log.Printf("Removed connection %s. Active connections: %d", connectionID, len(activeConnections))
204 | 	}
205 | 	defer cleanup()
206 | 
207 | 	// --- Goroutine to write messages from channel to SSE stream ---
208 | 	ctx, cancel := context.WithCancel(r.Context())
209 | 	defer cancel()
210 | 
211 | 	go func() {
212 | 		log.Printf("[SSE Writer %s] Starting message writer goroutine", connectionID)
213 | 		defer log.Printf("[SSE Writer %s] Exiting message writer goroutine", connectionID)
214 | 		for {
215 | 			select {
216 | 			case <-ctx.Done():
217 | 				return // Exit if main context is cancelled
218 | 			case resp, ok := <-msgChan:
219 | 				if !ok {
220 | 					log.Printf("[SSE Writer %s] Message channel closed.", connectionID)
221 | 					return // Exit if channel is closed
222 | 				}
223 | 				log.Printf("[SSE Writer %s] Sending message (ID: %v) via SSE", connectionID, resp.ID)
224 | 				if err := writeSSEEvent(w, "message", resp); err != nil {
225 | 					log.Printf("[SSE Writer %s] Error writing message to SSE stream: %v. Cancelling context.", connectionID, err)
226 | 					cancel() // Signal main loop to exit on write error
227 | 					return
228 | 				}
229 | 				flusher.Flush() // Flush after writing message
230 | 			}
231 | 		}
232 | 	}()
233 | 
234 | 	// --- Keep connection alive (main loop) ---
235 | 	keepAliveTicker := time.NewTicker(20 * time.Second)
236 | 	defer keepAliveTicker.Stop()
237 | 
238 | 	log.Printf("[SSE %s] Entering keep-alive loop", connectionID)
239 | 	for {
240 | 		select {
241 | 		case <-ctx.Done():
242 | 			log.Printf("[SSE %s] Context done. Exiting keep-alive loop.", connectionID)
243 | 			return // Exit loop if context cancelled (client disconnect or write error)
244 | 		case <-keepAliveTicker.C:
245 | 			// Send JSON-RPC ping notification instead of SSE comment
246 | 			pingMsg := jsonRPCRequest{ // Use request struct for notification format
247 | 				Jsonrpc: "2.0",
248 | 				Method:  "ping",
249 | 				Params: map[string]interface{}{ // Include timestamp like gin-mcp
250 | 					"timestamp": time.Now().Unix(),
251 | 				},
252 | 			}
253 | 			if err := writeSSEEvent(w, "message", pingMsg); err != nil {
254 | 				log.Printf("[SSE %s] Error sending ping notification: %v. Closing connection.", connectionID, err)
255 | 				cancel() // Signal writer goroutine and exit
256 | 				return
257 | 			}
258 | 			flusher.Flush()
259 | 		}
260 | 	}
261 | }
262 | 
// writeSSEEvent serialises data as one Server-Sent Event and writes it to w.
//
// The frame is:
//
//	event: <eventName>   (omitted when eventName is empty)
//	data: <line 1>
//	data: <line 2>
//	...
//	<blank line>
//
// For the "endpoint" event, a string payload is written verbatim because it is
// a URL, not JSON. Every other payload is JSON-encoded. The payload is emitted
// as one "data:" field per line. CRLF and lone CR are normalised to LF first,
// because the SSE parser treats all three as line terminators; splitting on
// "\n" alone would let a stray "\r" start a bogus field on the client.
//
// The event is assembled in a buffer and handed to w in a single Write call.
// The function does not flush; callers flush as needed. It returns an error if
// marshalling or writing fails.
func writeSSEEvent(w http.ResponseWriter, eventName string, data interface{}) error {
	var buf bytes.Buffer
	if eventName != "" {
		fmt.Fprintf(&buf, "event: %s\n", eventName)
	}

	var payload string
	if s, ok := data.(string); ok && eventName == "endpoint" {
		payload = s
	} else {
		encoded, err := json.Marshal(data)
		if err != nil {
			return fmt.Errorf("failed to marshal data for SSE event '%s': %w", eventName, err)
		}
		payload = string(encoded)
	}

	payload = strings.ReplaceAll(payload, "\r\n", "\n")
	payload = strings.ReplaceAll(payload, "\r", "\n")
	for _, line := range strings.Split(payload, "\n") {
		fmt.Fprintf(&buf, "data: %s\n", line)
	}
	buf.WriteByte('\n') // blank line terminates the event

	if _, err := w.Write(buf.Bytes()); err != nil {
		return fmt.Errorf("failed to write SSE event '%s': %w", eventName, err)
	}
	return nil
}
298 | 
299 | // httpMethodPostHandler handles incoming POST requests containing MCP messages.
300 | func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp.ToolSet, cfg *config.Config) {
301 | 	// --- Original Logic (Restored) ---
302 | 	connID := r.Header.Get("X-Connection-ID") // Try header first
303 | 	if connID == "" {
304 | 		connID = r.URL.Query().Get("sessionId") // Fallback to query parameter
305 | 		log.Printf("X-Connection-ID header missing, checking sessionId query param: found='%s'", connID)
306 | 	}
307 | 
308 | 	if connID == "" {
309 | 		log.Println("Error: POST request received without X-Connection-ID header or sessionId query parameter")
310 | 		http.Error(w, "Missing X-Connection-ID header or sessionId query parameter", http.StatusBadRequest)
311 | 		return
312 | 	}
313 | 
314 | 	// Find the corresponding message channel for this connection
315 | 	connMutex.RLock()
316 | 	msgChan, isActive := activeConnections[connID]
317 | 	connMutex.RUnlock()
318 | 
319 | 	if !isActive {
320 | 		log.Printf("Error: POST request received for inactive/unknown connection ID: %s", connID)
321 | 		// Still send sync error here, as we don't have a channel
322 | 		tryWriteHTTPError(w, http.StatusNotFound, "Invalid or expired connection ID")
323 | 		return
324 | 	}
325 | 
326 | 	bodyBytes, err := io.ReadAll(r.Body)
327 | 	if err != nil {
328 | 		log.Printf("Error reading POST request body for %s: %v", connID, err)
329 | 		// Create error response in the ToolResultPayload format
330 | 		errPayload := ToolResultPayload{
331 | 			IsError: true,
332 | 			Error: &MCPError{
333 | 				Code:    -32700, // JSON-RPC Parse Error Code
334 | 				Message: "Parse error reading request body",
335 | 			},
336 | 			// ToolCallID doesn't really apply here, maybe use connID or leave empty?
337 | 			// ToolCallID: connID,
338 | 		}
339 | 		errResp := jsonRPCResponse{
340 | 			Jsonrpc: "2.0",
341 | 			ID:      nil, // ID is unknown if we can't read the body
342 | 			Result:  errPayload,
343 | 			Error:   nil, // Ensure top-level error is nil
344 | 		}
345 | 		// Attempt to send via SSE channel
346 | 		select {
347 | 		case msgChan <- errResp:
348 | 			log.Printf("Queued read error response (ID: %v) for %s onto SSE channel (as Result)", errResp.ID, connID)
349 | 			// Send HTTP 202 Accepted back to the POST request
350 | 			w.WriteHeader(http.StatusAccepted)
351 | 			fmt.Fprintln(w, "Request accepted (with parse error), response will be sent via SSE.")
352 | 		default:
353 | 			log.Printf("Error: Failed to queue read error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID)
354 | 			// Send an error back on the POST request if channel fails
355 | 			tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel")
356 | 		}
357 | 		return // Stop processing
358 | 	}
359 | 	// No defer r.Body.Close() needed here as io.ReadAll reads to EOF
360 | 
361 | 	log.Printf("Received POST data for %s: %s", connID, string(bodyBytes))
362 | 
363 | 	// Attempt to unmarshal into a temporary map first to extract ID if possible
364 | 	var rawReq map[string]interface{}
365 | 	var reqID interface{} // Keep track of ID even if full unmarshal fails
366 | 
367 | 	// Try unmarshalling into raw map
368 | 	if err := json.Unmarshal(bodyBytes, &rawReq); err == nil {
369 | 		// Ensure reqID is treated as a string or number if possible, handle potential null
370 | 		if idVal, idExists := rawReq["id"]; idExists && idVal != nil {
371 | 			reqID = idVal
372 | 		} else {
373 | 			reqID = nil // Explicitly set to nil if missing or JSON null
374 | 		}
375 | 	} else {
376 | 		// Full unmarshal failed, log it but continue to try specific struct
377 | 		log.Printf("Warning: Initial unmarshal into map failed for %s: %v. Will attempt specific struct unmarshal.", connID, err)
378 | 		reqID = nil // ID is unknown
379 | 	}
380 | 
381 | 	var req jsonRPCRequest // Expect JSON-RPC request
382 | 	if err := json.Unmarshal(bodyBytes, &req); err != nil {
383 | 		log.Printf("Error decoding JSON-RPC request for %s: %v", connID, err)
384 | 		// Use createJSONRPCError to correctly format the error response
385 | 		errResp := createJSONRPCError(reqID, -32700, "Parse error decoding JSON request", err.Error())
386 | 
387 | 		// Attempt to send via SSE channel
388 | 		select {
389 | 		case msgChan <- errResp:
390 | 			log.Printf("Queued decode error response (ID: %v) for %s onto SSE channel", errResp.ID, connID)
391 | 			// Send HTTP 202 Accepted back to the POST request
392 | 			w.WriteHeader(http.StatusAccepted)
393 | 			// Use a specific message for decode errors
394 | 			fmt.Fprintln(w, "Request accepted (with decode error), response will be sent via SSE.")
395 | 		default:
396 | 			log.Printf("Error: Failed to queue decode error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID)
397 | 			// Send an error back on the POST request if channel fails
398 | 			tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel")
399 | 		}
400 | 		return // Stop processing
401 | 	}
402 | 
403 | 	// If we successfully unmarshalled 'req', ensure reqID matches req.ID
404 | 	if req.ID != nil {
405 | 		reqID = req.ID
406 | 	} else {
407 | 		reqID = nil
408 | 	}
409 | 
410 | 	// --- Variable to hold the final response to be sent via SSE ---
411 | 	var respToSend jsonRPCResponse
412 | 
413 | 	// --- Validate JSON-RPC Request ---
414 | 	if req.Jsonrpc != "2.0" {
415 | 		log.Printf("Invalid JSON-RPC version ('%s') for %s, ID: %v", req.Jsonrpc, connID, reqID)
416 | 		respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: jsonrpc field must be \"2.0\"", nil)
417 | 	} else if req.Method == "" {
418 | 		log.Printf("Missing JSON-RPC method for %s, ID: %v", connID, reqID)
419 | 		respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: method field is missing or empty", nil)
420 | 	} else {
421 | 		// --- Process the valid request ---
422 | 		log.Printf("Processing JSON-RPC message for %s: Method=%s, ID=%v", connID, req.Method, reqID)
423 | 		switch req.Method {
424 | 		case "initialize":
425 | 			incomingInitializeJSON, _ := json.Marshal(req)
426 | 			log.Printf("DEBUG: Handling 'initialize' for %s. Incoming request: %s", connID, string(incomingInitializeJSON))
427 | 			respToSend = handleInitializeJSONRPC(connID, &req)
428 | 			outgoingInitializeJSON, _ := json.Marshal(respToSend)
429 | 			log.Printf("DEBUG: Prepared 'initialize' response for %s. Outgoing response: %s", connID, string(outgoingInitializeJSON))
430 | 		case "notifications/initialized":
431 | 			log.Printf("Received 'notifications/initialized' notification for %s. Ignoring.", connID)
432 | 			w.WriteHeader(http.StatusAccepted)
433 | 			fmt.Fprintln(w, "Notification received.")
434 | 			return // Return early, do not send anything on SSE channel
435 | 		case "tools/list":
436 | 			respToSend = handleToolsListJSONRPC(connID, &req, toolSet)
437 | 		case "tools/call":
438 | 			respToSend = handleToolCallJSONRPC(connID, &req, toolSet, cfg)
439 | 		default:
440 | 			log.Printf("Received unknown JSON-RPC method '%s' for %s", req.Method, connID)
441 | 			respToSend = createJSONRPCError(reqID, -32601, fmt.Sprintf("Method not found: %s", req.Method), nil)
442 | 		}
443 | 	}
444 | 
445 | 	// --- Send response ASYNCHRONOUSLY via SSE channel (unless handled earlier) ---
446 | 	select {
447 | 	case msgChan <- respToSend:
448 | 		log.Printf("Queued response (ID: %v) for %s onto SSE channel", respToSend.ID, connID)
449 | 		// Send HTTP 202 Accepted back to the POST request
450 | 		w.WriteHeader(http.StatusAccepted)
451 | 		// Use the standard message for successfully queued responses
452 | 		fmt.Fprintln(w, "Request accepted, response will be sent via SSE.")
453 | 	default:
454 | 		log.Printf("Error: Failed to queue response (ID: %v) for %s - SSE channel likely full or closed.", respToSend.ID, connID)
455 | 		http.Error(w, "Failed to queue response for SSE channel", http.StatusInternalServerError)
456 | 	}
457 | }
458 | 
459 | // --- JSON-RPC Message Handlers --- // Implementations returning jsonRPCResponse
460 | 
461 | func handleInitializeJSONRPC(connID string, req *jsonRPCRequest) jsonRPCResponse {
462 | 	log.Printf("Handling 'initialize' (JSON-RPC) for %s", connID)
463 | 
464 | 	// Construct the result payload based on gin-mcp's structure using map[string]interface{}
465 | 	resultPayload := map[string]interface{}{
466 | 		"protocolVersion": "2024-11-05", // Aligning with gin-mcp
467 | 		"capabilities": map[string]interface{}{
468 | 			"tools": map[string]interface{}{
469 | 				"enabled": true,
470 | 				"config": map[string]interface{}{
471 | 					"listChanged": false,
472 | 				},
473 | 			},
474 | 			"prompts": map[string]interface{}{
475 | 				"enabled": false,
476 | 			},
477 | 			"resources": map[string]interface{}{
478 | 				"enabled": true,
479 | 			},
480 | 			"logging": map[string]interface{}{
481 | 				"enabled": false,
482 | 			},
483 | 			"roots": map[string]interface{}{
484 | 				"listChanged": false,
485 | 			},
486 | 		},
487 | 		"serverInfo": map[string]interface{}{
488 | 			"name":       "OpenAPI-MCP",       // Or use config name if available
489 | 			"version":    "openapi-mcp-0.1.0", // Your server version
490 | 			"apiVersion": "2024-11-05",        // MCP API version
491 | 		},
492 | 		"connectionId": connID, // Include the connection ID
493 | 	}
494 | 
495 | 	return jsonRPCResponse{
496 | 		Jsonrpc: "2.0",
497 | 		ID:      req.ID, // Match request ID
498 | 		Result:  resultPayload,
499 | 	}
500 | }
501 | 
502 | func handleToolsListJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet) jsonRPCResponse {
503 | 	log.Printf("Handling 'tools/list' (JSON-RPC) for %s", connID)
504 | 
505 | 	// Construct the result payload based on gin-mcp's structure
506 | 	resultPayload := map[string]interface{}{
507 | 		"tools": toolSet.Tools,
508 | 		"metadata": map[string]interface{}{
509 | 			"version": "2024-11-05", // Align with gin-mcp if possible
510 | 			"count":   len(toolSet.Tools),
511 | 		},
512 | 	}
513 | 
514 | 	return jsonRPCResponse{
515 | 		Jsonrpc: "2.0",
516 | 		ID:      req.ID, // Match request ID
517 | 		Result:  resultPayload,
518 | 	}
519 | }
520 | 
521 | // executeToolCall performs the actual HTTP request based on the resolved operation and parameters.
522 | // It now correctly handles API key injection based on the *cfg* parameter.
523 | func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.Config) (*http.Response, error) {
524 | 	toolName := params.ToolName
525 | 	toolInput := params.Input // This is the map[string]interface{} from the client
526 | 
527 | 	log.Printf("[ExecuteToolCall] Looking up details for tool: %s", toolName)
528 | 	operation, ok := toolSet.Operations[toolName]
529 | 	if !ok {
530 | 		log.Printf("[ExecuteToolCall] Error: Operation details not found for tool '%s'", toolName)
531 | 		return nil, fmt.Errorf("operation details for tool '%s' not found", toolName)
532 | 	}
533 | 	log.Printf("[ExecuteToolCall] Found operation: Method=%s, Path=%s", operation.Method, operation.Path)
534 | 
535 | 	// --- Resolve API Key (using cfg passed from main) ---
536 | 	resolvedKey := cfg.GetAPIKey()
537 | 	apiKeyName := cfg.APIKeyName
538 | 	apiKeyLocation := cfg.APIKeyLocation
539 | 	hasServerKey := resolvedKey != "" && apiKeyName != "" && apiKeyLocation != ""
540 | 
541 | 	log.Printf("[ExecuteToolCall] API Key Details: Name='%s', In='%s', HasServerValue=%t", apiKeyName, apiKeyLocation, resolvedKey != "")
542 | 
543 | 	// --- Prepare Request Components ---
544 | 	baseURL := operation.BaseURL // Use BaseURL from the specific operation
545 | 	if cfg.ServerBaseURL != "" {
546 | 		baseURL = cfg.ServerBaseURL // Override if global base URL is set
547 | 		log.Printf("[ExecuteToolCall] Overriding base URL with global config: %s", baseURL)
548 | 	}
549 | 	if baseURL == "" {
550 | 		log.Printf("[ExecuteToolCall] Warning: No base URL found for operation %s and no global override set.", toolName)
551 | 		// For now, assume relative if empty.
552 | 	}
553 | 
554 | 	path := operation.Path
555 | 	queryParams := url.Values{}
556 | 	pathParams := make(map[string]string)
557 | 	headerParams := make(http.Header)        // For headers to add
558 | 	cookieParams := []*http.Cookie{}         // For cookies to add
559 | 	bodyData := make(map[string]interface{}) // For building the request body
560 | 	requestBodyRequired := operation.Method == "POST" || operation.Method == "PUT" || operation.Method == "PATCH"
561 | 
562 | 	// Create a map of expected parameters from the operation details for easier lookup
563 | 	expectedParams := make(map[string]string) // Map param name to its location ('in')
564 | 	for _, p := range operation.Parameters {
565 | 		expectedParams[p.Name] = p.In
566 | 	}
567 | 
568 | 	// --- Process Input Parameters (Separating and Handling API Key Override) ---
569 | 	log.Printf("[ExecuteToolCall] Processing %d input parameters...", len(toolInput))
570 | 	for key, value := range toolInput {
571 | 		// --- API Key Override Check ---
572 | 		// If this input param is the API key AND we have a valid server key config,
573 | 		// skip processing the client's value entirely.
574 | 		if hasServerKey && key == apiKeyName {
575 | 			log.Printf("[ExecuteToolCall] Skipping client-provided param '%s' due to server API key override.", key)
576 | 			continue
577 | 		}
578 | 		// --- End API Key Override ---
579 | 
580 | 		paramLocation, knownParam := expectedParams[key]
581 | 		pathPlaceholder := "{" + key + "}" // OpenAPI uses {param}
582 | 
583 | 		if strings.Contains(path, pathPlaceholder) {
584 | 			// Handle path parameter substitution
585 | 			pathParams[key] = fmt.Sprintf("%v", value)
586 | 			log.Printf("[ExecuteToolCall] Found path parameter %s=%v", key, value)
587 | 		} else if knownParam {
588 | 			// Handle parameters defined in the spec (query, header, cookie)
589 | 			switch paramLocation {
590 | 			case "query":
591 | 				queryParams.Add(key, fmt.Sprintf("%v", value))
592 | 				log.Printf("[ExecuteToolCall] Found query parameter %s=%v (from spec)", key, value)
593 | 			case "header":
594 | 				headerParams.Add(key, fmt.Sprintf("%v", value))
595 | 				log.Printf("[ExecuteToolCall] Found header parameter %s=%v (from spec)", key, value)
596 | 			case "cookie":
597 | 				cookieParams = append(cookieParams, &http.Cookie{Name: key, Value: fmt.Sprintf("%v", value)})
598 | 				log.Printf("[ExecuteToolCall] Found cookie parameter %s=%v (from spec)", key, value)
599 | 			// case "formData": // TODO: Handle form data if needed
600 | 			// 	bodyData[key] = value // Or handle differently based on content type
601 | 			// 	log.Printf("[ExecuteToolCall] Found formData parameter %s=%v (from spec)", key, value)
602 | 			default:
603 | 				// Known parameter but location handling is missing or mismatched.
604 | 				if paramLocation == "path" && (operation.Method == "GET" || operation.Method == "DELETE") {
605 | 					// If spec says 'path' but it wasn't in the actual path, and it's a GET/DELETE,
606 | 					// treat it as a query parameter as a fallback.
607 | 					log.Printf("[ExecuteToolCall] Warning: Parameter '%s' is 'path' in spec but not in URL path '%s'. Adding to query parameters as fallback for GET/DELETE.", key, operation.Path)
608 | 					queryParams.Add(key, fmt.Sprintf("%v", value))
609 | 				} else {
610 | 					// Otherwise, log the warning and ignore.
611 | 					log.Printf("[ExecuteToolCall] Warning: Parameter '%s' has unsupported or unhandled location '%s' in spec. Ignoring.", key, paramLocation)
612 | 				}
613 | 			}
614 | 		} else if requestBodyRequired {
615 | 			// If parameter is not in path or defined in spec params, and method expects a body,
616 | 			// assume it belongs in the request body.
617 | 			bodyData[key] = value
618 | 			log.Printf("[ExecuteToolCall] Added body parameter %s=%v (assumed)", key, value)
619 | 		} else {
620 | 			// Parameter not in path, not in spec, and not a body method.
621 | 			// This could be an extraneous parameter like 'explanation'. Log it.
622 | 			log.Printf("[ExecuteToolCall] Ignoring parameter '%s' as it doesn't match path or known parameter location for method %s.", key, operation.Method)
623 | 		}
624 | 	}
625 | 
626 | 	// --- Substitute Path Parameters ---
627 | 	for key, value := range pathParams {
628 | 		path = strings.Replace(path, "{"+key+"}", value, -1)
629 | 	}
630 | 
631 | 	// --- Inject Server API Key (if applicable) ---
632 | 	if hasServerKey {
633 | 		log.Printf("[ExecuteToolCall] Injecting server API key (Name: %s, Location: %s)", apiKeyName, string(apiKeyLocation))
634 | 		switch apiKeyLocation {
635 | 		case config.APIKeyLocationQuery:
636 | 			queryParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value
637 | 			log.Printf("[ExecuteToolCall] Injected API key '%s' into query parameters", apiKeyName)
638 | 		case config.APIKeyLocationHeader:
639 | 			headerParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value
640 | 			log.Printf("[ExecuteToolCall] Injected API key '%s' into headers", apiKeyName)
641 | 		case config.APIKeyLocationPath:
642 | 			pathPlaceholder := "{" + apiKeyName + "}"
643 | 			if strings.Contains(path, pathPlaceholder) {
644 | 				path = strings.Replace(path, pathPlaceholder, resolvedKey, -1)
645 | 				log.Printf("[ExecuteToolCall] Injected API key into path parameter '%s'", apiKeyName)
646 | 			} else {
647 | 				log.Printf("[ExecuteToolCall] Warning: API key location is 'path' but placeholder '%s' not found in final path '%s' for injection.", pathPlaceholder, path)
648 | 			}
649 | 		case config.APIKeyLocationCookie:
650 | 			// Check if cookie already exists from input, replace if so
651 | 			foundCookie := false
652 | 			for i, c := range cookieParams {
653 | 				if c.Name == apiKeyName {
654 | 					log.Printf("[ExecuteToolCall] Replacing existing cookie '%s' with injected API key.", apiKeyName)
655 | 					cookieParams[i] = &http.Cookie{Name: apiKeyName, Value: resolvedKey} // Replace existing
656 | 					foundCookie = true
657 | 					break
658 | 				}
659 | 			}
660 | 			if !foundCookie {
661 | 				log.Printf("[ExecuteToolCall] Adding new cookie '%s' with injected API key.", apiKeyName)
662 | 				cookieParams = append(cookieParams, &http.Cookie{Name: apiKeyName, Value: resolvedKey}) // Append new
663 | 			}
664 | 		default:
665 | 			// Use log.Printf for consistency
666 | 			log.Printf("Warning: Unsupported API key location specified in config: '%s'", apiKeyLocation)
667 | 		}
668 | 	} else {
669 | 		log.Printf("[ExecuteToolCall] Skipping server API key injection (config incomplete or key unresolved).")
670 | 	}
671 | 
672 | 	// --- Final URL Construction ---
673 | 	// Reconstruct query string *after* potential API key injection
674 | 	targetURL := baseURL + path
675 | 	if len(queryParams) > 0 {
676 | 		targetURL += "?" + queryParams.Encode()
677 | 	}
678 | 	log.Printf("[ExecuteToolCall] Final Target URL: %s %s", operation.Method, targetURL)
679 | 
680 | 	// --- Prepare Request Body ---
681 | 	var reqBody io.Reader
682 | 	var bodyBytes []byte // Keep for logging
683 | 	if requestBodyRequired && len(bodyData) > 0 {
684 | 		var err error
685 | 		bodyBytes, err = json.Marshal(bodyData)
686 | 		if err != nil {
687 | 			log.Printf("[ExecuteToolCall] Error marshalling request body: %v", err)
688 | 			return nil, fmt.Errorf("error marshalling request body: %w", err)
689 | 		}
690 | 		reqBody = bytes.NewBuffer(bodyBytes)
691 | 		log.Printf("[ExecuteToolCall] Request body: %s", string(bodyBytes))
692 | 	}
693 | 
694 | 	// --- Create HTTP Request ---
695 | 	req, err := http.NewRequest(operation.Method, targetURL, reqBody)
696 | 	if err != nil {
697 | 		log.Printf("[ExecuteToolCall] Error creating HTTP request: %v", err)
698 | 		return nil, fmt.Errorf("error creating request: %w", err)
699 | 	}
700 | 
701 | 	// --- Set Headers ---
702 | 	// Default headers
703 | 	req.Header.Set("Accept", "application/json") // Assume JSON response typical for APIs
704 | 	if reqBody != nil {
705 | 		req.Header.Set("Content-Type", "application/json") // Assume JSON body if body exists
706 | 	}
707 | 
708 | 	// Add headers collected from input/spec AND potentially injected API key
709 | 	for key, values := range headerParams {
710 | 		// Note: We use Set, assuming single value per header from input typically.
711 | 		// If multi-value headers are needed from spec/input, use Add.
712 | 		if len(values) > 0 {
713 | 			req.Header.Set(key, values[0])
714 | 		}
715 | 	}
716 | 
717 | 	// Add custom headers from config (comma-separated)
718 | 	if cfg.CustomHeaders != "" {
719 | 		headers := strings.Split(cfg.CustomHeaders, ",")
720 | 		for _, h := range headers {
721 | 			parts := strings.SplitN(h, ":", 2)
722 | 			if len(parts) == 2 {
723 | 				headerName := strings.TrimSpace(parts[0])
724 | 				headerValue := strings.TrimSpace(parts[1])
725 | 				if headerName != "" {
726 | 					req.Header.Set(headerName, headerValue) // Set overrides potential input
727 | 					log.Printf("[ExecuteToolCall] Added custom header from config: %s", headerName)
728 | 				}
729 | 			}
730 | 		}
731 | 	}
732 | 
733 | 	// --- Add Cookies ---
734 | 	for _, cookie := range cookieParams {
735 | 		req.AddCookie(cookie)
736 | 	}
737 | 
738 | 	log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header)
739 | 	if len(req.Cookies()) > 0 {
740 | 		log.Printf("[ExecuteToolCall] Sending request with cookies: %+v", req.Cookies())
741 | 	}
742 | 
743 | 	// --- Execute HTTP Request ---
744 | 	log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header)
745 | 	client := &http.Client{Timeout: 30 * time.Second}
746 | 	resp, err := client.Do(req)
747 | 	if err != nil {
748 | 		log.Printf("[ExecuteToolCall] Error executing HTTP request: %v", err)
749 | 		return nil, fmt.Errorf("error executing request: %w", err)
750 | 	}
751 | 
752 | 	log.Printf("[ExecuteToolCall] Request executed. Status Code: %d", resp.StatusCode)
753 | 	// Note: Don't close resp.Body here, the caller (handleToolCallJSONRPC) needs it.
754 | 	return resp, nil
755 | }
756 | 
757 | func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet, cfg *config.Config) jsonRPCResponse {
758 | 	// req.Params is interface{}, but should contain json.RawMessage for tools/call
759 | 	rawParams, ok := req.Params.(json.RawMessage)
760 | 	if !ok {
761 | 		// If it's not RawMessage, maybe it was already decoded to a map? Handle that case too.
762 | 		if paramsMap, mapOk := req.Params.(map[string]interface{}); mapOk {
763 | 			// Attempt to marshal the map back to JSON bytes
764 | 			var marshalErr error
765 | 			rawParams, marshalErr = json.Marshal(paramsMap)
766 | 			if marshalErr != nil {
767 | 				log.Printf("Error marshalling params map for %s: %v", connID, marshalErr)
768 | 				return createJSONRPCError(req.ID, -32602, "Invalid parameters format (map marshal failed)", marshalErr.Error())
769 | 			}
770 | 			log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from map)", connID, string(rawParams))
771 | 		} else {
772 | 			log.Printf("Invalid parameters format for tools/call (not json.RawMessage or map[string]interface{}): %T", req.Params)
773 | 			return createJSONRPCError(req.ID, -32602, "Invalid parameters format (expected JSON object)", nil)
774 | 		}
775 | 	} else {
776 | 		log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from RawMessage)", connID, string(rawParams))
777 | 	}
778 | 
779 | 	// Now, unmarshal the rawParams ([]byte) into ToolCallParams
780 | 	var params ToolCallParams
781 | 	if err := json.Unmarshal(rawParams, &params); err != nil {
782 | 		log.Printf("Error unmarshalling tools/call params for %s: %v", connID, err)
783 | 		return createJSONRPCError(req.ID, -32602, "Invalid parameters structure (unmarshal)", err.Error())
784 | 	}
785 | 
786 | 	log.Printf("Executing tool '%s' for %s with input: %+v", params.ToolName, connID, params.Input)
787 | 
788 | 	// --- Execute the actual tool call ---
789 | 	httpResp, execErr := executeToolCall(&params, toolSet, cfg)
790 | 
791 | 	// --- Process Response ---
792 | 	var resultPayload ToolResultPayload
793 | 	if execErr != nil {
794 | 		log.Printf("Error executing tool call '%s': %v", params.ToolName, execErr)
795 | 		resultPayload = ToolResultPayload{
796 | 			IsError: true,
797 | 			Error: &MCPError{
798 | 				Message: fmt.Sprintf("Failed to execute tool '%s': %v", params.ToolName, execErr),
799 | 			},
800 | 			ToolCallID: fmt.Sprintf("%v", req.ID),
801 | 		}
802 | 	} else {
803 | 		defer httpResp.Body.Close() // Ensure body is closed
804 | 		bodyBytes, readErr := io.ReadAll(httpResp.Body)
805 | 		if readErr != nil {
806 | 			log.Printf("Error reading response body for tool '%s': %v", params.ToolName, readErr)
807 | 			resultPayload = ToolResultPayload{
808 | 				IsError: true,
809 | 				Error: &MCPError{
810 | 					Message: fmt.Sprintf("Failed to read response from tool '%s': %v", params.ToolName, readErr),
811 | 				},
812 | 				ToolCallID: fmt.Sprintf("%v", req.ID),
813 | 			}
814 | 		} else {
815 | 			log.Printf("Received response body for tool '%s': %s", params.ToolName, string(bodyBytes))
816 | 			// Check status code for API-level errors
817 | 			if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
818 | 				resultPayload = ToolResultPayload{
819 | 					IsError: true,
820 | 					Error: &MCPError{
821 | 						Code:    httpResp.StatusCode,
822 | 						Message: fmt.Sprintf("Tool '%s' API call failed with status %s", params.ToolName, httpResp.Status),
823 | 						Data:    string(bodyBytes), // Include response body in error data
824 | 					},
825 | 					ToolCallID: fmt.Sprintf("%v", req.ID),
826 | 				}
827 | 			} else {
828 | 				// Successful execution
829 | 				resultContent := []ToolResultContent{
830 | 					{
831 | 						Type: "text", // TODO: Handle JSON responses properly if Content-Type indicates it
832 | 						Text: string(bodyBytes),
833 | 					},
834 | 				}
835 | 				resultPayload = ToolResultPayload{
836 | 					Content:    resultContent,
837 | 					IsError:    false,
838 | 					ToolCallID: fmt.Sprintf("%v", req.ID),
839 | 				}
840 | 			}
841 | 		}
842 | 	}
843 | 
844 | 	// --- Send Response ---
845 | 	return jsonRPCResponse{
846 | 		Jsonrpc: "2.0",
847 | 		ID:      req.ID,        // Match request ID
848 | 		Result:  resultPayload, // Use the actual result payload
849 | 	}
850 | }
851 | 
852 | // --- Helper Functions (Updated for JSON-RPC) ---
853 | 
854 | // sendJSONRPCResponse sends a JSON-RPC response *synchronously*.
855 | // Keep this for now for sending synchronous errors on POST decode/read failures.
856 | func sendJSONRPCResponse(w http.ResponseWriter, resp jsonRPCResponse) {
857 | 	w.Header().Set("Content-Type", "application/json")
858 | 	if err := json.NewEncoder(w).Encode(resp); err != nil {
859 | 		log.Printf("Error encoding JSON-RPC response (ID: %v) for ConnID %v: %v", resp.ID, resp.Error, err)
860 | 		// Attempt to send a plain text error if JSON encoding fails
861 | 		tryWriteHTTPError(w, http.StatusInternalServerError, "Internal Server Error encoding JSON-RPC response")
862 | 	}
863 | 	log.Printf("Sent JSON-RPC response: Method=%s, ID=%v", getMethodFromResponse(resp), resp.ID)
864 | }
865 | 
866 | // createJSONRPCError creates a JSON-RPC error response.
867 | func createJSONRPCError(id interface{}, code int, message string, data interface{}) jsonRPCResponse {
868 | 	jsonErr := &jsonError{Code: code, Message: message, Data: data}
869 | 	return jsonRPCResponse{
870 | 		Jsonrpc: "2.0",
871 | 		ID:      id, // Error response should echo the request ID
872 | 		Error:   jsonErr,
873 | 	}
874 | }
875 | 
876 | // sendJSONRPCError sends a JSON-RPC error response.
877 | func sendJSONRPCError(w http.ResponseWriter, connID string, id interface{}, code int, message string, data interface{}) {
878 | 	resp := createJSONRPCError(id, code, message, data)
879 | 	log.Printf("Sending JSON-RPC Error for ConnID %s, ID %v: Code=%d, Message='%s'", connID, id, code, message)
880 | 	sendJSONRPCResponse(w, resp)
881 | }
882 | 
883 | // Helper to get the method name for logging purposes (from the result/error structure if possible)
884 | func getMethodFromResponse(resp jsonRPCResponse) string {
885 | 	if resp.Result != nil {
886 | 		// Attempt to infer method from result structure if it has a type field
887 | 		if resMap, ok := resp.Result.(map[string]interface{}); ok {
888 | 			if methodType, typeOk := resMap["type"].(string); typeOk {
889 | 				return methodType + "_result"
890 | 			}
891 | 		}
892 | 		// Infer based on known result types if possible
893 | 		if _, ok := resp.Result.(map[string]interface{}); ok && resp.Result.(map[string]interface{})["tools"] != nil {
894 | 			return "tool_set"
895 | 		}
896 | 		// If not easily identifiable, just indicate success
897 | 		return "success"
898 | 	} else if resp.Error != nil {
899 | 		return "error"
900 | 	}
901 | 	return "unknown"
902 | }
903 | 
// tryWriteHTTPError makes a best-effort attempt to send message as a
// plain-text response with the given HTTP status code. A write failure is
// logged rather than returned.
//
// If the response headers were already sent (for example because a body write
// had started), the status code and Content-Type can no longer change. net/http
// ignores the superfluous WriteHeader call, and the message is simply appended
// to the body.
func tryWriteHTTPError(w http.ResponseWriter, code int, message string) {
	// The body is plain text. Override any Content-Type the caller set earlier
	// (e.g. application/json) so the error is not mislabelled.
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.WriteHeader(code)
	if _, err := w.Write([]byte(message)); err != nil {
		log.Printf("Error writing plain HTTP error response: %v", err)
	}
	log.Printf("Sent plain HTTP error: %s (Code: %d)", message, code)
}
911 | 
```

--------------------------------------------------------------------------------
/pkg/server/server_test.go:
--------------------------------------------------------------------------------

```go
   1 | package server
   2 | 
   3 | import (
   4 | 	"bytes"
   5 | 	"context"
   6 | 	"encoding/json"
   7 | 	"fmt"
   8 | 	"io"
   9 | 	"log"
  10 | 	"net/http"
  11 | 	"net/http/httptest"
  12 | 	"strings"
  13 | 	"sync"
  14 | 	"testing"
  15 | 	"time"
  16 | 
  17 | 	"github.com/ckanthony/openapi-mcp/pkg/config"
  18 | 	"github.com/ckanthony/openapi-mcp/pkg/mcp"
  19 | 	"github.com/google/uuid"
  20 | 	"github.com/stretchr/testify/assert"
  21 | 	"github.com/stretchr/testify/require"
  22 | )
  23 | 
  24 | // --- Re-added Helper Functions ---
  25 | 
  26 | // Helper function to create a simple ToolSet for testing tool calls
  27 | func createTestToolSetForCall() *mcp.ToolSet {
  28 | 	return &mcp.ToolSet{
  29 | 		Name: "Call Test API",
  30 | 		Tools: []mcp.Tool{
  31 | 			{
  32 | 				Name:        "get_user",
  33 | 				Description: "Get user details",
  34 | 				InputSchema: mcp.Schema{
  35 | 					Type: "object",
  36 | 					Properties: map[string]mcp.Schema{
  37 | 						"user_id": {Type: "string"},
  38 | 					},
  39 | 					Required: []string{"user_id"},
  40 | 				},
  41 | 			},
  42 | 			{
  43 | 				Name:        "post_data",
  44 | 				Description: "Post some data",
  45 | 				InputSchema: mcp.Schema{
  46 | 					Type: "object",
  47 | 					Properties: map[string]mcp.Schema{
  48 | 						"data": {Type: "string"},
  49 | 					},
  50 | 					Required: []string{"data"},
  51 | 				},
  52 | 			},
  53 | 		},
  54 | 		Operations: map[string]mcp.OperationDetail{
  55 | 			"get_user": {
  56 | 				Method: "GET",
  57 | 				Path:   "/users/{user_id}",
  58 | 				Parameters: []mcp.ParameterDetail{
  59 | 					{Name: "user_id", In: "path"},
  60 | 				},
  61 | 			},
  62 | 			"post_data": {
  63 | 				Method:     "POST",
  64 | 				Path:       "/data",
  65 | 				Parameters: []mcp.ParameterDetail{}, // Body params assumed
  66 | 			},
  67 | 		},
  68 | 	}
  69 | }
  70 | 
  71 | // Helper to safely manage activeConnections for tests
  72 | func setupTestConnection(connID string) chan jsonRPCResponse {
  73 | 	msgChan := make(chan jsonRPCResponse, 1) // Buffer of 1 sufficient for most tests
  74 | 	connMutex.Lock()
  75 | 	activeConnections[connID] = msgChan
  76 | 	connMutex.Unlock()
  77 | 	return msgChan
  78 | }
  79 | 
  80 | func cleanupTestConnection(connID string) {
  81 | 	connMutex.Lock()
  82 | 	msgChan, exists := activeConnections[connID]
  83 | 	if exists {
  84 | 		delete(activeConnections, connID)
  85 | 		close(msgChan)
  86 | 	}
  87 | 	connMutex.Unlock()
  88 | }
  89 | 
  90 | // --- End Re-added Helper Functions ---
  91 | 
  92 | func TestHttpMethodPostHandler(t *testing.T) {
  93 | 	// --- Setup common test items ---
  94 | 	toolSet := createTestToolSetForCall() // Use the helper
  95 | 	cfg := &config.Config{}               // Basic config
  96 | 	// NOTE: connID is now generated within each subtest to ensure isolation
  97 | 
  98 | 	// --- Define Test Cases ---
  99 | 	tests := []struct {
 100 | 		name                 string
 101 | 		requestBodyFn        func(connID string) string               // Function to generate body with dynamic connID
 102 | 		expectedSyncStatus   int                                      // Expected status code for the immediate POST response
 103 | 		expectedSyncBody     string                                   // Expected body for the immediate POST response
 104 | 		checkAsyncResponse   func(t *testing.T, resp jsonRPCResponse) // Function to check async response
 105 | 		mockBackend          http.HandlerFunc                         // Optional mock backend for tool calls
 106 | 		setupChannelDirectly func(connID string) chan jsonRPCResponse // Optional: For specific channel setups
 107 | 	}{
 108 | 		{
 109 | 			name: "Valid Initialize Request",
 110 | 			requestBodyFn: func(connID string) string {
 111 | 				return fmt.Sprintf(`{
 112 | 					"jsonrpc": "2.0", 
 113 | 					"method": "initialize", 
 114 | 					"id": "init-post-1", 
 115 | 					"params": {"connectionId": "%s"}
 116 | 				}`, connID)
 117 | 			},
 118 | 			expectedSyncStatus: http.StatusAccepted,
 119 | 			expectedSyncBody:   "Request accepted, response will be sent via SSE.\n",
 120 | 			checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
 121 | 				assert.Equal(t, "init-post-1", resp.ID)
 122 | 				assert.Nil(t, resp.Error)
 123 | 				resultMap, ok := resp.Result.(map[string]interface{})
 124 | 				require.True(t, ok)
 125 | 				assert.Contains(t, resultMap, "connectionId") // Check existence, actual ID checked separately
 126 | 				assert.Equal(t, "2024-11-05", resultMap["protocolVersion"])
 127 | 			},
 128 | 		},
 129 | 		{
 130 | 			name: "Valid Tools List Request",
 131 | 			requestBodyFn: func(connID string) string {
 132 | 				return `{
 133 | 					"jsonrpc": "2.0", 
 134 | 					"method": "tools/list", 
 135 | 					"id": "list-post-1"
 136 | 				}`
 137 | 			},
 138 | 			expectedSyncStatus: http.StatusAccepted,
 139 | 			expectedSyncBody:   "Request accepted, response will be sent via SSE.\n",
 140 | 			checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
 141 | 				assert.Equal(t, "list-post-1", resp.ID)
 142 | 				assert.Nil(t, resp.Error)
 143 | 				resultMap, ok := resp.Result.(map[string]interface{})
 144 | 				require.True(t, ok)
 145 | 				assert.Contains(t, resultMap, "metadata")
 146 | 				assert.Contains(t, resultMap, "tools")
 147 | 				metadata, _ := resultMap["metadata"].(map[string]interface{})
 148 | 				assert.Equal(t, 2, metadata["count"]) // Corrected: Expect int(2)
 149 | 			},
 150 | 		},
 151 | 		{
 152 | 			name: "Valid Tool Call Request (Success)",
 153 | 			requestBodyFn: func(connID string) string {
 154 | 				return `{
 155 | 					"jsonrpc": "2.0", 
 156 | 					"method": "tools/call", 
 157 | 					"id": "call-post-1", 
 158 | 					"params": {"name": "get_user", "arguments": {"user_id": "postUser"}}
 159 | 				}`
 160 | 			},
 161 | 			expectedSyncStatus: http.StatusAccepted,
 162 | 			expectedSyncBody:   "Request accepted, response will be sent via SSE.\n",
 163 | 			checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
 164 | 				assert.Equal(t, "call-post-1", resp.ID)
 165 | 				assert.Nil(t, resp.Error)
 166 | 				resultPayload, ok := resp.Result.(ToolResultPayload)
 167 | 				require.True(t, ok)
 168 | 				assert.False(t, resultPayload.IsError)
 169 | 				require.Len(t, resultPayload.Content, 1)
 170 | 				assert.JSONEq(t, `{"id":"postUser"}`, resultPayload.Content[0].Text)
 171 | 			},
 172 | 			mockBackend: func(w http.ResponseWriter, r *http.Request) {
 173 | 				w.WriteHeader(http.StatusOK)
 174 | 				fmt.Fprintln(w, `{"id":"postUser"}`)
 175 | 			},
 176 | 		},
 177 | 		{
 178 | 			name: "Valid Tool Call Request (Tool Not Found)",
 179 | 			requestBodyFn: func(connID string) string {
 180 | 				return `{
 181 | 					"jsonrpc": "2.0", 
 182 | 					"method": "tools/call", 
 183 | 					"id": "call-post-err-1", 
 184 | 					"params": {"name": "nonexistent_tool", "arguments": {}}
 185 | 				}`
 186 | 			},
 187 | 			expectedSyncStatus: http.StatusAccepted,
 188 | 			expectedSyncBody:   "Request accepted, response will be sent via SSE.\n",
 189 | 			checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
 190 | 				assert.Equal(t, "call-post-err-1", resp.ID)
 191 | 				assert.Nil(t, resp.Error)
 192 | 				resultPayload, ok := resp.Result.(ToolResultPayload)
 193 | 				require.True(t, ok)
 194 | 				assert.True(t, resultPayload.IsError)
 195 | 				require.NotNil(t, resultPayload.Error)
 196 | 				assert.Contains(t, resultPayload.Error.Message, "operation details for tool 'nonexistent_tool' not found")
 197 | 			},
 198 | 		},
 199 | 		{
 200 | 			name: "Malformed JSON Request",
 201 | 			requestBodyFn: func(connID string) string {
 202 | 				return `{"jsonrpc": "2.0", "method": "initialize"`
 203 | 			},
 204 | 			expectedSyncStatus: http.StatusAccepted, // Even decode errors return 202, error is sent async
 205 | 			expectedSyncBody:   "Request accepted (with decode error), response will be sent via SSE.\n",
 206 | 			checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
 207 | 				assert.Nil(t, resp.ID) // ID might be nil if request parsing failed early
 208 | 				require.NotNil(t, resp.Error)
 209 | 				assert.Equal(t, -32700, resp.Error.Code)                                 // Parse Error
 210 | 				assert.Equal(t, "Parse error decoding JSON request", resp.Error.Message) // Corrected assertion
 211 | 			},
 212 | 		},
 213 | 		{
 214 | 			name: "Missing JSON-RPC Version",
 215 | 			requestBodyFn: func(connID string) string {
 216 | 				return `{
 217 | 					"method": "initialize", 
 218 | 					"id": "rpc-err-1"
 219 | 				}`
 220 | 			},
 221 | 			expectedSyncStatus: http.StatusAccepted,
 222 | 			expectedSyncBody:   "Request accepted, response will be sent via SSE.\n",
 223 | 			checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
 224 | 				assert.Equal(t, "rpc-err-1", resp.ID)
 225 | 				require.NotNil(t, resp.Error)
 226 | 				assert.Equal(t, -32600, resp.Error.Code) // Invalid Request
 227 | 				assert.Contains(t, resp.Error.Message, "jsonrpc field must be \"2.0\"")
 228 | 			},
 229 | 		},
 230 | 		{
 231 | 			name: "Unknown Method",
 232 | 			requestBodyFn: func(connID string) string {
 233 | 				return `{
 234 | 					"jsonrpc": "2.0", 
 235 | 					"method": "unknown/method", 
 236 | 					"id": "rpc-err-2"
 237 | 				}`
 238 | 			},
 239 | 			expectedSyncStatus: http.StatusAccepted,
 240 | 			expectedSyncBody:   "Request accepted, response will be sent via SSE.\n",
 241 | 			checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
 242 | 				assert.Equal(t, "rpc-err-2", resp.ID)
 243 | 				require.NotNil(t, resp.Error)
 244 | 				assert.Equal(t, -32601, resp.Error.Code) // Method not found
 245 | 				assert.Contains(t, resp.Error.Message, "Method not found")
 246 | 			},
 247 | 		},
 248 | 		{
 249 | 			name: "Missing Method",
 250 | 			requestBodyFn: func(connID string) string {
 251 | 				return `{
 252 | 					"jsonrpc": "2.0", 
 253 | 					"id": "rpc-err-3"
 254 | 				}`
 255 | 			},
 256 | 			expectedSyncStatus: http.StatusAccepted,
 257 | 			expectedSyncBody:   "Request accepted, response will be sent via SSE.\n",
 258 | 			checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
 259 | 				assert.Equal(t, "rpc-err-3", resp.ID)
 260 | 				require.NotNil(t, resp.Error)
 261 | 				assert.Equal(t, -32600, resp.Error.Code)                                                 // Invalid Request
 262 | 				assert.Equal(t, "Invalid Request: method field is missing or empty", resp.Error.Message) // Corrected assertion
 263 | 			},
 264 | 		},
 265 | 		{
 266 | 			name: "Error Queuing Response To SSE",
 267 | 			requestBodyFn: func(connID string) string { // Use a simple valid request like tools/list
 268 | 				return `{
 269 | 					"jsonrpc": "2.0", 
 270 | 					"method": "tools/list", 
 271 | 					"id": "list-post-err-queue"
 272 | 				}`
 273 | 			},
 274 | 			expectedSyncStatus: http.StatusInternalServerError,               // Expect 500 when channel is blocked
 275 | 			expectedSyncBody:   "Failed to queue response for SSE channel\n", // Specific error message expected
 276 | 			setupChannelDirectly: func(connID string) chan jsonRPCResponse {
 277 | 				// Create a NON-BUFFERED channel to simulate blocking/full channel
 278 | 				msgChan := make(chan jsonRPCResponse) // No buffer size!
 279 | 				connMutex.Lock()
 280 | 				activeConnections[connID] = msgChan
 281 | 				connMutex.Unlock()
 282 | 				// Important: Do NOT start a reader for this channel
 283 | 				return msgChan
 284 | 			},
 285 | 			checkAsyncResponse: nil, // No async response should be successfully sent
 286 | 		},
 287 | 	}
 288 | 
 289 | 	// --- Run Test Cases ---
 290 | 	for _, tc := range tests {
 291 | 		t.Run(tc.name, func(t *testing.T) {
 292 | 			connID := uuid.NewString() // Generate unique connID for each subtest
 293 | 
 294 | 			// Setup mock backend if needed for this test case
 295 | 			var backendServer *httptest.Server
 296 | 			// --- Add Connection ID before test ---
 297 | 			var msgChan chan jsonRPCResponse
 298 | 			if tc.setupChannelDirectly != nil {
 299 | 				// Use custom setup if provided (e.g., for blocking channel test)
 300 | 				msgChan = tc.setupChannelDirectly(connID)
 301 | 			} else {
 302 | 				// Default setup using the helper with buffered channel
 303 | 				msgChan = setupTestConnection(connID)
 304 | 			}
 305 | 			defer cleanupTestConnection(connID) // Ensure cleanup after test
 306 | 
 307 | 			if tc.mockBackend != nil {
 308 | 				backendServer = httptest.NewServer(tc.mockBackend)
 309 | 				defer backendServer.Close()
 310 | 				// IMPORTANT: Update the toolset's BaseURL for the relevant operation
 311 | 				if strings.Contains(tc.requestBodyFn(connID), "get_user") { // Simple check based on request
 312 | 					op := toolSet.Operations["get_user"]
 313 | 					op.BaseURL = backendServer.URL
 314 | 					toolSet.Operations["get_user"] = op
 315 | 				}
 316 | 				// Update post_data BaseURL if needed
 317 | 				if strings.Contains(tc.requestBodyFn(connID), "post_data") {
 318 | 					op := toolSet.Operations["post_data"]
 319 | 					op.BaseURL = backendServer.URL
 320 | 					toolSet.Operations["post_data"] = op
 321 | 				}
 322 | 			}
 323 | 
 324 | 			reqBody := tc.requestBodyFn(connID) // Generate request body
 325 | 			req := httptest.NewRequest(http.MethodPost, "/mcp", strings.NewReader(reqBody))
 326 | 			req.Header.Set("Content-Type", "application/json")
 327 | 			req.Header.Set("X-Connection-ID", connID) // Use the generated connID
 328 | 			rr := httptest.NewRecorder()
 329 | 
 330 | 			httpMethodPostHandler(rr, req, toolSet, cfg)
 331 | 
 332 | 			// 1. Check synchronous response
 333 | 			assert.Equal(t, tc.expectedSyncStatus, rr.Code, "Unexpected status code for sync response")
 334 | 			// Trim space for comparison as http.Error might add a newline our literal doesn't have
 335 | 			assert.Equal(t, strings.TrimSpace(tc.expectedSyncBody), strings.TrimSpace(rr.Body.String()), "Unexpected body for sync response")
 336 | 
 337 | 			// 2. Check asynchronous response (sent via SSE channel)
 338 | 			if tc.checkAsyncResponse != nil {
 339 | 				select {
 340 | 				case asyncResp := <-msgChan:
 341 | 					tc.checkAsyncResponse(t, asyncResp)
 342 | 				case <-time.After(100 * time.Millisecond): // Add a timeout
 343 | 					t.Fatal("Timeout waiting for async response on SSE channel")
 344 | 				}
 345 | 			} else {
 346 | 				// If no async check is defined, ensure nothing was sent (e.g., for queue error test)
 347 | 				select {
 348 | 				case unexpectedResp, ok := <-msgChan:
 349 | 					if ok { // Only fail if the channel wasn't closed AND we got a message
 350 | 						t.Errorf("Received unexpected async response when none was expected: %+v", unexpectedResp)
 351 | 					}
 352 | 					// If !ok, channel was closed, which is fine/expected after cleanup
 353 | 				case <-time.After(50 * time.Millisecond):
 354 | 					// Success - no message received quickly, channel likely blocked as expected
 355 | 				}
 356 | 			}
 357 | 		})
 358 | 	}
 359 | }
 360 | 
 361 | func TestHttpMethodGetHandler(t *testing.T) {
 362 | 	// --- Setup ---
 363 | 	// Reset global state for this test
 364 | 	connMutex.Lock()
 365 | 	originalConnections := activeConnections
 366 | 	activeConnections = make(map[string]chan jsonRPCResponse)
 367 | 	connMutex.Unlock()
 368 | 
 369 | 	req, err := http.NewRequest("GET", "/mcp", nil)
 370 | 	require.NoError(t, err, "Failed to create request")
 371 | 
 372 | 	rr := httptest.NewRecorder()
 373 | 
 374 | 	// Ensure cleanup happens regardless of test outcome
 375 | 	defer func() {
 376 | 		connMutex.Lock()
 377 | 		// Clean up any connections potentially left by the test
 378 | 		for id, ch := range activeConnections {
 379 | 			close(ch)
 380 | 			delete(activeConnections, id)
 381 | 			log.Printf("[DEFER Cleanup] Closed channel and removed connection %s", id)
 382 | 		}
 383 | 		activeConnections = originalConnections // Restore the original map
 384 | 		connMutex.Unlock()
 385 | 	}()
 386 | 
 387 | 	// --- Execute Handler (in a goroutine as it blocks waiting for context) ---
 388 | 	ctx, cancel := context.WithCancel(context.Background())
 389 | 	req = req.WithContext(ctx)
 390 | 
 391 | 	hwg := sync.WaitGroup{}
 392 | 	hwg.Add(1)
 393 | 	go func() {
 394 | 		defer hwg.Done()
 395 | 		// Simulate some work before handler returns
 396 | 		// In a real scenario, this would block on ctx.Done() or keepAliveTicker
 397 | 		// For the test, we just call cancel() after a short delay
 398 | 		// to simulate the connection ending gracefully.
 399 | 		time.AfterFunc(100*time.Millisecond, cancel) // Allow handler to start and write initial data
 400 | 		httpMethodGetHandler(rr, req)
 401 | 	}()
 402 | 
 403 | 	// Wait for the handler goroutine to finish.
 404 | 	// This ensures all writes to rr are complete before we read.
 405 | 	if !waitTimeout(&hwg, 2*time.Second) { // Use a reasonable timeout
 406 | 		t.Fatal("Handler goroutine did not exit cleanly after context cancellation")
 407 | 	}
 408 | 
 409 | 	// --- Assertions (Performed *after* handler completion) ---
 410 | 	assert.Equal(t, http.StatusOK, rr.Code, "Status code should be OK")
 411 | 
 412 | 	// Check headers are set correctly
 413 | 	assert.Equal(t, "text/event-stream", rr.Header().Get("Content-Type"))
 414 | 	assert.Equal(t, "no-cache", rr.Header().Get("Cache-Control"))
 415 | 	assert.Equal(t, "keep-alive", rr.Header().Get("Connection"))
 416 | 	connID := rr.Header().Get("X-Connection-ID")
 417 | 	assert.NotEmpty(t, connID, "X-Connection-ID header should be set")
 418 | 
 419 | 	// Check connection was registered and then cleaned up
 420 | 	connMutex.RLock()
 421 | 	_, exists := originalConnections[connID] // Check original map after cleanup
 422 | 	connMutex.RUnlock()
 423 | 	assert.False(t, exists, "Connection ID should be removed from map after handler exits")
 424 | 
 425 | 	// Check initial body content is present
 426 | 	bodyContent := rr.Body.String()
 427 | 	assert.Contains(t, bodyContent, ":ok\n\n", "Body should contain :ok preamble")
 428 | 	// Construct the expected endpoint data string accurately
 429 | 	expectedEndpointData := "data: /mcp?sessionId=" + connID + "\n\n"
 430 | 	assert.Contains(t, bodyContent, "event: endpoint\n"+expectedEndpointData, "Body should contain endpoint event")
 431 | 	assert.Contains(t, bodyContent, "event: message\ndata: {", "Body should contain start of a message event (e.g., mcp-ready)")
 432 | 	// Check if connectionId is present in the ready message (adjust based on actual JSON structure)
 433 | 	assert.Contains(t, bodyContent, `"connectionId":"`+connID+`"`, "Body should contain mcp-ready event with correct connection ID")
 434 | 
 435 | 	// The explicit cleanupTestConnection call is not needed because the handler's defer and the test's defer handle it.
 436 | }
 437 | 
 438 | func TestExecuteToolCall(t *testing.T) {
 439 | 	tests := []struct {
 440 | 		name              string
 441 | 		params            ToolCallParams
 442 | 		opDetail          mcp.OperationDetail
 443 | 		cfg               *config.Config
 444 | 		expectError       bool
 445 | 		containsError     string
 446 | 		requestAsserter   func(t *testing.T, r *http.Request) // Function to assert details of the received HTTP request
 447 | 		backendResponse   string                              // Response body from mock backend
 448 | 		backendStatusCode int                                 // Status code from mock backend
 449 | 	}{
 450 | 		// --- Basic GET with Path Param ---
 451 | 		{
 452 | 			name: "GET with path parameter",
 453 | 			params: ToolCallParams{
 454 | 				ToolName: "get_item",
 455 | 				Input:    map[string]interface{}{"item_id": "item123"},
 456 | 			},
 457 | 			opDetail: mcp.OperationDetail{
 458 | 				Method:     "GET",
 459 | 				Path:       "/items/{item_id}",
 460 | 				Parameters: []mcp.ParameterDetail{{Name: "item_id", In: "path"}},
 461 | 			},
 462 | 			cfg:               &config.Config{},
 463 | 			expectError:       false,
 464 | 			backendStatusCode: http.StatusOK,
 465 | 			backendResponse:   `{"status":"ok"}`,
 466 | 			requestAsserter: func(t *testing.T, r *http.Request) {
 467 | 				assert.Equal(t, http.MethodGet, r.Method)
 468 | 				assert.Equal(t, "/items/item123", r.URL.Path)
 469 | 				assert.Empty(t, r.URL.RawQuery)
 470 | 			},
 471 | 		},
 472 | 		// --- POST with Query, Header, Cookie, and Body Params ---
 473 | 		{
 474 | 			name: "POST with various params",
 475 | 			params: ToolCallParams{
 476 | 				ToolName: "create_resource",
 477 | 				Input: map[string]interface{}{
 478 | 					"queryArg":     "value1",
 479 | 					"X-Custom-Hdr": "headerValue",
 480 | 					"sessionToken": "cookieValue",
 481 | 					"bodyFieldA":   "A",
 482 | 					"bodyFieldB":   123,
 483 | 				},
 484 | 			},
 485 | 			opDetail: mcp.OperationDetail{
 486 | 				Method: "POST",
 487 | 				Path:   "/resources",
 488 | 				Parameters: []mcp.ParameterDetail{
 489 | 					{Name: "queryArg", In: "query"},
 490 | 					{Name: "X-Custom-Hdr", In: "header"},
 491 | 					{Name: "sessionToken", In: "cookie"},
 492 | 					// Body fields are implicitly handled
 493 | 				},
 494 | 			},
 495 | 			cfg:               &config.Config{},
 496 | 			expectError:       false,
 497 | 			backendStatusCode: http.StatusCreated,
 498 | 			backendResponse:   `{"id":"res456"}`,
 499 | 			requestAsserter: func(t *testing.T, r *http.Request) {
 500 | 				assert.Equal(t, http.MethodPost, r.Method)
 501 | 				assert.Equal(t, "/resources", r.URL.Path)
 502 | 				assert.Equal(t, "value1", r.URL.Query().Get("queryArg"))
 503 | 				assert.Equal(t, "headerValue", r.Header.Get("X-Custom-Hdr"))
 504 | 				cookie, err := r.Cookie("sessionToken")
 505 | 				require.NoError(t, err)
 506 | 				assert.Equal(t, "cookieValue", cookie.Value)
 507 | 				bodyBytes, _ := io.ReadAll(r.Body)
 508 | 				assert.JSONEq(t, `{"bodyFieldA":"A", "bodyFieldB":123}`, string(bodyBytes))
 509 | 			},
 510 | 		},
 511 | 		// --- API Key Injection (Header) ---
 512 | 		{
 513 | 			name: "API Key Injection (Header)",
 514 | 			params: ToolCallParams{
 515 | 				ToolName: "get_secure",
 516 | 				Input:    map[string]interface{}{}, // No client key provided
 517 | 			},
 518 | 			opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure"},
 519 | 			cfg: &config.Config{
 520 | 				APIKey:         "secret-server-key",
 521 | 				APIKeyName:     "Authorization",
 522 | 				APIKeyLocation: config.APIKeyLocationHeader,
 523 | 			},
 524 | 			expectError:       false,
 525 | 			backendStatusCode: http.StatusOK,
 526 | 			requestAsserter: func(t *testing.T, r *http.Request) {
 527 | 				assert.Equal(t, "secret-server-key", r.Header.Get("Authorization"))
 528 | 			},
 529 | 		},
 530 | 		// --- API Key Injection (Query) ---
 531 | 		{
 532 | 			name: "API Key Injection (Query)",
 533 | 			params: ToolCallParams{
 534 | 				ToolName: "get_secure",
 535 | 				Input:    map[string]interface{}{"otherParam": "abc"},
 536 | 			},
 537 | 			opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure", Parameters: []mcp.ParameterDetail{{Name: "otherParam", In: "query"}}},
 538 | 			cfg: &config.Config{
 539 | 				APIKey:         "secret-server-key-q",
 540 | 				APIKeyName:     "api_key",
 541 | 				APIKeyLocation: config.APIKeyLocationQuery,
 542 | 			},
 543 | 			expectError:       false,
 544 | 			backendStatusCode: http.StatusOK,
 545 | 			requestAsserter: func(t *testing.T, r *http.Request) {
 546 | 				assert.Equal(t, "secret-server-key-q", r.URL.Query().Get("api_key"))
 547 | 				assert.Equal(t, "abc", r.URL.Query().Get("otherParam")) // Ensure other params are preserved
 548 | 			},
 549 | 		},
 550 | 		// --- API Key Injection (Path) ---
 551 | 		{
 552 | 			name: "API Key Injection (Path)",
 553 | 			params: ToolCallParams{
 554 | 				ToolName: "get_secure_path",
 555 | 				Input:    map[string]interface{}{}, // Key comes from config
 556 | 			},
 557 | 			opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure/{apiKey}/data"},
 558 | 			cfg: &config.Config{
 559 | 				APIKey:         "path-key-123",
 560 | 				APIKeyName:     "apiKey", // Matches the placeholder name
 561 | 				APIKeyLocation: config.APIKeyLocationPath,
 562 | 			},
 563 | 			expectError:       false,
 564 | 			backendStatusCode: http.StatusOK,
 565 | 			requestAsserter: func(t *testing.T, r *http.Request) {
 566 | 				assert.Equal(t, "/secure/path-key-123/data", r.URL.Path)
 567 | 			},
 568 | 		},
 569 | 		// --- API Key Injection (Cookie) ---
 570 | 		{
 571 | 			name: "API Key Injection (Cookie)",
 572 | 			params: ToolCallParams{
 573 | 				ToolName: "get_secure_cookie",
 574 | 				Input:    map[string]interface{}{}, // Key comes from config
 575 | 			},
 576 | 			opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure_cookie"},
 577 | 			cfg: &config.Config{
 578 | 				APIKey:         "cookie-key-abc",
 579 | 				APIKeyName:     "AuthToken",
 580 | 				APIKeyLocation: config.APIKeyLocationCookie,
 581 | 			},
 582 | 			expectError:       false,
 583 | 			backendStatusCode: http.StatusOK,
 584 | 			requestAsserter: func(t *testing.T, r *http.Request) {
 585 | 				cookie, err := r.Cookie("AuthToken")
 586 | 				require.NoError(t, err)
 587 | 				assert.Equal(t, "cookie-key-abc", cookie.Value)
 588 | 			},
 589 | 		},
 590 | 		// --- Base URL Handling Tests ---
 591 | 		{
 592 | 			name:              "Base URL from Default (Mock Server)",
 593 | 			params:            ToolCallParams{ToolName: "get_default_url", Input: map[string]interface{}{}},
 594 | 			opDetail:          mcp.OperationDetail{Method: "GET", Path: "/path1"}, // No BaseURL here
 595 | 			cfg:               &config.Config{},                                   // No global override
 596 | 			expectError:       false,
 597 | 			backendStatusCode: http.StatusOK,
 598 | 			requestAsserter: func(t *testing.T, r *http.Request) {
 599 | 				// Should hit the mock server at the correct path
 600 | 				assert.Equal(t, "/path1", r.URL.Path)
 601 | 			},
 602 | 		},
 603 | 		{
 604 | 			name:     "Base URL from Global Config Override",
 605 | 			params:   ToolCallParams{ToolName: "get_global_url", Input: map[string]interface{}{}},
 606 | 			opDetail: mcp.OperationDetail{Method: "GET", Path: "/path2", BaseURL: "http://should-be-ignored.com"},
 607 | 			// cfg will be updated in test loop to point ServerBaseURL to mock server
 608 | 			cfg:               &config.Config{},
 609 | 			expectError:       false,
 610 | 			backendStatusCode: http.StatusOK,
 611 | 			requestAsserter: func(t *testing.T, r *http.Request) {
 612 | 				// Should hit the mock server (set via cfg override) at the correct path
 613 | 				assert.Equal(t, "/path2", r.URL.Path)
 614 | 			},
 615 | 		},
 616 | 		// --- Error Case (Tool Not Found in ToolSet) ---
 617 | 		{
 618 | 			name: "Error - Tool Not Found",
 619 | 			params: ToolCallParams{
 620 | 				ToolName: "nonexistent",
 621 | 				Input:    map[string]interface{}{},
 622 | 			},
 623 | 			opDetail:          mcp.OperationDetail{}, // Not used, error occurs before this
 624 | 			cfg:               &config.Config{},
 625 | 			expectError:       true,
 626 | 			containsError:     "operation details for tool 'nonexistent' not found",
 627 | 			requestAsserter:   nil, // No request should be made
 628 | 			backendStatusCode: 0,   // Not applicable
 629 | 		},
 630 | 	}
 631 | 
 632 | 	for _, tc := range tests {
 633 | 		t.Run(tc.name, func(t *testing.T) {
 634 | 			// --- Mock Backend Setup ---
 635 | 			var backendServer *httptest.Server
 636 | 			backendServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 637 | 				if tc.requestAsserter != nil {
 638 | 					tc.requestAsserter(t, r)
 639 | 				}
 640 | 				w.WriteHeader(tc.backendStatusCode)
 641 | 				fmt.Fprint(w, tc.backendResponse)
 642 | 			}))
 643 | 			defer backendServer.Close()
 644 | 
 645 | 			// --- Prepare ToolSet (using mock server URL if needed) ---
 646 | 			toolSet := &mcp.ToolSet{
 647 | 				Operations: make(map[string]mcp.OperationDetail),
 648 | 			}
 649 | 
 650 | 			// Clone config to avoid modifying the template test case config
 651 | 			testCfg := *tc.cfg
 652 | 
 653 | 			// Special handling for the global override test case
 654 | 			if tc.name == "Base URL from Global Config Override" {
 655 | 				testCfg.ServerBaseURL = backendServer.URL // Point global override to mock server
 656 | 			}
 657 | 
 658 | 			// If the opDetail needs a BaseURL, set it to the mock server ONLY if it wasn't
 659 | 			// already set in the test case definition AND the global override isn't being used.
 660 | 			if tc.opDetail.Method != "" { // Only add if it's a valid detail for the test
 661 | 				if tc.opDetail.BaseURL == "" && testCfg.ServerBaseURL == "" {
 662 | 					tc.opDetail.BaseURL = backendServer.URL
 663 | 				}
 664 | 				toolSet.Operations[tc.params.ToolName] = tc.opDetail
 665 | 			}
 666 | 
 667 | 			// --- Execute Function ---
 668 | 			httpResp, err := executeToolCall(&tc.params, toolSet, &testCfg) // Use the potentially modified testCfg
 669 | 
 670 | 			// --- Assertions ---
 671 | 			if tc.expectError {
 672 | 				assert.Error(t, err)
 673 | 				if tc.containsError != "" {
 674 | 					assert.Contains(t, err.Error(), tc.containsError)
 675 | 				}
 676 | 				assert.Nil(t, httpResp)
 677 | 			} else {
 678 | 				assert.NoError(t, err)
 679 | 				require.NotNil(t, httpResp)
 680 | 				defer httpResp.Body.Close()
 681 | 				assert.Equal(t, tc.backendStatusCode, httpResp.StatusCode)
 682 | 				bodyBytes, _ := io.ReadAll(httpResp.Body)
 683 | 				assert.Equal(t, tc.backendResponse, string(bodyBytes))
 684 | 			}
 685 | 		})
 686 | 	}
 687 | }
 688 | 
 689 | func TestWriteSSEEvent(t *testing.T) {
 690 | 	tests := []struct {
 691 | 		name        string
 692 | 		eventName   string
 693 | 		data        interface{}
 694 | 		expectedOut string
 695 | 		expectError bool
 696 | 	}{
 697 | 		{
 698 | 			name:        "Simple String Data",
 699 | 			eventName:   "endpoint",
 700 | 			data:        "/mcp?sessionId=123",
 701 | 			expectedOut: "event: endpoint\ndata: /mcp?sessionId=123\n\n",
 702 | 			expectError: false,
 703 | 		},
 704 | 		{
 705 | 			name:      "Struct Data (JSON-RPC Request)",
 706 | 			eventName: "message",
 707 | 			data: jsonRPCRequest{
 708 | 				Jsonrpc: "2.0",
 709 | 				Method:  "mcp-ready",
 710 | 				Params:  map[string]interface{}{"connectionId": "abc"},
 711 | 			},
 712 | 			// Note: JSON marshaling order isn't guaranteed, so use JSONEq or check fields
 713 | 			expectedOut: "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"mcp-ready\",\"params\":{\"connectionId\":\"abc\"}}\n\n",
 714 | 			expectError: false,
 715 | 		},
 716 | 		{
 717 | 			name:      "Struct Data (JSON-RPC Response)",
 718 | 			eventName: "message",
 719 | 			data: jsonRPCResponse{
 720 | 				Jsonrpc: "2.0",
 721 | 				Result:  map[string]interface{}{"status": "ok"},
 722 | 				ID:      "req-1",
 723 | 			},
 724 | 			expectedOut: "event: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{\"status\":\"ok\"},\"id\":\"req-1\"}\n\n",
 725 | 			expectError: false,
 726 | 		},
 727 | 		{
 728 | 			name:        "Error - Unmarshalable Data",
 729 | 			eventName:   "error",
 730 | 			data:        make(chan int), // Channels cannot be marshaled to JSON
 731 | 			expectError: true,
 732 | 		},
 733 | 	}
 734 | 
 735 | 	for _, tc := range tests {
 736 | 		t.Run(tc.name, func(t *testing.T) {
 737 | 			rr := httptest.NewRecorder()
 738 | 			err := writeSSEEvent(rr, tc.eventName, tc.data)
 739 | 
 740 | 			if tc.expectError {
 741 | 				assert.Error(t, err)
 742 | 			} else {
 743 | 				assert.NoError(t, err)
 744 | 				// For struct data, use JSONEq for robust comparison
 745 | 				if _, isStruct := tc.data.(jsonRPCRequest); isStruct {
 746 | 					prefix := fmt.Sprintf("event: %s\ndata: ", tc.eventName)
 747 | 					suffix := "\n\n"
 748 | 					require.True(t, strings.HasPrefix(rr.Body.String(), prefix))
 749 | 					require.True(t, strings.HasSuffix(rr.Body.String(), suffix))
 750 | 					actualJSON := strings.TrimSuffix(strings.TrimPrefix(rr.Body.String(), prefix), suffix)
 751 | 					expectedJSONBytes, _ := json.Marshal(tc.data)
 752 | 					assert.JSONEq(t, string(expectedJSONBytes), actualJSON)
 753 | 				} else if _, isStruct := tc.data.(jsonRPCResponse); isStruct {
 754 | 					prefix := fmt.Sprintf("event: %s\ndata: ", tc.eventName)
 755 | 					suffix := "\n\n"
 756 | 					require.True(t, strings.HasPrefix(rr.Body.String(), prefix))
 757 | 					require.True(t, strings.HasSuffix(rr.Body.String(), suffix))
 758 | 					actualJSON := strings.TrimSuffix(strings.TrimPrefix(rr.Body.String(), prefix), suffix)
 759 | 					expectedJSONBytes, _ := json.Marshal(tc.data)
 760 | 					assert.JSONEq(t, string(expectedJSONBytes), actualJSON)
 761 | 				} else {
 762 | 					// For simple types, direct string comparison is fine
 763 | 					assert.Equal(t, tc.expectedOut, rr.Body.String())
 764 | 				}
 765 | 			}
 766 | 		})
 767 | 	}
 768 | }
 769 | 
 770 | func TestTryWriteHTTPError(t *testing.T) {
 771 | 	rr := httptest.NewRecorder()
 772 | 	message := "Test Error Message"
 773 | 	code := http.StatusInternalServerError
 774 | 
 775 | 	tryWriteHTTPError(rr, code, message)
 776 | 
 777 | 	// Note: tryWriteHTTPError doesn't set the status code, it only writes the body.
 778 | 	// The calling function is expected to have set the code earlier.
 779 | 	// So, we only check the body content here.
 780 | 	assert.Equal(t, message, rr.Body.String())
 781 | }
 782 | 
 783 | func TestGetMethodFromResponse(t *testing.T) {
 784 | 	tests := []struct {
 785 | 		name     string
 786 | 		response jsonRPCResponse
 787 | 		expected string
 788 | 	}{
 789 | 		{
 790 | 			name: "Error Response",
 791 | 			response: jsonRPCResponse{
 792 | 				Error: &jsonError{Code: -32600, Message: "..."},
 793 | 			},
 794 | 			expected: "error",
 795 | 		},
 796 | 		{
 797 | 			name: "Tool List Response",
 798 | 			response: jsonRPCResponse{
 799 | 				Result: map[string]interface{}{"tools": []interface{}{}, "metadata": map[string]interface{}{}},
 800 | 			},
 801 | 			expected: "tool_set",
 802 | 		},
 803 | 		{
 804 | 			name: "Initialize Response (Result is Map)",
 805 | 			response: jsonRPCResponse{
 806 | 				Result: map[string]interface{}{"protocolVersion": "...", "capabilities": map[string]interface{}{}},
 807 | 			},
 808 | 			expected: "success", // Falls back to 'success' as type isn't explicitly set
 809 | 		},
 810 | 		{
 811 | 			name: "Tool Call Response (Result is ToolResultPayload)",
 812 | 			response: jsonRPCResponse{
 813 | 				Result: ToolResultPayload{Content: []ToolResultContent{{Type: "text", Text: "..."}}},
 814 | 			},
 815 | 			expected: "success", // Falls back to 'success'
 816 | 		},
 817 | 		{
 818 | 			name:     "Empty Response",
 819 | 			response: jsonRPCResponse{},
 820 | 			expected: "unknown",
 821 | 		},
 822 | 	}
 823 | 
 824 | 	for _, tc := range tests {
 825 | 		t.Run(tc.name, func(t *testing.T) {
 826 | 			actual := getMethodFromResponse(tc.response)
 827 | 			assert.Equal(t, tc.expected, actual)
 828 | 		})
 829 | 	}
 830 | }
 831 | 
 832 | // --- Mock ResponseWriter for error simulation ---
 833 | 
 834 | // mockResponseWriter implements http.ResponseWriter and http.Flusher for testing SSE.
 835 | type sseMockResponseWriter struct {
 836 | 	hdr              http.Header // Internal map for headers
 837 | 	statusCode       int
 838 | 	body             *bytes.Buffer
 839 | 	flushed          bool
 840 | 	forceError       error // If set, Write and Flush will return this error
 841 | 	failAfterNWrites int   // Start failing after this many writes (-1 = disable)
 842 | 	writesMade       int   // Counter for writes made
 843 | }
 844 | 
 845 | // Renamed constructor
 846 | func newSseMockResponseWriter() *sseMockResponseWriter {
 847 | 	return &sseMockResponseWriter{
 848 | 		hdr:              make(http.Header), // Initialize internal map
 849 | 		body:             &bytes.Buffer{},
 850 | 		failAfterNWrites: -1, // Default to disabled
 851 | 	}
 852 | }
 853 | 
 854 | // Implement http.ResponseWriter interface
 855 | func (m *sseMockResponseWriter) Header() http.Header {
 856 | 	return m.hdr // Return the internal map
 857 | }
 858 | 
 859 | func (m *sseMockResponseWriter) WriteHeader(statusCode int) {
 860 | 	m.statusCode = statusCode
 861 | }
 862 | 
 863 | func (m *sseMockResponseWriter) Write(p []byte) (int, error) {
 864 | 	// Check if already forced error
 865 | 	if m.forceError != nil {
 866 | 		return 0, m.forceError
 867 | 	}
 868 | 
 869 | 	// Increment write count
 870 | 	m.writesMade++
 871 | 
 872 | 	// Check if write count triggers failure
 873 | 	if m.failAfterNWrites >= 0 && m.writesMade >= m.failAfterNWrites {
 874 | 		m.forceError = fmt.Errorf("forced write error after %d writes", m.failAfterNWrites)
 875 | 		log.Printf("DEBUG: sseMockResponseWriter triggering error: %v", m.forceError) // Debug log
 876 | 		return 0, m.forceError
 877 | 	}
 878 | 
 879 | 	// Proceed with normal write
 880 | 	return m.body.Write(p)
 881 | }
 882 | 
 883 | // Implement http.Flusher interface
 884 | func (m *sseMockResponseWriter) Flush() {
 885 | 	// Check if already forced error
 886 | 	if m.forceError != nil {
 887 | 		// Optional: log or handle repeated flush attempts after error
 888 | 		return
 889 | 	}
 890 | 
 891 | 	// Check if flush count triggers failure (less common to fail on flush, but possible)
 892 | 	// We are primarily testing Write failures, so we might skip count check here for simplicity
 893 | 	// or use a separate failAfterNFlushes counter if needed.
 894 | 
 895 | 	m.flushed = true
 896 | }
 897 | 
 898 | // Helper to get body content
 899 | func (m *sseMockResponseWriter) String() string {
 900 | 	return m.body.String()
 901 | }
 902 | 
 903 | // --- End Mock ResponseWriter ---
 904 | 
 905 | func TestHttpMethodGetHandler_WriteErrors(t *testing.T) {
 906 | 	tests := []struct {
 907 | 		name              string
 908 | 		errorOnStage      string // "preamble", "endpoint", "ready", "ping", "message"
 909 | 		forceError        error  // Error to set on the mock writer *before* handler runs
 910 | 		expectConnRemoved bool
 911 | 	}{
 912 | 		{"Error on Preamble (:ok)", "preamble", fmt.Errorf("forced write error during preamble"), true},
 913 | 		// Removed: {"Error on Endpoint Event", "endpoint", nil, true}, // Hard to simulate reliably without patching
 914 | 		// Removed: {"Error on MCP-Ready Event", "ready", nil, true},    // Hard to simulate reliably without patching
 915 | 		// TODO: Add test for error during keep-alive ping
 916 | 		// TODO: Add test for error during message write from channel
 917 | 	}
 918 | 
 919 | 	for _, tc := range tests {
 920 | 		t.Run(tc.name, func(t *testing.T) {
 921 | 			// Use renamed mock writer
 922 | 			mockWriter := newSseMockResponseWriter()
 923 | 			req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
 924 | 			var connID string // Variable to capture assigned ID
 925 | 
 926 | 			// Set the error on the writer *before* calling the handler
 927 | 			if tc.forceError != nil {
 928 | 				mockWriter.forceError = tc.forceError
 929 | 			}
 930 | 
 931 | 			// Need to capture connID *if* headers get written before error
 932 | 			// We can check mockWriter.Header() after the handler potentially runs
 933 | 
 934 | 			// Inject error based on the test stage - REMOVED FUNCTION PATCHING
 935 | 			/*
 936 | 				originalWriteSSE := writeSSEEvent
 937 | 				defer func() { writeSSEEvent = originalWriteSSE }() // Restore original
 938 | 
 939 | 				writeSSEEvent = func(w http.ResponseWriter, eventName string, data interface{}) error {
 940 | 				    // ... removed patching logic ...
 941 | 				}
 942 | 			*/
 943 | 
 944 | 			// Execute handler in goroutine as it might block briefly before erroring
 945 | 			done := make(chan struct{})
 946 | 			go func() {
 947 | 				defer close(done)
 948 | 				httpMethodGetHandler(mockWriter, req)
 949 | 			}()
 950 | 
 951 | 			// Wait for the handler goroutine to finish or timeout
 952 | 			select {
 953 | 			case <-done:
 954 | 				// Handler finished (presumably due to error)
 955 | 			case <-time.After(200 * time.Millisecond): // Generous timeout
 956 | 				t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after injected error")
 957 | 			}
 958 | 
 959 | 			// Capture ConnID *after* handler exit, in case headers were set before error
 960 | 			connID = mockWriter.Header().Get("X-Connection-ID")
 961 | 
 962 | 			// Assert connection removal
 963 | 			if tc.expectConnRemoved && connID != "" {
 964 | 				connMutex.RLock()
 965 | 				_, exists := activeConnections[connID]
 966 | 				connMutex.RUnlock()
 967 | 				assert.False(t, exists, "Connection %s should have been removed from activeConnections after write error", connID)
 968 | 			} else if tc.expectConnRemoved && connID == "" {
 969 | 				t.Log("Cannot assert connection removal as ConnID was not captured before error")
 970 | 			}
 971 | 		})
 972 | 	}
 973 | }
 974 | 
 975 | func TestHttpMethodGetHandler_GoroutineErrors(t *testing.T) {
 976 | 	t.Run("Error_on_Message_Write", func(t *testing.T) {
 977 | 		// Estimate writes before first message: :ok(1), endpoint(1), ready(1) = 3 writes
 978 | 		// Target failure on the 4th write (first write of the actual message event line)
 979 | 		mockWriter := newSseMockResponseWriter()
 980 | 		mockWriter.failAfterNWrites = 4 // Fail on the 4th write overall
 981 | 
 982 | 		req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
 983 | 		var connID string
 984 | 		var msgChan chan jsonRPCResponse
 985 | 
 986 | 		// Clean connections before test
 987 | 		connMutex.Lock()
 988 | 		activeConnections = make(map[string]chan jsonRPCResponse)
 989 | 		connMutex.Unlock()
 990 | 		defer func() {
 991 | 			// Clean up after test, ensure channel is closed if exists
 992 | 			connMutex.Lock()
 993 | 			if msgChan != nil {
 994 | 				// Only delete from map, handler is responsible for closing channel
 995 | 				delete(activeConnections, connID)
 996 | 			}
 997 | 			activeConnections = make(map[string]chan jsonRPCResponse) // Reset for other tests
 998 | 			connMutex.Unlock()
 999 | 		}()
1000 | 
1001 | 		done := make(chan struct{})
1002 | 		go func() {
1003 | 			defer close(done)
1004 | 			httpMethodGetHandler(mockWriter, req)
1005 | 			log.Println("DEBUG: httpMethodGetHandler goroutine exited")
1006 | 		}()
1007 | 
1008 | 		// Wait for the connection to be established
1009 | 		assert.Eventually(t, func() bool {
1010 | 			connMutex.RLock()
1011 | 			defer connMutex.RUnlock()
1012 | 			for id, ch := range activeConnections {
1013 | 				connID = id
1014 | 				msgChan = ch
1015 | 				log.Printf("DEBUG: Connection established: %s", connID)
1016 | 				return true
1017 | 			}
1018 | 			return false
1019 | 		}, 200*time.Millisecond, 20*time.Millisecond, "Connection not established in time")
1020 | 
1021 | 		require.NotEmpty(t, connID, "connID should have been captured")
1022 | 		require.NotNil(t, msgChan, "msgChan should have been captured")
1023 | 
1024 | 		// Send a message that should trigger the write error
1025 | 		testResp := jsonRPCResponse{Jsonrpc: "2.0", ID: "test-msg-1", Result: "test data"}
1026 | 		log.Printf("DEBUG: Sending test message to channel for %s", connID)
1027 | 		select {
1028 | 		case msgChan <- testResp:
1029 | 			log.Printf("DEBUG: Test message sent to channel for %s", connID)
1030 | 		case <-time.After(100 * time.Millisecond):
1031 | 			t.Fatal("Timeout sending message to channel")
1032 | 		}
1033 | 
1034 | 		// Wait for the handler goroutine to finish due to the write error
1035 | 		select {
1036 | 		case <-done:
1037 | 			log.Printf("DEBUG: Handler goroutine finished as expected after message write error")
1038 | 			// Handler finished (presumably due to write error)
1039 | 		case <-time.After(1000 * time.Millisecond): // Increased timeout to 1 second
1040 | 			t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after message write error")
1041 | 		}
1042 | 
1043 | 		// Assert connection removal
1044 | 		connMutex.RLock()
1045 | 		_, exists := activeConnections[connID]
1046 | 		connMutex.RUnlock()
1047 | 		assert.False(t, exists, "Connection %s should have been removed after message write error", connID)
1048 | 	})
1049 | 
1050 | 	// TODO: Add sub-test for Error_on_Ping_Write
1051 | }
1052 | 
// waitTimeout waits for wg to reach zero, but gives up after timeout.
// It reports true when the WaitGroup finished in time and false on timeout.
// On timeout, the helper goroutine stays blocked in wg.Wait until the group
// eventually completes.
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		wg.Wait()
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-finished:
		return true
	case <-timer.C:
		return false
	}
}
1067 | 
```
Page 2/3 · First · Prev · Next · Last