This is page 1 of 2. Use http://codebase.md/ckanthony/openapi-mcp?page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .github
│   └── workflows
│       ├── ci.yml
│       └── publish.yml
├── .gitignore
├── cmd
│   └── openapi-mcp
│       └── main.go
├── Dockerfile
├── example
│   ├── agent_demo.png
│   ├── docker-compose.yml
│   └── weather
│       ├── .env.example
│       └── weatherbitio-swagger.json
├── go.mod
├── go.sum
├── openapi-mcp.png
├── pkg
│   ├── config
│   │   ├── config_test.go
│   │   └── config.go
│   ├── mcp
│   │   └── types.go
│   ├── parser
│   │   ├── parser_test.go
│   │   └── parser.go
│   └── server
│       ├── manager_test.go
│       ├── manager.go
│       ├── server_test.go
│       └── server.go
└── README.md
```
# Files
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
```
# Git files
.git
.gitignore
# Docker files
.dockerignore
Dockerfile
# Documentation
*.md
# Environment files (except example)
.env
*.env
!.env.example
# Go cache and modules (handled in multi-stage build)
vendor/
# Local build artifacts
openapi-mcp
*.exe
*.test
*.out
# OS generated files
.DS_Store
*~
```
--------------------------------------------------------------------------------
/example/weather/.env.example:
--------------------------------------------------------------------------------
```
# Example environment variables for the Weatherbit API example.
# Copy this file to .env in the same directory (example/weather/.env)
# and replace placeholders with your actual values.
# Required: Your Weatherbit API Key
API_KEY=YOUR_WEATHERBIT_API_KEY_HERE
# Optional: Custom headers (JSON format)
# REQUEST_HEADERS='{"X-Client-ID": "MyTestClient"}'
```
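The `.env.example` above gives `REQUEST_HEADERS` as a JSON object that maps header names to values. The sketch below shows one way such a value could be decoded and copied onto outgoing requests with the Go standard library. `parseRequestHeaders` and `applyHeaders` are hypothetical helpers (the project's own handling lives in code not shown on this page), and the sketch assumes every value in the object is a JSON string.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

// parseRequestHeaders decodes a JSON object such as {"X-Client-ID": "MyTestClient"}
// into an http.Header. Hypothetical helper: it assumes every value is a JSON string.
// An empty input yields an empty (non-nil) header set.
func parseRequestHeaders(raw string) (http.Header, error) {
	headers := http.Header{}
	if raw == "" {
		return headers, nil
	}
	var m map[string]string
	if err := json.Unmarshal([]byte(raw), &m); err != nil {
		return nil, fmt.Errorf("REQUEST_HEADERS must be a JSON object of string values: %w", err)
	}
	for name, value := range m {
		// Set canonicalizes the name, e.g. "X-Client-ID" is stored as "X-Client-Id".
		headers.Set(name, value)
	}
	return headers, nil
}

// applyHeaders copies the custom headers onto an outgoing request,
// replacing any value the request already had for the same header.
func applyHeaders(req *http.Request, headers http.Header) {
	for name, values := range headers {
		req.Header.Del(name)
		for _, v := range values {
			req.Header.Add(name, v)
		}
	}
}

func main() {
	headers, err := parseRequestHeaders(os.Getenv("REQUEST_HEADERS"))
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	req, err := http.NewRequest(http.MethodGet, "https://api.example.com/resource", nil)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	applyHeaders(req, headers)
	fmt.Println(req.Header)
}
```

Decoding into `map[string]string` makes malformed input (for example `X-Client-ID: MyTestClient` instead of JSON) fail loudly at startup rather than silently sending no headers.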
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
# Go workspace file
go.work
go.work.sum
# Environment configuration files
.env
*.env
!.env.example
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
# OpenAPI-MCP: Dockerized MCP Server to allow your AI agent to access any API with existing API docs
[Go Reference](https://pkg.go.dev/github.com/ckanthony/openapi-mcp)
[CI](https://github.com/ckanthony/openapi-mcp/actions/workflows/ci.yml)
[codecov](https://codecov.io/gh/ckanthony/openapi-mcp)


**Generate MCP tool definitions directly from a Swagger/OpenAPI specification file.**
OpenAPI-MCP is a dockerized MCP server that reads a `swagger.json` or `openapi.yaml` file and generates a corresponding [Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) toolset. This allows MCP-compatible clients like [Cursor](https://cursor.sh/) to interact with APIs described by standard OpenAPI specifications. Now you can enable your AI agent to access any API by simply providing its OpenAPI/Swagger specification: no additional coding required.
## Table of Contents
- [Why OpenAPI-MCP?](#why-openapi-mcp)
- [Features](#features)
- [Installation](#installation)
  - [Using the Pre-built Docker Hub Image (Recommended)](#using-the-pre-built-docker-hub-image-recommended)
  - [Building Locally (Optional)](#building-locally-optional)
- [Running the Weatherbit Example (Step-by-Step)](#running-the-weatherbit-example-step-by-step)
- [Command-Line Options](#command-line-options)
  - [Environment Variables](#environment-variables)
## Demo
Run the demo yourself: [Running the Weatherbit Example (Step-by-Step)](#running-the-weatherbit-example-step-by-step)

## Why OpenAPI-MCP?
- **Standard Compliance:** Leverage your existing OpenAPI/Swagger documentation.
- **Automatic Tool Generation:** Create MCP tools without manual configuration for each endpoint.
- **Flexible API Key Handling:** Securely manage API key authentication for the proxied API without exposing keys to the MCP client.
- **Local & Remote Specs:** Works with local specification files or remote URLs.
- **Dockerized Tool:** Easily deploy and run as a containerized service with Docker.
## Features
- **OpenAPI v2 (Swagger) & v3 Support:** Parses standard specification formats.
- **Schema Generation:** Creates MCP tool schemas from OpenAPI operation parameters and request/response definitions.
- **Secure API Key Management:**
  - Injects API keys into requests (`header`, `query`, `path`, `cookie`) based on command-line configuration.
  - Loads API keys directly from flags (`--api-key`), environment variables (`--api-key-env`), or `.env` files located alongside local specs.
  - Keeps API keys hidden from the end MCP client (e.g., the AI assistant).
- **Server URL Detection:** Uses server URLs from the spec as the base for tool interactions (can be overridden).
- **Filtering:** Options to include/exclude specific operations or tags (`--include-tag`, `--exclude-tag`, `--include-op`, `--exclude-op`).
- **Request Header Injection:** Pass custom headers (e.g., for additional auth, tracing) via the `REQUEST_HEADERS` environment variable.
## Installation
### Docker
The recommended way to run this tool is via [Docker](https://hub.docker.com/r/ckanthony/openapi-mcp).
#### Using the Pre-built Docker Hub Image (Recommended)
The quickest way to get started is the pre-built image available on [Docker Hub](https://hub.docker.com/r/ckanthony/openapi-mcp); no local build is needed.
1. **Pull the Image:**
```bash
docker pull ckanthony/openapi-mcp:latest
```
2. **Run the Container:**
Follow the `docker run` examples under [Building Locally (Optional)](#building-locally-optional), but replace `openapi-mcp:latest` with `ckanthony/openapi-mcp:latest`.
#### Building Locally (Optional)
1. **Build the Docker Image Locally:**
```bash
# Navigate to the repository root
cd openapi-mcp
# Build the Docker image (tag it as you like, e.g., openapi-mcp:latest)
docker build -t openapi-mcp:latest .
```
2. **Run the Container:**
You need to provide the OpenAPI specification and any necessary API key configuration when running the container.
* **Example 1: Using a local spec file and `.env` file:**
- Create a directory (e.g., `./my-api`) containing your `openapi.json` or `swagger.yaml`.
- If the API requires a key, create a `.env` file in the *same directory* (e.g., `./my-api/.env`) with `API_KEY=your_actual_key` (replace `API_KEY` if your `--api-key-env` flag is different).
```bash
docker run -p 8080:8080 --rm \
  -v $(pwd)/my-api:/app/spec \
  --env-file $(pwd)/my-api/.env \
  openapi-mcp:latest \
  --spec /app/spec/openapi.json \
  --api-key-env API_KEY \
  --api-key-name X-API-Key \
  --api-key-loc header
```
*(Adjust `--spec`, `--api-key-env`, `--api-key-name`, `--api-key-loc`, and `-p` as needed.)*
* **Example 2: Using a remote spec URL and direct environment variable:**
```bash
docker run -p 8080:8080 --rm \
  -e SOME_API_KEY="your_actual_key" \
  openapi-mcp:latest \
  --spec https://petstore.swagger.io/v2/swagger.json \
  --api-key-env SOME_API_KEY \
  --api-key-name api_key \
  --api-key-loc header
```
* **Key Docker Run Options:**
  * `-p <host_port>:8080`: Map a port on your host to the container's default port 8080.
  * `--rm`: Automatically remove the container when it exits.
  * `-v <host_path>:<container_path>`: Mount a local directory containing your spec into the container. Use absolute paths or `$(pwd)/...`. Common container path: `/app/spec`.
  * `--env-file <path_to_host_env_file>`: Load environment variables from a local file (for API keys, etc.). Path is on the host.
  * `-e <VAR_NAME>="<value>"`: Pass a single environment variable directly.
  * `openapi-mcp:latest`: The name of the image you built locally.
  * `--spec ...`: **Required.** Path to the spec file *inside the container* (e.g., `/app/spec/openapi.json`) or a public URL.
  * `--port 8080`: (Optional) Change the internal port the server listens on (must match the container port in `-p`).
  * `--api-key-env`, `--api-key-name`, `--api-key-loc`: Required if the target API needs an API key.
  * (See `--help` for all command-line options by running `docker run --rm openapi-mcp:latest --help`)
## Running the Weatherbit Example (Step-by-Step)
This repository includes an example using the [Weatherbit API](https://www.weatherbit.io/). Here's how to run it using the public Docker image:
1. **Find OpenAPI Specs (Optional Knowledge):**
Many public APIs have their OpenAPI/Swagger specifications available online. A great resource for discovering them is [APIs.guru](https://apis.guru/). The Weatherbit specification used in this example (`weatherbitio-swagger.json`) was sourced from there.
2. **Get a Weatherbit API Key:**
* Go to [Weatherbit.io](https://www.weatherbit.io/) and sign up for an account (they offer a free tier).
* Find your API key in your Weatherbit account dashboard.
3. **Clone this Repository:**
You need the example files from this repository.
```bash
git clone https://github.com/ckanthony/openapi-mcp.git
cd openapi-mcp
```
4. **Prepare Environment File:**
* Navigate to the example directory: `cd example/weather`
* Copy the example environment file: `cp .env.example .env`
* Edit the new `.env` file and replace `YOUR_WEATHERBIT_API_KEY_HERE` with the actual API key you obtained from Weatherbit.
5. **Run the Docker Container:**
From the `openapi-mcp` **root directory** (the one containing the `example` folder), run the following command:
```bash
docker run -p 8080:8080 --rm \
  -v $(pwd)/example/weather:/app/spec \
  --env-file $(pwd)/example/weather/.env \
  ckanthony/openapi-mcp:latest \
  --spec /app/spec/weatherbitio-swagger.json \
  --api-key-env API_KEY \
  --api-key-name key \
  --api-key-loc query
```
* `-v $(pwd)/example/weather:/app/spec`: Mounts the local `example/weather` directory (containing the spec and `.env` file) to `/app/spec` inside the container.
* `--env-file $(pwd)/example/weather/.env`: Tells Docker to load environment variables (specifically `API_KEY`) from your `.env` file.
* `ckanthony/openapi-mcp:latest`: Uses the public Docker image.
* `--spec /app/spec/weatherbitio-swagger.json`: Points to the spec file inside the container.
* The `--api-key-*` flags configure how the tool should inject the API key (read from the `API_KEY` env var, named `key`, placed in the `query` string).
6. **Access the MCP Server:**
The MCP server should now be running and accessible at `http://localhost:8080` for compatible clients.
**Using Docker Compose (Example):**
A `docker-compose.yml` file is provided in the `example/` directory to demonstrate running the Weatherbit API example using the *locally built* image.
1. **Prepare Environment File:** Copy `example/weather/.env.example` to `example/weather/.env` and add your actual Weatherbit API key:
```dotenv
# example/weather/.env
API_KEY=YOUR_ACTUAL_WEATHERBIT_KEY
```
2. **Run with Docker Compose:** Navigate to the `example` directory and run:
```bash
cd example
# This builds the image locally based on ../Dockerfile
# It does NOT use the public Docker Hub image
docker-compose up --build
```
* `--build`: Forces Docker Compose to build the image using the `Dockerfile` in the project root before starting the service.
* Compose will read `example/docker-compose.yml`, build the image, mount `./weather`, read `./weather/.env`, and start the `openapi-mcp` container with the specified command-line arguments.
* The MCP server will be available at `http://localhost:8080`.
3. **Stop the service:** Press `Ctrl+C` in the terminal where Compose is running, or run `docker-compose down` from the `example` directory in another terminal.
## Command-Line Options
The `openapi-mcp` command accepts the following flags:
| Flag | Description | Type | Default |
|----------------------|---------------------------------------------------------------------------------------------------------------------|---------------|----------------------------------|
| `--spec` | **Required.** Path or URL to the OpenAPI specification file. | `string` | (none) |
| `--port` | Port to run the MCP server on. | `int` | `8080` |
| `--api-key` | Direct API key value (use `--api-key-env` or `.env` file instead for security). | `string` | (none) |
| `--api-key-env` | Environment variable name containing the API key. If spec is local, also checks `.env` file in the spec's directory. | `string` | (none) |
| `--api-key-name` | **Required if key used.** Name of the API key parameter (header, query, path, or cookie name). | `string` | (none) |
| `--api-key-loc` | **Required if key used.** Location of API key: `header`, `query`, `path`, or `cookie`. | `string` | (none) |
| `--include-tag` | Tag to include (can be repeated). If include flags are used, only included items are exposed. | `string slice`| (none) |
| `--exclude-tag` | Tag to exclude (can be repeated). Exclusions apply after inclusions. | `string slice`| (none) |
| `--include-op` | Operation ID to include (can be repeated). | `string slice`| (none) |
| `--exclude-op` | Operation ID to exclude (can be repeated). | `string slice`| (none) |
| `--base-url` | Manually override the target API server base URL detected from the spec. | `string` | (none) |
| `--name` | Default name for the generated MCP toolset (used if spec has no title). | `string` | "OpenAPI-MCP Tools" |
| `--desc` | Default description for the generated MCP toolset (used if spec has no description). | `string` | "Tools generated from OpenAPI spec" |
**Note:** You can get this list by running the tool with the `--help` flag (e.g., `docker run --rm ckanthony/openapi-mcp:latest --help`).
### Environment Variables
* `REQUEST_HEADERS`: Set this environment variable to a JSON string (e.g., `'{"X-Custom": "Value"}'`) to add custom headers to *all* outgoing requests to the target API.
```
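The filtering flags in the table above can be modeled as a small predicate. This is a minimal sketch, not the project's parser code: it reuses the repeatable `stringSliceFlag` pattern from `cmd/openapi-mcp/main.go`, while the `operation` and `filter` types, the `parseFilter`, `contains`, `anyIn` and `keep` helpers, and the operation IDs are hypothetical. It assumes an operation passes the inclusion step if it matches any `--include-tag` or any `--include-op`, and that exclusions are applied afterwards and always win, as the table describes.

```go
package main

import (
	"flag"
	"fmt"
	"strings"
)

// stringSliceFlag collects every occurrence of a repeatable flag (same pattern as main.go).
type stringSliceFlag []string

func (s *stringSliceFlag) String() string {
	return strings.Join(*s, ", ")
}

func (s *stringSliceFlag) Set(v string) error {
	*s = append(*s, v)
	return nil
}

// operation is a hypothetical, trimmed-down view of an OpenAPI operation.
type operation struct {
	ID   string
	Tags []string
}

// filter holds the values of the four filtering flags.
type filter struct {
	IncludeTags, ExcludeTags, IncludeOps, ExcludeOps []string
}

// parseFilter parses the filtering flags from args into a filter.
func parseFilter(args []string) (filter, error) {
	var f filter
	fs := flag.NewFlagSet("openapi-mcp", flag.ContinueOnError)
	fs.Var((*stringSliceFlag)(&f.IncludeTags), "include-tag", "Tag to include (can be repeated)")
	fs.Var((*stringSliceFlag)(&f.ExcludeTags), "exclude-tag", "Tag to exclude (can be repeated)")
	fs.Var((*stringSliceFlag)(&f.IncludeOps), "include-op", "Operation ID to include (can be repeated)")
	fs.Var((*stringSliceFlag)(&f.ExcludeOps), "exclude-op", "Operation ID to exclude (can be repeated)")
	err := fs.Parse(args)
	return f, err
}

// contains reports whether s appears in list.
func contains(list []string, s string) bool {
	for _, v := range list {
		if v == s {
			return true
		}
	}
	return false
}

// anyIn reports whether any candidate appears in list.
func anyIn(list, candidates []string) bool {
	for _, c := range candidates {
		if contains(list, c) {
			return true
		}
	}
	return false
}

// keep reports whether op should be exposed as a tool.
func (f filter) keep(op operation) bool {
	// Inclusion step: only applies when at least one include flag was given.
	if len(f.IncludeTags) > 0 || len(f.IncludeOps) > 0 {
		if !anyIn(f.IncludeTags, op.Tags) && !contains(f.IncludeOps, op.ID) {
			return false
		}
	}
	// Exclusion step: applied after inclusion, so it overrides a match above.
	return !anyIn(f.ExcludeTags, op.Tags) && !contains(f.ExcludeOps, op.ID)
}

func main() {
	f, err := parseFilter([]string{
		"--include-tag", "forecast", "--include-tag", "current",
		"--exclude-op", "getDailyForecast",
	})
	if err != nil {
		fmt.Println("flag error:", err)
		return
	}
	ops := []operation{
		{ID: "getCurrent", Tags: []string{"current"}},
		{ID: "getDailyForecast", Tags: []string{"forecast"}},
		{ID: "getAlerts", Tags: []string{"alerts"}},
	}
	for _, op := range ops {
		fmt.Printf("%-18s exposed=%v\n", op.ID, f.keep(op))
	}
}
```

Because `stringSliceFlag` appends on every `Set` call, passing `--include-tag` twice yields both tags rather than keeping only the last one, which is what "can be repeated" in the table relies on.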
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
```yaml
name: CI
on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
jobs:
  test:
    name: Test
    environment: CI
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.21'
          cache: true
      - name: Install dependencies
        run: go mod tidy
        working-directory: .
      - name: Run tests with coverage
        run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
        working-directory: .
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v5
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          slug: ckanthony/openapi-mcp
```
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
```yaml
name: Publish Docker image
on:
  push:
    tags:
      - 'v*.*.*' # Trigger on version tags like v1.0.0
jobs:
  push_to_registry:
    name: Build and push Docker image to Docker Hub
    environment: CI
    runs-on: ubuntu-latest
    steps:
      - name: Check out the repo
        uses: actions/checkout@v4
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ckanthony
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ckanthony/openapi-mcp
          # Add git tag as Docker tag
          tags: |
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=raw,value=latest,enable={{is_default_branch}}
      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
```
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
```dockerfile
# --- Build Stage ---
ARG GO_VERSION=1.22
FROM golang:${GO_VERSION}-alpine AS builder
WORKDIR /app
# Copy Go modules and download dependencies first
# This layer is cached unless go.mod or go.sum changes
COPY go.mod go.sum ./
RUN go mod download
# Copy the rest of the application source code
COPY . .
# Build the static binary for the command-line tool
# CGO_ENABLED=0 produces a static binary, important for distroless/scratch images
# -ldflags="-s -w" strips debug symbols and DWARF info, reducing binary size
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -o /openapi-mcp ./cmd/openapi-mcp/main.go
# --- Final Stage ---
# Use a minimal base image. distroless/static is very small and secure.
# alpine is another good option if you need a shell for debugging.
# FROM alpine:latest
FROM gcr.io/distroless/static-debian12 AS final
# Copy the static binary from the builder stage
COPY --from=builder /openapi-mcp /openapi-mcp
# Copy example files (optional, but useful for demonstrating)
COPY example /app/example
WORKDIR /app
# Define the default command to run when the container starts
# Users can override this command or provide arguments like --spec, --port etc.
ENTRYPOINT ["/openapi-mcp"]
# Expose the default port (optional, good documentation)
EXPOSE 8080
```
--------------------------------------------------------------------------------
/example/docker-compose.yml:
--------------------------------------------------------------------------------
```yaml
version: '3.8' # Specifies the Docker Compose file version
services:
  openapi-mcp:
    # Build the image using the Dockerfile located in the parent directory
    build:
      context: .. # The context is the parent directory (project root)
      dockerfile: Dockerfile # Explicitly points to the Dockerfile
    image: openapi-mcp-example-weather-compose:latest # Optional: Name the image built by compose
    container_name: openapi-mcp-example-weather-service # Sets a specific name for the container
    ports:
      # Map port 8080 on the host to port 8080 in the container
      - "8080:8080"
    volumes:
      # Mount the local './weather' directory (relative to this compose file)
      # to '/app/example/weather' inside the container.
      # This makes the spec file accessible to the application.
      - ./weather:/app/example/weather
    # Load environment variables from the .env file located in ./weather
    # This is the recommended way to handle secrets like API keys.
    # Ensure 'example/weather/.env' exists and defines API_KEY.
    env_file:
      - ./weather/.env
    # Arguments appended to the image's ENTRYPOINT (/openapi-mcp); 'command' replaces the default CMD.
    # --api-key-env names the variable loaded from env_file.
    # Make sure the --spec path matches the volume mount point.
    # --port is the port the app listens on inside the container.
    # Keep comments outside the folded block below: a '#' inside it becomes part of the command string.
    command: >
      --spec /app/example/weather/weatherbitio-swagger.json
      --api-key-env API_KEY
      --api-key-name key
      --api-key-loc query
      --port 8080
    # Restart policy: Automatically restart the container unless it was manually stopped.
    restart: unless-stopped
```
--------------------------------------------------------------------------------
/pkg/server/manager.go:
--------------------------------------------------------------------------------
```go
package server

import (
	"fmt"
	"log"
	"net/http"
	"sync"
)

// client holds information about a connected SSE client.
type client struct {
	writer  http.ResponseWriter
	flusher http.Flusher
	// channel chan []byte // Could be used later for broadcasting updates
}

// connectionManager manages active client connections.
type connectionManager struct {
	clients map[*http.Request]*client // Use request ptr as key
	mu      sync.RWMutex
	toolSet []byte // Pre-encoded toolset JSON
}

// newConnectionManager creates a manager.
func newConnectionManager(toolSetJSON []byte) *connectionManager {
	return &connectionManager{
		clients: make(map[*http.Request]*client),
		toolSet: toolSetJSON,
	}
}

// addClient registers a new client and sends the initial toolset.
func (m *connectionManager) addClient(r *http.Request, w http.ResponseWriter, f http.Flusher) {
	newClient := &client{writer: w, flusher: f}
	m.mu.Lock()
	m.clients[r] = newClient
	m.mu.Unlock()
	log.Printf("Client connected: %s (Total: %d)", r.RemoteAddr, m.getClientCount())
	// Send initial toolset immediately
	go m.sendToolset(newClient) // Send in a goroutine to avoid blocking registration?
}

// removeClient removes a client.
func (m *connectionManager) removeClient(r *http.Request) {
	m.mu.Lock()
	_, ok := m.clients[r]
	if ok {
		delete(m.clients, r)
		log.Printf("Client disconnected: %s (Total: %d)", r.RemoteAddr, len(m.clients))
	} else {
		log.Printf("Attempted to remove already disconnected client: %s", r.RemoteAddr)
	}
	m.mu.Unlock()
}

// getClientCount returns the number of active clients.
func (m *connectionManager) getClientCount() int {
	m.mu.RLock()
	count := len(m.clients)
	m.mu.RUnlock()
	return count
}

// sendToolset sends the pre-encoded toolset to a specific client.
func (m *connectionManager) sendToolset(c *client) {
	if c == nil {
		return
	}
	log.Printf("Attempting to send toolset to client...")
	_, err := fmt.Fprintf(c.writer, "event: tool_set\ndata: %s\n\n", string(m.toolSet))
	if err != nil {
		// This error often happens if the client disconnected before/during the write
		log.Printf("Error sending toolset data to client: %v (client likely disconnected)", err)
		// Optionally trigger removal here if possible, though context done in handler is primary mechanism
		return
	}
	// Flush the data
	c.flusher.Flush()
	log.Println("Sent tool_set event and flushed.")
}
```
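`sendToolset` writes the whole toolset as a single `data:` line. That is only a valid Server-Sent Events frame if the encoded JSON contains no newlines, which holds for compact output such as `json.Marshal` produces (presumably how the toolset is pre-encoded; that code is not on this page). The hypothetical `formatSSEEvent` and `writeSSEEvent` helpers below sketch the general rule from the SSE format: every line of the payload gets its own `data:` prefix, and a blank line ends the event.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"strings"
)

// formatSSEEvent renders one Server-Sent Events frame. Each line of data gets its
// own "data:" prefix and a blank line terminates the event. (Hypothetical helper;
// "\r\n" line endings are not handled here.)
func formatSSEEvent(event string, data []byte) string {
	var b strings.Builder
	if event != "" {
		fmt.Fprintf(&b, "event: %s\n", event)
	}
	for _, line := range strings.Split(string(data), "\n") {
		fmt.Fprintf(&b, "data: %s\n", line)
	}
	b.WriteString("\n")
	return b.String()
}

// writeSSEEvent writes the frame to w, e.g. an http.ResponseWriter that is flushed afterwards.
func writeSSEEvent(w io.Writer, event string, data []byte) error {
	_, err := io.WriteString(w, formatSSEEvent(event, data))
	return err
}

func main() {
	pretty := []byte("{\n  \"name\": \"demo\"\n}")

	// Compact JSON fits on one data: line, which is what sendToolset relies on.
	var compact bytes.Buffer
	if err := json.Compact(&compact, pretty); err != nil {
		fmt.Println("invalid JSON:", err)
		return
	}
	_ = writeSSEEvent(os.Stdout, "tool_set", compact.Bytes())

	// Multi-line JSON is still a valid frame once every line is prefixed.
	_ = writeSSEEvent(os.Stdout, "tool_set", pretty)
}
```

A browser `EventSource` (or any spec-following client) joins consecutive `data:` lines with `\n`, so both frames deliver the same JSON document.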
--------------------------------------------------------------------------------
/pkg/config/config.go:
--------------------------------------------------------------------------------
```go
package config

import (
	"log"
	"os"
)

// APIKeyLocation specifies where the API key is located for requests.
type APIKeyLocation string

const (
	APIKeyLocationHeader APIKeyLocation = "header"
	APIKeyLocationQuery  APIKeyLocation = "query"
	APIKeyLocationPath   APIKeyLocation = "path"
	APIKeyLocationCookie APIKeyLocation = "cookie"
)

// Config holds the configuration for generating the MCP toolset.
type Config struct {
	SpecPath string // Path or URL to the OpenAPI specification file.

	// API Key details (optional, inferred from spec if possible)
	APIKey           string         // The actual API key value.
	APIKeyName       string         // Name of the header, query, path, or cookie parameter for the API key (e.g., "X-API-Key", "api_key").
	APIKeyLocation   APIKeyLocation // Where the API key should be placed (header, query, path, or cookie).
	APIKeyFromEnvVar string         // Environment variable name to read the API key from.

	// Filtering (optional)
	IncludeTags       []string // Only include operations with these tags.
	ExcludeTags       []string // Exclude operations with these tags.
	IncludeOperations []string // Only include operations with these IDs.
	ExcludeOperations []string // Exclude operations with these IDs.

	// Overrides (optional)
	ServerBaseURL   string // Manually override the base URL for API calls, ignoring the spec's servers field.
	DefaultToolName string // Name for the toolset if not specified in the spec's info section.
	DefaultToolDesc string // Description for the toolset if not specified in the spec's info section.

	// Server-side request modification
	CustomHeaders string // Comma-separated list of headers (e.g., "Header1:Value1,Header2:Value2") to add to outgoing requests.
}

// GetAPIKey resolves the API key value, prioritizing the environment variable over the direct flag.
func (c *Config) GetAPIKey() string {
	log.Println("GetAPIKey: Attempting to resolve API key...")
	// 1. Check environment variable specified by --api-key-env
	if c.APIKeyFromEnvVar != "" {
		log.Printf("GetAPIKey: Checking environment variable specified by --api-key-env: %s", c.APIKeyFromEnvVar)
		val := os.Getenv(c.APIKeyFromEnvVar)
		if val != "" {
			log.Printf("GetAPIKey: Found key in environment variable %s.", c.APIKeyFromEnvVar)
			return val
		}
		log.Printf("GetAPIKey: Environment variable %s not found or empty.", c.APIKeyFromEnvVar)
	} else {
		log.Println("GetAPIKey: No --api-key-env variable specified.")
	}
	// 2. Check direct flag --api-key
	if c.APIKey != "" {
		log.Println("GetAPIKey: Found key provided directly via --api-key flag.")
		return c.APIKey
	}
	// 3. No key found
	log.Println("GetAPIKey: No API key found from config (env var or direct flag).")
	return ""
}
```
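`Config` only records where the key belongs; the code that actually places it on outgoing requests is not on this page. As a hedged illustration, the sketch below applies a resolved key for each `APIKeyLocation` using `net/http`. `applyAPIKey` and `buildRequest` are hypothetical helpers, the location constants are copied so the example compiles on its own, and the `path` case assumes the URL path contains a `{name}` placeholder and that the key contains no `/`.

```go
package main

import (
	"fmt"
	"net/http"
	"strings"
)

// APIKeyLocation and its constants mirror pkg/config so this sketch compiles on its own.
type APIKeyLocation string

const (
	APIKeyLocationHeader APIKeyLocation = "header"
	APIKeyLocationQuery  APIKeyLocation = "query"
	APIKeyLocationPath   APIKeyLocation = "path"
	APIKeyLocationCookie APIKeyLocation = "cookie"
)

// applyAPIKey is a hypothetical helper that places key on req according to loc.
func applyAPIKey(req *http.Request, name string, loc APIKeyLocation, key string) error {
	if key == "" || name == "" {
		return nil // nothing configured, send the request unchanged
	}
	switch loc {
	case APIKeyLocationHeader:
		req.Header.Set(name, key)
	case APIKeyLocationQuery:
		q := req.URL.Query() // keeps any existing query parameters
		q.Set(name, key)
		req.URL.RawQuery = q.Encode()
	case APIKeyLocationPath:
		// Assumes a "{name}" placeholder in the path and a key without '/'.
		placeholder := "{" + name + "}"
		if !strings.Contains(req.URL.Path, placeholder) {
			return fmt.Errorf("path %q has no %s placeholder", req.URL.Path, placeholder)
		}
		req.URL.Path = strings.ReplaceAll(req.URL.Path, placeholder, key)
		req.URL.RawPath = "" // let net/url re-derive the escaped form from Path
	case APIKeyLocationCookie:
		req.AddCookie(&http.Cookie{Name: name, Value: key})
	default:
		return fmt.Errorf("unsupported API key location %q", loc)
	}
	return nil
}

// buildRequest creates a GET request for rawURL with the key applied (hypothetical helper).
func buildRequest(rawURL, name string, loc APIKeyLocation, key string) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodGet, rawURL, nil)
	if err != nil {
		return nil, err
	}
	if err := applyAPIKey(req, name, loc, key); err != nil {
		return nil, err
	}
	return req, nil
}

func main() {
	// Mirrors the Weatherbit example flags: --api-key-name key --api-key-loc query.
	req, err := buildRequest("https://api.example.com/current?city=Raleigh", "key", APIKeyLocationQuery, "YOUR_KEY")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(req.URL.String())
}
```

Rejecting an unknown location with an error, rather than silently skipping it, surfaces a mistyped `--api-key-loc` value immediately.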
--------------------------------------------------------------------------------
/pkg/config/config_test.go:
--------------------------------------------------------------------------------
```go
package config

import (
	"os"
	"testing"
)

func TestConfig_GetAPIKey(t *testing.T) {
	tests := []struct {
		name        string
		config      Config
		envKey      string // Environment variable name to set
		envValue    string // Value to set for the env var
		expectedKey string
		cleanupEnv  bool // Flag to indicate if env var needs cleanup
	}{
		{
			name:        "No key set",
			config:      Config{}, // Empty config
			expectedKey: "",
		},
		{
			name: "Direct key set only",
			config: Config{
				APIKey: "direct-key-123",
			},
			expectedKey: "direct-key-123",
		},
		{
			name: "Env var set only",
			config: Config{
				APIKeyFromEnvVar: "TEST_API_KEY_ENV_ONLY",
			},
			envKey:      "TEST_API_KEY_ENV_ONLY",
			envValue:    "env-key-456",
			expectedKey: "env-key-456",
			cleanupEnv:  true,
		},
		{
			name: "Both direct and env var set (env takes precedence)",
			config: Config{
				APIKey:           "direct-key-789",
				APIKeyFromEnvVar: "TEST_API_KEY_BOTH",
			},
			envKey:      "TEST_API_KEY_BOTH",
			envValue:    "env-key-abc",
			expectedKey: "env-key-abc",
			cleanupEnv:  true,
		},
		{
			name: "Direct key set, env var specified but not set",
			config: Config{
				APIKey:           "direct-key-xyz",
				APIKeyFromEnvVar: "TEST_API_KEY_UNSET",
			},
			envKey:      "TEST_API_KEY_UNSET", // Ensure this is not set
			envValue:    "",
			expectedKey: "direct-key-xyz", // Should fall back to direct key
			cleanupEnv:  true,             // Cleanup in case it was set previously
		},
		{
			name: "Env var specified but empty string value",
			config: Config{
				APIKeyFromEnvVar: "TEST_API_KEY_EMPTY",
			},
			envKey:      "TEST_API_KEY_EMPTY",
			envValue:    "", // Explicitly set to empty string
			expectedKey: "", // Empty env var should result in empty key
			cleanupEnv:  true,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Set environment variable if needed for this test case
			if tc.envKey != "" {
				originalValue, wasSet := os.LookupEnv(tc.envKey)
				err := os.Setenv(tc.envKey, tc.envValue)
				if err != nil {
					t.Fatalf("Failed to set environment variable %s: %v", tc.envKey, err)
				}
				// Schedule cleanup
				if tc.cleanupEnv {
					t.Cleanup(func() {
						if wasSet {
							os.Setenv(tc.envKey, originalValue)
						} else {
							os.Unsetenv(tc.envKey)
						}
					})
				}
			} else {
				// Ensure env var is unset if tc.envKey is empty (for tests like "Direct key set only")
				// This prevents interference from previous tests if not cleaned up properly.
				os.Unsetenv(tc.config.APIKeyFromEnvVar) // Unset based on config field if relevant
			}
			// Call the method under test
			actualKey := tc.config.GetAPIKey()
			// Assert the result
			if actualKey != tc.expectedKey {
				t.Errorf("Expected API key %q, but got %q", tc.expectedKey, actualKey)
			}
		})
	}
}
```
--------------------------------------------------------------------------------
/pkg/mcp/types.go:
--------------------------------------------------------------------------------
```go
package mcp

// Based on the MCP specification: https://modelcontextprotocol.io/spec/

// ParameterDetail describes a single parameter for an operation.
type ParameterDetail struct {
	Name string `json:"name"`
	In   string `json:"in"` // Location (query, header, path, cookie)
	// Add other details if needed, e.g., required, type
}

// OperationDetail holds the necessary information to execute a specific API operation.
type OperationDetail struct {
	Method     string            `json:"method"`
	Path       string            `json:"path"` // Path template (e.g., /users/{id})
	BaseURL    string            `json:"baseUrl"`
	Parameters []ParameterDetail `json:"parameters,omitempty"`
	// Add RequestBody schema if needed
}

// ToolSet represents the collection of tools provided by an MCP server.
type ToolSet struct {
	MCPVersion  string `json:"mcp_version"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	// Auth *AuthInfo `json:"auth,omitempty"` // Removed authentication info
	Tools []Tool `json:"tools"`

	// Operations maps Tool.Name (operationId) to its execution details.
	// This is internal to the server and not part of the standard MCP JSON response.
	Operations map[string]OperationDetail `json:"-"` // Use json:"-" to exclude from JSON

	// Internal fields for server-side auth handling (not exposed in JSON)
	apiKeyName string // e.g., "key", "X-API-Key"
	apiKeyIn   string // e.g., "query", "header"
}

// SetAPIKeyDetails allows the parser to set internal API key info.
func (ts *ToolSet) SetAPIKeyDetails(name, in string) {
	ts.apiKeyName = name
	ts.apiKeyIn = in
}

// GetAPIKeyDetails allows the server to retrieve internal API key info.
// We might need this later when making the request.
func (ts *ToolSet) GetAPIKeyDetails() (name, in string) {
	return ts.apiKeyName, ts.apiKeyIn
}

// Tool represents a single function or capability exposed via MCP.
type Tool struct {
	Name        string `json:"name"` // Corresponds to OpenAPI operationId or generated name
	Description string `json:"description,omitempty"`
	InputSchema Schema `json:"inputSchema"` // Renamed from Parameters, consolidate parameters/body here
	// Entrypoint string `json:"entrypoint"` // Removed for simplicity, schema should contain enough info?
	// RequestBody RequestBody `json:"request_body,omitempty"` // Removed, info should be part of InputSchema
	// HTTPMethod string `json:"http_method"` // Removed for simplicity
	// TODO: Add Response handling if needed by spec/client
}

// RequestBody describes the expected request body for a tool.
// This might become redundant if all info is in InputSchema.
// Keeping it for now as the parser might still use it internally.
type RequestBody struct {
	Description string            `json:"description,omitempty"`
	Required    bool              `json:"required,omitempty"`
	Content     map[string]Schema `json:"content"` // Keyed by media type (e.g., "application/json")
}

// Schema defines the structure and constraints of data (parameters or request/response bodies).
// This mirrors a subset of JSON Schema properties.
type Schema struct {
	Type        string            `json:"type,omitempty"` // e.g., "object", "string", "integer", "array"
	Description string            `json:"description,omitempty"`
	Properties  map[string]Schema `json:"properties,omitempty"` // For type "object"
	Required    []string          `json:"required,omitempty"`   // For type "object"
	Items       *Schema           `json:"items,omitempty"`      // For type "array"
	Format      string            `json:"format,omitempty"`     // e.g., "int32", "date-time"
	Enum        []interface{}     `json:"enum,omitempty"`
	// Add other relevant JSON Schema fields as needed (e.g., minimum, maximum, pattern)
}
```
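The struct tags in `types.go` decide what reaches the MCP client: `json:"-"` hides `Operations`, unexported fields such as `apiKeyName` are never encoded by `encoding/json`, and `omitempty` drops empty descriptions. The sketch below uses trimmed copies of these types (reduced for brevity, so they are not the package's exact definitions) together with a hypothetical `sampleToolSet` whose names and `mcp_version` value are illustrative only, to show what the encoder actually emits.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the mcp package types, reproduced only so this sketch compiles on its own.
type Schema struct {
	Type       string            `json:"type,omitempty"`
	Properties map[string]Schema `json:"properties,omitempty"`
	Required   []string          `json:"required,omitempty"`
}

type Tool struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	InputSchema Schema `json:"inputSchema"`
}

type OperationDetail struct {
	Method  string `json:"method"`
	Path    string `json:"path"`
	BaseURL string `json:"baseUrl"`
}

type ToolSet struct {
	MCPVersion  string                     `json:"mcp_version"`
	Name        string                     `json:"name"`
	Description string                     `json:"description,omitempty"`
	Tools       []Tool                     `json:"tools"`
	Operations  map[string]OperationDetail `json:"-"` // never serialized
	apiKeyName  string                     // unexported: never serialized
}

// sampleToolSet builds an illustrative toolset; all names and values are hypothetical.
func sampleToolSet() ToolSet {
	return ToolSet{
		MCPVersion: "1.0",
		Name:       "Weather",
		Tools: []Tool{{
			Name: "getCurrent",
			InputSchema: Schema{
				Type:       "object",
				Properties: map[string]Schema{"city": {Type: "string"}},
				Required:   []string{"city"},
			},
		}},
		Operations: map[string]OperationDetail{
			"getCurrent": {Method: "GET", Path: "/current", BaseURL: "https://api.example.com"},
		},
		apiKeyName: "key",
	}
}

// encodeToolSet returns the JSON a client would receive.
func encodeToolSet(ts ToolSet) (string, error) {
	b, err := json.Marshal(ts)
	return string(b), err
}

func main() {
	ts := sampleToolSet()
	out, err := encodeToolSet(ts)
	if err != nil {
		fmt.Println("encode error:", err)
		return
	}
	// Prints the tools and schema only: no Operations, no apiKeyName, no empty descriptions.
	fmt.Println(out)
	// The server can still look up execution details internally.
	fmt.Println(ts.Operations["getCurrent"].Method, ts.Operations["getCurrent"].Path)
}
```

Keeping execution details (`Operations`, API key placement) in fields the encoder skips is what lets one struct serve both as the client-facing payload and as the server's internal routing table.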
--------------------------------------------------------------------------------
/cmd/openapi-mcp/main.go:
--------------------------------------------------------------------------------
```go
package main

import (
	"flag"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"

	"github.com/ckanthony/openapi-mcp/pkg/config"
	"github.com/ckanthony/openapi-mcp/pkg/parser"
	"github.com/ckanthony/openapi-mcp/pkg/server"
	"github.com/joho/godotenv"
)

// stringSliceFlag allows defining a flag that can be repeated to collect multiple string values.
type stringSliceFlag []string

func (i *stringSliceFlag) String() string {
	return strings.Join(*i, ", ")
}

func (i *stringSliceFlag) Set(value string) error {
	*i = append(*i, value)
	return nil
}

func main() {
	// --- Flag Definitions First ---
	// Define specPath early so we can use it for .env loading
	specPath := flag.String("spec", "", "Path or URL to the OpenAPI specification file (required)")
	port := flag.Int("port", 8080, "Port to run the MCP server on")
	apiKey := flag.String("api-key", "", "Direct API key value")
	apiKeyEnv := flag.String("api-key-env", "", "Environment variable name containing the API key")
	apiKeyName := flag.String("api-key-name", "", "Name of the API key header, query parameter, path parameter, or cookie (required if api-key or api-key-env is set)")
	apiKeyLocStr := flag.String("api-key-loc", "", "Location of API key: 'header', 'query', 'path', or 'cookie' (required if api-key or api-key-env is set)")
	var includeTags stringSliceFlag
	flag.Var(&includeTags, "include-tag", "Tag to include (can be repeated)")
	var excludeTags stringSliceFlag
	flag.Var(&excludeTags, "exclude-tag", "Tag to exclude (can be repeated)")
	var includeOps stringSliceFlag
	flag.Var(&includeOps, "include-op", "Operation ID to include (can be repeated)")
	var excludeOps stringSliceFlag
	flag.Var(&excludeOps, "exclude-op", "Operation ID to exclude (can be repeated)")
	serverBaseURL := flag.String("base-url", "", "Manually override the server base URL")
	defaultToolName := flag.String("name", "OpenAPI-MCP Tools", "Default name for the toolset")
	defaultToolDesc := flag.String("desc", "Tools generated from OpenAPI spec", "Default description for the toolset")
	// Parse flags *after* defining them all
	flag.Parse()

	// --- Load .env after parsing flags ---
	if *specPath != "" && !strings.HasPrefix(*specPath, "http://") && !strings.HasPrefix(*specPath, "https://") {
		envPath := filepath.Join(filepath.Dir(*specPath), ".env")
		log.Printf("Attempting to load .env file from spec directory: %s", envPath)
		err := godotenv.Load(envPath)
		if err != nil {
			// It's okay if the file doesn't exist, log other errors.
			if !os.IsNotExist(err) {
				log.Printf("Warning: Error loading .env file from %s: %v", envPath, err)
			} else {
				log.Printf("Info: No .env file found at %s, proceeding without it.", envPath)
			}
		} else {
			log.Printf("Successfully loaded .env file from %s", envPath)
		}
	} else if *specPath == "" {
		log.Println("Skipping .env load because --spec is missing.")
	} else {
		log.Println("Skipping .env load because spec path appears to be a URL.")
	}

	// --- Read REQUEST_HEADERS env var ---
	customHeadersEnv := os.Getenv("REQUEST_HEADERS")
	if customHeadersEnv != "" {
		log.Printf("Found REQUEST_HEADERS environment variable: %s", customHeadersEnv)
	}

	// --- Input Validation ---
	if *specPath == "" {
log.Println("Error: --spec flag is required.")
flag.Usage()
os.Exit(1)
}
var apiKeyLocation config.APIKeyLocation
if *apiKeyLocStr != "" {
switch *apiKeyLocStr {
case string(config.APIKeyLocationHeader):
apiKeyLocation = config.APIKeyLocationHeader
case string(config.APIKeyLocationQuery):
apiKeyLocation = config.APIKeyLocationQuery
case string(config.APIKeyLocationPath):
apiKeyLocation = config.APIKeyLocationPath
case string(config.APIKeyLocationCookie):
apiKeyLocation = config.APIKeyLocationCookie
default:
log.Fatalf("Error: invalid --api-key-loc value: %s. Must be 'header', 'query', 'path', or 'cookie'.", *apiKeyLocStr)
}
}
// --- Configuration Population ---
cfg := &config.Config{
SpecPath: *specPath,
APIKey: *apiKey,
APIKeyFromEnvVar: *apiKeyEnv,
APIKeyName: *apiKeyName,
APIKeyLocation: apiKeyLocation,
IncludeTags: includeTags,
ExcludeTags: excludeTags,
IncludeOperations: includeOps,
ExcludeOperations: excludeOps,
ServerBaseURL: *serverBaseURL,
DefaultToolName: *defaultToolName,
DefaultToolDesc: *defaultToolDesc,
CustomHeaders: customHeadersEnv,
}
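// Log the configuration without secrets: printing cfg with %+v would expose
// the APIKey field, so only non-sensitive fields are listed and the resolved
// key is reported as present or absent.
log.Printf("Configuration loaded: spec=%s baseURL=%q includeTags=%v excludeTags=%v includeOps=%v excludeOps=%v",
cfg.SpecPath, cfg.ServerBaseURL, cfg.IncludeTags, cfg.ExcludeTags, cfg.IncludeOperations, cfg.ExcludeOperations)
log.Printf("API key resolved: %t", cfg.GetAPIKey() != "")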
// --- Call Parser ---
specDoc, version, err := parser.LoadSwagger(cfg.SpecPath)
if err != nil {
log.Fatalf("Failed to load OpenAPI/Swagger spec: %v", err)
}
log.Printf("Spec type %s loaded successfully from %s.\n", version, cfg.SpecPath)
toolSet, err := parser.GenerateToolSet(specDoc, version, cfg)
if err != nil {
log.Fatalf("Failed to generate MCP toolset: %v", err)
}
log.Printf("MCP toolset generated with %d tools.\n", len(toolSet.Tools))
// --- Start Server ---
addr := fmt.Sprintf(":%d", *port)
log.Printf("Starting MCP server on %s...", addr)
err = server.ServeMCP(addr, toolSet, cfg) // Pass cfg to ServeMCP
if err != nil {
log.Fatalf("Failed to start server: %v", err)
}
}
```
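Because `stringSliceFlag` implements the standard `flag.Value` interface (`String` and `Set`), every repetition of a flag such as `--include-tag` appends one value. The sketch below reproduces that behavior in isolation, using a copy of the type and a private `flag.FlagSet` so the global `flag.CommandLine` is untouched. The tag names `forecast`, `current`, and `alerts` are hypothetical.

```go
package main

import (
	"flag"
	"fmt"
	"strings"
)

// stringSliceFlag is a copy of the type in main.go: each occurrence of the
// flag appends one value to the slice.
type stringSliceFlag []string

func (s *stringSliceFlag) String() string { return strings.Join(*s, ", ") }

func (s *stringSliceFlag) Set(v string) error {
	*s = append(*s, v)
	return nil
}

// parseTagFlags parses an argument list with its own FlagSet and returns the
// collected include and exclude tags.
func parseTagFlags(args []string) ([]string, []string, error) {
	fs := flag.NewFlagSet("openapi-mcp", flag.ContinueOnError)
	var inc, exc stringSliceFlag
	fs.Var(&inc, "include-tag", "Tag to include (can be repeated)")
	fs.Var(&exc, "exclude-tag", "Tag to exclude (can be repeated)")
	if err := fs.Parse(args); err != nil {
		return nil, nil, err
	}
	return inc, exc, nil
}

func main() {
	inc, exc, err := parseTagFlags([]string{
		"--include-tag", "forecast",
		"--include-tag", "current",
		"--exclude-tag", "alerts",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(inc, exc) // prints "[forecast current] [alerts]"
}
```

The `flag` package accepts both `-include-tag` and `--include-tag`, so either spelling works on the command line.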
--------------------------------------------------------------------------------
/pkg/server/manager_test.go:
--------------------------------------------------------------------------------
```go
package server
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// mockResponseWriter implements http.ResponseWriter and http.Flusher for testing SSE.
type mockResponseWriter struct {
*httptest.ResponseRecorder // Embed to get ResponseWriter behavior
flushed bool // Track if Flush was called
forceError error // If non-nil, Write returns this error and Flush does nothing
}
// NewMockResponseWriter creates a new mock response writer.
func NewMockResponseWriter() *mockResponseWriter {
return &mockResponseWriter{
ResponseRecorder: httptest.NewRecorder(),
}
}
// Write returns forceError when it is set; otherwise it delegates to the embedded recorder.
func (m *mockResponseWriter) Write(p []byte) (int, error) {
if m.forceError != nil {
return 0, m.forceError
}
return m.ResponseRecorder.Write(p) // Use embedded writer
}
// Flush records that it was called, unless forceError is set.
func (m *mockResponseWriter) Flush() {
if m.forceError != nil { // Don't flush if write failed
return
}
m.flushed = true
// We don't actually flush the embedded recorder in this mock
}
// --- Simple Mock Flusher ---
type mockFlusher struct {
flushed bool
}
func (f *mockFlusher) Flush() {
f.flushed = true
}
// --- End Mock Flusher ---
func TestManager_Run_Stop(t *testing.T) {
// Lifecycle smoke test. It does not call Run or Stop on the manager itself:
// a placeholder goroutine stands in for a blocking Run loop while the test
// adds and removes a client. Broader resource-handling coverage is still needed.
dummyToolSet := []byte(`{"tools": []}`) // Minimal toolset JSON for initialization
m := newConnectionManager(dummyToolSet)
stopChan := make(chan struct{})
go func() {
// Placeholder for a blocking Run loop: wait until the test signals shutdown.
<-stopChan
}()
// Simulate adding a client to test remove logic
req := httptest.NewRequest(http.MethodGet, "/events", nil)
mrr := NewMockResponseWriter() // Use the mock
m.addClient(req, mrr, mrr) // Pass the mock which implements both interfaces
if m.getClientCount() != 1 {
t.Errorf("Expected 1 client after add, got %d", m.getClientCount())
}
time.Sleep(100 * time.Millisecond) // Give time for potential background tasks
// Test removing the client
m.removeClient(req)
if m.getClientCount() != 0 {
t.Errorf("Expected 0 clients after remove, got %d", m.getClientCount())
}
// Signal the placeholder goroutine to exit.
close(stopChan)
// TODO: assert that the manager actually stopped once it exposes an
// observable signal, for example a channel closed by Stop():
// select {
// case <-m.internalStopChan: // hypothetical field closed by Stop()
// 	// Expected behavior
// case <-time.After(1 * time.Second):
// 	t.Fatal("Manager did not signal stop within the expected time")
// }
}
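// nonFlusher wraps an http.ResponseWriter with a no-op Flush. Despite the name,
// *nonFlusher still satisfies http.Flusher; its Flush simply does nothing.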
type nonFlusher struct {
http.ResponseWriter
}
func (nf *nonFlusher) Flush() { /* Do nothing */ }
func TestManager_AddRemoveClient(t *testing.T) {
dummyToolSet := []byte(`{"tools": []}`)
m := newConnectionManager(dummyToolSet)
req1 := httptest.NewRequest(http.MethodGet, "/events?id=1", nil)
mrr1 := NewMockResponseWriter() // Use mock
req2 := httptest.NewRequest(http.MethodGet, "/events?id=2", nil)
mrr2 := NewMockResponseWriter() // Use mock
m.addClient(req1, mrr1, mrr1) // Pass mock
if count := m.getClientCount(); count != 1 {
t.Errorf("Expected 1 client, got %d", count)
}
m.addClient(req2, mrr2, mrr2) // Pass mock
if count := m.getClientCount(); count != 2 {
t.Errorf("Expected 2 clients, got %d", count)
}
m.removeClient(req1)
if count := m.getClientCount(); count != 1 {
t.Errorf("Expected 1 client after removing req1, got %d", count)
}
// Ensure the correct client was removed
m.mu.RLock()
_, exists := m.clients[req1]
m.mu.RUnlock()
if exists {
t.Error("req1 should have been removed but still exists in map")
}
m.removeClient(req2)
if count := m.getClientCount(); count != 0 {
t.Errorf("Expected 0 clients after removing req2, got %d", count)
}
// Test removing non-existent client
m.removeClient(req1) // Remove again
if count := m.getClientCount(); count != 0 {
t.Errorf("Expected 0 clients after removing non-existent, got %d", count)
}
}
// TestManager_SendToolset captures the SSE output in the mock's embedded
// httptest.ResponseRecorder and checks the event framing and the Flush call.
func TestManager_SendToolset(t *testing.T) {
toolSetData := `{"tools": ["tool1", "tool2"]}`
m := newConnectionManager([]byte(toolSetData))
mrr := NewMockResponseWriter() // Use mock
// Directly create a client struct instance for testing sendToolset specifically
// Note: This bypasses addClient logic for focused testing of sendToolset.
testClient := &client{writer: mrr, flusher: mrr} // Use mock for both
m.sendToolset(testClient)
// Expected SSE frame: an "event:" line, a "data:" line, then a blank line.
// Surrounding whitespace is trimmed before comparing.
expectedOutput := "event: tool_set\ndata: {\"tools\": [\"tool1\", \"tool2\"]}\n\n"
actualOutput := mrr.Body.String()
if strings.TrimSpace(actualOutput) != strings.TrimSpace(expectedOutput) {
// %q quotes the strings so whitespace differences are visible
t.Errorf("Expected toolset output %q, got %q", expectedOutput, actualOutput)
}
if !mrr.flushed { // Check if flush was called
t.Error("Expected Flush() to be called on the writer, but it wasn't")
}
// Test sending to nil client
m.sendToolset(nil) // Should not panic
}
// Test case for when writing the toolset fails (e.g., client disconnected)
func TestConnectionManager_SendToolset_WriteError(t *testing.T) {
mgr := newConnectionManager([]byte(`{"tool":"set"}`))
// Create a mock writer that always returns an error
mockWriter := &mockResponseWriter{
ResponseRecorder: httptest.NewRecorder(), // Initialize embedded recorder
forceError: fmt.Errorf("simulated write error"),
}
mockFlusher := &mockFlusher{}
// Create a client with the erroring writer
mockClient := &client{
writer: mockWriter,
flusher: mockFlusher,
}
// Call sendToolset - we expect it to log the error and return early
// We don't easily assert the log, but we run it for coverage.
mgr.sendToolset(mockClient)
// Assert that Flush was NOT called because the function should have returned early
assert.False(t, mockFlusher.flushed, "Flush should not be called when Write fails")
// Assert that Write was attempted (optional, depends on mock capabilities)
// If mockResponseWriter tracks calls, assert Write was called once.
}
```
--------------------------------------------------------------------------------
/pkg/parser/parser_test.go:
--------------------------------------------------------------------------------
```go
package parser
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sort"
"strings"
"testing"
"github.com/getkin/kin-openapi/openapi3"
"github.com/go-openapi/spec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ckanthony/openapi-mcp/pkg/config"
"github.com/ckanthony/openapi-mcp/pkg/mcp"
)
// Minimal valid OpenAPI V3 spec (JSON string)
const minimalV3SpecJSON = `{
"openapi": "3.0.0",
"info": {
"title": "Minimal V3 API",
"version": "1.0.0"
},
"paths": {
"/ping": {
"get": {
"summary": "Simple ping endpoint",
"operationId": "getPing",
"responses": {
"200": {
"description": "OK"
}
}
}
}
}
}`
// Minimal valid Swagger V2 spec (JSON string)
const minimalV2SpecJSON = `{
"swagger": "2.0",
"info": {
"title": "Minimal V2 API",
"version": "1.0.0"
},
"paths": {
"/health": {
"get": {
"summary": "Simple health check",
"operationId": "getHealth",
"produces": ["application/json"],
"responses": {
"200": {
"description": "OK"
}
}
}
}
}
}`
// Malformed JSON: trailing comma after the "title" value
const malformedJSON = `{
"openapi": "3.0.0",
"info": {
"title": "Missing Version",
}
}`
// Well-formed JSON that has neither an 'openapi' nor a 'swagger' key
const noVersionKeyJSON = `{
"info": {
"title": "No Version Key",
"version": "1.0"
},
"paths": {}
}`
// V3 Spec with tags and multiple operations
const complexV3SpecJSON = `{
"openapi": "3.0.0",
"info": {
"title": "Complex V3 API",
"version": "1.1.0"
},
"tags": [
{"name": "tag1", "description": "First Tag"},
{"name": "tag2", "description": "Second Tag"}
],
"paths": {
"/items": {
"get": {
"summary": "List Items",
"operationId": "listItems",
"tags": ["tag1"],
"responses": {"200": {"description": "OK"}}
},
"post": {
"summary": "Create Item",
"operationId": "createItem",
"tags": ["tag1", "tag2"],
"responses": {"201": {"description": "Created"}}
}
},
"/users": {
"get": {
"summary": "List Users",
"operationId": "listUsers",
"tags": ["tag2"],
"responses": {"200": {"description": "OK"}}
}
},
"/ping": {
"get": {
"summary": "Simple ping",
"operationId": "getPing",
"responses": {"200": {"description": "OK"}}
}
}
}
}`
// V2 Spec with tags and multiple operations
const complexV2SpecJSON = `{
"swagger": "2.0",
"info": {
"title": "Complex V2 API",
"version": "1.1.0"
},
"tags": [
{"name": "tag1", "description": "First Tag"},
{"name": "tag2", "description": "Second Tag"}
],
"paths": {
"/items": {
"get": {
"summary": "List Items",
"operationId": "listItems",
"tags": ["tag1"],
"produces": ["application/json"],
"responses": {"200": {"description": "OK"}}
},
"post": {
"summary": "Create Item",
"operationId": "createItem",
"tags": ["tag1", "tag2"],
"produces": ["application/json"],
"responses": {"201": {"description": "Created"}}
}
},
"/users": {
"get": {
"summary": "List Users",
"operationId": "listUsers",
"tags": ["tag2"],
"produces": ["application/json"],
"responses": {"200": {"description": "OK"}}
}
},
"/ping": {
"get": {
"summary": "Simple ping",
"operationId": "getPing",
"produces": ["application/json"],
"responses": {"200": {"description": "OK"}}
}
}
}
}`
// V3 Spec with various parameter types and request body
const paramsV3SpecJSON = `{
"openapi": "3.0.0",
"info": {
"title": "Params V3 API",
"version": "1.0.0"
},
"paths": {
"/test/{path_param}": {
"post": {
"summary": "Test various params",
"operationId": "testParams",
"parameters": [
{
"name": "path_param",
"in": "path",
"required": true,
"schema": {"type": "integer", "format": "int32"}
},
{
"name": "query_param",
"in": "query",
"required": true,
"schema": {"type": "string", "enum": ["A", "B"]}
},
{
"name": "optional_query",
"in": "query",
"schema": {"type": "boolean"}
},
{
"name": "X-Header-Param",
"in": "header",
"required": true,
"schema": {"type": "string"}
},
{
"name": "CookieParam",
"in": "cookie",
"schema": {"type": "number"}
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"id": {"type": "string"},
"value": {"type": "number"}
},
"required": ["id"]
}
}
}
},
"responses": {
"200": {"description": "OK"}
}
}
}
}
}`
// V2 Spec with various parameter types and $ref
const paramsV2SpecJSON = `{
"swagger": "2.0",
"info": {
"title": "Params V2 API",
"version": "1.0.0"
},
"definitions": {
"Item": {
"type": "object",
"properties": {
"id": {"type": "string", "format": "uuid"},
"name": {"type": "string"}
},
"required": ["id"]
}
},
"paths": {
"/test/{path_id}": {
"put": {
"summary": "Test V2 params and ref",
"operationId": "testV2Params",
"consumes": ["application/json"],
"produces": ["application/json"],
"parameters": [
{
"name": "path_id",
"in": "path",
"required": true,
"type": "string"
},
{
"name": "query_flag",
"in": "query",
"type": "boolean",
"required": true
},
{
"name": "X-Request-ID",
"in": "header",
"type": "string",
"required": false
},
{
"name": "body_param",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/Item"
}
}
],
"responses": {
"200": {"description": "OK"}
}
}
}
}
}`
// V3 Spec with array types
const arraysV3SpecJSON = `{
"openapi": "3.0.0",
"info": {"title": "Arrays V3 API", "version": "1.0.0"},
"paths": {
"/process": {
"post": {
"summary": "Process arrays",
"operationId": "processArrays",
"parameters": [
{
"name": "string_array_query",
"in": "query",
"schema": {
"type": "array",
"items": {"type": "string"}
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"int_array_body": {
"type": "array",
"items": {"type": "integer", "format": "int64"}
}
}
}
}
}
},
"responses": {"200": {"description": "OK"}}
}
}
}
}`
// V2 Spec with array types
const arraysV2SpecJSON = `{
"swagger": "2.0",
"info": {"title": "Arrays V2 API", "version": "1.0.0"},
"paths": {
"/process": {
"get": {
"summary": "Get arrays",
"operationId": "getArrays",
"parameters": [
{
"name": "string_array_query",
"in": "query",
"type": "array",
"items": {"type": "string"},
"collectionFormat": "csv"
},
{
"name": "int_array_form",
"in": "formData",
"type": "array",
"items": {"type": "integer", "format": "int32"}
}
],
"responses": {"200": {"description": "OK"}}
}
}
}
}`
// V2 Spec with file parameter
const fileV2SpecJSON = `{
"swagger": "2.0",
"info": {"title": "File V2 API", "version": "1.0.0"},
"paths": {
"/upload": {
"post": {
"summary": "Upload file",
"operationId": "uploadFile",
"consumes": ["multipart/form-data"],
"parameters": [
{
"name": "description",
"in": "formData",
"type": "string"
},
{
"name": "file_upload",
"in": "formData",
"required": true,
"type": "file"
}
],
"responses": {"200": {"description": "OK"}}
}
}
}
}`
func TestLoadSwagger(t *testing.T) {
tests := []struct {
name string
content string
fileName string
expectError bool
expectVersion string
containsError string // Substring to check in error message
isURLTest bool // Flag to indicate if the test uses a URL
handler http.HandlerFunc // Handler for mock HTTP server
}{
{
name: "Valid V3 JSON file",
content: minimalV3SpecJSON,
fileName: "valid_v3.json",
expectError: false,
expectVersion: VersionV3,
},
{
name: "Valid V2 JSON file",
content: minimalV2SpecJSON,
fileName: "valid_v2.json",
expectError: false,
expectVersion: VersionV2,
},
{
name: "Malformed JSON file",
content: malformedJSON,
fileName: "malformed.json",
expectError: true,
containsError: "failed to parse JSON",
},
{
name: "No version key JSON file",
content: noVersionKeyJSON,
fileName: "no_version.json",
expectError: true,
containsError: "missing 'openapi' or 'swagger' key",
},
{
name: "Non-existent file",
content: "", // No content needed
fileName: "non_existent.json",
expectError: true,
containsError: "failed reading file path",
},
// --- URL Tests ---
{
name: "Valid V3 JSON URL",
content: minimalV3SpecJSON,
expectError: false,
expectVersion: VersionV3,
isURLTest: true,
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(minimalV3SpecJSON))
},
},
{
name: "Valid V2 JSON URL",
content: minimalV2SpecJSON, // Content used by handler
expectError: false,
expectVersion: VersionV2,
isURLTest: true,
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(minimalV2SpecJSON))
},
},
{
name: "Malformed JSON URL",
content: malformedJSON,
expectError: true,
containsError: "failed to parse JSON",
isURLTest: true,
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(malformedJSON))
},
},
{
name: "No version key JSON URL",
content: noVersionKeyJSON,
expectError: true,
containsError: "missing 'openapi' or 'swagger' key",
isURLTest: true,
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(noVersionKeyJSON))
},
},
{
name: "URL Not Found (404)",
expectError: true,
containsError: "failed to fetch URL", // Check for fetch error
isURLTest: true,
handler: func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r) // Use standard http.NotFound
},
},
{
name: "URL Internal Server Error (500)",
expectError: true,
containsError: "failed to fetch URL", // Check for fetch error
isURLTest: true,
handler: func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) // Use standard http.Error
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var location string
var server *httptest.Server // Declare server variable
if tc.isURLTest {
// Set up mock HTTP server
require.NotNil(t, tc.handler, "URL test case must provide a handler")
server = httptest.NewServer(tc.handler)
defer server.Close()
location = server.URL // Use the mock server's URL
} else {
// Existing file path logic
tempDir := t.TempDir()
filePath := filepath.Join(tempDir, tc.fileName)
// Create the file only if content is provided
if tc.content != "" {
err := os.WriteFile(filePath, []byte(tc.content), 0644)
require.NoError(t, err, "Failed to write temp spec file")
}
// For the non-existent file case, ensure it really doesn't exist
if tc.name == "Non-existent file" {
filePath = filepath.Join(tempDir, "definitely_not_here.json")
}
location = filePath
}
specDoc, version, err := LoadSwagger(location)
if tc.expectError {
assert.Error(t, err)
if tc.containsError != "" {
assert.True(t, strings.Contains(err.Error(), tc.containsError),
"Error message %q does not contain expected substring %q", err.Error(), tc.containsError)
}
assert.Nil(t, specDoc)
assert.Empty(t, version)
} else {
assert.NoError(t, err)
assert.NotNil(t, specDoc)
assert.Equal(t, tc.expectVersion, version)
// Basic type assertion based on expected version
if version == VersionV3 {
assert.IsType(t, &openapi3.T{}, specDoc) // Expecting a pointer
} else if version == VersionV2 {
assert.IsType(t, &spec.Swagger{}, specDoc) // Expecting a pointer
}
}
})
}
}
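// TestGenerateToolSet checks toolset generation from V2 and V3 specs under
// default and overridden configs, tag/operation filters, parameters, and arrays.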
func TestGenerateToolSet(t *testing.T) {
// --- Load Specs Once ---
// Load V3 spec (error checked in TestLoadSwagger)
tempDirV3 := t.TempDir()
filePathV3 := filepath.Join(tempDirV3, "minimal_v3.json")
err := os.WriteFile(filePathV3, []byte(minimalV3SpecJSON), 0644)
require.NoError(t, err)
docV3, versionV3, err := LoadSwagger(filePathV3)
require.NoError(t, err)
require.Equal(t, VersionV3, versionV3)
specV3 := docV3.(*openapi3.T)
// Load V2 spec (error checked in TestLoadSwagger)
tempDirV2 := t.TempDir()
filePathV2 := filepath.Join(tempDirV2, "minimal_v2.json")
err = os.WriteFile(filePathV2, []byte(minimalV2SpecJSON), 0644)
require.NoError(t, err)
docV2, versionV2, err := LoadSwagger(filePathV2)
require.NoError(t, err)
require.Equal(t, VersionV2, versionV2)
specV2 := docV2.(*spec.Swagger)
// Load Complex V3 spec
tempDirComplexV3 := t.TempDir()
filePathComplexV3 := filepath.Join(tempDirComplexV3, "complex_v3.json")
err = os.WriteFile(filePathComplexV3, []byte(complexV3SpecJSON), 0644)
require.NoError(t, err)
docComplexV3, versionComplexV3, err := LoadSwagger(filePathComplexV3)
require.NoError(t, err)
require.Equal(t, VersionV3, versionComplexV3)
specComplexV3 := docComplexV3.(*openapi3.T)
// Load Complex V2 spec
tempDirComplexV2 := t.TempDir()
filePathComplexV2 := filepath.Join(tempDirComplexV2, "complex_v2.json")
err = os.WriteFile(filePathComplexV2, []byte(complexV2SpecJSON), 0644)
require.NoError(t, err)
docComplexV2, versionComplexV2, err := LoadSwagger(filePathComplexV2)
require.NoError(t, err)
require.Equal(t, VersionV2, versionComplexV2)
specComplexV2 := docComplexV2.(*spec.Swagger)
// Load Params V3 spec
tempDirParamsV3 := t.TempDir()
filePathParamsV3 := filepath.Join(tempDirParamsV3, "params_v3.json")
err = os.WriteFile(filePathParamsV3, []byte(paramsV3SpecJSON), 0644)
require.NoError(t, err)
docParamsV3, versionParamsV3, err := LoadSwagger(filePathParamsV3)
require.NoError(t, err)
require.Equal(t, VersionV3, versionParamsV3)
specParamsV3 := docParamsV3.(*openapi3.T)
// Load Params V2 spec
tempDirParamsV2 := t.TempDir()
filePathParamsV2 := filepath.Join(tempDirParamsV2, "params_v2.json")
err = os.WriteFile(filePathParamsV2, []byte(paramsV2SpecJSON), 0644)
require.NoError(t, err)
docParamsV2, versionParamsV2, err := LoadSwagger(filePathParamsV2)
require.NoError(t, err)
require.Equal(t, VersionV2, versionParamsV2)
specParamsV2 := docParamsV2.(*spec.Swagger)
// Load Arrays V3 spec
tempDirArraysV3 := t.TempDir()
filePathArraysV3 := filepath.Join(tempDirArraysV3, "arrays_v3.json")
err = os.WriteFile(filePathArraysV3, []byte(arraysV3SpecJSON), 0644)
require.NoError(t, err)
docArraysV3, versionArraysV3, err := LoadSwagger(filePathArraysV3)
require.NoError(t, err)
require.Equal(t, VersionV3, versionArraysV3)
specArraysV3 := docArraysV3.(*openapi3.T)
// Load Arrays V2 spec
tempDirArraysV2 := t.TempDir()
filePathArraysV2 := filepath.Join(tempDirArraysV2, "arrays_v2.json")
err = os.WriteFile(filePathArraysV2, []byte(arraysV2SpecJSON), 0644)
require.NoError(t, err)
docArraysV2, versionArraysV2, err := LoadSwagger(filePathArraysV2)
require.NoError(t, err)
require.Equal(t, VersionV2, versionArraysV2)
specArraysV2 := docArraysV2.(*spec.Swagger)
// Load File V2 spec
tempDirFileV2 := t.TempDir()
filePathFileV2 := filepath.Join(tempDirFileV2, "file_v2.json")
err = os.WriteFile(filePathFileV2, []byte(fileV2SpecJSON), 0644)
require.NoError(t, err)
docFileV2, versionFileV2, err := LoadSwagger(filePathFileV2)
require.NoError(t, err)
require.Equal(t, VersionV2, versionFileV2)
specFileV2 := docFileV2.(*spec.Swagger)
// --- Test Cases ---
tests := []struct {
name string
spec interface{}
version string
cfg *config.Config
expectError bool
expectedToolSet *mcp.ToolSet // Define expected basic structure
}{
{
name: "V3 Minimal Spec - Default Config",
spec: specV3,
version: VersionV3,
cfg: &config.Config{}, // Default empty config
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Minimal V3 API",
Description: "",
Tools: []mcp.Tool{
{
Name: "getPing",
Description: "Note: The API key is handled by the server, no need to provide it. Simple ping endpoint",
InputSchema: mcp.Schema{Type: "object", Properties: map[string]mcp.Schema{}, Required: []string{}},
},
},
Operations: map[string]mcp.OperationDetail{
"getPing": {
Method: "GET",
Path: "/ping",
BaseURL: "", // No server defined
Parameters: []mcp.ParameterDetail{}, // Expect empty slice
},
},
},
},
{
name: "V2 Minimal Spec - Default Config",
spec: specV2,
version: VersionV2,
cfg: &config.Config{}, // Default empty config
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Minimal V2 API",
Description: "",
Tools: []mcp.Tool{
{
Name: "getHealth",
Description: "Note: The API key is handled by the server, no need to provide it. Simple health check",
InputSchema: mcp.Schema{Type: "object", Properties: map[string]mcp.Schema{}, Required: []string{}},
},
},
Operations: map[string]mcp.OperationDetail{
"getHealth": {
Method: "GET",
Path: "/health",
BaseURL: "", // No host/schemes/basePath
Parameters: []mcp.ParameterDetail{}, // Expect empty slice
},
},
},
},
{
name: "V3 Minimal Spec - Config Overrides",
spec: specV3,
version: VersionV3,
cfg: &config.Config{
ServerBaseURL: "http://override.com/v1",
DefaultToolName: "Override Name",
DefaultToolDesc: "Override Desc",
},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Override Name", // Uses override
Description: "Override Desc", // Uses override
Tools: []mcp.Tool{
{
Name: "getPing",
Description: "Note: The API key is handled by the server, no need to provide it. Simple ping endpoint",
InputSchema: mcp.Schema{Type: "object", Properties: map[string]mcp.Schema{}, Required: []string{}},
},
},
Operations: map[string]mcp.OperationDetail{
"getPing": {
Method: "GET",
Path: "/ping",
BaseURL: "http://override.com/v1", // Uses override
Parameters: []mcp.ParameterDetail{}, // Expect empty slice
},
},
},
},
// --- Filtering Tests (Using Complex Specs) ---
{
name: "V3 Complex - Include Tag1",
spec: specComplexV3,
version: VersionV3,
cfg: &config.Config{IncludeTags: []string{"tag1"}},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Complex V3 API", Description: "", // Should only include listItems and createItem
Tools: []mcp.Tool{{Name: "listItems"}, {Name: "createItem"}}, // Simplified for length check
Operations: map[string]mcp.OperationDetail{"listItems": {}, "createItem": {}}, // Simplified for length check
},
},
{
name: "V3 Complex - Exclude Tag2",
spec: specComplexV3,
version: VersionV3,
cfg: &config.Config{ExcludeTags: []string{"tag2"}},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Complex V3 API", Description: "", // Should include listItems and getPing
Tools: []mcp.Tool{{Name: "listItems"}, {Name: "getPing"}}, // Simplified for length check
Operations: map[string]mcp.OperationDetail{"listItems": {}, "getPing": {}}, // Simplified for length check
},
},
{
name: "V3 Complex - Include Operation listItems",
spec: specComplexV3,
version: VersionV3,
cfg: &config.Config{IncludeOperations: []string{"listItems"}},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Complex V3 API", Description: "", // Should include only listItems
Tools: []mcp.Tool{{Name: "listItems"}}, // Simplified for length check
Operations: map[string]mcp.OperationDetail{"listItems": {}}, // Simplified for length check
},
},
{
name: "V3 Complex - Exclude Operation createItem, getPing",
spec: specComplexV3,
version: VersionV3,
cfg: &config.Config{ExcludeOperations: []string{"createItem", "getPing"}},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Complex V3 API", Description: "", // Should include listItems and listUsers
Tools: []mcp.Tool{{Name: "listItems"}, {Name: "listUsers"}}, // Simplified for length check
Operations: map[string]mcp.OperationDetail{"listItems": {}, "listUsers": {}}, // Simplified for length check
},
},
{
name: "V2 Complex - Include Tag1",
spec: specComplexV2,
version: VersionV2,
cfg: &config.Config{IncludeTags: []string{"tag1"}},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Complex V2 API", Description: "", // Should only include listItems and createItem
Tools: []mcp.Tool{{Name: "listItems"}, {Name: "createItem"}}, // Simplified for length check
Operations: map[string]mcp.OperationDetail{"listItems": {}, "createItem": {}}, // Simplified for length check
},
},
{
name: "V2 Complex - Exclude Tag2",
spec: specComplexV2,
version: VersionV2,
cfg: &config.Config{ExcludeTags: []string{"tag2"}},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Complex V2 API", Description: "", // Should include listItems and getPing
Tools: []mcp.Tool{{Name: "listItems"}, {Name: "getPing"}}, // Simplified for length check
Operations: map[string]mcp.OperationDetail{"listItems": {}, "getPing": {}}, // Simplified for length check
},
},
// --- Parameter/Schema Tests ---
{
name: "V3 Params and Request Body",
spec: specParamsV3,
version: VersionV3,
cfg: &config.Config{},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Params V3 API",
Description: "", // No description in spec info
Tools: []mcp.Tool{
{
Name: "testParams",
Description: "Note: The API key is handled by the server, no need to provide it. Test various params",
InputSchema: mcp.Schema{
Type: "object",
Properties: map[string]mcp.Schema{
// Parameters merged with Request Body properties
"path_param": {Type: "integer", Format: "int32"},
"query_param": {Type: "string", Enum: []interface{}{"A", "B"}},
"optional_query": {Type: "boolean"},
"X-Header-Param": {Type: "string"},
"CookieParam": {Type: "number"},
"id": {Type: "string"},
"value": {Type: "number"},
},
Required: []string{"path_param", "query_param", "X-Header-Param", "id"}, // Order might differ, will sort before assert
},
},
},
Operations: map[string]mcp.OperationDetail{
"testParams": {
Method: "POST",
Path: "/test/{path_param}",
BaseURL: "", // No server
Parameters: []mcp.ParameterDetail{
{Name: "path_param", In: "path"},
{Name: "query_param", In: "query"},
{Name: "optional_query", In: "query"},
{Name: "X-Header-Param", In: "header"},
{Name: "CookieParam", In: "cookie"},
},
},
},
},
},
{
name: "V2 Params and Ref",
spec: specParamsV2,
version: VersionV2,
cfg: &config.Config{},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Params V2 API",
Description: "", // No description in spec info
Tools: []mcp.Tool{
{
Name: "testV2Params",
Description: "Note: The API key is handled by the server, no need to provide it. Test V2 params and ref",
InputSchema: mcp.Schema{
Type: "object",
Properties: map[string]mcp.Schema{
// Path, Query, Header params first
"path_id": {Type: "string"},
"query_flag": {Type: "boolean"},
"X-Request-ID": {Type: "string"},
// Body param ($ref to Item) merged
"id": {Type: "string", Format: "uuid"},
"name": {Type: "string"},
},
Required: []string{"path_id", "query_flag", "id"}, // Required params + required definition props
},
},
},
Operations: map[string]mcp.OperationDetail{
"testV2Params": {
Method: "PUT",
Path: "/test/{path_id}",
BaseURL: "", // No server
Parameters: []mcp.ParameterDetail{
{Name: "path_id", In: "path"},
{Name: "query_flag", In: "query"},
{Name: "X-Request-ID", In: "header"},
{Name: "body_param", In: "body"}, // Body param listed here
},
},
},
},
},
// --- Array Tests ---
{
name: "V3 Arrays",
spec: specArraysV3,
version: VersionV3,
cfg: &config.Config{},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Arrays V3 API", Description: "",
Tools: []mcp.Tool{
{
Name: "processArrays",
Description: "Note: The API key is handled by the server, no need to provide it. Process arrays",
InputSchema: mcp.Schema{
Type: "object",
Properties: map[string]mcp.Schema{
"string_array_query": {Type: "array", Items: &mcp.Schema{Type: "string"}},
"int_array_body": {Type: "array", Items: &mcp.Schema{Type: "integer", Format: "int64"}},
},
Required: []string{}, // No required fields specified
},
},
},
Operations: map[string]mcp.OperationDetail{
"processArrays": {
Method: "POST",
Path: "/process",
BaseURL: "",
Parameters: []mcp.ParameterDetail{
{Name: "string_array_query", In: "query"},
// Body param details are not explicitly listed in V3 op details
},
},
},
},
},
{
name: "V2 Arrays",
spec: specArraysV2,
version: VersionV2,
cfg: &config.Config{},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "Arrays V2 API", Description: "",
Tools: []mcp.Tool{
{
Name: "getArrays",
Description: "Note: The API key is handled by the server, no need to provide it. Get arrays",
InputSchema: mcp.Schema{
Type: "object",
Properties: map[string]mcp.Schema{
"string_array_query": {Type: "array", Items: &mcp.Schema{Type: "string"}},
"int_array_form": {Type: "array", Items: &mcp.Schema{Type: "integer", Format: "int32"}},
},
Required: []string{}, // No required fields specified
},
},
},
Operations: map[string]mcp.OperationDetail{
"getArrays": {
Method: "GET",
Path: "/process",
BaseURL: "",
Parameters: []mcp.ParameterDetail{
{Name: "string_array_query", In: "query"},
{Name: "int_array_form", In: "formData"},
},
},
},
},
},
{
name: "V2 File Param",
spec: specFileV2,
version: VersionV2,
cfg: &config.Config{},
expectError: false,
expectedToolSet: &mcp.ToolSet{
Name: "File V2 API", Description: "",
Tools: []mcp.Tool{
{
Name: "uploadFile",
Description: "Note: The API key is handled by the server, no need to provide it. Upload file",
InputSchema: mcp.Schema{
Type: "object",
Properties: map[string]mcp.Schema{
"description": {Type: "string"},
"file_upload": {Type: "string"}, // file type maps to string
},
Required: []string{"file_upload"}, // file_upload is required
},
},
},
Operations: map[string]mcp.OperationDetail{
"uploadFile": {
Method: "POST",
Path: "/upload",
BaseURL: "",
Parameters: []mcp.ParameterDetail{
{Name: "description", In: "formData"},
{Name: "file_upload", In: "formData"},
},
},
},
},
},
// TODO: Add V3/V2 tests for refs
// TODO: Add V3/V2 tests for file types (V2)
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
toolSet, err := GenerateToolSet(tc.spec, tc.version, tc.cfg)
if tc.expectError {
assert.Error(t, err)
assert.Nil(t, toolSet)
} else {
assert.NoError(t, err)
require.NotNil(t, toolSet)
// Compare basic ToolSet fields
assert.Equal(t, tc.expectedToolSet.Name, toolSet.Name, "ToolSet Name mismatch")
assert.Equal(t, tc.expectedToolSet.Description, toolSet.Description, "ToolSet Description mismatch")
// Compare Tool/Operation counts first for filtering tests
assert.Equal(t, len(tc.expectedToolSet.Tools), len(toolSet.Tools), "Tool count mismatch")
assert.Equal(t, len(tc.expectedToolSet.Operations), len(toolSet.Operations), "Operation count mismatch")
// If counts match, check specific tool names exist (more robust for filtering tests)
if len(tc.expectedToolSet.Tools) == len(toolSet.Tools) {
actualToolNames := make(map[string]bool)
for _, actualTool := range toolSet.Tools {
actualToolNames[actualTool.Name] = true
}
for _, expectedTool := range tc.expectedToolSet.Tools {
assert.Contains(t, actualToolNames, expectedTool.Name, "Expected tool %s not found in actual tools", expectedTool.Name)
}
}
// If counts match, check specific operation IDs exist (more robust for filtering tests)
if len(tc.expectedToolSet.Operations) == len(toolSet.Operations) {
for opID := range tc.expectedToolSet.Operations {
assert.Contains(t, toolSet.Operations, opID, "Expected operation detail %s not found", opID)
}
}
// Full comparison only for non-filtering tests for now (can be expanded)
if !strings.Contains(tc.name, "Complex") {
// Compare Tools slice fully
for i, expectedTool := range tc.expectedToolSet.Tools {
if i < len(toolSet.Tools) { // Bounds check
actualTool := toolSet.Tools[i]
assert.Equal(t, expectedTool.Name, actualTool.Name, "Tool[%d] Name mismatch", i)
assert.Equal(t, expectedTool.Description, actualTool.Description, "Tool[%d] Description mismatch", i)
// Sort Required slices before comparing Schemas
expectedSchema := expectedTool.InputSchema
actualSchema := actualTool.InputSchema
sort.Strings(expectedSchema.Required)
sort.Strings(actualSchema.Required)
assert.Equal(t, expectedSchema, actualSchema, "Tool[%d] InputSchema mismatch", i)
}
}
// Compare Operations map fully
for opID, expectedOpDetail := range tc.expectedToolSet.Operations {
if actualOpDetail, ok := toolSet.Operations[opID]; ok {
assert.Equal(t, expectedOpDetail.Method, actualOpDetail.Method, "OpDetail %s Method mismatch", opID)
assert.Equal(t, expectedOpDetail.Path, actualOpDetail.Path, "OpDetail %s Path mismatch", opID)
assert.Equal(t, expectedOpDetail.BaseURL, actualOpDetail.BaseURL, "OpDetail %s BaseURL mismatch", opID)
assert.Equal(t, expectedOpDetail.Parameters, actualOpDetail.Parameters, "OpDetail %s Parameters mismatch", opID)
}
}
}
}
})
}
}
```
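LoadSwagger in parser.go decides which loader to use by decoding the raw bytes into a generic map and checking for a top-level `openapi` key (OpenAPI 3.x) or `swagger` key (Swagger 2.0). The program below is a minimal standalone sketch of just that detection step, under the assumption that only this key check matters. `detectSpecVersion` is a hypothetical name and is not part of the package. Like the original, it accepts JSON only, so a YAML spec is rejected at this stage.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// detectSpecVersion is a hypothetical helper mirroring the key check in
// LoadSwagger: "openapi" is tested first (v3), then "swagger" (v2).
// Input must be JSON; YAML fails at json.Unmarshal.
func detectSpecVersion(data []byte) (string, error) {
	var detector map[string]interface{}
	if err := json.Unmarshal(data, &detector); err != nil {
		return "", fmt.Errorf("parse JSON for version detection: %w", err)
	}
	if _, ok := detector["openapi"]; ok {
		return "v3", nil
	}
	if _, ok := detector["swagger"]; ok {
		return "v2", nil
	}
	return "", errors.New("missing 'openapi' or 'swagger' key")
}

func main() {
	docs := []string{
		`{"openapi": "3.0.3", "info": {"title": "t", "version": "1"}}`,
		`{"swagger": "2.0", "info": {"title": "t", "version": "1"}}`,
		`{"info": {}}`,
	}
	for _, doc := range docs {
		v, err := detectSpecVersion([]byte(doc))
		fmt.Printf("version=%q err=%v\n", v, err)
	}
}
```

Only after this check does the real function hand the document to the version-specific loader (`openapi3.Loader` for v3, `loads.Analyzed` for v2).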
--------------------------------------------------------------------------------
/pkg/parser/parser.go:
--------------------------------------------------------------------------------
```go
package parser
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ckanthony/openapi-mcp/pkg/config"
"github.com/ckanthony/openapi-mcp/pkg/mcp"
"github.com/getkin/kin-openapi/openapi3"
"github.com/go-openapi/loads"
"github.com/go-openapi/spec"
)
const (
VersionV2 = "v2"
VersionV3 = "v3"
)
// LoadSwagger detects the version and loads an OpenAPI/Swagger specification
// from a local file path or a remote URL.
// It returns the loaded spec document (as interface{}), the detected version (string), and an error.
func LoadSwagger(location string) (interface{}, string, error) {
// Determine if location is URL or file path
locationURL, urlErr := url.ParseRequestURI(location)
isURL := urlErr == nil && locationURL != nil && (locationURL.Scheme == "http" || locationURL.Scheme == "https")
var data []byte
var err error
var absPath string // Store absolute path if it's a file
if !isURL {
log.Printf("Detected file path location: %s", location)
absPath, err = filepath.Abs(location)
if err != nil {
return nil, "", fmt.Errorf("failed to get absolute path for '%s': %w", location, err)
}
// Read data first for version detection
data, err = os.ReadFile(absPath)
if err != nil {
return nil, "", fmt.Errorf("failed reading file path '%s': %w", absPath, err)
}
} else {
log.Printf("Detected URL location: %s", location)
// Read data first for version detection
resp, err := http.Get(location)
if err != nil {
return nil, "", fmt.Errorf("failed to fetch URL '%s': %w", location, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body) // Attempt to read body for error context
return nil, "", fmt.Errorf("failed to fetch URL '%s': status code %d, body: %s", location, resp.StatusCode, string(bodyBytes))
}
data, err = io.ReadAll(resp.Body)
if err != nil {
return nil, "", fmt.Errorf("failed to read response body from URL '%s': %w", location, err)
}
}
// Detect version from data
var detector map[string]interface{}
if err := json.Unmarshal(data, &detector); err != nil {
return nil, "", fmt.Errorf("failed to parse JSON from '%s' for version detection: %w", location, err)
}
if _, ok := detector["openapi"]; ok {
// OpenAPI 3.x
loader := openapi3.NewLoader()
loader.IsExternalRefsAllowed = true
var doc *openapi3.T
var loadErr error
if !isURL {
// Use LoadFromFile for local files
log.Printf("Loading V3 spec using LoadFromFile: %s", absPath)
doc, loadErr = loader.LoadFromFile(absPath)
} else {
// Use LoadFromURI for URLs
log.Printf("Loading V3 spec using LoadFromURI: %s", location)
doc, loadErr = loader.LoadFromURI(locationURL)
}
if loadErr != nil {
return nil, "", fmt.Errorf("failed to load OpenAPI v3 spec from '%s': %w", location, loadErr)
}
if err := doc.Validate(context.Background()); err != nil {
return nil, "", fmt.Errorf("OpenAPI v3 spec validation failed for '%s': %w", location, err)
}
return doc, VersionV3, nil
} else if _, ok := detector["swagger"]; ok {
// Swagger 2.0 - Still load from data as loads.Analyzed expects bytes
log.Printf("Loading V2 spec using loads.Analyzed from data (source: %s)", location)
doc, err := loads.Analyzed(data, "2.0")
if err != nil {
return nil, "", fmt.Errorf("failed to load or validate Swagger v2 spec from '%s': %w", location, err)
}
return doc.Spec(), VersionV2, nil
} else {
return nil, "", fmt.Errorf("failed to detect OpenAPI/Swagger version in '%s': missing 'openapi' or 'swagger' key", location)
}
}
// GenerateToolSet converts a loaded spec (v2 or v3) into an MCP ToolSet.
func GenerateToolSet(specDoc interface{}, version string, cfg *config.Config) (*mcp.ToolSet, error) {
switch version {
case VersionV3:
docV3, ok := specDoc.(*openapi3.T)
if !ok {
return nil, fmt.Errorf("internal error: expected *openapi3.T for v3 spec, got %T", specDoc)
}
return generateToolSetV3(docV3, cfg)
case VersionV2:
docV2, ok := specDoc.(*spec.Swagger)
if !ok {
return nil, fmt.Errorf("internal error: expected *spec.Swagger for v2 spec, got %T", specDoc)
}
return generateToolSetV2(docV2, cfg)
default:
return nil, fmt.Errorf("unsupported specification version: %s", version)
}
}
// --- V3 Specific Implementation ---
func generateToolSetV3(doc *openapi3.T, cfg *config.Config) (*mcp.ToolSet, error) {
toolSet := createBaseToolSet(doc.Info.Title, doc.Info.Description, cfg)
toolSet.Operations = make(map[string]mcp.OperationDetail) // Initialize the map
// Determine Base URL once
baseURL, err := determineBaseURLV3(doc, cfg)
if err != nil {
log.Printf("Warning: Could not determine base URL for V3 spec: %v. Operations might fail if base URL override is not set.", err)
baseURL = "" // Allow proceeding if override is set
}
// // V3 Handles security differently (Components.SecuritySchemes). Rely on config flags for server-side injection.
// apiKeyName := cfg.APIKeyName
// apiKeyIn := string(cfg.APIKeyLocation)
// // Store detected/configured key details internally - Let config handle this
// toolSet.SetAPIKeyDetails(apiKeyName, apiKeyIn)
paths := getSortedPathsV3(doc.Paths)
for _, rawPath := range paths { // rawPath may still contain a query string
pathItem := doc.Paths.Value(rawPath)
for method, op := range pathItem.Operations() {
if op == nil || !shouldIncludeOperationV3(op, cfg) {
continue
}
// Clean the path
cleanPath := rawPath
if queryIndex := strings.Index(rawPath, "?"); queryIndex != -1 {
cleanPath = rawPath[:queryIndex]
}
toolName := generateToolNameV3(op, method, rawPath) // Still generate name from raw path
toolDesc := getOperationDescriptionV3(op)
// Convert parameters (query, header, path, cookie)
parametersSchema, opParams, err := parametersToMCPSchemaAndDetailsV3(op.Parameters, cfg)
if err != nil {
return nil, fmt.Errorf("error processing v3 parameters for %s %s: %w", method, rawPath, err)
}
// Handle request body
requestBody, err := requestBodyToMCPV3(op.RequestBody)
if err != nil {
log.Printf("Warning: skipping request body for %s %s due to error: %v", method, rawPath, err)
} else {
// Merge request body schema into the main parameter schema
if requestBody.Content != nil {
if parametersSchema.Properties == nil {
parametersSchema.Properties = make(map[string]mcp.Schema)
}
for _, mediaTypeSchema := range requestBody.Content {
if mediaTypeSchema.Type == "object" && mediaTypeSchema.Properties != nil {
for propName, propSchema := range mediaTypeSchema.Properties {
parametersSchema.Properties[propName] = propSchema
}
} else {
// If body is not an object, represent as 'requestBody'
log.Printf("Warning: V3 request body for %s %s is not an object schema. Representing as 'requestBody' field.", method, rawPath)
parametersSchema.Properties["requestBody"] = mediaTypeSchema
}
break // Only process the first content type
}
// Merge required fields from the body *schema* (not the requestBody boolean)
var bodySchemaRequired []string
for _, mediaTypeSchema := range requestBody.Content {
if len(mediaTypeSchema.Required) > 0 {
bodySchemaRequired = mediaTypeSchema.Required
break // Use required from the first content type with a schema
}
}
if len(bodySchemaRequired) > 0 {
if parametersSchema.Required == nil {
parametersSchema.Required = make([]string, 0)
}
for _, r := range bodySchemaRequired { // Range over the correct schema required list
if !sliceContains(parametersSchema.Required, r) {
parametersSchema.Required = append(parametersSchema.Required, r)
}
}
sort.Strings(parametersSchema.Required)
}
// Optionally, add a note if the requestBody itself was marked as required
if requestBody.Required { // Check the boolean field
// How to indicate this? Maybe add to description?
log.Printf("Note: Request body for %s %s is marked as required.", method, rawPath)
// Or add all top-level body props to required? Needs decision.
}
}
}
// Prepend note about API key handling
finalToolDesc := "Note: The API key is handled by the server, no need to provide it. " + toolDesc
tool := mcp.Tool{
Name: toolName,
Description: finalToolDesc, // Use potentially modified description
InputSchema: parametersSchema, // Use InputSchema, assuming it contains combined params/body
}
toolSet.Tools = append(toolSet.Tools, tool)
// Store operation details for execution
toolSet.Operations[toolName] = mcp.OperationDetail{
Method: method,
Path: cleanPath, // Use the cleaned path here
BaseURL: baseURL,
Parameters: opParams,
}
}
}
return toolSet, nil
}
func determineBaseURLV3(doc *openapi3.T, cfg *config.Config) (string, error) {
if cfg.ServerBaseURL != "" {
return strings.TrimSuffix(cfg.ServerBaseURL, "/"), nil
}
if len(doc.Servers) > 0 {
baseURL := ""
for _, server := range doc.Servers {
if baseURL == "" {
baseURL = server.URL
}
if strings.HasPrefix(strings.ToLower(server.URL), "https://") {
baseURL = server.URL
break
}
if strings.HasPrefix(strings.ToLower(server.URL), "http://") {
baseURL = server.URL
}
}
if baseURL == "" {
return "", fmt.Errorf("v3: could not determine a suitable base URL from servers list")
}
return strings.TrimSuffix(baseURL, "/"), nil
}
return "", fmt.Errorf("v3: no server base URL specified in config or OpenAPI spec servers list")
}
func getSortedPathsV3(paths *openapi3.Paths) []string {
if paths == nil {
return []string{}
}
keys := make([]string, 0, len(paths.Map()))
for k := range paths.Map() {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
func generateToolNameV3(op *openapi3.Operation, method, path string) string {
if op.OperationID != "" {
return op.OperationID
}
return generateDefaultToolName(method, path)
}
func getOperationDescriptionV3(op *openapi3.Operation) string {
if op.Summary != "" {
return op.Summary
}
return op.Description
}
func shouldIncludeOperationV3(op *openapi3.Operation, cfg *config.Config) bool {
return shouldInclude(op.OperationID, op.Tags, cfg)
}
// parametersToMCPSchemaAndDetailsV3 converts parameters and also returns the parameter details.
func parametersToMCPSchemaAndDetailsV3(params openapi3.Parameters, cfg *config.Config) (mcp.Schema, []mcp.ParameterDetail, error) {
mcpSchema := mcp.Schema{Type: "object", Properties: make(map[string]mcp.Schema), Required: []string{}}
opParams := []mcp.ParameterDetail{}
for _, paramRef := range params {
if paramRef.Value == nil {
log.Printf("Warning: Skipping parameter with nil value.")
continue
}
param := paramRef.Value
if param.Schema == nil {
log.Printf("Warning: Skipping parameter '%s' with nil schema.", param.Name)
continue
}
// Skip the API key parameter if configured
if cfg.APIKeyName != "" && param.Name == cfg.APIKeyName && param.In == string(cfg.APIKeyLocation) {
log.Printf("Parser V3: Skipping API key parameter '%s' ('%s') from input schema generation.", param.Name, param.In)
continue
}
// Record the parameter for request building. The API key parameter (if configured)
// was skipped above, so it is excluded from both opParams and the client-facing
// input schema; the server is responsible for supplying it.
opParams = append(opParams, mcp.ParameterDetail{
Name: param.Name,
In: param.In,
})
propSchema, err := openapiSchemaToMCPSchemaV3(param.Schema)
if err != nil {
return mcp.Schema{}, nil, fmt.Errorf("v3 param '%s': %w", param.Name, err)
}
propSchema.Description = param.Description
mcpSchema.Properties[param.Name] = propSchema
if param.Required {
mcpSchema.Required = append(mcpSchema.Required, param.Name)
}
}
if len(mcpSchema.Required) > 1 {
sort.Strings(mcpSchema.Required)
}
return mcpSchema, opParams, nil
}
func requestBodyToMCPV3(rbRef *openapi3.RequestBodyRef) (mcp.RequestBody, error) {
mcpRB := mcp.RequestBody{Content: make(map[string]mcp.Schema)}
if rbRef == nil || rbRef.Value == nil {
return mcpRB, nil
}
rb := rbRef.Value
mcpRB.Description = rb.Description
mcpRB.Required = rb.Required
var mediaType *openapi3.MediaType
var chosenMediaTypeKey string
if mt, ok := rb.Content["application/json"]; ok {
mediaType, chosenMediaTypeKey = mt, "application/json"
} else {
for key, mt := range rb.Content {
mediaType, chosenMediaTypeKey = mt, key
break
}
}
if mediaType != nil && mediaType.Schema != nil {
contentSchema, err := openapiSchemaToMCPSchemaV3(mediaType.Schema)
if err != nil {
return mcp.RequestBody{}, fmt.Errorf("v3 request body (media type: %s): %w", chosenMediaTypeKey, err)
}
mcpRB.Content["application/json"] = contentSchema
} else if mediaType != nil {
mcpRB.Content["application/json"] = mcp.Schema{Type: "string", Description: fmt.Sprintf("Request body with media type %s (no specific schema defined)", chosenMediaTypeKey)}
}
return mcpRB, nil
}
func openapiSchemaToMCPSchemaV3(oapiSchemaRef *openapi3.SchemaRef) (mcp.Schema, error) {
if oapiSchemaRef == nil {
return mcp.Schema{Type: "string", Description: "Schema reference was nil"}, nil
}
if oapiSchemaRef.Value == nil {
return mcp.Schema{Type: "string", Description: fmt.Sprintf("Schema reference value was nil (ref: %s)", oapiSchemaRef.Ref)}, nil
}
oapiSchema := oapiSchemaRef.Value
var primaryType string
if oapiSchema.Type != nil && len(*oapiSchema.Type) > 0 {
primaryType = (*oapiSchema.Type)[0]
}
mcpSchema := mcp.Schema{
Type: mapJSONSchemaType(primaryType),
Description: oapiSchema.Description,
Format: oapiSchema.Format,
Enum: oapiSchema.Enum,
}
switch mcpSchema.Type {
case "object":
mcpSchema.Properties = make(map[string]mcp.Schema)
mcpSchema.Required = oapiSchema.Required
for name, propRef := range oapiSchema.Properties {
propSchema, err := openapiSchemaToMCPSchemaV3(propRef)
if err != nil {
return mcp.Schema{}, fmt.Errorf("v3 object property '%s': %w", name, err)
}
mcpSchema.Properties[name] = propSchema
}
if len(mcpSchema.Required) > 1 {
sort.Strings(mcpSchema.Required)
}
case "array":
if oapiSchema.Items != nil {
itemsSchema, err := openapiSchemaToMCPSchemaV3(oapiSchema.Items)
if err != nil {
return mcp.Schema{}, fmt.Errorf("v3 array items: %w", err)
}
mcpSchema.Items = &itemsSchema
}
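case "string":
// mapJSONSchemaType appears to fall back to "string" for types it does not
// recognize (e.g. the V2 "file" type), so note the original type here.
// (In a default branch this check could never be true.)
if primaryType != "" && primaryType != "string" {
mcpSchema.Description += fmt.Sprintf(" (Original type '%s' unknown or unsupported)", primaryType)
}
case "number", "integer", "boolean", "null":
// Basic types are mapped directly
}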
return mcpSchema, nil
}
// --- V2 Specific Implementation ---
func generateToolSetV2(doc *spec.Swagger, cfg *config.Config) (*mcp.ToolSet, error) {
toolSet := createBaseToolSet(doc.Info.Title, doc.Info.Description, cfg)
toolSet.Operations = make(map[string]mcp.OperationDetail) // Initialize map
// Determine Base URL once
baseURL, err := determineBaseURLV2(doc, cfg)
if err != nil {
log.Printf("Warning: Could not determine base URL for V2 spec: %v. Operations might fail if base URL override is not set.", err)
baseURL = "" // Allow proceeding if override is set
}
// Detect API Key (Security Definitions)
apiKeyName := cfg.APIKeyName
apiKeyIn := string(cfg.APIKeyLocation)
if apiKeyName == "" && apiKeyIn == "" { // Only infer if not provided by config
for name, secDef := range doc.SecurityDefinitions {
if secDef.Type == "apiKey" {
apiKeyName = secDef.Name
apiKeyIn = secDef.In // "query" or "header"
log.Printf("Parser V2: Detected API key from security definition '%s': Name='%s', In='%s'", name, apiKeyName, apiKeyIn)
break // Assume only one apiKey definition for simplicity
}
}
}
// Store detected/configured key details internally
toolSet.SetAPIKeyDetails(apiKeyName, apiKeyIn)
// --- Iterate through Paths ---
paths := getSortedPathsV2(doc.Paths)
for _, rawPath := range paths { // rawPath may still contain a query string
pathItem := doc.Paths.Paths[rawPath]
ops := map[string]*spec.Operation{
"GET": pathItem.Get,
"PUT": pathItem.Put,
"POST": pathItem.Post,
"DELETE": pathItem.Delete,
"OPTIONS": pathItem.Options,
"HEAD": pathItem.Head,
"PATCH": pathItem.Patch,
}
for method, op := range ops {
if op == nil || !shouldIncludeOperationV2(op, cfg) {
continue
}
// Clean the path
cleanPath := rawPath
if queryIndex := strings.Index(rawPath, "?"); queryIndex != -1 {
cleanPath = rawPath[:queryIndex]
}
toolName := generateToolNameV2(op, method, rawPath) // Still generate name from raw path
toolDesc := getOperationDescriptionV2(op)
// Convert parameters and potential body schema
parametersSchema, bodySchema, opParams, err := parametersToMCPSchemaAndDetailsV2(op.Parameters, doc.Definitions, apiKeyName)
if err != nil {
return nil, fmt.Errorf("error processing v2 parameters for %s %s: %w", method, rawPath, err)
}
// Combine request body into parameters schema if it exists
if bodySchema.Type != "" { // Check if bodySchema was actually populated
if bodySchema.Type == "object" && bodySchema.Properties != nil {
if parametersSchema.Properties == nil {
parametersSchema.Properties = make(map[string]mcp.Schema)
}
for propName, propSchema := range bodySchema.Properties {
parametersSchema.Properties[propName] = propSchema
}
if len(bodySchema.Required) > 0 {
if parametersSchema.Required == nil {
parametersSchema.Required = make([]string, 0)
}
for _, r := range bodySchema.Required {
if !sliceContains(parametersSchema.Required, r) {
parametersSchema.Required = append(parametersSchema.Required, r)
}
}
sort.Strings(parametersSchema.Required)
}
} else {
// If body is not an object, represent as 'requestBody'
log.Printf("Warning: V2 request body for %s %s is not an object schema. Representing as 'requestBody' field.", method, rawPath)
if parametersSchema.Properties == nil {
parametersSchema.Properties = make(map[string]mcp.Schema)
}
parametersSchema.Properties["requestBody"] = bodySchema
}
}
// Prepend note about API key handling
finalToolDesc := "Note: The API key is handled by the server, no need to provide it. " + toolDesc
tool := mcp.Tool{
Name: toolName,
Description: finalToolDesc, // Use potentially modified description
InputSchema: parametersSchema, // Use InputSchema, assuming it contains combined params/body
}
toolSet.Tools = append(toolSet.Tools, tool)
// Store operation details for execution
toolSet.Operations[toolName] = mcp.OperationDetail{
Method: method,
Path: cleanPath, // Use the cleaned path here
BaseURL: baseURL,
Parameters: opParams,
}
}
}
return toolSet, nil
}
func determineBaseURLV2(doc *spec.Swagger, cfg *config.Config) (string, error) {
if cfg.ServerBaseURL != "" {
return strings.TrimSuffix(cfg.ServerBaseURL, "/"), nil
}
host := doc.Host
if host == "" {
return "", fmt.Errorf("v2: missing 'host' in spec")
}
scheme := "https"
if len(doc.Schemes) > 0 {
// Prefer https, then http, then first
preferred := []string{"https", "http"}
found := false
for _, p := range preferred {
for _, s := range doc.Schemes {
if s == p {
scheme = s
found = true
break
}
}
if found {
break
}
}
if !found {
scheme = doc.Schemes[0] // fall back to the first listed scheme
}
} // otherwise keep the "https" default
basePath := doc.BasePath
return strings.TrimSuffix(scheme+"://"+host+basePath, "/"), nil
}
func getSortedPathsV2(paths *spec.Paths) []string {
if paths == nil {
return []string{}
}
keys := make([]string, 0, len(paths.Paths))
for k := range paths.Paths {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
func generateToolNameV2(op *spec.Operation, method, path string) string {
if op.ID != "" {
return op.ID
}
return generateDefaultToolName(method, path)
}
func getOperationDescriptionV2(op *spec.Operation) string {
if op.Summary != "" {
return op.Summary
}
return op.Description
}
func shouldIncludeOperationV2(op *spec.Operation, cfg *config.Config) bool {
return shouldInclude(op.ID, op.Tags, cfg)
}
// parametersToMCPSchemaAndDetailsV2 converts V2 parameters and also returns details and request body.
func parametersToMCPSchemaAndDetailsV2(params []spec.Parameter, definitions spec.Definitions, apiKeyName string) (mcp.Schema, mcp.Schema, []mcp.ParameterDetail, error) {
mcpSchema := mcp.Schema{Type: "object", Properties: make(map[string]mcp.Schema), Required: []string{}}
bodySchema := mcp.Schema{} // Initialize empty
opParams := []mcp.ParameterDetail{}
hasBodyParam := false
var bodyParam *spec.Parameter // Declare bodyParam here to be accessible later
// First pass: Separate body param, process others
for _, param := range params {
// Skip the API key parameter if it's configured/detected
if apiKeyName != "" && param.Name == apiKeyName && (param.In == "query" || param.In == "header") {
log.Printf("Parser V2: Skipping API key parameter '%s' ('%s') from input schema generation.", param.Name, param.In)
continue
}
if param.In == "body" {
if hasBodyParam {
return mcp.Schema{}, mcp.Schema{}, nil, fmt.Errorf("v2: multiple 'body' parameters found")
}
hasBodyParam = true
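paramCopy := param // copy so the pointer stays valid after the loop variable changes (pre-Go 1.22 semantics)
bodyParam = &paramCopy // assign to the outer-scope variable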
continue // Don't process body param further in this loop
}
if param.In != "query" && param.In != "path" && param.In != "header" && param.In != "formData" {
log.Printf("Parser V2: Skipping unsupported parameter type '%s' for parameter '%s'", param.In, param.Name)
continue
}
// Add non-body param detail
opParams = append(opParams, mcp.ParameterDetail{
Name: param.Name,
In: param.In, // query, header, path, formData
})
// Convert non-body param schema and add to mcpSchema
propSchema, err := swaggerParamToMCPSchema(&param, definitions)
if err != nil {
return mcp.Schema{}, mcp.Schema{}, nil, fmt.Errorf("v2 param '%s': %w", param.Name, err)
}
mcpSchema.Properties[param.Name] = propSchema
if param.Required {
mcpSchema.Required = append(mcpSchema.Required, param.Name)
}
}
// Second pass: Process the body parameter if found
if bodyParam != nil {
bodySchema.Description = bodyParam.Description
if bodyParam.Schema != nil {
// Convert the body schema (resolving $refs)
bodySchemaFields, err := swaggerSchemaToMCPSchemaV2(bodyParam.Schema, definitions)
if err != nil {
return mcp.Schema{}, mcp.Schema{}, nil, fmt.Errorf("v2 request body schema: %w", err)
}
// Update our local bodySchema with the converted fields
bodySchema.Type = bodySchemaFields.Type
bodySchema.Properties = bodySchemaFields.Properties
bodySchema.Items = bodySchemaFields.Items
bodySchema.Format = bodySchemaFields.Format
bodySchema.Enum = bodySchemaFields.Enum
bodySchema.Required = bodySchemaFields.Required // Required fields from the *schema* itself
// Merge bodySchema properties into the main mcpSchema
if bodySchema.Type == "object" && bodySchema.Properties != nil {
for propName, propSchema := range bodySchema.Properties {
mcpSchema.Properties[propName] = propSchema
}
// Merge required fields from the body's schema into the main required list
if len(bodySchema.Required) > 0 {
mcpSchema.Required = append(mcpSchema.Required, bodySchema.Required...)
}
} else {
// Handle non-object body schema (e.g., array, string)
// Add a single property named after the body parameter
mcpSchema.Properties[bodyParam.Name] = bodySchemaFields // Use the converted schema
if bodyParam.Required { // Check the parameter's required status
mcpSchema.Required = append(mcpSchema.Required, bodyParam.Name)
}
}
} else {
// Body param defined without a schema? Treat as simple string.
log.Printf("Warning: V2 body parameter '%s' defined without a schema. Treating as string.", bodyParam.Name)
bodySchema.Type = "string"
mcpSchema.Properties[bodyParam.Name] = bodySchema
if bodyParam.Required {
mcpSchema.Required = append(mcpSchema.Required, bodyParam.Name)
}
}
// Always add the body parameter to the OperationDetail list
opParams = append(opParams, mcp.ParameterDetail{
Name: bodyParam.Name,
In: bodyParam.In,
})
}
// Sort and deduplicate the final required list
if len(mcpSchema.Required) > 1 {
sort.Strings(mcpSchema.Required)
seen := make(map[string]struct{}, len(mcpSchema.Required))
j := 0
for _, r := range mcpSchema.Required {
if _, ok := seen[r]; !ok {
seen[r] = struct{}{}
mcpSchema.Required[j] = r
j++
}
}
mcpSchema.Required = mcpSchema.Required[:j]
}
return mcpSchema, bodySchema, opParams, nil
}
// swaggerParamToMCPSchema converts a V2 Parameter (non-body) to an MCP Schema.
func swaggerParamToMCPSchema(param *spec.Parameter, definitions spec.Definitions) (mcp.Schema, error) {
// Maps param.Type, param.Format, param.Enum and (for arrays) param.Items.
// Simplified: validation constraints are not mapped yet.
mcpSchema := mcp.Schema{
Type: mapJSONSchemaType(param.Type), // Use the same mapping
Description: param.Description,
Format: param.Format,
Enum: param.Enum,
// TODO: map constraints (maximum, minimum, etc.); array items are handled below
}
if param.Type == "array" && param.Items != nil {
// Need to convert param.Items (which is *spec.Items) to MCP schema
itemsSchema, err := swaggerItemsToMCPSchema(param.Items, definitions)
if err != nil {
return mcp.Schema{}, fmt.Errorf("v2 array param '%s' items: %w", param.Name, err)
}
mcpSchema.Items = &itemsSchema
}
return mcpSchema, nil
}
// swaggerItemsToMCPSchema converts V2 Items object
func swaggerItemsToMCPSchema(items *spec.Items, definitions spec.Definitions) (mcp.Schema, error) {
if items == nil {
return mcp.Schema{Type: "string", Description: "nil items"}, nil
}
// Similar logic to swaggerParamToMCPSchema but for Items structure
mcpSchema := mcp.Schema{
Type: mapJSONSchemaType(items.Type),
Description: "", // Items don't have descriptions typically
Format: items.Format,
Enum: items.Enum,
}
if items.Type == "array" && items.Items != nil {
subItemsSchema, err := swaggerItemsToMCPSchema(items.Items, definitions)
if err != nil {
return mcp.Schema{}, fmt.Errorf("v2 nested array items: %w", err)
}
mcpSchema.Items = &subItemsSchema
}
// TODO: Handle $ref within items? Not directly supported by spec.Items
return mcpSchema, nil
}
// swaggerSchemaToMCPSchemaV2 converts a Swagger v2 schema (from definitions or body param) to mcp.Schema
func swaggerSchemaToMCPSchemaV2(oapiSchema *spec.Schema, definitions spec.Definitions) (mcp.Schema, error) {
if oapiSchema == nil {
return mcp.Schema{Type: "string", Description: "Schema was nil"}, nil
}
// Handle $ref
if oapiSchema.Ref.String() != "" {
refSchema, err := resolveRefV2(oapiSchema.Ref, definitions)
if err != nil {
return mcp.Schema{}, err
}
// Recursively convert the resolved schema. There is no cycle detection, so a self-referencing definition would recurse indefinitely.
return swaggerSchemaToMCPSchemaV2(refSchema, definitions)
}
var primaryType string
if len(oapiSchema.Type) > 0 {
primaryType = oapiSchema.Type[0]
}
mcpSchema := mcp.Schema{
Type: mapJSONSchemaType(primaryType),
Description: oapiSchema.Description,
Format: oapiSchema.Format,
Enum: oapiSchema.Enum,
// TODO: Map V2 constraints (Maximum, Minimum, etc.)
}
switch mcpSchema.Type {
case "object":
mcpSchema.Properties = make(map[string]mcp.Schema)
mcpSchema.Required = append([]string(nil), oapiSchema.Required...) // Copy so the sort below does not mutate the source spec
for name, propSchema := range oapiSchema.Properties {
// propSchema here is spec.Schema, need recursive call
propMCPSchema, err := swaggerSchemaToMCPSchemaV2(&propSchema, definitions)
if err != nil {
return mcp.Schema{}, fmt.Errorf("v2 object property '%s': %w", name, err)
}
mcpSchema.Properties[name] = propMCPSchema
}
if len(mcpSchema.Required) > 1 {
sort.Strings(mcpSchema.Required)
}
case "array":
if oapiSchema.Items != nil && oapiSchema.Items.Schema != nil {
// V2 Items has a single Schema field
itemsSchema, err := swaggerSchemaToMCPSchemaV2(oapiSchema.Items.Schema, definitions)
if err != nil {
return mcp.Schema{}, fmt.Errorf("v2 array items: %w", err)
}
mcpSchema.Items = &itemsSchema
} else if oapiSchema.Items != nil && len(oapiSchema.Items.Schemas) > 0 {
// Handle tuple-like arrays (less common, maybe simplify to single type?)
// For now, take the first schema
itemsSchema, err := swaggerSchemaToMCPSchemaV2(&oapiSchema.Items.Schemas[0], definitions)
if err != nil {
return mcp.Schema{}, fmt.Errorf("v2 tuple array items: %w", err)
}
mcpSchema.Items = &itemsSchema
mcpSchema.Description += " (Note: original was tuple-like array, showing first type)"
}
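case "string", "number", "integer", "boolean":
// Basic types map directly. mapJSONSchemaType maps every input to one of the
// cases in this switch (falling back to "string"), so note the original type here.
if mcpSchema.Type == "string" && primaryType != "" && strings.ToLower(primaryType) != "string" {
mcpSchema.Description += fmt.Sprintf(" (Original type '%s' unknown or unsupported)", primaryType)
}
}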
return mcpSchema, nil
}
// resolveRefV2 resolves a local "#/definitions/<name>" reference against the spec's definitions.
func resolveRefV2(ref spec.Ref, definitions spec.Definitions) (*spec.Schema, error) {
// Simple local definition resolution
refStr := ref.String()
if !strings.HasPrefix(refStr, "#/definitions/") {
return nil, fmt.Errorf("unsupported $ref format: %s", refStr)
}
defName := strings.TrimPrefix(refStr, "#/definitions/")
schema, ok := definitions[defName]
if !ok {
return nil, fmt.Errorf("$ref '%s' not found in definitions", refStr)
}
return &schema, nil
}
// --- Common Helper Functions ---
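// createBaseToolSet builds an empty ToolSet named after the spec title and description,
// unless cfg.DefaultToolName or cfg.DefaultToolDesc override them.
func createBaseToolSet(title, desc string, cfg *config.Config) *mcp.ToolSet {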
// Prioritize config overrides if they are set
toolSetName := title // Default to spec title
if cfg.DefaultToolName != "" {
toolSetName = cfg.DefaultToolName // Use config override if provided
}
toolSetDesc := desc // Default to spec description
if cfg.DefaultToolDesc != "" {
toolSetDesc = cfg.DefaultToolDesc // Use config override if provided
}
toolSet := &mcp.ToolSet{
MCPVersion: "0.1.0",
Name: toolSetName, // Use determined name
Description: toolSetDesc, // Use determined description
Tools: []mcp.Tool{},
Operations: make(map[string]mcp.OperationDetail), // Initialize map
}
return toolSet
}
// generateDefaultToolName creates a name if operationId is missing,
// e.g. ("get", "/users/{id}") -> "GetUsersById".
func generateDefaultToolName(method, path string) string {
pathParts := strings.Split(strings.Trim(path, "/"), "/")
var nameParts []string
if method != "" { // Guard: method[:1] would panic on an empty string
nameParts = append(nameParts, strings.ToUpper(method[:1])+strings.ToLower(method[1:]))
}
for _, part := range pathParts {
if part == "" {
continue
}
if strings.HasPrefix(part, "{") && strings.HasSuffix(part, "}") {
paramName := strings.Trim(part, "{}")
if paramName == "" { // Guard: "{}" has no name to capitalize
continue
}
nameParts = append(nameParts, "By"+strings.ToUpper(paramName[:1])+paramName[1:])
} else {
sanitizedPart := strings.ReplaceAll(part, "-", "_")
sanitizedPart = strings.Title(sanitizedPart) // Capitalize word starts (strings.Title is deprecated but adequate for ASCII segments)
nameParts = append(nameParts, sanitizedPart)
}
}
return strings.Join(nameParts, "")
}
// shouldInclude determines if an operation should be included based on config filters.
func shouldInclude(opID string, opTags []string, cfg *config.Config) bool {
// Exclusion rules take precedence
if len(cfg.ExcludeOperations) > 0 && opID != "" && sliceContains(cfg.ExcludeOperations, opID) {
return false
}
if len(cfg.ExcludeTags) > 0 {
for _, tag := range opTags {
if sliceContains(cfg.ExcludeTags, tag) {
return false
}
}
}
// Inclusion rules: IncludeOperations, when set, decides on its own;
// IncludeTags is consulted only when IncludeOperations is empty.
if len(cfg.IncludeOperations) == 0 && len(cfg.IncludeTags) == 0 {
return true // No inclusion rules, include by default
}
if len(cfg.IncludeOperations) > 0 {
return opID != "" && sliceContains(cfg.IncludeOperations, opID)
}
for _, tag := range opTags {
if sliceContains(cfg.IncludeTags, tag) {
return true
}
}
return false // Did not match any inclusion tag
}
// mapJSONSchemaType ensures the type is one recognized by JSON Schema / MCP.
func mapJSONSchemaType(oapiType string) string {
switch t := strings.ToLower(oapiType); t { // Normalize case before matching
case "integer", "number", "string", "boolean", "array", "object":
return t
case "null":
return "string" // No null type is emitted; represent it as a string
case "file": // Swagger 2.0 specific type
return "string" // File uploads are represented as strings (e.g. a path or encoded content)
default:
return "string" // Unknown types fall back to string
}
}
// sliceContains checks if a string slice contains a specific string.
func sliceContains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
```
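
The precedence encoded in `shouldInclude` is easy to misread, so the following is a standalone sketch that mirrors it: exclusions win first, then `IncludeOperations`, and `IncludeTags` only when no operation list is set. `filterConfig` is a hypothetical stand-in for the four filter fields of `config.Config`, and the operation IDs and tags in `main` are made up for illustration.

```go
package main

import "fmt"

// filterConfig is a hypothetical stand-in for the filter fields of config.Config.
type filterConfig struct {
	IncludeOperations []string
	ExcludeOperations []string
	IncludeTags       []string
	ExcludeTags       []string
}

// contains reports whether item is in list.
func contains(list []string, item string) bool {
	for _, s := range list {
		if s == item {
			return true
		}
	}
	return false
}

// include mirrors shouldInclude: exclusions win, then IncludeOperations,
// then IncludeTags (consulted only when IncludeOperations is empty).
func include(opID string, tags []string, cfg filterConfig) bool {
	if opID != "" && contains(cfg.ExcludeOperations, opID) {
		return false
	}
	for _, t := range tags {
		if contains(cfg.ExcludeTags, t) {
			return false
		}
	}
	if len(cfg.IncludeOperations) == 0 && len(cfg.IncludeTags) == 0 {
		return true // No inclusion rules: include everything not excluded
	}
	if len(cfg.IncludeOperations) > 0 {
		return opID != "" && contains(cfg.IncludeOperations, opID)
	}
	for _, t := range tags {
		if contains(cfg.IncludeTags, t) {
			return true
		}
	}
	return false
}

func main() {
	cfg := filterConfig{
		IncludeOperations: []string{"getCurrentWeather"},
		IncludeTags:       []string{"forecast"},
		ExcludeTags:       []string{"internal"},
	}
	// Listed in IncludeOperations: included.
	fmt.Println(include("getCurrentWeather", []string{"current"}, cfg))
	// Tagged "forecast", but IncludeOperations takes precedence: dropped.
	fmt.Println(include("getDailyForecast", []string{"forecast"}, cfg))
	// An excluded tag wins over every inclusion rule.
	fmt.Println(include("getCurrentWeather", []string{"internal"}, cfg))
}
```

Running it prints `true`, `false`, `false`. The second case is the one to watch: once `IncludeOperations` is non-empty, a matching include tag no longer rescues an operation.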
--------------------------------------------------------------------------------
/pkg/server/server.go:
--------------------------------------------------------------------------------
```go
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/ckanthony/openapi-mcp/pkg/config"
"github.com/ckanthony/openapi-mcp/pkg/mcp"
"github.com/google/uuid" // Import UUID package
)
// --- JSON-RPC 2.0 Structures (handshake and messages) ---
type jsonRPCRequest struct {
Jsonrpc string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
ID interface{} `json:"id,omitempty"` // Can be string, number, or null
}
type jsonRPCResponse struct {
Jsonrpc string `json:"jsonrpc"`
Result interface{} `json:"result,omitempty"`
Error *jsonError `json:"error,omitempty"`
ID interface{} `json:"id"` // ID should match the request ID
}
type jsonError struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// --- MCP Message Structures (Kept for clarity on expected payloads) ---
// MCPMessage represents a generic message exchanged over the transport.
// Note: Adapt this structure based on the exact MCP spec requirements if needed.
// This structure is now more for understanding the *payloads* within JSON-RPC.
type MCPMessage struct {
Type string `json:"type"` // e.g., "initialize", "tools/list", "tools/call", "tool_result", "error"
ID string `json:"id,omitempty"` // Unique message ID (less relevant for JSON-RPC wrapper)
Payload json.RawMessage `json:"payload,omitempty"` // Content specific to the message type
ConnID string `json:"connectionId,omitempty"` // Included in responses related to a connection
}
// MCPError defines a structured error for MCP responses.
// This will be used within the 'Error.Data' field of a jsonRPCResponse.
type MCPError struct {
Code int `json:"code,omitempty"` // Optional error code
Message string `json:"message"`
Data interface{} `json:"data,omitempty"` // Optional additional data
}
// ToolCallParams represents the expected payload for a tools/call request.
// This will be the structure within the 'params' field of a jsonRPCRequest.
type ToolCallParams struct {
ToolName string `json:"name"` // Aligning with gin-mcp JSON-RPC 'name'
Input map[string]interface{} `json:"arguments"` // Aligning with gin-mcp JSON-RPC 'arguments'
}
// ToolResultContent represents an item in the 'content' array of a tool_result.
type ToolResultContent struct {
Type string `json:"type"`
Text string `json:"text"` // Assuming text/JSON string result
// Add other content types if needed
}
// ToolResultPayload represents the structure for the 'result' of a 'tool_result' JSON-RPC response.
type ToolResultPayload struct {
Content []ToolResultContent `json:"content"` // Array of content items
IsError bool `json:"isError"` // Aligning with gin-mcp
Error *MCPError `json:"error,omitempty"` // Detailed error info if IsError is true
ToolCallID string `json:"tool_call_id,omitempty"` // Optional: Can be helpful
}
// --- Server State ---
// activeConnections maps each connection ID to the channel that feeds its SSE stream.
var activeConnections = make(map[string]chan jsonRPCResponse)
var connMutex sync.RWMutex // Guards activeConnections
// Buffer size of each connection's message channel
const messageChannelBufferSize = 10
// --- Server Implementation ---
// ServeMCP starts an HTTP server handling MCP communication.
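//
// Protocol flow implemented by the handlers in this file:
//  1. GET /mcp opens an SSE stream that starts with an ":ok" comment, an "endpoint"
//     event whose data is /mcp?sessionId=<id>, and an mcp-ready notification.
//  2. The client POSTs each JSON-RPC request to that endpoint (or sends the ID in an
//     X-Connection-ID header) and receives 202 Accepted.
//  3. The JSON-RPC response arrives as a "message" event on the SSE stream, interleaved
//     with "ping" notifications every 20 seconds.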
func ServeMCP(addr string, toolSet *mcp.ToolSet, cfg *config.Config) error {
log.Printf("Preparing ToolSet for MCP...")
// --- Handler Functions ---
mcpHandler := func(w http.ResponseWriter, r *http.Request) {
// CORS Headers (Apply to all relevant requests)
w.Header().Set("Access-Control-Allow-Origin", "*") // Be more specific in production
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-Connection-ID")
w.Header().Set("Access-Control-Expose-Headers", "X-Connection-ID")
if r.Method == http.MethodOptions {
log.Println("Responding to OPTIONS request")
w.WriteHeader(http.StatusNoContent) // Use 204 No Content for OPTIONS
return
}
if r.Method == http.MethodGet {
httpMethodGetHandler(w, r) // Handle SSE connection setup
} else if r.Method == http.MethodPost {
httpMethodPostHandler(w, r, toolSet, cfg) // Pass the cfg object here
} else {
log.Printf("Method Not Allowed: %s", r.Method)
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
}
}
// Setup server mux
mux := http.NewServeMux()
mux.HandleFunc("/mcp", mcpHandler) // Single endpoint for GET/POST/OPTIONS
log.Printf("MCP server listening on %s/mcp", addr)
return http.ListenAndServe(addr, mux)
}
// httpMethodGetHandler handles the initial GET request to establish the SSE connection.
func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) {
connectionID := uuid.New().String()
log.Printf("SSE client connecting: %s (Assigning ID: %s)", r.RemoteAddr, connectionID)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
log.Println("Error: Client connection does not support flushing")
return
}
// --- Set headers FIRST ---
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
// CORS headers are set in the main handler
w.Header().Set("X-Connection-ID", connectionID)
w.Header().Set("X-Accel-Buffering", "no") // Useful for proxies like Nginx
w.WriteHeader(http.StatusOK) // Write headers and status code
flusher.Flush() // Ensure headers are sent immediately
// --- Send initial :ok --- (Must happen *after* headers)
if _, err := fmt.Fprintf(w, ":ok\n\n"); err != nil {
log.Printf("Error sending SSE preamble to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
return // Cannot proceed if preamble fails
}
flusher.Flush()
log.Printf("Sent :ok preamble to %s (ID: %s)", r.RemoteAddr, connectionID)
// --- Send initial SSE events --- (endpoint, mcp-ready)
endpointURL := fmt.Sprintf("/mcp?sessionId=%s", connectionID) // Path matches the /mcp route registered in ServeMCP
if err := writeSSEEvent(w, "endpoint", endpointURL); err != nil {
log.Printf("Error sending SSE endpoint event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
return
}
flusher.Flush()
log.Printf("Sent endpoint event to %s (ID: %s)", r.RemoteAddr, connectionID)
readyMsg := jsonRPCRequest{ // Use request struct for notification format
Jsonrpc: "2.0",
Method: "mcp-ready",
Params: map[string]interface{}{ // Put data in params
"connectionId": connectionID,
"status": "connected",
"protocol": "2.0",
},
}
if err := writeSSEEvent(w, "message", readyMsg); err != nil {
log.Printf("Error sending SSE mcp-ready event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err)
return
}
flusher.Flush()
log.Printf("Sent mcp-ready event to %s (ID: %s)", r.RemoteAddr, connectionID)
// --- Setup message channel and store connection ---
msgChan := make(chan jsonRPCResponse, messageChannelBufferSize) // Channel for responses
connMutex.Lock()
activeConnections[connectionID] = msgChan
activeCount := len(activeConnections) // Read under the lock to avoid a data race
connMutex.Unlock()
log.Printf("Registered channel for connection %s. Active connections: %d", connectionID, activeCount)
// --- Cleanup: deregister the connection when the handler returns ---
// msgChan is deliberately not closed: a concurrent POST handler may already
// hold it, and sending on a closed channel panics. Once no handler references
// it, the channel is garbage collected.
defer func() {
connMutex.Lock()
delete(activeConnections, connectionID)
remaining := len(activeConnections)
connMutex.Unlock()
log.Printf("Removed connection %s. Active connections: %d", connectionID, remaining)
}()
ctx, cancel := context.WithCancel(r.Context()) // Done on client disconnect or handler exit
defer cancel()
// --- Event loop: the only goroutine that writes to the SSE stream ---
// Responses and keep-alive pings share this loop, so writes to w never run
// concurrently and never happen after the handler has returned.
keepAliveTicker := time.NewTicker(20 * time.Second)
defer keepAliveTicker.Stop()
log.Printf("[SSE %s] Entering event loop", connectionID)
for {
select {
case <-ctx.Done():
log.Printf("[SSE %s] Context done. Exiting event loop.", connectionID)
return // Client disconnected
case resp, ok := <-msgChan:
if !ok {
log.Printf("[SSE %s] Message channel closed.", connectionID)
return
}
log.Printf("[SSE %s] Sending message (ID: %v) via SSE", connectionID, resp.ID)
if err := writeSSEEvent(w, "message", resp); err != nil {
log.Printf("[SSE %s] Error writing message to SSE stream: %v. Closing connection.", connectionID, err)
return
}
flusher.Flush()
case <-keepAliveTicker.C:
// Send a JSON-RPC ping notification instead of an SSE comment
pingMsg := jsonRPCRequest{ // Notification format: no ID
Jsonrpc: "2.0",
Method: "ping",
Params: map[string]interface{}{ // Include a timestamp, like gin-mcp
"timestamp": time.Now().Unix(),
},
}
if err := writeSSEEvent(w, "message", pingMsg); err != nil {
log.Printf("[SSE %s] Error sending ping notification: %v. Closing connection.", connectionID, err)
return
}
flusher.Flush()
}
}
}
// writeSSEEvent formats and writes data as a Server-Sent Event.
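// For example, writeSSEEvent(w, "endpoint", "/mcp?sessionId=abc") writes
//
//	event: endpoint
//	data: /mcp?sessionId=abc
//
// followed by an empty line, which terminates the event. Any other payload is
// JSON-encoded first, so a jsonRPCResponse becomes a single "data:" line.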
func writeSSEEvent(w http.ResponseWriter, eventName string, data interface{}) error {
var buffer bytes.Buffer
if eventName != "" {
fmt.Fprintf(&buffer, "event: %s\n", eventName)
}
// The endpoint event carries a plain URL string; everything else is sent as JSON
var dataStr string
if strData, ok := data.(string); ok && eventName == "endpoint" {
dataStr = strData
} else {
jsonData, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal data for SSE event '%s': %w", eventName, err)
}
dataStr = string(jsonData)
}
// Each payload line needs its own "data:" prefix. json.Marshal escapes
// newlines, so splitting only matters for raw string payloads.
for _, line := range strings.Split(dataStr, "\n") {
fmt.Fprintf(&buffer, "data: %s\n", line)
}
// A blank line terminates the event
buffer.WriteString("\n")
if _, err := w.Write(buffer.Bytes()); err != nil {
return fmt.Errorf("failed to write SSE event '%s': %w", eventName, err)
}
return nil
}
// httpMethodPostHandler handles incoming POST requests containing MCP messages.
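//
// The POST itself is answered with a status code; the JSON-RPC response goes out over SSE:
//   - 400 if neither the X-Connection-ID header nor the sessionId query parameter is set,
//   - 404 if the connection ID is unknown or expired,
//   - 202 once a response (including a JSON-RPC error) is queued, or for notifications/initialized,
//   - 500 if the response cannot be queued on the connection's SSE channel.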
func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp.ToolSet, cfg *config.Config) {
connID := r.Header.Get("X-Connection-ID") // Try header first
if connID == "" {
connID = r.URL.Query().Get("sessionId") // Fallback to query parameter
log.Printf("X-Connection-ID header missing, checking sessionId query param: found='%s'", connID)
}
if connID == "" {
log.Println("Error: POST request received without X-Connection-ID header or sessionId query parameter")
http.Error(w, "Missing X-Connection-ID header or sessionId query parameter", http.StatusBadRequest)
return
}
// Find the corresponding message channel for this connection
connMutex.RLock()
msgChan, isActive := activeConnections[connID]
connMutex.RUnlock()
if !isActive {
log.Printf("Error: POST request received for inactive/unknown connection ID: %s", connID)
// Still send sync error here, as we don't have a channel
tryWriteHTTPError(w, http.StatusNotFound, "Invalid or expired connection ID")
return
}
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
log.Printf("Error reading POST request body for %s: %v", connID, err)
// Create error response in the ToolResultPayload format
errPayload := ToolResultPayload{
IsError: true,
Error: &MCPError{
Code: -32700, // JSON-RPC Parse Error Code
Message: "Parse error reading request body",
},
// ToolCallID doesn't really apply here, maybe use connID or leave empty?
// ToolCallID: connID,
}
errResp := jsonRPCResponse{
Jsonrpc: "2.0",
ID: nil, // ID is unknown if we can't read the body
Result: errPayload,
Error: nil, // Ensure top-level error is nil
}
// Attempt to send via SSE channel
select {
case msgChan <- errResp:
log.Printf("Queued read error response (ID: %v) for %s onto SSE channel (as Result)", errResp.ID, connID)
// Send HTTP 202 Accepted back to the POST request
w.WriteHeader(http.StatusAccepted)
fmt.Fprintln(w, "Request accepted (with parse error), response will be sent via SSE.")
default:
log.Printf("Error: Failed to queue read error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID)
// Send an error back on the POST request if channel fails
tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel")
}
return // Stop processing
}
// No explicit r.Body.Close() is needed: the HTTP server closes the request body after the handler returns.
log.Printf("Received POST data for %s: %s", connID, string(bodyBytes))
// Attempt to unmarshal into a temporary map first to extract ID if possible
var rawReq map[string]interface{}
var reqID interface{} // Keep track of ID even if full unmarshal fails
// Try unmarshalling into raw map
if err := json.Unmarshal(bodyBytes, &rawReq); err == nil {
// Record the id if present; a missing id and a JSON null both leave reqID nil
if idVal, idExists := rawReq["id"]; idExists && idVal != nil {
reqID = idVal
} else {
reqID = nil // Explicitly set to nil if missing or JSON null
}
} else {
// The map unmarshal failed (e.g. malformed JSON or a non-object body); the struct unmarshal below reports the error
log.Printf("Warning: Initial unmarshal into map failed for %s: %v. Will attempt specific struct unmarshal.", connID, err)
reqID = nil // ID is unknown
}
var req jsonRPCRequest // Expect JSON-RPC request
if err := json.Unmarshal(bodyBytes, &req); err != nil {
log.Printf("Error decoding JSON-RPC request for %s: %v", connID, err)
// Use createJSONRPCError to correctly format the error response
errResp := createJSONRPCError(reqID, -32700, "Parse error decoding JSON request", err.Error())
// Attempt to send via SSE channel
select {
case msgChan <- errResp:
log.Printf("Queued decode error response (ID: %v) for %s onto SSE channel", errResp.ID, connID)
// Send HTTP 202 Accepted back to the POST request
w.WriteHeader(http.StatusAccepted)
// Use a specific message for decode errors
fmt.Fprintln(w, "Request accepted (with decode error), response will be sent via SSE.")
default:
log.Printf("Error: Failed to queue decode error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID)
// Send an error back on the POST request if channel fails
tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel")
}
return // Stop processing
}
// Once 'req' has been decoded, its ID is authoritative
reqID = req.ID
// --- Variable to hold the final response to be sent via SSE ---
var respToSend jsonRPCResponse
// --- Validate JSON-RPC Request ---
if req.Jsonrpc != "2.0" {
log.Printf("Invalid JSON-RPC version ('%s') for %s, ID: %v", req.Jsonrpc, connID, reqID)
respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: jsonrpc field must be \"2.0\"", nil)
} else if req.Method == "" {
log.Printf("Missing JSON-RPC method for %s, ID: %v", connID, reqID)
respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: method field is missing or empty", nil)
} else {
// --- Process the valid request ---
log.Printf("Processing JSON-RPC message for %s: Method=%s, ID=%v", connID, req.Method, reqID)
switch req.Method {
case "initialize":
incomingInitializeJSON, _ := json.Marshal(req)
log.Printf("DEBUG: Handling 'initialize' for %s. Incoming request: %s", connID, string(incomingInitializeJSON))
respToSend = handleInitializeJSONRPC(connID, &req)
outgoingInitializeJSON, _ := json.Marshal(respToSend)
log.Printf("DEBUG: Prepared 'initialize' response for %s. Outgoing response: %s", connID, string(outgoingInitializeJSON))
case "notifications/initialized":
log.Printf("Received 'notifications/initialized' notification for %s. Ignoring.", connID)
w.WriteHeader(http.StatusAccepted)
fmt.Fprintln(w, "Notification received.")
return // Return early, do not send anything on SSE channel
case "tools/list":
respToSend = handleToolsListJSONRPC(connID, &req, toolSet)
case "tools/call":
respToSend = handleToolCallJSONRPC(connID, &req, toolSet, cfg)
default:
log.Printf("Received unknown JSON-RPC method '%s' for %s", req.Method, connID)
respToSend = createJSONRPCError(reqID, -32601, fmt.Sprintf("Method not found: %s", req.Method), nil)
}
}
// --- Send response ASYNCHRONOUSLY via SSE channel (unless handled earlier) ---
select {
case msgChan <- respToSend:
log.Printf("Queued response (ID: %v) for %s onto SSE channel", respToSend.ID, connID)
// Send HTTP 202 Accepted back to the POST request
w.WriteHeader(http.StatusAccepted)
// Use the standard message for successfully queued responses
fmt.Fprintln(w, "Request accepted, response will be sent via SSE.")
default:
log.Printf("Error: Failed to queue response (ID: %v) for %s - SSE channel likely full or closed.", respToSend.ID, connID)
http.Error(w, "Failed to queue response for SSE channel", http.StatusInternalServerError)
}
}
// --- JSON-RPC Message Handlers --- // Implementations returning jsonRPCResponse
func handleInitializeJSONRPC(connID string, req *jsonRPCRequest) jsonRPCResponse {
log.Printf("Handling 'initialize' (JSON-RPC) for %s", connID)
// Construct the result payload based on gin-mcp's structure using map[string]interface{}
resultPayload := map[string]interface{}{
"protocolVersion": "2024-11-05", // Aligning with gin-mcp
"capabilities": map[string]interface{}{
"tools": map[string]interface{}{
"enabled": true,
"config": map[string]interface{}{
"listChanged": false,
},
},
"prompts": map[string]interface{}{
"enabled": false,
},
"resources": map[string]interface{}{
"enabled": true,
},
"logging": map[string]interface{}{
"enabled": false,
},
"roots": map[string]interface{}{
"listChanged": false,
},
},
"serverInfo": map[string]interface{}{
"name": "OpenAPI-MCP", // Server name reported to clients
"version": "openapi-mcp-0.1.0", // Server version
"apiVersion": "2024-11-05", // MCP protocol version
},
"connectionId": connID, // Include the connection ID
}
return jsonRPCResponse{
Jsonrpc: "2.0",
ID: req.ID, // Match request ID
Result: resultPayload,
}
}
func handleToolsListJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet) jsonRPCResponse {
log.Printf("Handling 'tools/list' (JSON-RPC) for %s", connID)
// Construct the result payload based on gin-mcp's structure
resultPayload := map[string]interface{}{
"tools": toolSet.Tools,
"metadata": map[string]interface{}{
"version": "2024-11-05", // Align with gin-mcp if possible
"count": len(toolSet.Tools),
},
}
return jsonRPCResponse{
Jsonrpc: "2.0",
ID: req.ID, // Match request ID
Result: resultPayload,
}
}
// executeToolCall performs the actual HTTP request based on the resolved operation and parameters.
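// For example, for an operation GET /items/{id} that declares "q" as a query
// parameter, the input {"id": 42, "q": "rain"} produces GET <baseURL>/items/42?q=rain.
// If cfg configures a server API key in the query, it is added with queryParams.Set
// and any client-supplied value under the same name is dropped.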
// Server-side API key injection is controlled by cfg (GetAPIKey, APIKeyName, APIKeyLocation).
func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.Config) (*http.Response, error) {
toolName := params.ToolName
toolInput := params.Input // This is the map[string]interface{} from the client
log.Printf("[ExecuteToolCall] Looking up details for tool: %s", toolName)
operation, ok := toolSet.Operations[toolName]
if !ok {
log.Printf("[ExecuteToolCall] Error: Operation details not found for tool '%s'", toolName)
return nil, fmt.Errorf("operation details for tool '%s' not found", toolName)
}
log.Printf("[ExecuteToolCall] Found operation: Method=%s, Path=%s", operation.Method, operation.Path)
// --- Resolve API Key (using cfg passed from main) ---
resolvedKey := cfg.GetAPIKey()
apiKeyName := cfg.APIKeyName
apiKeyLocation := cfg.APIKeyLocation
hasServerKey := resolvedKey != "" && apiKeyName != "" && apiKeyLocation != ""
log.Printf("[ExecuteToolCall] API Key Details: Name='%s', In='%s', HasServerValue=%t", apiKeyName, apiKeyLocation, resolvedKey != "")
// --- Prepare Request Components ---
baseURL := operation.BaseURL // Use BaseURL from the specific operation
if cfg.ServerBaseURL != "" {
baseURL = cfg.ServerBaseURL // Override if global base URL is set
log.Printf("[ExecuteToolCall] Overriding base URL with global config: %s", baseURL)
}
if baseURL == "" {
log.Printf("[ExecuteToolCall] Warning: No base URL found for operation %s and no global override set.", toolName)
// Continue with the bare path; sending a request without a scheme and host will fail.
}
path := operation.Path
queryParams := url.Values{}
pathParams := make(map[string]string)
headerParams := make(http.Header) // For headers to add
cookieParams := []*http.Cookie{} // For cookies to add
bodyData := make(map[string]interface{}) // For building the request body
requestBodyRequired := operation.Method == "POST" || operation.Method == "PUT" || operation.Method == "PATCH"
// Create a map of expected parameters from the operation details for easier lookup
expectedParams := make(map[string]string) // Map param name to its location ('in')
for _, p := range operation.Parameters {
expectedParams[p.Name] = p.In
}
// --- Process Input Parameters (Separating and Handling API Key Override) ---
log.Printf("[ExecuteToolCall] Processing %d input parameters...", len(toolInput))
for key, value := range toolInput {
// --- API Key Override Check ---
// If this input param is the API key AND we have a valid server key config,
// skip processing the client's value entirely.
if hasServerKey && key == apiKeyName {
log.Printf("[ExecuteToolCall] Skipping client-provided param '%s' due to server API key override.", key)
continue
}
// --- End API Key Override ---
paramLocation, knownParam := expectedParams[key]
pathPlaceholder := "{" + key + "}" // OpenAPI uses {param}
if strings.Contains(path, pathPlaceholder) {
// Handle path parameter substitution
pathParams[key] = fmt.Sprintf("%v", value)
log.Printf("[ExecuteToolCall] Found path parameter %s=%v", key, value)
} else if knownParam {
// Handle parameters defined in the spec (query, header, cookie)
switch paramLocation {
case "query":
queryParams.Add(key, fmt.Sprintf("%v", value))
log.Printf("[ExecuteToolCall] Found query parameter %s=%v (from spec)", key, value)
case "header":
headerParams.Add(key, fmt.Sprintf("%v", value))
log.Printf("[ExecuteToolCall] Found header parameter %s=%v (from spec)", key, value)
case "cookie":
cookieParams = append(cookieParams, &http.Cookie{Name: key, Value: fmt.Sprintf("%v", value)})
log.Printf("[ExecuteToolCall] Found cookie parameter %s=%v (from spec)", key, value)
// case "formData": // TODO: Handle form data if needed
// bodyData[key] = value // Or handle differently based on content type
// log.Printf("[ExecuteToolCall] Found formData parameter %s=%v (from spec)", key, value)
default:
// Known parameter but location handling is missing or mismatched.
if paramLocation == "path" && (operation.Method == "GET" || operation.Method == "DELETE") {
// If spec says 'path' but it wasn't in the actual path, and it's a GET/DELETE,
// treat it as a query parameter as a fallback.
log.Printf("[ExecuteToolCall] Warning: Parameter '%s' is 'path' in spec but not in URL path '%s'. Adding to query parameters as fallback for GET/DELETE.", key, operation.Path)
queryParams.Add(key, fmt.Sprintf("%v", value))
} else {
// Otherwise, log the warning and ignore.
log.Printf("[ExecuteToolCall] Warning: Parameter '%s' has unsupported or unhandled location '%s' in spec. Ignoring.", key, paramLocation)
}
}
} else if requestBodyRequired {
// If parameter is not in path or defined in spec params, and method expects a body,
// assume it belongs in the request body.
bodyData[key] = value
log.Printf("[ExecuteToolCall] Added body parameter %s=%v (assumed)", key, value)
} else {
// Parameter not in path, not in spec, and not a body method.
// This could be an extraneous parameter like 'explanation'. Log it.
log.Printf("[ExecuteToolCall] Ignoring parameter '%s' as it doesn't match path or known parameter location for method %s.", key, operation.Method)
}
}
// --- Substitute Path Parameters ---
for key, value := range pathParams {
path = strings.ReplaceAll(path, "{"+key+"}", url.PathEscape(value)) // Escape so '/', '?' or spaces in a value cannot alter the URL structure
}
// --- Inject Server API Key (if applicable) ---
if hasServerKey {
log.Printf("[ExecuteToolCall] Injecting server API key (Name: %s, Location: %s)", apiKeyName, string(apiKeyLocation))
switch apiKeyLocation {
case config.APIKeyLocationQuery:
queryParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value
log.Printf("[ExecuteToolCall] Injected API key '%s' into query parameters", apiKeyName)
case config.APIKeyLocationHeader:
headerParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value
log.Printf("[ExecuteToolCall] Injected API key '%s' into headers", apiKeyName)
case config.APIKeyLocationPath:
pathPlaceholder := "{" + apiKeyName + "}"
if strings.Contains(path, pathPlaceholder) {
path = strings.ReplaceAll(path, pathPlaceholder, url.PathEscape(resolvedKey)) // Escape the key like any other path value
log.Printf("[ExecuteToolCall] Injected API key into path parameter '%s'", apiKeyName)
} else {
log.Printf("[ExecuteToolCall] Warning: API key location is 'path' but placeholder '%s' not found in final path '%s' for injection.", pathPlaceholder, path)
}
case config.APIKeyLocationCookie:
// Check if cookie already exists from input, replace if so
foundCookie := false
for i, c := range cookieParams {
if c.Name == apiKeyName {
log.Printf("[ExecuteToolCall] Replacing existing cookie '%s' with injected API key.", apiKeyName)
cookieParams[i] = &http.Cookie{Name: apiKeyName, Value: resolvedKey} // Replace existing
foundCookie = true
break
}
}
if !foundCookie {
log.Printf("[ExecuteToolCall] Adding new cookie '%s' with injected API key.", apiKeyName)
cookieParams = append(cookieParams, &http.Cookie{Name: apiKeyName, Value: resolvedKey}) // Append new
}
default:
log.Printf("[ExecuteToolCall] Warning: unsupported API key location in config: '%s'", apiKeyLocation)
}
} else {
log.Printf("[ExecuteToolCall] Skipping server API key injection (config incomplete or key unresolved).")
}
// --- Final URL Construction ---
// Reconstruct query string *after* potential API key injection
targetURL := baseURL + path
if len(queryParams) > 0 {
targetURL += "?" + queryParams.Encode()
}
log.Printf("[ExecuteToolCall] Final Target URL: %s %s", operation.Method, targetURL)
// --- Prepare Request Body ---
var reqBody io.Reader
var bodyBytes []byte // Keep for logging
if requestBodyRequired && len(bodyData) > 0 {
var err error
bodyBytes, err = json.Marshal(bodyData)
if err != nil {
log.Printf("[ExecuteToolCall] Error marshalling request body: %v", err)
return nil, fmt.Errorf("error marshalling request body: %w", err)
}
reqBody = bytes.NewBuffer(bodyBytes)
log.Printf("[ExecuteToolCall] Request body: %s", string(bodyBytes))
}
// --- Create HTTP Request ---
req, err := http.NewRequest(operation.Method, targetURL, reqBody)
if err != nil {
log.Printf("[ExecuteToolCall] Error creating HTTP request: %v", err)
return nil, fmt.Errorf("error creating request: %w", err)
}
// --- Set Headers ---
// Default headers
req.Header.Set("Accept", "application/json") // Assume JSON response typical for APIs
if reqBody != nil {
req.Header.Set("Content-Type", "application/json") // Assume JSON body if body exists
}
// Add headers collected from input/spec AND potentially injected API key
for key, values := range headerParams {
// Set keeps only the first value: inputs are expected to carry one value per header.
// Switch to Add if multi-value headers from the spec or input must be forwarded.
if len(values) > 0 {
req.Header.Set(key, values[0])
}
}
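// Add custom headers from config: a comma-separated list of "Name: Value" pairs.
// They are applied last, so they override same-named input headers and a header-injected API key.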
if cfg.CustomHeaders != "" {
headers := strings.Split(cfg.CustomHeaders, ",")
for _, h := range headers {
parts := strings.SplitN(h, ":", 2)
if len(parts) == 2 {
headerName := strings.TrimSpace(parts[0])
headerValue := strings.TrimSpace(parts[1])
if headerName != "" {
req.Header.Set(headerName, headerValue) // Set overrides potential input
log.Printf("[ExecuteToolCall] Added custom header from config: %s", headerName)
}
}
}
}
// --- Add Cookies ---
for _, cookie := range cookieParams {
req.AddCookie(cookie)
}
log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header)
if len(req.Cookies()) > 0 {
log.Printf("[ExecuteToolCall] Sending request with cookies: %+v", req.Cookies())
}
// --- Execute HTTP Request ---
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
log.Printf("[ExecuteToolCall] Error executing HTTP request: %v", err)
return nil, fmt.Errorf("error executing request: %w", err)
}
log.Printf("[ExecuteToolCall] Request executed. Status Code: %d", resp.StatusCode)
// Note: Don't close resp.Body here, the caller (handleToolCallJSONRPC) needs it.
return resp, nil
}
func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet, cfg *config.Config) jsonRPCResponse {
// req.Params is interface{}, but should contain json.RawMessage for tools/call
rawParams, ok := req.Params.(json.RawMessage)
if !ok {
// If it's not RawMessage, maybe it was already decoded to a map? Handle that case too.
if paramsMap, mapOk := req.Params.(map[string]interface{}); mapOk {
// Attempt to marshal the map back to JSON bytes
var marshalErr error
rawParams, marshalErr = json.Marshal(paramsMap)
if marshalErr != nil {
log.Printf("Error marshalling params map for %s: %v", connID, marshalErr)
return createJSONRPCError(req.ID, -32602, "Invalid parameters format (map marshal failed)", marshalErr.Error())
}
log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from map)", connID, string(rawParams))
} else {
log.Printf("Invalid parameters format for tools/call (not json.RawMessage or map[string]interface{}): %T", req.Params)
return createJSONRPCError(req.ID, -32602, "Invalid parameters format (expected JSON object)", nil)
}
} else {
log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from RawMessage)", connID, string(rawParams))
}
// Now, unmarshal the rawParams ([]byte) into ToolCallParams
var params ToolCallParams
if err := json.Unmarshal(rawParams, &params); err != nil {
log.Printf("Error unmarshalling tools/call params for %s: %v", connID, err)
return createJSONRPCError(req.ID, -32602, "Invalid parameters structure (unmarshal)", err.Error())
}
log.Printf("Executing tool '%s' for %s with input: %+v", params.ToolName, connID, params.Input)
// --- Execute the actual tool call ---
httpResp, execErr := executeToolCall(&params, toolSet, cfg)
// --- Process Response ---
var resultPayload ToolResultPayload
if execErr != nil {
log.Printf("Error executing tool call '%s': %v", params.ToolName, execErr)
resultPayload = ToolResultPayload{
IsError: true,
Error: &MCPError{
Message: fmt.Sprintf("Failed to execute tool '%s': %v", params.ToolName, execErr),
},
ToolCallID: fmt.Sprintf("%v", req.ID),
}
} else {
defer httpResp.Body.Close() // Ensure body is closed
bodyBytes, readErr := io.ReadAll(httpResp.Body)
if readErr != nil {
log.Printf("Error reading response body for tool '%s': %v", params.ToolName, readErr)
resultPayload = ToolResultPayload{
IsError: true,
Error: &MCPError{
Message: fmt.Sprintf("Failed to read response from tool '%s': %v", params.ToolName, readErr),
},
ToolCallID: fmt.Sprintf("%v", req.ID),
}
} else {
log.Printf("Received response body for tool '%s': %s", params.ToolName, string(bodyBytes))
// Check status code for API-level errors
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
resultPayload = ToolResultPayload{
IsError: true,
Error: &MCPError{
Code: httpResp.StatusCode,
Message: fmt.Sprintf("Tool '%s' API call failed with status %s", params.ToolName, httpResp.Status),
Data: string(bodyBytes), // Include response body in error data
},
ToolCallID: fmt.Sprintf("%v", req.ID),
}
} else {
// Successful execution
resultContent := []ToolResultContent{
{
Type: "text", // TODO: Handle JSON responses properly if Content-Type indicates it
Text: string(bodyBytes),
},
}
resultPayload = ToolResultPayload{
Content: resultContent,
IsError: false,
ToolCallID: fmt.Sprintf("%v", req.ID),
}
}
}
}
// --- Send Response ---
return jsonRPCResponse{
Jsonrpc: "2.0",
ID: req.ID, // Match request ID
Result: resultPayload, // Use the actual result payload
}
}
// --- Helper Functions (Updated for JSON-RPC) ---
// sendJSONRPCResponse sends a JSON-RPC response *synchronously*.
// Keep this for now for sending synchronous errors on POST decode/read failures.
func sendJSONRPCResponse(w http.ResponseWriter, resp jsonRPCResponse) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Error encoding JSON-RPC response (ID: %v): %v", resp.ID, err)
// Attempt to send a plain text error if JSON encoding fails
tryWriteHTTPError(w, http.StatusInternalServerError, "Internal Server Error encoding JSON-RPC response")
}
log.Printf("Sent JSON-RPC response: Method=%s, ID=%v", getMethodFromResponse(resp), resp.ID)
}
// createJSONRPCError creates a JSON-RPC error response.
func createJSONRPCError(id interface{}, code int, message string, data interface{}) jsonRPCResponse {
jsonErr := &jsonError{Code: code, Message: message, Data: data}
return jsonRPCResponse{
Jsonrpc: "2.0",
ID: id, // Error response should echo the request ID
Error: jsonErr,
}
}
// sendJSONRPCError sends a JSON-RPC error response.
func sendJSONRPCError(w http.ResponseWriter, connID string, id interface{}, code int, message string, data interface{}) {
resp := createJSONRPCError(id, code, message, data)
log.Printf("Sending JSON-RPC Error for ConnID %s, ID %v: Code=%d, Message='%s'", connID, id, code, message)
sendJSONRPCResponse(w, resp)
}
// getMethodFromResponse returns a best-effort label for logging. A response does not
// carry its method name, so the label is inferred from the result or error fields.
func getMethodFromResponse(resp jsonRPCResponse) string {
if resp.Result != nil {
if resMap, ok := resp.Result.(map[string]interface{}); ok {
// Infer from a "type" field if the result has one
if methodType, typeOk := resMap["type"].(string); typeOk {
return methodType + "_result"
}
// A "tools" key indicates a tools/list result
if resMap["tools"] != nil {
return "tool_set"
}
}
// If not easily identifiable, just indicate success
return "success"
} else if resp.Error != nil {
return "error"
}
return "unknown"
}
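// tryWriteHTTPError writes a plain-text error body with the given HTTP status code.
// Write failures are logged, not returned. If the response header was already written,
// the status cannot change and net/http logs a superfluous WriteHeader warning.
func tryWriteHTTPError(w http.ResponseWriter, code int, message string) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(code)
if _, err := w.Write([]byte(message)); err != nil {
log.Printf("Error writing plain HTTP error response: %v", err)
}
log.Printf("Sent plain HTTP error: %s (Code: %d)", message, code)
}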
```
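The custom-header loop in `executeToolCall` expects `cfg.CustomHeaders` to be a comma-separated list of `Name: Value` pairs. The standalone sketch below restates that parsing as a hypothetical helper, `parseCustomHeaders` (it is not part of the package), so its edge cases are easy to see: entries without a colon or with an empty name are skipped, and because the string is split on every comma, a header value that itself contains a comma is truncated.

```go
package main

import (
	"fmt"
	"net/http"
	"strings"
)

// parseCustomHeaders is a hypothetical helper mirroring the loop over
// cfg.CustomHeaders in executeToolCall: "Name: Value" pairs separated by commas.
func parseCustomHeaders(raw string) http.Header {
	h := make(http.Header)
	if raw == "" {
		return h
	}
	for _, entry := range strings.Split(raw, ",") {
		parts := strings.SplitN(entry, ":", 2)
		if len(parts) != 2 {
			continue // no colon: the entry is ignored
		}
		name := strings.TrimSpace(parts[0])
		value := strings.TrimSpace(parts[1])
		if name == "" {
			continue // empty header name: ignored
		}
		// Set canonicalizes the name and replaces any earlier value for it.
		h.Set(name, value)
	}
	return h
}

func main() {
	h := parseCustomHeaders("X-Client-ID: MyTestClient, X-Trace: on, malformed, : novalue")
	fmt.Println(h.Get("X-Client-ID")) // MyTestClient
	fmt.Println(h.Get("X-Trace"))     // on
	fmt.Println(len(h))               // 2: "malformed" and ": novalue" were skipped

	// A comma inside a value splits the entry, so only "a" survives.
	fmt.Println(parseCustomHeaders("Accept: a, b").Get("Accept")) // a
}
```

Because `http.Header.Set` canonicalizes names, `x-client-id` and `X-Client-ID` refer to the same header, and a later entry with the same name replaces an earlier one.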
--------------------------------------------------------------------------------
/pkg/server/server_test.go:
--------------------------------------------------------------------------------
```go
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/ckanthony/openapi-mcp/pkg/config"
"github.com/ckanthony/openapi-mcp/pkg/mcp"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Re-added Helper Functions ---
// Helper function to create a simple ToolSet for testing tool calls
func createTestToolSetForCall() *mcp.ToolSet {
return &mcp.ToolSet{
Name: "Call Test API",
Tools: []mcp.Tool{
{
Name: "get_user",
Description: "Get user details",
InputSchema: mcp.Schema{
Type: "object",
Properties: map[string]mcp.Schema{
"user_id": {Type: "string"},
},
Required: []string{"user_id"},
},
},
{
Name: "post_data",
Description: "Post some data",
InputSchema: mcp.Schema{
Type: "object",
Properties: map[string]mcp.Schema{
"data": {Type: "string"},
},
Required: []string{"data"},
},
},
},
Operations: map[string]mcp.OperationDetail{
"get_user": {
Method: "GET",
Path: "/users/{user_id}",
Parameters: []mcp.ParameterDetail{
{Name: "user_id", In: "path"},
},
},
"post_data": {
Method: "POST",
Path: "/data",
Parameters: []mcp.ParameterDetail{}, // Body params assumed
},
},
}
}
// Helper to safely manage activeConnections for tests
func setupTestConnection(connID string) chan jsonRPCResponse {
msgChan := make(chan jsonRPCResponse, 1) // Buffer of 1 sufficient for most tests
connMutex.Lock()
activeConnections[connID] = msgChan
connMutex.Unlock()
return msgChan
}
func cleanupTestConnection(connID string) {
connMutex.Lock()
msgChan, exists := activeConnections[connID]
if exists {
delete(activeConnections, connID)
close(msgChan)
}
connMutex.Unlock()
}
// --- End Re-added Helper Functions ---
func TestHttpMethodPostHandler(t *testing.T) {
// --- Setup common test items ---
toolSet := createTestToolSetForCall() // Use the helper
cfg := &config.Config{} // Basic config
// NOTE: connID is now generated within each subtest to ensure isolation
// --- Define Test Cases ---
tests := []struct {
name string
requestBodyFn func(connID string) string // Function to generate body with dynamic connID
expectedSyncStatus int // Expected status code for the immediate POST response
expectedSyncBody string // Expected body for the immediate POST response
checkAsyncResponse func(t *testing.T, resp jsonRPCResponse) // Function to check async response
mockBackend http.HandlerFunc // Optional mock backend for tool calls
setupChannelDirectly func(connID string) chan jsonRPCResponse // Optional: For specific channel setups
}{
{
name: "Valid Initialize Request",
requestBodyFn: func(connID string) string {
return fmt.Sprintf(`{
"jsonrpc": "2.0",
"method": "initialize",
"id": "init-post-1",
"params": {"connectionId": "%s"}
}`, connID)
},
expectedSyncStatus: http.StatusAccepted,
expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
assert.Equal(t, "init-post-1", resp.ID)
assert.Nil(t, resp.Error)
resultMap, ok := resp.Result.(map[string]interface{})
require.True(t, ok)
assert.Contains(t, resultMap, "connectionId") // Check existence, actual ID checked separately
assert.Equal(t, "2024-11-05", resultMap["protocolVersion"])
},
},
{
name: "Valid Tools List Request",
requestBodyFn: func(connID string) string {
return `{
"jsonrpc": "2.0",
"method": "tools/list",
"id": "list-post-1"
}`
},
expectedSyncStatus: http.StatusAccepted,
expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
assert.Equal(t, "list-post-1", resp.ID)
assert.Nil(t, resp.Error)
resultMap, ok := resp.Result.(map[string]interface{})
require.True(t, ok)
assert.Contains(t, resultMap, "metadata")
assert.Contains(t, resultMap, "tools")
metadata, _ := resultMap["metadata"].(map[string]interface{})
assert.Equal(t, 2, metadata["count"]) // int, not float64: the response is read from the channel without a JSON round-trip
},
},
{
name: "Valid Tool Call Request (Success)",
requestBodyFn: func(connID string) string {
return `{
"jsonrpc": "2.0",
"method": "tools/call",
"id": "call-post-1",
"params": {"name": "get_user", "arguments": {"user_id": "postUser"}}
}`
},
expectedSyncStatus: http.StatusAccepted,
expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
assert.Equal(t, "call-post-1", resp.ID)
assert.Nil(t, resp.Error)
resultPayload, ok := resp.Result.(ToolResultPayload)
require.True(t, ok)
assert.False(t, resultPayload.IsError)
require.Len(t, resultPayload.Content, 1)
assert.JSONEq(t, `{"id":"postUser"}`, resultPayload.Content[0].Text)
},
mockBackend: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, `{"id":"postUser"}`)
},
},
{
name: "Valid Tool Call Request (Tool Not Found)",
requestBodyFn: func(connID string) string {
return `{
"jsonrpc": "2.0",
"method": "tools/call",
"id": "call-post-err-1",
"params": {"name": "nonexistent_tool", "arguments": {}}
}`
},
expectedSyncStatus: http.StatusAccepted,
expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
assert.Equal(t, "call-post-err-1", resp.ID)
assert.Nil(t, resp.Error)
resultPayload, ok := resp.Result.(ToolResultPayload)
require.True(t, ok)
assert.True(t, resultPayload.IsError)
require.NotNil(t, resultPayload.Error)
assert.Contains(t, resultPayload.Error.Message, "operation details for tool 'nonexistent_tool' not found")
},
},
{
name: "Malformed JSON Request",
requestBodyFn: func(connID string) string {
return `{"jsonrpc": "2.0", "method": "initialize"`
},
expectedSyncStatus: http.StatusAccepted, // Even decode errors return 202, error is sent async
expectedSyncBody: "Request accepted (with decode error), response will be sent via SSE.\n",
checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
assert.Nil(t, resp.ID) // ID might be nil if request parsing failed early
require.NotNil(t, resp.Error)
assert.Equal(t, -32700, resp.Error.Code) // Parse Error
assert.Equal(t, "Parse error decoding JSON request", resp.Error.Message)
},
},
{
name: "Missing JSON-RPC Version",
requestBodyFn: func(connID string) string {
return `{
"method": "initialize",
"id": "rpc-err-1"
}`
},
expectedSyncStatus: http.StatusAccepted,
expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
assert.Equal(t, "rpc-err-1", resp.ID)
require.NotNil(t, resp.Error)
assert.Equal(t, -32600, resp.Error.Code) // Invalid Request
assert.Contains(t, resp.Error.Message, "jsonrpc field must be \"2.0\"")
},
},
{
name: "Unknown Method",
requestBodyFn: func(connID string) string {
return `{
"jsonrpc": "2.0",
"method": "unknown/method",
"id": "rpc-err-2"
}`
},
expectedSyncStatus: http.StatusAccepted,
expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
assert.Equal(t, "rpc-err-2", resp.ID)
require.NotNil(t, resp.Error)
assert.Equal(t, -32601, resp.Error.Code) // Method not found
assert.Contains(t, resp.Error.Message, "Method not found")
},
},
{
name: "Missing Method",
requestBodyFn: func(connID string) string {
return `{
"jsonrpc": "2.0",
"id": "rpc-err-3"
}`
},
expectedSyncStatus: http.StatusAccepted,
expectedSyncBody: "Request accepted, response will be sent via SSE.\n",
checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) {
assert.Equal(t, "rpc-err-3", resp.ID)
require.NotNil(t, resp.Error)
assert.Equal(t, -32600, resp.Error.Code) // Invalid Request
assert.Equal(t, "Invalid Request: method field is missing or empty", resp.Error.Message)
},
},
{
name: "Error Queuing Response To SSE",
requestBodyFn: func(connID string) string { // Use a simple valid request like tools/list
return `{
"jsonrpc": "2.0",
"method": "tools/list",
"id": "list-post-err-queue"
}`
},
expectedSyncStatus: http.StatusInternalServerError, // Expect 500 when channel is blocked
expectedSyncBody: "Failed to queue response for SSE channel\n", // Specific error message expected
setupChannelDirectly: func(connID string) chan jsonRPCResponse {
// Create a NON-BUFFERED channel to simulate blocking/full channel
msgChan := make(chan jsonRPCResponse) // No buffer size!
connMutex.Lock()
activeConnections[connID] = msgChan
connMutex.Unlock()
// Important: Do NOT start a reader for this channel
return msgChan
},
checkAsyncResponse: nil, // No async response should be successfully sent
},
}
// --- Run Test Cases ---
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
connID := uuid.NewString() // Generate unique connID for each subtest
// Setup mock backend if needed for this test case
var backendServer *httptest.Server
// --- Add Connection ID before test ---
var msgChan chan jsonRPCResponse
if tc.setupChannelDirectly != nil {
// Use custom setup if provided (e.g., for blocking channel test)
msgChan = tc.setupChannelDirectly(connID)
} else {
// Default setup using the helper with buffered channel
msgChan = setupTestConnection(connID)
}
defer cleanupTestConnection(connID) // Ensure cleanup after test
reqBody := tc.requestBodyFn(connID) // Generate the request body once
if tc.mockBackend != nil {
backendServer = httptest.NewServer(tc.mockBackend)
defer backendServer.Close()
// IMPORTANT: Point the operation targeted by this request at the mock backend.
// This mutates the shared toolSet, so later subtests see the updated BaseURL.
for _, opName := range []string{"get_user", "post_data"} {
if strings.Contains(reqBody, opName) { // Simple check based on request
op := toolSet.Operations[opName]
op.BaseURL = backendServer.URL
toolSet.Operations[opName] = op
}
}
}
req := httptest.NewRequest(http.MethodPost, "/mcp", strings.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Connection-ID", connID) // Use the generated connID
rr := httptest.NewRecorder()
httpMethodPostHandler(rr, req, toolSet, cfg)
// 1. Check synchronous response
assert.Equal(t, tc.expectedSyncStatus, rr.Code, "Unexpected status code for sync response")
// Trim both sides so the comparison does not depend on the trailing newline http.Error appends
assert.Equal(t, strings.TrimSpace(tc.expectedSyncBody), strings.TrimSpace(rr.Body.String()), "Unexpected body for sync response")
// 2. Check asynchronous response (sent via SSE channel)
if tc.checkAsyncResponse != nil {
select {
case asyncResp := <-msgChan:
tc.checkAsyncResponse(t, asyncResp)
case <-time.After(100 * time.Millisecond): // Fail if no response arrives in time
t.Fatal("Timeout waiting for async response on SSE channel")
}
} else {
// If no async check is defined, ensure nothing was sent (e.g., for queue error test)
select {
case unexpectedResp, ok := <-msgChan:
if ok { // Only fail if the channel wasn't closed AND we got a message
t.Errorf("Received unexpected async response when none was expected: %+v", unexpectedResp)
}
// If !ok, channel was closed, which is fine/expected after cleanup
case <-time.After(50 * time.Millisecond):
// Success - no message received quickly, channel likely blocked as expected
}
}
})
}
}
func TestHttpMethodGetHandler(t *testing.T) {
// --- Setup ---
// Reset global state for this test
connMutex.Lock()
originalConnections := activeConnections
activeConnections = make(map[string]chan jsonRPCResponse)
connMutex.Unlock()
req, err := http.NewRequest("GET", "/mcp", nil)
require.NoError(t, err, "Failed to create request")
rr := httptest.NewRecorder()
// Ensure cleanup happens regardless of test outcome
defer func() {
connMutex.Lock()
// Clean up any connections potentially left by the test
for id, ch := range activeConnections {
close(ch)
delete(activeConnections, id)
log.Printf("[DEFER Cleanup] Closed channel and removed connection %s", id)
}
activeConnections = originalConnections // Restore the original map
connMutex.Unlock()
}()
// --- Execute Handler (in a goroutine as it blocks waiting for context) ---
ctx, cancel := context.WithCancel(context.Background())
req = req.WithContext(ctx)
hwg := sync.WaitGroup{}
hwg.Add(1)
go func() {
defer hwg.Done()
// httpMethodGetHandler blocks until the request context is done.
// Cancel the context after 100ms so the handler can write its initial
// events and then return as if the client had disconnected.
time.AfterFunc(100*time.Millisecond, cancel)
httpMethodGetHandler(rr, req)
}()
// Wait for the handler goroutine to finish.
// This ensures all writes to rr are complete before we read.
if !waitTimeout(&hwg, 2*time.Second) { // Use a reasonable timeout
t.Fatal("Handler goroutine did not exit cleanly after context cancellation")
}
// --- Assertions (Performed *after* handler completion) ---
assert.Equal(t, http.StatusOK, rr.Code, "Status code should be OK")
// Check headers are set correctly
assert.Equal(t, "text/event-stream", rr.Header().Get("Content-Type"))
assert.Equal(t, "no-cache", rr.Header().Get("Cache-Control"))
assert.Equal(t, "keep-alive", rr.Header().Get("Connection"))
connID := rr.Header().Get("X-Connection-ID")
assert.NotEmpty(t, connID, "X-Connection-ID header should be set")
// Check the connection was cleaned up: the handler registers it in activeConnections
// (the test-local map, since the deferred restore has not run yet) and removes it on exit.
connMutex.RLock()
_, exists := activeConnections[connID]
connMutex.RUnlock()
assert.False(t, exists, "Connection ID should be removed from map after handler exits")
// Check initial body content is present
bodyContent := rr.Body.String()
assert.Contains(t, bodyContent, ":ok\n\n", "Body should contain :ok preamble")
// Construct the expected endpoint data string accurately
expectedEndpointData := "data: /mcp?sessionId=" + connID + "\n\n"
assert.Contains(t, bodyContent, "event: endpoint\n"+expectedEndpointData, "Body should contain endpoint event")
assert.Contains(t, bodyContent, "event: message\ndata: {", "Body should contain start of a message event (e.g., mcp-ready)")
// Check if connectionId is present in the ready message (adjust based on actual JSON structure)
assert.Contains(t, bodyContent, `"connectionId":"`+connID+`"`, "Body should contain mcp-ready event with correct connection ID")
// The explicit cleanupTestConnection call is not needed because the handler's defer and the test's defer handle it.
}
func TestExecuteToolCall(t *testing.T) {
tests := []struct {
name string
params ToolCallParams
opDetail mcp.OperationDetail
cfg *config.Config
expectError bool
containsError string
requestAsserter func(t *testing.T, r *http.Request) // Function to assert details of the received HTTP request
backendResponse string // Response body from mock backend
backendStatusCode int // Status code from mock backend
}{
// --- Basic GET with Path Param ---
{
name: "GET with path parameter",
params: ToolCallParams{
ToolName: "get_item",
Input: map[string]interface{}{"item_id": "item123"},
},
opDetail: mcp.OperationDetail{
Method: "GET",
Path: "/items/{item_id}",
Parameters: []mcp.ParameterDetail{{Name: "item_id", In: "path"}},
},
cfg: &config.Config{},
expectError: false,
backendStatusCode: http.StatusOK,
backendResponse: `{"status":"ok"}`,
requestAsserter: func(t *testing.T, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/items/item123", r.URL.Path)
assert.Empty(t, r.URL.RawQuery)
},
},
// --- POST with Query, Header, Cookie, and Body Params ---
{
name: "POST with various params",
params: ToolCallParams{
ToolName: "create_resource",
Input: map[string]interface{}{
"queryArg": "value1",
"X-Custom-Hdr": "headerValue",
"sessionToken": "cookieValue",
"bodyFieldA": "A",
"bodyFieldB": 123,
},
},
opDetail: mcp.OperationDetail{
Method: "POST",
Path: "/resources",
Parameters: []mcp.ParameterDetail{
{Name: "queryArg", In: "query"},
{Name: "X-Custom-Hdr", In: "header"},
{Name: "sessionToken", In: "cookie"},
// Body fields are implicitly handled
},
},
cfg: &config.Config{},
expectError: false,
backendStatusCode: http.StatusCreated,
backendResponse: `{"id":"res456"}`,
requestAsserter: func(t *testing.T, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/resources", r.URL.Path)
assert.Equal(t, "value1", r.URL.Query().Get("queryArg"))
assert.Equal(t, "headerValue", r.Header.Get("X-Custom-Hdr"))
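cookie, err := r.Cookie("sessionToken")
// Use assert, not require: this runs on the mock server's goroutine,
// where t.FailNow (called by require) must not be used.
if assert.NoError(t, err) {
assert.Equal(t, "cookieValue", cookie.Value)
}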
bodyBytes, _ := io.ReadAll(r.Body)
assert.JSONEq(t, `{"bodyFieldA":"A", "bodyFieldB":123}`, string(bodyBytes))
},
},
// --- API Key Injection (Header) ---
{
name: "API Key Injection (Header)",
params: ToolCallParams{
ToolName: "get_secure",
Input: map[string]interface{}{}, // No client key provided
},
opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure"},
cfg: &config.Config{
APIKey: "secret-server-key",
APIKeyName: "Authorization",
APIKeyLocation: config.APIKeyLocationHeader,
},
expectError: false,
backendStatusCode: http.StatusOK,
requestAsserter: func(t *testing.T, r *http.Request) {
assert.Equal(t, "secret-server-key", r.Header.Get("Authorization"))
},
},
// --- API Key Injection (Query) ---
{
name: "API Key Injection (Query)",
params: ToolCallParams{
ToolName: "get_secure",
Input: map[string]interface{}{"otherParam": "abc"},
},
opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure", Parameters: []mcp.ParameterDetail{{Name: "otherParam", In: "query"}}},
cfg: &config.Config{
APIKey: "secret-server-key-q",
APIKeyName: "api_key",
APIKeyLocation: config.APIKeyLocationQuery,
},
expectError: false,
backendStatusCode: http.StatusOK,
requestAsserter: func(t *testing.T, r *http.Request) {
assert.Equal(t, "secret-server-key-q", r.URL.Query().Get("api_key"))
assert.Equal(t, "abc", r.URL.Query().Get("otherParam")) // Ensure other params are preserved
},
},
// --- API Key Injection (Path) ---
{
name: "API Key Injection (Path)",
params: ToolCallParams{
ToolName: "get_secure_path",
Input: map[string]interface{}{}, // Key comes from config
},
opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure/{apiKey}/data"},
cfg: &config.Config{
APIKey: "path-key-123",
APIKeyName: "apiKey", // Matches the placeholder name
APIKeyLocation: config.APIKeyLocationPath,
},
expectError: false,
backendStatusCode: http.StatusOK,
requestAsserter: func(t *testing.T, r *http.Request) {
assert.Equal(t, "/secure/path-key-123/data", r.URL.Path)
},
},
// --- API Key Injection (Cookie) ---
{
name: "API Key Injection (Cookie)",
params: ToolCallParams{
ToolName: "get_secure_cookie",
Input: map[string]interface{}{}, // Key comes from config
},
opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure_cookie"},
cfg: &config.Config{
APIKey: "cookie-key-abc",
APIKeyName: "AuthToken",
APIKeyLocation: config.APIKeyLocationCookie,
},
expectError: false,
backendStatusCode: http.StatusOK,
requestAsserter: func(t *testing.T, r *http.Request) {
cookie, err := r.Cookie("AuthToken")
// Use assert, not require: this runs on the mock server's goroutine.
if assert.NoError(t, err) {
assert.Equal(t, "cookie-key-abc", cookie.Value)
}
},
},
// --- Base URL Handling Tests ---
{
name: "Base URL from Default (Mock Server)",
params: ToolCallParams{ToolName: "get_default_url", Input: map[string]interface{}{}},
opDetail: mcp.OperationDetail{Method: "GET", Path: "/path1"}, // No BaseURL here
cfg: &config.Config{}, // No global override
expectError: false,
backendStatusCode: http.StatusOK,
requestAsserter: func(t *testing.T, r *http.Request) {
// Should hit the mock server at the correct path
assert.Equal(t, "/path1", r.URL.Path)
},
},
{
name: "Base URL from Global Config Override",
params: ToolCallParams{ToolName: "get_global_url", Input: map[string]interface{}{}},
opDetail: mcp.OperationDetail{Method: "GET", Path: "/path2", BaseURL: "http://should-be-ignored.com"},
// cfg will be updated in test loop to point ServerBaseURL to mock server
cfg: &config.Config{},
expectError: false,
backendStatusCode: http.StatusOK,
requestAsserter: func(t *testing.T, r *http.Request) {
// Should hit the mock server (set via cfg override) at the correct path
assert.Equal(t, "/path2", r.URL.Path)
},
},
// --- Error Case (Tool Not Found in ToolSet) ---
{
name: "Error - Tool Not Found",
params: ToolCallParams{
ToolName: "nonexistent",
Input: map[string]interface{}{},
},
opDetail: mcp.OperationDetail{}, // Not used, error occurs before this
cfg: &config.Config{},
expectError: true,
containsError: "operation details for tool 'nonexistent' not found",
requestAsserter: nil, // No request should be made
backendStatusCode: 0, // Not applicable
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// --- Mock Backend Setup ---
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if tc.requestAsserter != nil {
tc.requestAsserter(t, r)
}
w.WriteHeader(tc.backendStatusCode)
fmt.Fprint(w, tc.backendResponse)
}))
defer backendServer.Close()
// --- Prepare ToolSet (using mock server URL if needed) ---
toolSet := &mcp.ToolSet{
Operations: make(map[string]mcp.OperationDetail),
}
// Clone config to avoid modifying the template test case config
testCfg := *tc.cfg
// Special handling for the global override test case
if tc.name == "Base URL from Global Config Override" {
testCfg.ServerBaseURL = backendServer.URL // Point global override to mock server
}
// If the opDetail needs a BaseURL, set it to the mock server ONLY if it wasn't
// already set in the test case definition AND the global override isn't being used.
if tc.opDetail.Method != "" { // Only add if it's a valid detail for the test
if tc.opDetail.BaseURL == "" && testCfg.ServerBaseURL == "" {
tc.opDetail.BaseURL = backendServer.URL
}
toolSet.Operations[tc.params.ToolName] = tc.opDetail
}
// --- Execute Function ---
httpResp, err := executeToolCall(&tc.params, toolSet, &testCfg) // Use the potentially modified testCfg
// --- Assertions ---
if tc.expectError {
assert.Error(t, err)
if tc.containsError != "" {
assert.Contains(t, err.Error(), tc.containsError)
}
assert.Nil(t, httpResp)
} else {
assert.NoError(t, err)
require.NotNil(t, httpResp)
defer httpResp.Body.Close()
assert.Equal(t, tc.backendStatusCode, httpResp.StatusCode)
bodyBytes, _ := io.ReadAll(httpResp.Body)
assert.Equal(t, tc.backendResponse, string(bodyBytes))
}
})
}
}
func TestWriteSSEEvent(t *testing.T) {
tests := []struct {
name string
eventName string
data interface{}
expectedOut string
expectError bool
}{
{
name: "Simple String Data",
eventName: "endpoint",
data: "/mcp?sessionId=123",
expectedOut: "event: endpoint\ndata: /mcp?sessionId=123\n\n",
expectError: false,
},
{
name: "Struct Data (JSON-RPC Request)",
eventName: "message",
data: jsonRPCRequest{
Jsonrpc: "2.0",
Method: "mcp-ready",
Params: map[string]interface{}{"connectionId": "abc"},
},
// expectedOut is informational for struct payloads; they are compared with JSONEq below
expectedOut: "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"mcp-ready\",\"params\":{\"connectionId\":\"abc\"}}\n\n",
expectError: false,
},
{
name: "Struct Data (JSON-RPC Response)",
eventName: "message",
data: jsonRPCResponse{
Jsonrpc: "2.0",
Result: map[string]interface{}{"status": "ok"},
ID: "req-1",
},
expectedOut: "event: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{\"status\":\"ok\"},\"id\":\"req-1\"}\n\n",
expectError: false,
},
{
name: "Error - Unmarshalable Data",
eventName: "error",
data: make(chan int), // Channels cannot be marshaled to JSON
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
rr := httptest.NewRecorder()
err := writeSSEEvent(rr, tc.eventName, tc.data)
if tc.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if _, isString := tc.data.(string); isString {
// String payloads are written verbatim, so compare the exact output
assert.Equal(t, tc.expectedOut, rr.Body.String())
} else {
// Struct payloads are JSON-encoded; compare them semantically with JSONEq
prefix := fmt.Sprintf("event: %s\ndata: ", tc.eventName)
suffix := "\n\n"
body := rr.Body.String()
require.True(t, strings.HasPrefix(body, prefix))
require.True(t, strings.HasSuffix(body, suffix))
actualJSON := strings.TrimSuffix(strings.TrimPrefix(body, prefix), suffix)
expectedJSONBytes, err := json.Marshal(tc.data)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSONBytes), actualJSON)
}
}
})
}
}
func TestTryWriteHTTPError(t *testing.T) {
rr := httptest.NewRecorder()
message := "Test Error Message"
code := http.StatusInternalServerError
tryWriteHTTPError(rr, code, message)
// tryWriteHTTPError writes only the body; the caller is expected to have set the
// status code already, so only the body content is checked here.
assert.Equal(t, message, rr.Body.String())
}
func TestGetMethodFromResponse(t *testing.T) {
tests := []struct {
name string
response jsonRPCResponse
expected string
}{
{
name: "Error Response",
response: jsonRPCResponse{
Error: &jsonError{Code: -32600, Message: "..."},
},
expected: "error",
},
{
name: "Tool List Response",
response: jsonRPCResponse{
Result: map[string]interface{}{"tools": []interface{}{}, "metadata": map[string]interface{}{}},
},
expected: "tool_set",
},
{
name: "Initialize Response (Result is Map)",
response: jsonRPCResponse{
Result: map[string]interface{}{"protocolVersion": "...", "capabilities": map[string]interface{}{}},
},
expected: "success", // Falls back to 'success' as type isn't explicitly set
},
{
name: "Tool Call Response (Result is ToolResultPayload)",
response: jsonRPCResponse{
Result: ToolResultPayload{Content: []ToolResultContent{{Type: "text", Text: "..."}}},
},
expected: "success", // Falls back to 'success'
},
{
name: "Empty Response",
response: jsonRPCResponse{},
expected: "unknown",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
actual := getMethodFromResponse(tc.response)
assert.Equal(t, tc.expected, actual)
})
}
}
// --- Mock ResponseWriter for error simulation ---
// sseMockResponseWriter implements http.ResponseWriter and http.Flusher so SSE
// handlers can be exercised with injected write failures.
type sseMockResponseWriter struct {
hdr http.Header // Internal map for headers
statusCode int
body *bytes.Buffer
flushed bool
forceError error // If set, Write returns this error and Flush becomes a no-op
failAfterNWrites int // Fail starting with the Nth Write call (1-based); -1 disables
writesMade int // Counter for writes made
}
// newSseMockResponseWriter returns a mock writer with failure injection disabled.
func newSseMockResponseWriter() *sseMockResponseWriter {
return &sseMockResponseWriter{
hdr: make(http.Header), // Initialize internal map
body: &bytes.Buffer{},
failAfterNWrites: -1, // Default to disabled
}
}
// Implement http.ResponseWriter interface
func (m *sseMockResponseWriter) Header() http.Header {
return m.hdr // Return the internal map
}
func (m *sseMockResponseWriter) WriteHeader(statusCode int) {
m.statusCode = statusCode
}
func (m *sseMockResponseWriter) Write(p []byte) (int, error) {
// Check if already forced error
if m.forceError != nil {
return 0, m.forceError
}
// Increment write count
m.writesMade++
// Check if write count triggers failure
if m.failAfterNWrites >= 0 && m.writesMade >= m.failAfterNWrites {
m.forceError = fmt.Errorf("forced write error after %d writes", m.failAfterNWrites)
log.Printf("DEBUG: sseMockResponseWriter triggering error: %v", m.forceError) // Debug log
return 0, m.forceError
}
// Proceed with normal write
return m.body.Write(p)
}
// Implement http.Flusher interface
func (m *sseMockResponseWriter) Flush() {
// Once an error has been forced, flushing is a no-op
if m.forceError != nil {
return
}
// Flush failures are not simulated; only Write calls count toward failAfterNWrites
m.flushed = true
}
// Helper to get body content
func (m *sseMockResponseWriter) String() string {
return m.body.String()
}
// --- End Mock ResponseWriter ---
func TestHttpMethodGetHandler_WriteErrors(t *testing.T) {
tests := []struct {
name string
errorOnStage string // "preamble", "endpoint", "ready", "ping", "message"
forceError error // Error to set on the mock writer *before* handler runs
expectConnRemoved bool
}{
{"Error on Preamble (:ok)", "preamble", fmt.Errorf("forced write error during preamble"), true},
// Errors on the endpoint and mcp-ready events are not covered here: they are
// hard to trigger reliably without patching writeSSEEvent.
// TODO: Add test for error during keep-alive ping
// TODO: Add test for error during message write from channel
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Fresh mock writer for each sub-test
mockWriter := newSseMockResponseWriter()
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
var connID string // Variable to capture assigned ID
// Set the error on the writer *before* calling the handler
if tc.forceError != nil {
mockWriter.forceError = tc.forceError
}
// Failures are injected through the mock writer rather than by patching writeSSEEvent.
// Execute handler in goroutine as it might block briefly before erroring
done := make(chan struct{})
go func() {
defer close(done)
httpMethodGetHandler(mockWriter, req)
}()
// Wait for the handler goroutine to finish or timeout
select {
case <-done:
// Handler finished (presumably due to error)
case <-time.After(200 * time.Millisecond): // Generous timeout
t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after injected error")
}
// Capture ConnID *after* handler exit, in case headers were set before error
connID = mockWriter.Header().Get("X-Connection-ID")
// Assert connection removal
if tc.expectConnRemoved && connID != "" {
connMutex.RLock()
_, exists := activeConnections[connID]
connMutex.RUnlock()
assert.False(t, exists, "Connection %s should have been removed from activeConnections after write error", connID)
} else if tc.expectConnRemoved && connID == "" {
t.Log("Cannot assert connection removal as ConnID was not captured before error")
}
})
}
}
func TestHttpMethodGetHandler_GoroutineErrors(t *testing.T) {
t.Run("Error_on_Message_Write", func(t *testing.T) {
// Assumes one Write each for the ':ok' preamble, the endpoint event, and the mcp-ready
// event (3 writes), so the 4th Write is the first one belonging to the message event
mockWriter := newSseMockResponseWriter()
mockWriter.failAfterNWrites = 4 // Fail on the 4th write overall
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
var connID string
var msgChan chan jsonRPCResponse
// Clean connections before test
connMutex.Lock()
activeConnections = make(map[string]chan jsonRPCResponse)
connMutex.Unlock()
defer func() {
// Reset shared state for other tests. The handler owns its channel and is
// responsible for closing it, so only the map is reset here.
connMutex.Lock()
activeConnections = make(map[string]chan jsonRPCResponse)
connMutex.Unlock()
}()
done := make(chan struct{})
go func() {
defer close(done)
httpMethodGetHandler(mockWriter, req)
log.Println("DEBUG: httpMethodGetHandler goroutine exited")
}()
// Wait for the connection to be established
assert.Eventually(t, func() bool {
connMutex.RLock()
defer connMutex.RUnlock()
for id, ch := range activeConnections {
connID = id
msgChan = ch
log.Printf("DEBUG: Connection established: %s", connID)
return true
}
return false
}, 200*time.Millisecond, 20*time.Millisecond, "Connection not established in time")
require.NotEmpty(t, connID, "connID should have been captured")
require.NotNil(t, msgChan, "msgChan should have been captured")
// Send a message that should trigger the write error
testResp := jsonRPCResponse{Jsonrpc: "2.0", ID: "test-msg-1", Result: "test data"}
log.Printf("DEBUG: Sending test message to channel for %s", connID)
select {
case msgChan <- testResp:
log.Printf("DEBUG: Test message sent to channel for %s", connID)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timeout sending message to channel")
}
// Wait for the handler goroutine to finish due to the write error
select {
case <-done:
log.Println("DEBUG: Handler goroutine finished as expected after message write error")
case <-time.After(1 * time.Second): // Allow time for the handler to observe the error and exit
t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after message write error")
}
// Assert connection removal
connMutex.RLock()
_, exists := activeConnections[connID]
connMutex.RUnlock()
assert.False(t, exists, "Connection %s should have been removed after message write error", connID)
})
// TODO: Add sub-test for Error_on_Ping_Write
}
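// waitTimeout waits for wg to finish and reports whether it did so before timeout.
// On timeout, the helper goroutine stays blocked in wg.Wait until the group completes.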
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
return true // Completed normally
case <-time.After(timeout):
return false // Timed out
}
}
```
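The SSE cases in `TestWriteSSEEvent` expect each event to be framed as an `event:` line, a `data:` line, and a terminating blank line, with string payloads written verbatim and other values JSON-encoded. The sketch below reproduces that framing in isolation. `formatSSEEvent` is a hypothetical helper, not the project's `writeSSEEvent`, and splitting multi-line payloads into several `data:` lines follows the Server-Sent Events format rather than anything these tests exercise.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// formatSSEEvent is a hypothetical helper mirroring the framing the tests expect:
// strings are written verbatim, other values are JSON-encoded, and every event
// ends with a blank line.
func formatSSEEvent(eventName string, data interface{}) (string, error) {
	var payload string
	switch v := data.(type) {
	case string:
		payload = v
	default:
		b, err := json.Marshal(v)
		if err != nil {
			// e.g. channels and funcs cannot be encoded as JSON
			return "", fmt.Errorf("marshal SSE data for event %q: %w", eventName, err)
		}
		payload = string(b)
	}

	var sb strings.Builder
	fmt.Fprintf(&sb, "event: %s\n", eventName)
	// In SSE, a raw newline ends a field, so each payload line gets its own data: line.
	for _, line := range strings.Split(payload, "\n") {
		fmt.Fprintf(&sb, "data: %s\n", line)
	}
	sb.WriteString("\n") // blank line dispatches the event
	return sb.String(), nil
}
```

Because `json.Marshal` escapes newlines inside strings and emits compact output, JSON payloads always fit on one `data:` line; only raw string payloads can need splitting.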