This is page 3 of 5. Use http://codebase.md/grafana/mcp-grafana?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── docker.yml
│       ├── e2e.yml
│       ├── integration.yml
│       ├── release.yml
│       └── unit.yml
├── .gitignore
├── .golangci.yaml
├── .goreleaser.yaml
├── cmd
│   ├── linters
│   │   └── jsonschema
│   │       └── main.go
│   └── mcp-grafana
│       └── main.go
├── CODEOWNERS
├── docker-compose.yaml
├── Dockerfile
├── examples
│   └── tls_example.go
├── gemini-extension.json
├── go.mod
├── go.sum
├── image-tag
├── internal
│   └── linter
│       └── jsonschema
│           ├── jsonschema_lint_test.go
│           ├── jsonschema_lint.go
│           └── README.md
├── LICENSE
├── Makefile
├── mcpgrafana_test.go
├── mcpgrafana.go
├── proxied_client.go
├── proxied_handler.go
├── proxied_tools_test.go
├── proxied_tools.go
├── README.md
├── renovate.json
├── server.json
├── session_test.go
├── session.go
├── testdata
│   ├── dashboards
│   │   └── demo.json
│   ├── loki-config.yml
│   ├── prometheus-entrypoint.sh
│   ├── prometheus-seed.yml
│   ├── prometheus.yml
│   ├── promtail-config.yml
│   ├── provisioning
│   │   ├── alerting
│   │   │   ├── alert_rules.yaml
│   │   │   └── contact_points.yaml
│   │   ├── dashboards
│   │   │   └── dashboards.yaml
│   │   └── datasources
│   │       └── datasources.yaml
│   ├── tempo-config-2.yaml
│   └── tempo-config.yaml
├── tests
│   ├── .gitignore
│   ├── .python-version
│   ├── admin_test.py
│   ├── conftest.py
│   ├── dashboards_test.py
│   ├── disable_write_test.py
│   ├── health_test.py
│   ├── loki_test.py
│   ├── navigation_test.py
│   ├── pyproject.toml
│   ├── README.md
│   ├── tempo_test.py
│   ├── utils.py
│   └── uv.lock
├── tls_test.go
├── tools
│   ├── admin_test.go
│   ├── admin.go
│   ├── alerting_client_test.go
│   ├── alerting_client.go
│   ├── alerting_test.go
│   ├── alerting_unit_test.go
│   ├── alerting.go
│   ├── annotations_integration_test.go
│   ├── annotations_unit_test.go
│   ├── annotations.go
│   ├── asserts_cloud_test.go
│   ├── asserts_test.go
│   ├── asserts.go
│   ├── cloud_testing_utils.go
│   ├── dashboard_test.go
│   ├── dashboard.go
│   ├── datasources_test.go
│   ├── datasources.go
│   ├── folder.go
│   ├── incident_integration_test.go
│   ├── incident_test.go
│   ├── incident.go
│   ├── loki_test.go
│   ├── loki.go
│   ├── navigation_test.go
│   ├── navigation.go
│   ├── oncall_cloud_test.go
│   ├── oncall.go
│   ├── prometheus_test.go
│   ├── prometheus_unit_test.go
│   ├── prometheus.go
│   ├── pyroscope_test.go
│   ├── pyroscope.go
│   ├── search_test.go
│   ├── search.go
│   ├── sift_cloud_test.go
│   ├── sift.go
│   └── testcontext_test.go
├── tools_test.go
└── tools.go
```
# Files
--------------------------------------------------------------------------------
/tools/oncall_cloud_test.go:
--------------------------------------------------------------------------------
```go
1 | //go:build cloud
2 | // +build cloud
3 |
4 | // This file contains cloud integration tests that run against a dedicated test instance
5 | // at mcptests.grafana-dev.net. This instance is configured with a minimal setup on the OnCall side:
6 | // - One team
7 | // - Two schedules (only one has a team assigned)
8 | // - One shift in the schedule with a team assigned
9 | // - One user
10 | // These tests expect this configuration to exist and will skip if the required
11 | // environment variables (GRAFANA_URL and either GRAFANA_SERVICE_ACCOUNT_TOKEN or GRAFANA_API_KEY) are not set.
12 | // The GRAFANA_API_KEY variable is deprecated.
13 |
14 | package tools
15 |
16 | import (
17 | "testing"
18 |
19 | "github.com/stretchr/testify/assert"
20 | "github.com/stretchr/testify/require"
21 | )
22 |
23 | func TestCloudOnCallSchedules(t *testing.T) {
24 | ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")
25 |
26 | // Test listing all schedules
27 | t.Run("list all schedules", func(t *testing.T) {
28 | result, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{})
29 | require.NoError(t, err, "Should not error when listing schedules")
30 | assert.NotNil(t, result, "Result should not be nil")
31 | })
32 |
33 | // Test pagination
34 | t.Run("list schedules with pagination", func(t *testing.T) {
35 | // Get first page
36 | page1, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{Page: 1})
37 | require.NoError(t, err, "Should not error when listing schedules page 1")
38 | assert.NotNil(t, page1, "Page 1 should not be nil")
39 |
40 | // Get second page
41 | page2, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{Page: 2})
42 | require.NoError(t, err, "Should not error when listing schedules page 2")
43 | assert.NotNil(t, page2, "Page 2 should not be nil")
44 | })
45 |
46 | // Get a team ID from an existing schedule to test filtering
47 | schedules, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{})
48 | require.NoError(t, err, "Should not error when listing schedules")
49 |
50 | if len(schedules) > 0 && schedules[0].TeamID != "" {
51 | teamID := schedules[0].TeamID
52 |
53 | // Test filtering by team ID
54 | t.Run("list schedules by team ID", func(t *testing.T) {
55 | result, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{
56 | TeamID: teamID,
57 | })
58 | require.NoError(t, err, "Should not error when listing schedules by team")
59 | assert.NotEmpty(t, result, "Should return at least one schedule")
60 | for _, schedule := range result {
61 | assert.Equal(t, teamID, schedule.TeamID, "All schedules should belong to the specified team")
62 | }
63 | })
64 | }
65 |
66 | // Test getting a specific schedule
67 | if len(schedules) > 0 {
68 | scheduleID := schedules[0].ID
69 | t.Run("get specific schedule", func(t *testing.T) {
70 | result, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{
71 | ScheduleID: scheduleID,
72 | })
73 | require.NoError(t, err, "Should not error when getting specific schedule")
74 | assert.Len(t, result, 1, "Should return exactly one schedule")
75 | assert.Equal(t, scheduleID, result[0].ID, "Should return the correct schedule")
76 |
77 | // Verify all summary fields are present
78 | schedule := result[0]
79 | assert.NotEmpty(t, schedule.Name, "Schedule should have a name")
80 | assert.NotEmpty(t, schedule.Timezone, "Schedule should have a timezone")
81 | assert.NotNil(t, schedule.Shifts, "Schedule should have a shifts field")
82 | })
83 | }
84 | }
85 |
86 | func TestCloudOnCallShift(t *testing.T) {
87 | ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")
88 |
89 | // First get a schedule to find a valid shift
90 | schedules, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{})
91 | require.NoError(t, err, "Should not error when listing schedules")
92 | require.NotEmpty(t, schedules, "Should have at least one schedule to test with")
93 | require.NotEmpty(t, schedules[0].Shifts, "Schedule should have at least one shift")
94 |
95 | shifts := schedules[0].Shifts
96 | shiftID := shifts[0]
97 |
98 | // Test getting shift details with valid ID
99 | t.Run("get shift details", func(t *testing.T) {
100 | result, err := getOnCallShift(ctx, GetOnCallShiftParams{
101 | ShiftID: shiftID,
102 | })
103 | require.NoError(t, err, "Should not error when getting shift details")
104 | assert.NotNil(t, result, "Result should not be nil")
105 | assert.Equal(t, shiftID, result.ID, "Should return the correct shift")
106 | })
107 |
108 | t.Run("get shift with invalid ID", func(t *testing.T) {
109 | _, err := getOnCallShift(ctx, GetOnCallShiftParams{
110 | ShiftID: "invalid-shift-id",
111 | })
112 | assert.Error(t, err, "Should error when getting shift with invalid ID")
113 | })
114 | }
115 |
116 | func TestCloudGetCurrentOnCallUsers(t *testing.T) {
117 | ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")
118 |
119 | // First get a schedule to use for testing
120 | schedules, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{})
121 | require.NoError(t, err, "Should not error when listing schedules")
122 | require.NotEmpty(t, schedules, "Should have at least one schedule to test with")
123 |
124 | scheduleID := schedules[0].ID
125 |
126 | // Test getting current on-call users
127 | t.Run("get current on-call users", func(t *testing.T) {
128 | result, err := getCurrentOnCallUsers(ctx, GetCurrentOnCallUsersParams{
129 | ScheduleID: scheduleID,
130 | })
131 | require.NoError(t, err, "Should not error when getting current on-call users")
132 | assert.NotNil(t, result, "Result should not be nil")
133 | assert.Equal(t, scheduleID, result.ScheduleID, "Should return the correct schedule")
134 | assert.NotEmpty(t, result.ScheduleName, "Schedule should have a name")
135 | assert.NotNil(t, result.Users, "Users field should be present")
136 |
137 | // Users is of type []*aapi.User; if any users are on call, check their basic fields
138 | if len(result.Users) > 0 {
139 | user := result.Users[0]
140 | assert.NotEmpty(t, user.ID, "User should have an ID")
141 | assert.NotEmpty(t, user.Username, "User should have a username")
142 | }
143 | })
144 |
145 | t.Run("get current on-call users with invalid schedule ID", func(t *testing.T) {
146 | _, err := getCurrentOnCallUsers(ctx, GetCurrentOnCallUsersParams{
147 | ScheduleID: "invalid-schedule-id",
148 | })
149 | assert.Error(t, err, "Should error when getting current on-call users with invalid schedule ID")
150 | })
151 | }
152 |
153 | func TestCloudOnCallTeams(t *testing.T) {
154 | ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")
155 |
156 | t.Run("list teams", func(t *testing.T) {
157 | result, err := listOnCallTeams(ctx, ListOnCallTeamsParams{})
158 | require.NoError(t, err, "Should not error when listing teams")
159 | assert.NotNil(t, result, "Result should not be nil")
160 |
161 | if len(result) > 0 {
162 | team := result[0]
163 | assert.NotEmpty(t, team.ID, "Team should have an ID")
164 | assert.NotEmpty(t, team.Name, "Team should have a name")
165 | }
166 | })
167 |
168 | // Test pagination
169 | t.Run("list teams with pagination", func(t *testing.T) {
170 | // Get first page
171 | page1, err := listOnCallTeams(ctx, ListOnCallTeamsParams{Page: 1})
172 | require.NoError(t, err, "Should not error when listing teams page 1")
173 | assert.NotNil(t, page1, "Page 1 should not be nil")
174 |
175 | // Get second page
176 | page2, err := listOnCallTeams(ctx, ListOnCallTeamsParams{Page: 2})
177 | require.NoError(t, err, "Should not error when listing teams page 2")
178 | assert.NotNil(t, page2, "Page 2 should not be nil")
179 | })
180 | }
181 |
182 | func TestCloudOnCallUsers(t *testing.T) {
183 | ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")
184 |
185 | t.Run("list all users", func(t *testing.T) {
186 | result, err := listOnCallUsers(ctx, ListOnCallUsersParams{})
187 | require.NoError(t, err, "Should not error when listing users")
188 | assert.NotNil(t, result, "Result should not be nil")
189 |
190 | if len(result) > 0 {
191 | user := result[0]
192 | assert.NotEmpty(t, user.ID, "User should have an ID")
193 | assert.NotEmpty(t, user.Username, "User should have a username")
194 | }
195 | })
196 |
197 | // Test pagination
198 | t.Run("list users with pagination", func(t *testing.T) {
199 | // Get first page
200 | page1, err := listOnCallUsers(ctx, ListOnCallUsersParams{Page: 1})
201 | require.NoError(t, err, "Should not error when listing users page 1")
202 | assert.NotNil(t, page1, "Page 1 should not be nil")
203 |
204 | // Get second page
205 | page2, err := listOnCallUsers(ctx, ListOnCallUsersParams{Page: 2})
206 | require.NoError(t, err, "Should not error when listing users page 2")
207 | assert.NotNil(t, page2, "Page 2 should not be nil")
208 | })
209 |
210 | // Get a user ID and username from the list to test filtering
211 | users, err := listOnCallUsers(ctx, ListOnCallUsersParams{})
212 | require.NoError(t, err, "Should not error when listing users")
213 | require.NotEmpty(t, users, "Should have at least one user to test with")
214 |
215 | userID := users[0].ID
216 | username := users[0].Username
217 |
218 | t.Run("get user by ID", func(t *testing.T) {
219 | result, err := listOnCallUsers(ctx, ListOnCallUsersParams{
220 | UserID: userID,
221 | })
222 | require.NoError(t, err, "Should not error when getting user by ID")
223 | assert.NotNil(t, result, "Result should not be nil")
224 | assert.Len(t, result, 1, "Should return exactly one user")
225 | assert.Equal(t, userID, result[0].ID, "Should return the correct user")
226 | assert.NotEmpty(t, result[0].Username, "User should have a username")
227 | })
228 |
229 | t.Run("get user by username", func(t *testing.T) {
230 | result, err := listOnCallUsers(ctx, ListOnCallUsersParams{
231 | Username: username,
232 | })
233 | require.NoError(t, err, "Should not error when getting user by username")
234 | assert.NotNil(t, result, "Result should not be nil")
235 | assert.Len(t, result, 1, "Should return exactly one user")
236 | assert.Equal(t, username, result[0].Username, "Should return the correct user")
237 | assert.NotEmpty(t, result[0].ID, "User should have an ID")
238 | })
239 |
240 | t.Run("get user with invalid ID", func(t *testing.T) {
241 | _, err := listOnCallUsers(ctx, ListOnCallUsersParams{
242 | UserID: "invalid-user-id",
243 | })
244 | assert.Error(t, err, "Should error when getting user with invalid ID")
245 | })
246 |
247 | t.Run("get user with invalid username", func(t *testing.T) {
248 | result, err := listOnCallUsers(ctx, ListOnCallUsersParams{
249 | Username: "invalid-username",
250 | })
251 | require.NoError(t, err, "Should not error when getting user with invalid username")
252 | assert.Empty(t, result, "Should return empty result set for invalid username")
253 | })
254 | }
255 |
256 | func TestCloudGetAlertGroup(t *testing.T) {
257 | ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")
258 |
259 | // First, get a list of alert groups to find a valid ID to test with
260 | alertGroups, err := listAlertGroups(ctx, ListAlertGroupsParams{})
261 | require.NoError(t, err, "Should not error when listing alert groups")
262 | require.NotEmpty(t, alertGroups, "Should have at least one alert group to test with")
263 |
264 | alertGroupID := alertGroups[0].ID
265 |
266 | t.Run("get alert group by ID", func(t *testing.T) {
267 | result, err := getAlertGroup(ctx, GetAlertGroupParams{
268 | AlertGroupID: alertGroupID,
269 | })
270 | require.NoError(t, err, "Should not error when getting alert group by ID")
271 | assert.NotNil(t, result, "Result should not be nil")
272 | assert.Equal(t, alertGroupID, result.ID, "Should return the correct alert group")
273 | assert.NotEmpty(t, result.Title, "Alert group should have a title")
274 | assert.NotEmpty(t, result.State, "Alert group should have a state")
275 | })
276 |
277 | t.Run("get alert group with invalid ID", func(t *testing.T) {
278 | _, err := getAlertGroup(ctx, GetAlertGroupParams{
279 | AlertGroupID: "invalid-alert-group-id",
280 | })
281 | assert.Error(t, err, "Should error when getting alert group with invalid ID")
282 | })
283 | }
284 |
```
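
The tests above gate themselves on environment variables via `createCloudTestContext`, which lives in `tools/cloud_testing_utils.go` and is not shown on this page. A minimal sketch of that guard follows; the helper name and signature below are illustrative assumptions, not the repository's actual implementation.

```go
package tools

import (
	"context"
	"os"
	"testing"
)

// exampleCloudTestContext is a hypothetical stand-in for createCloudTestContext
// (tools/cloud_testing_utils.go, not shown on this page). It illustrates the guard
// the header comment describes: skip the test unless a Grafana URL and token are set.
func exampleCloudTestContext(t *testing.T, testName, urlEnv, tokenEnv string) context.Context {
	t.Helper()
	url, token := os.Getenv(urlEnv), os.Getenv(tokenEnv)
	if url == "" || token == "" {
		t.Skipf("skipping %s cloud test: %s and %s must be set", testName, urlEnv, tokenEnv)
	}
	// The real helper presumably also attaches a configured Grafana client to the
	// context; that wiring is omitted from this sketch.
	return context.Background()
}
```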
--------------------------------------------------------------------------------
/proxied_tools.go:
--------------------------------------------------------------------------------
```go
1 | package mcpgrafana
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log/slog"
7 | "net/http"
8 | "strings"
9 | "sync"
10 |
11 | "github.com/go-openapi/runtime"
12 | "github.com/mark3labs/mcp-go/mcp"
13 | "github.com/mark3labs/mcp-go/server"
14 | )
15 |
16 | // MCPDatasourceConfig defines configuration for a datasource type that supports MCP
17 | type MCPDatasourceConfig struct {
18 | Type string
19 | EndpointPath string // e.g., "/api/mcp"
20 | }
21 |
22 | // mcpEnabledDatasources is a registry of datasource types that support MCP
23 | var mcpEnabledDatasources = map[string]MCPDatasourceConfig{
24 | "tempo": {Type: "tempo", EndpointPath: "/api/mcp"},
25 | // Future: add other datasource types here
26 | }
27 |
28 | // DiscoveredDatasource represents a datasource that supports MCP
29 | type DiscoveredDatasource struct {
30 | UID string
31 | Name string
32 | Type string
33 | MCPURL string // The MCP endpoint URL
34 | }
35 |
36 | // discoverMCPDatasources discovers datasources that support MCP
37 | // Returns a list of datasources with MCP endpoints
38 | func discoverMCPDatasources(ctx context.Context) ([]DiscoveredDatasource, error) {
39 | gc := GrafanaClientFromContext(ctx)
40 | if gc == nil {
41 | return nil, fmt.Errorf("grafana client not found in context")
42 | }
43 |
44 | var discovered []DiscoveredDatasource
45 |
46 | // List all datasources
47 | resp, err := gc.Datasources.GetDataSources()
48 | if err != nil {
49 | return nil, fmt.Errorf("failed to list datasources: %w", err)
50 | }
51 |
52 | // Get the Grafana base URL from context
53 | config := GrafanaConfigFromContext(ctx)
54 | if config.URL == "" {
55 | return nil, fmt.Errorf("grafana url not found in context")
56 | }
57 | grafanaBaseURL := config.URL
58 |
59 | // Filter for datasources that support MCP
60 | for _, ds := range resp.Payload {
61 | // Check if this datasource type supports MCP
62 | dsConfig, supported := mcpEnabledDatasources[ds.Type]
63 | if !supported {
64 | continue
65 | }
66 |
67 | // Check if the datasource instance has MCP enabled
68 | // We use a DELETE request to probe the MCP endpoint since:
69 | // - GET would start an event stream and hang
70 | // - POST doesn't work with the Grafana OpenAPI client
71 | // - DELETE returns 200 if MCP is enabled, 404 if not
72 | _, err := gc.Datasources.DatasourceProxyDELETEByUIDcalls(ds.UID, strings.TrimPrefix(dsConfig.EndpointPath, "/"))
73 | if err == nil {
74 | // Something unexpected happened: the server should never return a 202 for this probe. Skip.
75 | continue
76 | }
77 | if apiErr, ok := err.(*runtime.APIError); !ok || (ok && !apiErr.IsCode(http.StatusOK)) {
78 | // Not a 200 response, MCP not enabled
79 | continue
80 | }
81 |
82 | // Build the MCP endpoint URL using Grafana's datasource proxy API
83 | // Format: <grafana URL>/api/datasources/proxy/uid/<uid><endpoint_path>
84 | mcpURL := fmt.Sprintf("%s/api/datasources/proxy/uid/%s%s", grafanaBaseURL, ds.UID, dsConfig.EndpointPath)
85 |
86 | discovered = append(discovered, DiscoveredDatasource{
87 | UID: ds.UID,
88 | Name: ds.Name,
89 | Type: ds.Type,
90 | MCPURL: mcpURL,
91 | })
92 | }
93 |
94 | slog.DebugContext(ctx, "discovered MCP datasources", "count", len(discovered))
95 | return discovered, nil
96 | }
97 |
98 | // addDatasourceUidParameter adds a required datasourceUid parameter to a tool's input schema
99 | func addDatasourceUidParameter(tool mcp.Tool, datasourceType string) mcp.Tool {
100 | modifiedTool := tool
101 | // Prefix tool name with datasource type (e.g., "tempo_traceql-search")
102 | modifiedTool.Name = datasourceType + "_" + tool.Name
103 |
104 | // Add datasourceUid to the input schema
105 | if modifiedTool.InputSchema.Properties == nil {
106 | modifiedTool.InputSchema.Properties = make(map[string]any)
107 | }
108 |
109 | modifiedTool.InputSchema.Properties["datasourceUid"] = map[string]any{
110 | "type": "string",
111 | "description": "UID of the " + datasourceType + " datasource to query",
112 | }
113 |
114 | // Add to required fields
115 | modifiedTool.InputSchema.Required = append(modifiedTool.InputSchema.Required, "datasourceUid")
116 |
117 | return modifiedTool
118 | }
119 |
120 | // parseProxiedToolName extracts datasource type and original tool name from a proxied tool name
121 | // Format: <datasource_type>_<original_tool_name>
122 | // Returns: datasourceType, originalToolName, error
123 | func parseProxiedToolName(toolName string) (string, string, error) {
124 | parts := strings.SplitN(toolName, "_", 2)
125 | if len(parts) != 2 {
126 | return "", "", fmt.Errorf("invalid proxied tool name format: %s", toolName)
127 | }
128 | return parts[0], parts[1], nil
129 | }
130 |
131 | // ToolManager manages proxied tools (either per-session or server-wide)
132 | type ToolManager struct {
133 | sm *SessionManager
134 | server *server.MCPServer
135 |
136 | // Whether to enable proxied tools.
137 | enableProxiedTools bool
138 |
139 | // For stdio transport: store clients at manager level (single-tenant).
140 | // These will be unused for HTTP/SSE transports.
141 | serverMode bool // true if using server-wide tools (stdio), false for per-session (HTTP/SSE)
142 | serverClients map[string]*ProxiedClient
143 | clientsMutex sync.RWMutex
144 | }
145 |
146 | // NewToolManager creates a new ToolManager
147 | func NewToolManager(sm *SessionManager, mcpServer *server.MCPServer, opts ...toolManagerOption) *ToolManager {
148 | tm := &ToolManager{
149 | sm: sm,
150 | server: mcpServer,
151 | serverClients: make(map[string]*ProxiedClient),
152 | }
153 | for _, opt := range opts {
154 | opt(tm)
155 | }
156 | return tm
157 | }
158 |
159 | type toolManagerOption func(*ToolManager)
160 |
161 | // WithProxiedTools sets whether proxied tools are enabled
162 | func WithProxiedTools(enabled bool) toolManagerOption {
163 | return func(tm *ToolManager) {
164 | tm.enableProxiedTools = enabled
165 | }
166 | }
167 |
168 | // InitializeAndRegisterServerTools discovers datasources and registers tools on the server (for stdio transport)
169 | // This should be called once at server startup for single-tenant stdio servers
170 | func (tm *ToolManager) InitializeAndRegisterServerTools(ctx context.Context) error {
171 | if !tm.enableProxiedTools {
172 | return nil
173 | }
174 |
175 | // Mark as server mode (stdio transport)
176 | tm.serverMode = true
177 |
178 | // Discover datasources with MCP support
179 | discovered, err := discoverMCPDatasources(ctx)
180 | if err != nil {
181 | return fmt.Errorf("failed to discover MCP datasources: %w", err)
182 | }
183 |
184 | if len(discovered) == 0 {
185 | slog.Info("no MCP datasources discovered")
186 | return nil
187 | }
188 |
189 | // Connect to each datasource and store in manager
190 | tm.clientsMutex.Lock()
191 | for _, ds := range discovered {
192 | client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
193 | if err != nil {
194 | slog.Error("failed to create proxied client", "datasource", ds.UID, "error", err)
195 | continue
196 | }
197 | key := ds.Type + "_" + ds.UID
198 | tm.serverClients[key] = client
199 | }
200 | clientCount := len(tm.serverClients)
201 | tm.clientsMutex.Unlock()
202 |
203 | if clientCount == 0 {
204 | slog.Warn("no proxied clients created")
205 | return nil
206 | }
207 |
208 | slog.Info("connected to proxied MCP servers", "datasources", clientCount)
209 |
210 | // Collect and register all unique tools
211 | tm.clientsMutex.RLock()
212 | toolMap := make(map[string]mcp.Tool)
213 | for _, client := range tm.serverClients {
214 | for _, tool := range client.ListTools() {
215 | toolName := client.DatasourceType + "_" + tool.Name
216 | if _, exists := toolMap[toolName]; !exists {
217 | modifiedTool := addDatasourceUidParameter(tool, client.DatasourceType)
218 | toolMap[toolName] = modifiedTool
219 | }
220 | }
221 | }
222 | tm.clientsMutex.RUnlock()
223 |
224 | // Register tools on the server (not per-session)
225 | for toolName, tool := range toolMap {
226 | handler := NewProxiedToolHandler(tm.sm, tm, toolName)
227 | tm.server.AddTool(tool, handler.Handle)
228 | }
229 |
230 | slog.Info("registered proxied tools on server", "tools", len(toolMap))
231 | return nil
232 | }
233 |
234 | // InitializeAndRegisterProxiedTools discovers datasources, creates clients, and registers tools per-session
235 | // This should be called in OnBeforeListTools and OnBeforeCallTool hooks for HTTP/SSE transports
236 | func (tm *ToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, session server.ClientSession) {
237 | if !tm.enableProxiedTools {
238 | return
239 | }
240 |
241 | sessionID := session.SessionID()
242 | state, exists := tm.sm.GetSession(sessionID)
243 | if !exists {
244 | // Session exists in server context but not in our SessionManager yet
245 | tm.sm.CreateSession(ctx, session)
246 | state, exists = tm.sm.GetSession(sessionID)
247 | if !exists {
248 | slog.Error("failed to create session in SessionManager", "sessionID", sessionID)
249 | return
250 | }
251 | }
252 |
253 | // Step 1: Discover and connect (guaranteed to run exactly once per session)
254 | state.initOnce.Do(func() {
255 | // Discover datasources with MCP support
256 | discovered, err := discoverMCPDatasources(ctx)
257 | if err != nil {
258 | slog.Error("failed to discover MCP datasources", "error", err)
259 | state.mutex.Lock()
260 | state.proxiedToolsInitialized = true
261 | state.mutex.Unlock()
262 | return
263 | }
264 |
265 | state.mutex.Lock()
266 | // For each discovered datasource, create a proxied client
267 | for _, ds := range discovered {
268 | client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
269 | if err != nil {
270 | slog.Error("failed to create proxied client", "datasource", ds.UID, "error", err)
271 | continue
272 | }
273 |
274 | // Store the client
275 | key := ds.Type + "_" + ds.UID
276 | state.proxiedClients[key] = client
277 | }
278 | state.proxiedToolsInitialized = true
279 | state.mutex.Unlock()
280 |
281 | slog.Info("connected to proxied MCP servers", "session", sessionID, "datasources", len(state.proxiedClients))
282 | })
283 |
284 | // Step 2: Register tools with the MCP server
285 | state.mutex.Lock()
286 | defer state.mutex.Unlock()
287 |
288 | // Check if tools already registered
289 | if len(state.proxiedTools) > 0 {
290 | return
291 | }
292 |
293 | // Check if we have any clients (discovery should have happened above)
294 | if len(state.proxiedClients) == 0 {
295 | return
296 | }
297 |
298 | // First pass: collect all unique tools and track which datasources support them
299 | toolMap := make(map[string]mcp.Tool) // unique tools by name
300 |
301 | for key, client := range state.proxiedClients {
302 | remoteTools := client.ListTools()
303 |
304 | for _, tool := range remoteTools {
305 | // Tool name format: datasourceType_originalToolName (e.g., "tempo_traceql-search")
306 | toolName := client.DatasourceType + "_" + tool.Name
307 |
308 | // Store the tool if we haven't seen it yet
309 | if _, exists := toolMap[toolName]; !exists {
310 | // Add datasourceUid parameter to the tool
311 | modifiedTool := addDatasourceUidParameter(tool, client.DatasourceType)
312 | toolMap[toolName] = modifiedTool
313 | }
314 |
315 | // Track which datasources support this tool
316 | state.toolToDatasources[toolName] = append(state.toolToDatasources[toolName], key)
317 | }
318 | }
319 |
320 | // Second pass: register all unique tools at once (reduces listChanged notifications)
321 | var serverTools []server.ServerTool
322 | for toolName, tool := range toolMap {
323 | handler := NewProxiedToolHandler(tm.sm, tm, toolName)
324 | serverTools = append(serverTools, server.ServerTool{
325 | Tool: tool,
326 | Handler: handler.Handle,
327 | })
328 | state.proxiedTools = append(state.proxiedTools, tool)
329 | }
330 |
331 | if err := tm.server.AddSessionTools(sessionID, serverTools...); err != nil {
332 | slog.Warn("failed to add session tools", "session", sessionID, "error", err)
333 | } else {
334 | slog.Info("registered proxied tools", "session", sessionID, "tools", len(state.proxiedTools))
335 | }
336 | }
337 |
338 | // GetServerClient retrieves a proxied client from server-level storage (for stdio transport)
339 | func (tm *ToolManager) GetServerClient(datasourceType, datasourceUID string) (*ProxiedClient, error) {
340 | tm.clientsMutex.RLock()
341 | defer tm.clientsMutex.RUnlock()
342 |
343 | key := datasourceType + "_" + datasourceUID
344 | client, exists := tm.serverClients[key]
345 | if !exists {
346 | // List available datasources to help with debugging
347 | var availableUIDs []string
348 | for _, c := range tm.serverClients {
349 | if c.DatasourceType == datasourceType {
350 | availableUIDs = append(availableUIDs, c.DatasourceUID)
351 | }
352 | }
353 |
354 | if len(availableUIDs) > 0 {
355 | return nil, fmt.Errorf("datasource '%s' not found. Available %s datasources: %v", datasourceUID, datasourceType, availableUIDs)
356 | }
357 | return nil, fmt.Errorf("datasource '%s' not found. No %s datasources with MCP support are configured", datasourceUID, datasourceType)
358 | }
359 |
360 | return client, nil
361 | }
362 |
```
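
The naming scheme above pairs `addDatasourceUidParameter`, which prefixes each remote tool with its datasource type and adds a required `datasourceUid` argument, with `parseProxiedToolName`, which strips the prefix off again when a call is routed back to a datasource client. A minimal sketch of that round-trip, reusing the `tempo_traceql-search` example from the comments; the base URL and UID below are placeholders.

```go
package main

import (
	"fmt"
	"strings"
)

// splitProxiedToolName mirrors parseProxiedToolName above: proxied tool names use
// the form <datasourceType>_<originalToolName>, e.g. "tempo_traceql-search".
func splitProxiedToolName(toolName string) (string, string, error) {
	parts := strings.SplitN(toolName, "_", 2)
	if len(parts) != 2 {
		return "", "", fmt.Errorf("invalid proxied tool name format: %s", toolName)
	}
	return parts[0], parts[1], nil
}

func main() {
	// Prefixing happens in addDatasourceUidParameter; splitting happens when a call
	// is dispatched to the matching datasource client.
	dsType, original, err := splitProxiedToolName("tempo_traceql-search")
	fmt.Println(dsType, original, err) // tempo traceql-search <nil>

	// The MCP endpoint follows the proxy URL format built in discoverMCPDatasources;
	// the base URL and UID here are made-up placeholders.
	fmt.Printf("%s/api/datasources/proxy/uid/%s%s\n", "https://grafana.example.com", "abc123", "/api/mcp")
}
```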
--------------------------------------------------------------------------------
/proxied_tools_test.go:
--------------------------------------------------------------------------------
```go
1 | package mcpgrafana
2 |
3 | import (
4 | "context"
5 | "sync"
6 | "sync/atomic"
7 | "testing"
8 | "time"
9 |
10 | "github.com/mark3labs/mcp-go/mcp"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestSessionStateRaceConditions(t *testing.T) {
15 | t.Run("concurrent initialization with sync.Once is safe", func(t *testing.T) {
16 | state := newSessionState()
17 |
18 | var initCounter int32
19 | var wg sync.WaitGroup
20 |
21 | // Launch 100 goroutines that all try to initialize at once
22 | const numGoroutines = 100
23 | wg.Add(numGoroutines)
24 |
25 | for i := 0; i < numGoroutines; i++ {
26 | go func() {
27 | defer wg.Done()
28 | state.initOnce.Do(func() {
29 | // Simulate initialization work
30 | atomic.AddInt32(&initCounter, 1)
31 | time.Sleep(10 * time.Millisecond) // Simulate some work
32 | state.mutex.Lock()
33 | state.proxiedToolsInitialized = true
34 | state.mutex.Unlock()
35 | })
36 | }()
37 | }
38 |
39 | wg.Wait()
40 |
41 | // Verify initialization happened exactly once
42 | assert.Equal(t, int32(1), atomic.LoadInt32(&initCounter),
43 | "Initialization should run exactly once despite 100 concurrent calls")
44 | assert.True(t, state.proxiedToolsInitialized)
45 | })
46 |
47 | t.Run("concurrent reads and writes with mutex protection", func(t *testing.T) {
48 | state := newSessionState()
49 | var wg sync.WaitGroup
50 |
51 | // Writer goroutines
52 | for i := 0; i < 10; i++ {
53 | wg.Add(1)
54 | go func(id int) {
55 | defer wg.Done()
56 | state.mutex.Lock()
57 | key := "tempo_" + string(rune('a'+id))
58 | state.proxiedClients[key] = &ProxiedClient{
59 | DatasourceUID: key,
60 | DatasourceName: "Test " + key,
61 | DatasourceType: "tempo",
62 | }
63 | state.mutex.Unlock()
64 | }(i)
65 | }
66 |
67 | // Reader goroutines
68 | for i := 0; i < 10; i++ {
69 | wg.Add(1)
70 | go func() {
71 | defer wg.Done()
72 | state.mutex.RLock()
73 | _ = len(state.proxiedClients)
74 | state.mutex.RUnlock()
75 | }()
76 | }
77 |
78 | wg.Wait()
79 |
80 | // Verify all writes succeeded
81 | state.mutex.RLock()
82 | count := len(state.proxiedClients)
83 | state.mutex.RUnlock()
84 |
85 | assert.Equal(t, 10, count, "All 10 clients should be stored")
86 | })
87 |
88 | t.Run("concurrent tool registration is safe", func(t *testing.T) {
89 | state := newSessionState()
90 | var wg sync.WaitGroup
91 |
92 | // Multiple goroutines trying to register tools
93 | const numGoroutines = 50
94 | wg.Add(numGoroutines)
95 |
96 | for i := 0; i < numGoroutines; i++ {
97 | go func(id int) {
98 | defer wg.Done()
99 | state.mutex.Lock()
100 | toolName := "tempo_tool-" + string(rune('a'+id%26))
101 | if state.toolToDatasources[toolName] == nil {
102 | state.toolToDatasources[toolName] = []string{}
103 | }
104 | state.toolToDatasources[toolName] = append(
105 | state.toolToDatasources[toolName],
106 | "datasource_"+string(rune('a'+id%26)),
107 | )
108 | state.mutex.Unlock()
109 | }(i)
110 | }
111 |
112 | wg.Wait()
113 |
114 | // Verify the tool mappings exist
115 | state.mutex.RLock()
116 | defer state.mutex.RUnlock()
117 | assert.Greater(t, len(state.toolToDatasources), 0, "Should have tool mappings")
118 | })
119 | }
120 |
121 | func TestSessionManagerConcurrency(t *testing.T) {
122 | t.Run("concurrent session creation is safe", func(t *testing.T) {
123 | sm := NewSessionManager()
124 | var wg sync.WaitGroup
125 |
126 | // Create many sessions concurrently
127 | const numSessions = 100
128 | wg.Add(numSessions)
129 |
130 | for i := 0; i < numSessions; i++ {
131 | go func(id int) {
132 | defer wg.Done()
133 | sessionID := "session-" + string(rune('a'+id%26)) + "-" + string(rune('0'+id/26))
134 | mockSession := &mockClientSession{id: sessionID}
135 | sm.CreateSession(context.Background(), mockSession)
136 | }(i)
137 | }
138 |
139 | wg.Wait()
140 |
141 | // Verify all sessions were created
142 | sm.mutex.RLock()
143 | count := len(sm.sessions)
144 | sm.mutex.RUnlock()
145 |
146 | assert.Equal(t, numSessions, count, "All sessions should be created")
147 | })
148 |
149 | t.Run("concurrent get and remove is safe", func(t *testing.T) {
150 | sm := NewSessionManager()
151 |
152 | // Pre-populate sessions
153 | for i := 0; i < 50; i++ {
154 | sessionID := "session-" + string(rune('a'+i%26))
155 | mockSession := &mockClientSession{id: sessionID}
156 | sm.CreateSession(context.Background(), mockSession)
157 | }
158 |
159 | var wg sync.WaitGroup
160 |
161 | // Readers
162 | for i := 0; i < 50; i++ {
163 | wg.Add(1)
164 | go func(id int) {
165 | defer wg.Done()
166 | sessionID := "session-" + string(rune('a'+id%26))
167 | _, _ = sm.GetSession(sessionID)
168 | }(i)
169 | }
170 |
171 | // Writers (removers)
172 | for i := 0; i < 25; i++ {
173 | wg.Add(1)
174 | go func(id int) {
175 | defer wg.Done()
176 | sessionID := "session-" + string(rune('a'+id%26))
177 | mockSession := &mockClientSession{id: sessionID}
178 | sm.RemoveSession(context.Background(), mockSession)
179 | }(i)
180 | }
181 |
182 | wg.Wait()
183 |
184 | // Test passed if no race conditions occurred
185 | })
186 | }
187 |
188 | func TestInitOncePattern(t *testing.T) {
189 | t.Run("verify sync.Once guarantees single execution", func(t *testing.T) {
190 | var once sync.Once
191 | var counter int32
192 | var wg sync.WaitGroup
193 |
194 | // Simulate what happens in InitializeAndRegisterProxiedTools
195 | initFunc := func() {
196 | atomic.AddInt32(&counter, 1)
197 | // Simulate expensive initialization
198 | time.Sleep(50 * time.Millisecond)
199 | }
200 |
201 | // Launch many concurrent calls
202 | for i := 0; i < 1000; i++ {
203 | wg.Add(1)
204 | go func() {
205 | defer wg.Done()
206 | once.Do(initFunc)
207 | }()
208 | }
209 |
210 | wg.Wait()
211 |
212 | assert.Equal(t, int32(1), atomic.LoadInt32(&counter),
213 | "sync.Once should guarantee function runs exactly once")
214 | })
215 |
216 | t.Run("sync.Once with different functions only runs first", func(t *testing.T) {
217 | var once sync.Once
218 | var result string
219 | var mu sync.Mutex
220 |
221 | once.Do(func() {
222 | mu.Lock()
223 | result = "first"
224 | mu.Unlock()
225 | })
226 |
227 | once.Do(func() {
228 | mu.Lock()
229 | result = "second"
230 | mu.Unlock()
231 | })
232 |
233 | mu.Lock()
234 | finalResult := result
235 | mu.Unlock()
236 |
237 | assert.Equal(t, "first", finalResult, "Only first function should execute")
238 | })
239 | }
240 |
241 | func TestProxiedToolsInitializationFlow(t *testing.T) {
242 | t.Run("initialization state transitions are correct", func(t *testing.T) {
243 | state := newSessionState()
244 |
245 | // Initial state
246 | assert.False(t, state.proxiedToolsInitialized)
247 | assert.Empty(t, state.proxiedClients)
248 | assert.Empty(t, state.proxiedTools)
249 |
250 | // Simulate initialization
251 | state.initOnce.Do(func() {
252 | state.mutex.Lock()
253 | state.proxiedToolsInitialized = true
254 | state.proxiedClients["tempo_test"] = &ProxiedClient{
255 | DatasourceUID: "test",
256 | DatasourceName: "Test",
257 | DatasourceType: "tempo",
258 | }
259 | state.mutex.Unlock()
260 | })
261 |
262 | // Verify state after initialization
263 | state.mutex.RLock()
264 | initialized := state.proxiedToolsInitialized
265 | clientCount := len(state.proxiedClients)
266 | state.mutex.RUnlock()
267 |
268 | assert.True(t, initialized)
269 | assert.Equal(t, 1, clientCount)
270 | })
271 |
272 | t.Run("multiple sessions maintain separate state", func(t *testing.T) {
273 | sm := NewSessionManager()
274 |
275 | // Create two sessions
276 | session1 := &mockClientSession{id: "session-1"}
277 | session2 := &mockClientSession{id: "session-2"}
278 |
279 | sm.CreateSession(context.Background(), session1)
280 | sm.CreateSession(context.Background(), session2)
281 |
282 | state1, _ := sm.GetSession("session-1")
283 | state2, _ := sm.GetSession("session-2")
284 |
285 | // Initialize only session1
286 | state1.initOnce.Do(func() {
287 | state1.mutex.Lock()
288 | state1.proxiedToolsInitialized = true
289 | state1.mutex.Unlock()
290 | })
291 |
292 | // Verify states are independent
293 | assert.True(t, state1.proxiedToolsInitialized)
294 | assert.False(t, state2.proxiedToolsInitialized)
295 | assert.NotSame(t, state1, state2)
296 | })
297 | }
298 |
299 | func TestRaceConditionDemonstration(t *testing.T) {
300 | t.Run("old pattern WITHOUT sync.Once would have race condition", func(t *testing.T) {
301 | // This test demonstrates what WOULD happen with the old mutex-check pattern
302 | state := newSessionState()
303 |
304 | var discoveryCallCount int32
305 | var wg sync.WaitGroup
306 |
307 | // Simulate the OLD pattern (mutex check, unlock, then do work)
308 | oldPatternInitialize := func() {
309 | state.mutex.Lock()
310 | // Check if already initialized
311 | if state.proxiedToolsInitialized {
312 | state.mutex.Unlock()
313 | return
314 | }
315 | alreadyDiscovered := state.proxiedToolsInitialized
316 | state.mutex.Unlock() // ❌ OLD PATTERN: Unlock before expensive work
317 |
318 | if !alreadyDiscovered {
319 | // Simulate discovery work that should only happen once
320 | atomic.AddInt32(&discoveryCallCount, 1)
321 | time.Sleep(10 * time.Millisecond) // Simulate expensive operation
322 |
323 | state.mutex.Lock()
324 | state.proxiedToolsInitialized = true
325 | state.mutex.Unlock()
326 | }
327 | }
328 |
329 | // Launch concurrent initializations
330 | const numGoroutines = 10
331 | wg.Add(numGoroutines)
332 | for i := 0; i < numGoroutines; i++ {
333 | go func() {
334 | defer wg.Done()
335 | oldPatternInitialize()
336 | }()
337 | }
338 | wg.Wait()
339 |
340 | // With the old pattern, multiple goroutines can get past the check
341 | // and call discovery multiple times
342 | count := atomic.LoadInt32(&discoveryCallCount)
343 | if count > 1 {
344 | t.Logf("OLD PATTERN: Discovery called %d times (race condition!)", count)
345 | }
346 | // We can't assert > 1 reliably because timing matters, but this demonstrates the problem
347 | })
348 |
349 | t.Run("new pattern WITH sync.Once prevents race condition", func(t *testing.T) {
350 | // This test demonstrates the NEW pattern with sync.Once
351 | state := newSessionState()
352 |
353 | var discoveryCallCount int32
354 | var wg sync.WaitGroup
355 |
356 | // NEW pattern: sync.Once guarantees single execution
357 | newPatternInitialize := func() {
358 | state.initOnce.Do(func() {
359 | // Simulate discovery work that should only happen once
360 | atomic.AddInt32(&discoveryCallCount, 1)
361 | time.Sleep(10 * time.Millisecond) // Simulate expensive operation
362 |
363 | state.mutex.Lock()
364 | state.proxiedToolsInitialized = true
365 | state.mutex.Unlock()
366 | })
367 | }
368 |
369 | // Launch concurrent initializations
370 | const numGoroutines = 10
371 | wg.Add(numGoroutines)
372 | for i := 0; i < numGoroutines; i++ {
373 | go func() {
374 | defer wg.Done()
375 | newPatternInitialize()
376 | }()
377 | }
378 | wg.Wait()
379 |
380 | // With sync.Once, discovery is guaranteed to run exactly once
381 | count := atomic.LoadInt32(&discoveryCallCount)
382 | assert.Equal(t, int32(1), count, "NEW PATTERN: Discovery must be called exactly once")
383 | })
384 | }
385 |
386 | func TestRaceDetector(t *testing.T) {
387 | // This test is primarily valuable when run with -race flag
388 | t.Run("stress test with race detector", func(t *testing.T) {
389 |
390 | sm := NewSessionManager()
391 | var wg sync.WaitGroup
392 |
393 | // Create a mix of operations happening concurrently
394 | for i := 0; i < 20; i++ {
395 | sessionID := "stress-session-" + string(rune('a'+i%10))
396 |
397 | // Create session
398 | wg.Add(1)
399 | go func(sid string) {
400 | defer wg.Done()
401 | mockSession := &mockClientSession{id: sid}
402 | sm.CreateSession(context.Background(), mockSession)
403 | }(sessionID)
404 |
405 | // Initialize session state
406 | wg.Add(1)
407 | go func(sid string) {
408 | defer wg.Done()
409 | time.Sleep(time.Millisecond) // Let creation happen first
410 | state, exists := sm.GetSession(sid)
411 | if exists {
412 | state.initOnce.Do(func() {
413 | state.mutex.Lock()
414 | state.proxiedToolsInitialized = true
415 | state.mutex.Unlock()
416 | })
417 | }
418 | }(sessionID)
419 |
420 | // Read session state
421 | wg.Add(1)
422 | go func(sid string) {
423 | defer wg.Done()
424 | time.Sleep(2 * time.Millisecond)
425 | state, exists := sm.GetSession(sid)
426 | if exists {
427 | state.mutex.RLock()
428 | _ = state.proxiedToolsInitialized
429 | state.mutex.RUnlock()
430 | }
431 | }(sessionID)
432 | }
433 |
434 | wg.Wait()
435 |
436 | // If we get here without race detector warnings, we're good
437 | t.Log("Stress test completed without race conditions")
438 | })
439 | }
440 |
441 | // mockClientSession implements server.ClientSession for testing
442 | type mockClientSession struct {
443 | id string
444 | notifChannel chan mcp.JSONRPCNotification
445 | isInitialized bool
446 | }
447 |
448 | func (m *mockClientSession) SessionID() string {
449 | return m.id
450 | }
451 |
452 | func (m *mockClientSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
453 | if m.notifChannel == nil {
454 | m.notifChannel = make(chan mcp.JSONRPCNotification, 10)
455 | }
456 | return m.notifChannel
457 | }
458 |
459 | func (m *mockClientSession) Initialize() {
460 | m.isInitialized = true
461 | }
462 |
463 | func (m *mockClientSession) Initialized() bool {
464 | return m.isInitialized
465 | }
466 |
```
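
These tests exercise the `sync.Once`-guarded initialization used by `sessionState` (defined in `session.go`, not shown on this page). Below is a minimal standalone sketch of that idiom, with illustrative type and field names rather than the package's actual struct.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// exampleSessionState is a simplified stand-in for the package's sessionState.
// It captures the idiom the tests above verify: sync.Once guards the expensive
// discovery step, and a mutex protects the fields written inside it.
type exampleSessionState struct {
	initOnce    sync.Once
	mutex       sync.RWMutex
	initialized bool
}

func main() {
	state := &exampleSessionState{}
	var discoveryCalls int32
	var wg sync.WaitGroup

	// Many concurrent ListTools/CallTool hooks may race to initialize a session;
	// only one of them performs the discovery work.
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			state.initOnce.Do(func() {
				atomic.AddInt32(&discoveryCalls, 1) // stands in for discoverMCPDatasources
				state.mutex.Lock()
				state.initialized = true
				state.mutex.Unlock()
			})
		}()
	}
	wg.Wait()

	fmt.Println("discovery calls:", atomic.LoadInt32(&discoveryCalls)) // always 1
}
```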
--------------------------------------------------------------------------------
/tools/dashboard_test.go:
--------------------------------------------------------------------------------
```go
1 | // Requires a Grafana instance running on localhost:3000,
2 | // with a dashboard provisioned.
3 | // Run with `go test -tags integration`.
4 | //go:build integration
5 |
6 | package tools
7 |
8 | import (
9 | "context"
10 | "testing"
11 |
12 | "github.com/grafana/grafana-openapi-client-go/models"
13 | "github.com/stretchr/testify/assert"
14 | "github.com/stretchr/testify/require"
15 | )
16 |
17 | const (
18 | newTestDashboardName = "Integration Test"
19 | )
20 |
21 | // getExistingTestDashboard fetches an existing dashboard for test purposes.
22 | // It searches for existing dashboards and returns the first one; otherwise
23 | // it triggers a test error.
24 | func getExistingTestDashboard(t *testing.T, ctx context.Context, dashboardName string) *models.Hit {
25 | // Make sure we query for the existing dashboard, not a folder
26 | if dashboardName == "" {
27 | dashboardName = "Demo"
28 | }
29 | searchResults, err := searchDashboards(ctx, SearchDashboardsParams{
30 | Query: dashboardName,
31 | })
32 | require.NoError(t, err)
33 | require.Greater(t, len(searchResults), 0, "No dashboards found")
34 | return searchResults[0]
35 | }
36 |
37 | // getTestDashboardJSON fetches the JSON map for an existing
38 | // dashboard in the test environment.
39 | func getTestDashboardJSON(t *testing.T, ctx context.Context, dashboard *models.Hit) map[string]interface{} {
40 | result, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
41 | UID: dashboard.UID,
42 | })
43 | require.NoError(t, err)
44 | dashboardMap, ok := result.Dashboard.(map[string]interface{})
45 | require.True(t, ok, "Dashboard should be a map")
46 | return dashboardMap
47 | }
48 |
49 | func TestDashboardTools(t *testing.T) {
50 | t.Run("get dashboard by uid", func(t *testing.T) {
51 | ctx := newTestContext()
52 |
53 | // First, let's search for a dashboard to get its UID
54 | dashboard := getExistingTestDashboard(t, ctx, "")
55 |
56 | // Now test the get dashboard by uid functionality
57 | result, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
58 | UID: dashboard.UID,
59 | })
60 | require.NoError(t, err)
61 | dashboardMap, ok := result.Dashboard.(map[string]interface{})
62 | require.True(t, ok, "Dashboard should be a map")
63 | assert.Equal(t, dashboard.UID, dashboardMap["uid"])
64 | assert.NotNil(t, result.Meta)
65 | })
66 |
67 | t.Run("get dashboard by uid - invalid uid", func(t *testing.T) {
68 | ctx := newTestContext()
69 |
70 | _, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
71 | UID: "non-existent-uid",
72 | })
73 | require.Error(t, err)
74 | })
75 |
76 | t.Run("update dashboard - create new", func(t *testing.T) {
77 | ctx := newTestContext()
78 |
79 | // Get the dashboard JSON
80 | // In this case, we will create a new dashboard with the same
81 | // content but different Title, and disable "overwrite"
82 | dashboard := getExistingTestDashboard(t, ctx, "")
83 | dashboardMap := getTestDashboardJSON(t, ctx, dashboard)
84 |
85 | // Avoid a clash by unsetting the existing IDs
86 | delete(dashboardMap, "uid")
87 | delete(dashboardMap, "id")
88 |
89 | // Set a new title and tag
90 | dashboardMap["title"] = newTestDashboardName
91 | dashboardMap["tags"] = []string{"integration-test"}
92 |
93 | params := UpdateDashboardParams{
94 | Dashboard: dashboardMap,
95 | Message: "creating a new dashboard",
96 | Overwrite: false,
97 | UserID: 1,
98 | }
99 |
100 | // Only pass in the Folder UID if it exists
101 | if dashboard.FolderUID != "" {
102 | params.FolderUID = dashboard.FolderUID
103 | }
104 |
105 | // create the dashboard
106 | _, err := updateDashboard(ctx, params)
107 | require.NoError(t, err)
108 | })
109 |
110 | t.Run("update dashboard - overwrite existing", func(t *testing.T) {
111 | ctx := newTestContext()
112 |
113 | // Get the dashboard JSON for the non-provisioned dashboard we've created
114 | dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
115 | dashboardMap := getTestDashboardJSON(t, ctx, dashboard)
116 |
117 | params := UpdateDashboardParams{
118 | Dashboard: dashboardMap,
119 | Message: "updating existing dashboard",
120 | Overwrite: true,
121 | UserID: 1,
122 | }
123 |
124 | // Only pass in the Folder UID if it exists
125 | if dashboard.FolderUID != "" {
126 | params.FolderUID = dashboard.FolderUID
127 | }
128 |
129 | // update the dashboard
130 | _, err := updateDashboard(ctx, params)
131 | require.NoError(t, err)
132 | })
133 |
134 | t.Run("get dashboard panel queries", func(t *testing.T) {
135 | ctx := newTestContext()
136 |
137 | // Get the test dashboard
138 | dashboard := getExistingTestDashboard(t, ctx, "")
139 |
140 | result, err := GetDashboardPanelQueriesTool(ctx, DashboardPanelQueriesParams{
141 | UID: dashboard.UID,
142 | })
143 | require.NoError(t, err)
144 | assert.Greater(t, len(result), 0, "Should return at least one panel query")
145 |
146 | // For the initial demo dashboard, as well as all dashboards created by the integration tests,
147 | // every panel should have identical title and query values.
148 | // The datasource UID may differ. The datasource type can be an empty string in general, but on the demo and test dashboards it should be "prometheus".
149 | for _, panelQuery := range result {
150 | assert.Equal(t, panelQuery.Title, "Node Load")
151 | assert.Equal(t, panelQuery.Query, "node_load1")
152 | assert.NotEmpty(t, panelQuery.Datasource.UID)
153 | assert.Equal(t, panelQuery.Datasource.Type, "prometheus")
154 | }
155 | })
156 |
157 | // Tests for the context window management tools added in issue #101
158 | t.Run("get dashboard summary", func(t *testing.T) {
159 | ctx := newTestContext()
160 |
161 | // Get the test dashboard
162 | dashboard := getExistingTestDashboard(t, ctx, "")
163 |
164 | result, err := getDashboardSummary(ctx, GetDashboardSummaryParams{
165 | UID: dashboard.UID,
166 | })
167 | require.NoError(t, err)
168 |
169 | assert.Equal(t, dashboard.UID, result.UID)
170 | assert.NotEmpty(t, result.Title)
171 | assert.Greater(t, result.PanelCount, 0, "Should have at least one panel")
172 | assert.Len(t, result.Panels, result.PanelCount, "Panel count should match panels array length")
173 | assert.NotNil(t, result.Meta)
174 |
175 | // Check that panels have expected structure
176 | for _, panel := range result.Panels {
177 | assert.NotEmpty(t, panel.Title)
178 | assert.NotEmpty(t, panel.Type)
179 | assert.GreaterOrEqual(t, panel.QueryCount, 0)
180 | }
181 | })
182 |
183 | t.Run("get dashboard property - title", func(t *testing.T) {
184 | ctx := newTestContext()
185 |
186 | dashboard := getExistingTestDashboard(t, ctx, "")
187 |
188 | result, err := getDashboardProperty(ctx, GetDashboardPropertyParams{
189 | UID: dashboard.UID,
190 | JSONPath: "$.title",
191 | })
192 | require.NoError(t, err)
193 |
194 | title, ok := result.(string)
195 | require.True(t, ok, "Title should be a string")
196 | assert.NotEmpty(t, title)
197 | })
198 |
199 | t.Run("get dashboard property - panel titles", func(t *testing.T) {
200 | ctx := newTestContext()
201 |
202 | dashboard := getExistingTestDashboard(t, ctx, "")
203 |
204 | result, err := getDashboardProperty(ctx, GetDashboardPropertyParams{
205 | UID: dashboard.UID,
206 | JSONPath: "$.panels[*].title",
207 | })
208 | require.NoError(t, err)
209 |
210 | titles, ok := result.([]interface{})
211 | require.True(t, ok, "Panel titles should be an array")
212 | assert.Greater(t, len(titles), 0, "Should have at least one panel title")
213 |
214 | for _, title := range titles {
215 | titleStr, ok := title.(string)
216 | require.True(t, ok, "Each title should be a string")
217 | assert.NotEmpty(t, titleStr)
218 | }
219 | })
220 |
221 | t.Run("get dashboard property - invalid path", func(t *testing.T) {
222 | ctx := newTestContext()
223 |
224 | dashboard := getExistingTestDashboard(t, ctx, "")
225 |
226 | _, err := getDashboardProperty(ctx, GetDashboardPropertyParams{
227 | UID: dashboard.UID,
228 | JSONPath: "$.nonexistent.path",
229 | })
230 | require.Error(t, err, "Should fail for non-existent path")
231 | })
232 |
233 | t.Run("update dashboard - patch title", func(t *testing.T) {
234 | ctx := newTestContext()
235 |
236 | // Get our test dashboard (not the provisioned one)
237 | dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
238 |
239 | newTitle := "Updated Integration Test Dashboard"
240 |
241 | result, err := updateDashboard(ctx, UpdateDashboardParams{
242 | UID: dashboard.UID,
243 | Operations: []PatchOperation{
244 | {
245 | Op: "replace",
246 | Path: "$.title",
247 | Value: newTitle,
248 | },
249 | },
250 | Message: "Updated title via patch",
251 | })
252 | require.NoError(t, err)
253 | assert.NotNil(t, result)
254 |
255 | // Verify the change was applied
256 | updatedDashboard, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
257 | UID: dashboard.UID,
258 | })
259 | require.NoError(t, err)
260 |
261 | dashboardMap, ok := updatedDashboard.Dashboard.(map[string]interface{})
262 | require.True(t, ok, "Dashboard should be a map")
263 | assert.Equal(t, newTitle, dashboardMap["title"])
264 | })
265 |
266 | t.Run("update dashboard - patch add description", func(t *testing.T) {
267 | ctx := newTestContext()
268 |
269 | dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
270 |
271 | description := "This is a test description added via patch"
272 |
273 | _, err := updateDashboard(ctx, UpdateDashboardParams{
274 | UID: dashboard.UID,
275 | Operations: []PatchOperation{
276 | {
277 | Op: "add",
278 | Path: "$.description",
279 | Value: description,
280 | },
281 | },
282 | Message: "Added description via patch",
283 | })
284 | require.NoError(t, err)
285 |
286 | // Verify the description was added
287 | updatedDashboard, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
288 | UID: dashboard.UID,
289 | })
290 | require.NoError(t, err)
291 |
292 | dashboardMap, ok := updatedDashboard.Dashboard.(map[string]interface{})
293 | require.True(t, ok, "Dashboard should be a map")
294 | assert.Equal(t, description, dashboardMap["description"])
295 | })
296 |
297 | t.Run("update dashboard - patch remove description", func(t *testing.T) {
298 | ctx := newTestContext()
299 |
300 | dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
301 |
302 | _, err := updateDashboard(ctx, UpdateDashboardParams{
303 | UID: dashboard.UID,
304 | Operations: []PatchOperation{
305 | {
306 | Op: "remove",
307 | Path: "$.description",
308 | },
309 | },
310 | Message: "Removed description via patch",
311 | })
312 | require.NoError(t, err)
313 |
314 | // Verify the description was removed
315 | updatedDashboard, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
316 | UID: dashboard.UID,
317 | })
318 | require.NoError(t, err)
319 |
320 | dashboardMap, ok := updatedDashboard.Dashboard.(map[string]interface{})
321 | require.True(t, ok, "Dashboard should be a map")
322 | _, hasDescription := dashboardMap["description"]
323 | assert.False(t, hasDescription, "Description should be removed")
324 | })
325 |
326 | t.Run("update dashboard - unsupported operation", func(t *testing.T) {
327 | ctx := newTestContext()
328 |
329 | dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
330 |
331 | _, err := updateDashboard(ctx, UpdateDashboardParams{
332 | UID: dashboard.UID,
333 | Operations: []PatchOperation{
334 | {
335 | Op: "copy", // Unsupported operation
336 | Path: "$.title",
337 | Value: "New Title",
338 | },
339 | },
340 | })
341 | require.Error(t, err, "Should fail for unsupported operation")
342 | })
343 |
344 | t.Run("update dashboard - invalid parameters", func(t *testing.T) {
345 | ctx := newTestContext()
346 |
347 | _, err := updateDashboard(ctx, UpdateDashboardParams{
348 | // Neither dashboard nor (uid + operations) provided
349 | })
350 | require.Error(t, err, "Should fail when no valid parameters provided")
351 | })
352 |
353 | t.Run("update dashboard - append to panels array", func(t *testing.T) {
354 | ctx := newTestContext()
355 |
356 | // Get our test dashboard
357 | dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
358 |
359 | // Create a new panel to append
360 | newPanel := map[string]interface{}{
361 | "id": 999,
362 | "title": "New Appended Panel",
363 | "type": "stat",
364 | "targets": []interface{}{
365 | map[string]interface{}{
366 | "expr": "up",
367 | },
368 | },
369 | "gridPos": map[string]interface{}{
370 | "h": 8,
371 | "w": 12,
372 | "x": 0,
373 | "y": 8,
374 | },
375 | }
376 |
377 | _, err := updateDashboard(ctx, UpdateDashboardParams{
378 | UID: dashboard.UID,
379 | Operations: []PatchOperation{
380 | {
381 | Op: "add",
382 | Path: "$.panels/-",
383 | Value: newPanel,
384 | },
385 | },
386 | Message: "Appended new panel via /- syntax",
387 | })
388 | require.NoError(t, err)
389 |
390 | // Verify the panel was appended
391 | updatedDashboard, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
392 | UID: dashboard.UID,
393 | })
394 | require.NoError(t, err)
395 |
396 | dashboardMap, ok := updatedDashboard.Dashboard.(map[string]interface{})
397 | require.True(t, ok, "Dashboard should be a map")
398 |
399 | panels, ok := dashboardMap["panels"].([]interface{})
400 | require.True(t, ok, "Panels should be an array")
401 |
402 | // Check that the new panel was appended (should be the last panel)
403 | lastPanel, ok := panels[len(panels)-1].(map[string]interface{})
404 | require.True(t, ok, "Last panel should be an object")
405 | assert.Equal(t, "New Appended Panel", lastPanel["title"])
406 | assert.Equal(t, float64(999), lastPanel["id"]) // JSON unmarshaling converts to float64
407 | })
408 |
409 | t.Run("update dashboard - remove with append syntax should fail", func(t *testing.T) {
410 | ctx := newTestContext()
411 |
412 | dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
413 |
414 | _, err := updateDashboard(ctx, UpdateDashboardParams{
415 | UID: dashboard.UID,
416 | Operations: []PatchOperation{
417 | {
418 | Op: "remove",
419 | Path: "$.panels/-", // Invalid: remove with append syntax
420 | },
421 | },
422 | })
423 | require.Error(t, err, "Should fail when using remove operation with append syntax")
424 | })
425 |
426 | t.Run("update dashboard - append to non-array should fail", func(t *testing.T) {
427 | ctx := newTestContext()
428 |
429 | dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
430 |
431 | _, err := updateDashboard(ctx, UpdateDashboardParams{
432 | UID: dashboard.UID,
433 | Operations: []PatchOperation{
434 | {
435 | Op: "add",
436 | Path: "$.title/-", // Invalid: title is not an array
437 | Value: "Invalid",
438 | },
439 | },
440 | })
441 | require.Error(t, err, "Should fail when trying to append to non-array field")
442 | })
443 | }
444 |
```
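
The tests above call `updateDashboard` in two modes: with the full dashboard JSON, or with a UID plus JSONPath-style patch operations (including the `/-` suffix to append to an array). The sketch below restates both modes in one place; `UpdateDashboardParams` and `PatchOperation` are defined in `tools/dashboard.go`, which is not shown on this page, so the field names simply follow their usage in the tests.

```go
package tools

// exampleUpdateDashboardParams is illustrative only: it shows the two parameter
// shapes the integration tests above pass to updateDashboard. The UID and values
// here are placeholders.
func exampleUpdateDashboardParams() (UpdateDashboardParams, UpdateDashboardParams) {
	// Mode 1: send the full dashboard JSON (optionally overwriting an existing one).
	full := UpdateDashboardParams{
		Dashboard: map[string]interface{}{"title": "My Dashboard"},
		Message:   "create or overwrite the full dashboard",
		Overwrite: true,
	}

	// Mode 2: patch an existing dashboard by UID with JSONPath-style operations,
	// including the "/-" suffix to append to an array such as $.panels.
	patch := UpdateDashboardParams{
		UID: "existing-uid",
		Operations: []PatchOperation{
			{Op: "replace", Path: "$.title", Value: "New Title"},
			{Op: "add", Path: "$.panels/-", Value: map[string]interface{}{"title": "Appended Panel", "type": "stat"}},
		},
		Message: "patch selected paths without sending the whole dashboard",
	}

	return full, patch
}
```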
--------------------------------------------------------------------------------
/tools_test.go:
--------------------------------------------------------------------------------
```go
1 | //go:build unit
2 | // +build unit
3 |
4 | package mcpgrafana
5 |
6 | import (
7 | "context"
8 | "errors"
9 | "testing"
10 |
11 | "github.com/mark3labs/mcp-go/mcp"
12 | "github.com/stretchr/testify/assert"
13 | "github.com/stretchr/testify/require"
14 | )
15 |
16 | type testToolParams struct {
17 | Name string `json:"name" jsonschema:"required,description=The name parameter"`
18 | Value int `json:"value" jsonschema:"required,description=The value parameter"`
19 | Optional bool `json:"optional,omitempty" jsonschema:"description=An optional parameter"`
20 | }
21 |
22 | func testToolHandler(ctx context.Context, params testToolParams) (*mcp.CallToolResult, error) {
23 | if params.Name == "error" {
24 | return nil, errors.New("test error")
25 | }
26 | return mcp.NewToolResultText(params.Name + ": " + string(rune(params.Value))), nil
27 | }
28 |
29 | type emptyToolParams struct{}
30 |
31 | func emptyToolHandler(ctx context.Context, params emptyToolParams) (*mcp.CallToolResult, error) {
32 | return mcp.NewToolResultText("empty"), nil
33 | }
34 |
35 | // New handlers for different return types
36 | func stringToolHandler(ctx context.Context, params testToolParams) (string, error) {
37 | if params.Name == "error" {
38 | return "", errors.New("test error")
39 | }
40 | if params.Name == "empty" {
41 | return "", nil
42 | }
43 | return params.Name + ": " + string(rune(params.Value)), nil
44 | }
45 |
46 | func stringPtrToolHandler(ctx context.Context, params testToolParams) (*string, error) {
47 | if params.Name == "error" {
48 | return nil, errors.New("test error")
49 | }
50 | if params.Name == "nil" {
51 | return nil, nil
52 | }
53 | if params.Name == "empty" {
54 | empty := ""
55 | return &empty, nil
56 | }
57 | result := params.Name + ": " + string(rune(params.Value))
58 | return &result, nil
59 | }
60 |
61 | type TestResult struct {
62 | Name string `json:"name"`
63 | Value int `json:"value"`
64 | }
65 |
66 | func structToolHandler(ctx context.Context, params testToolParams) (TestResult, error) {
67 | if params.Name == "error" {
68 | return TestResult{}, errors.New("test error")
69 | }
70 | return TestResult{
71 | Name: params.Name,
72 | Value: params.Value,
73 | }, nil
74 | }
75 |
76 | func structPtrToolHandler(ctx context.Context, params testToolParams) (*TestResult, error) {
77 | if params.Name == "error" {
78 | return nil, errors.New("test error")
79 | }
80 | if params.Name == "nil" {
81 | return nil, nil
82 | }
83 | return &TestResult{
84 | Name: params.Name,
85 | Value: params.Value,
86 | }, nil
87 | }
88 |
89 | func TestConvertTool(t *testing.T) {
90 | t.Run("valid handler conversion", func(t *testing.T) {
91 | tool, handler, err := ConvertTool("test_tool", "A test tool", testToolHandler)
92 |
93 | require.NoError(t, err)
94 | require.NotNil(t, tool)
95 | require.NotNil(t, handler)
96 |
97 | // Check tool properties
98 | assert.Equal(t, "test_tool", tool.Name)
99 | assert.Equal(t, "A test tool", tool.Description)
100 |
101 | // Check schema properties
102 | assert.Equal(t, "object", tool.InputSchema.Type)
103 | assert.Contains(t, tool.InputSchema.Properties, "name")
104 | assert.Contains(t, tool.InputSchema.Properties, "value")
105 | assert.Contains(t, tool.InputSchema.Properties, "optional")
106 |
107 | // Test handler execution
108 | ctx := context.Background()
109 | request := mcp.CallToolRequest{
110 | Params: struct {
111 | Name string `json:"name"`
112 | Arguments any `json:"arguments,omitempty"`
113 | Meta *mcp.Meta `json:"_meta,omitempty"`
114 | }{
115 | Name: "test_tool",
116 | Arguments: map[string]any{
117 | "name": "test",
118 | "value": 65, // ASCII 'A'
119 | },
120 | },
121 | }
122 |
123 | result, err := handler(ctx, request)
124 | require.NoError(t, err)
125 | require.Len(t, result.Content, 1)
126 | resultString, ok := result.Content[0].(mcp.TextContent)
127 | require.True(t, ok)
128 | assert.Equal(t, "test: A", resultString.Text)
129 |
130 | // Test error handling
131 | errorRequest := mcp.CallToolRequest{
132 | Params: struct {
133 | Name string `json:"name"`
134 | Arguments any `json:"arguments,omitempty"`
135 | Meta *mcp.Meta `json:"_meta,omitempty"`
136 | }{
137 | Name: "test_tool",
138 | Arguments: map[string]any{
139 | "name": "error",
140 | "value": 66,
141 | },
142 | },
143 | }
144 |
145 | _, err = handler(ctx, errorRequest)
146 | assert.Error(t, err)
147 | assert.Equal(t, "test error", err.Error())
148 | })
149 |
150 | t.Run("empty handler params", func(t *testing.T) {
151 | tool, handler, err := ConvertTool("empty", "description", emptyToolHandler)
152 |
153 | require.NoError(t, err)
154 | require.NotNil(t, tool)
155 | require.NotNil(t, handler)
156 |
157 | // Check tool properties
158 | assert.Equal(t, "empty", tool.Name)
159 | assert.Equal(t, "description", tool.Description)
160 |
161 | // Check schema properties
162 | assert.Equal(t, "object", tool.InputSchema.Type)
163 | assert.Len(t, tool.InputSchema.Properties, 0)
164 |
165 | // Test handler execution
166 | ctx := context.Background()
167 | request := mcp.CallToolRequest{
168 | Params: struct {
169 | Name string `json:"name"`
170 | Arguments any `json:"arguments,omitempty"`
171 | Meta *mcp.Meta `json:"_meta,omitempty"`
172 | }{
173 | Name: "empty",
174 | },
175 | }
176 | result, err := handler(ctx, request)
177 | require.NoError(t, err)
178 | require.Len(t, result.Content, 1)
179 | resultString, ok := result.Content[0].(mcp.TextContent)
180 | require.True(t, ok)
181 | assert.Equal(t, "empty", resultString.Text)
182 | })
183 |
184 | t.Run("string return type", func(t *testing.T) {
185 | _, handler, err := ConvertTool("string_tool", "A string tool", stringToolHandler)
186 | require.NoError(t, err)
187 |
188 | // Test normal string return
189 | ctx := context.Background()
190 | request := mcp.CallToolRequest{
191 | Params: struct {
192 | Name string `json:"name"`
193 | Arguments any `json:"arguments,omitempty"`
194 | Meta *mcp.Meta `json:"_meta,omitempty"`
195 | }{
196 | Name: "string_tool",
197 | Arguments: map[string]any{
198 | "name": "test",
199 | "value": 65, // ASCII 'A'
200 | },
201 | },
202 | }
203 |
204 | result, err := handler(ctx, request)
205 | require.NoError(t, err)
206 | require.NotNil(t, result)
207 | require.Len(t, result.Content, 1)
208 | resultString, ok := result.Content[0].(mcp.TextContent)
209 | require.True(t, ok)
210 | assert.Equal(t, "test: A", resultString.Text)
211 |
212 | // Test empty string return
213 | emptyRequest := mcp.CallToolRequest{
214 | Params: struct {
215 | Name string `json:"name"`
216 | Arguments any `json:"arguments,omitempty"`
217 | Meta *mcp.Meta `json:"_meta,omitempty"`
218 | }{
219 | Name: "string_tool",
220 | Arguments: map[string]any{
221 | "name": "empty",
222 | "value": 65,
223 | },
224 | },
225 | }
226 |
227 | result, err = handler(ctx, emptyRequest)
228 | require.NoError(t, err)
229 | assert.Nil(t, result)
230 |
231 | // Test error return
232 | errorRequest := mcp.CallToolRequest{
233 | Params: struct {
234 | Name string `json:"name"`
235 | Arguments any `json:"arguments,omitempty"`
236 | Meta *mcp.Meta `json:"_meta,omitempty"`
237 | }{
238 | Name: "string_tool",
239 | Arguments: map[string]any{
240 | "name": "error",
241 | "value": 65,
242 | },
243 | },
244 | }
245 |
246 | _, err = handler(ctx, errorRequest)
247 | assert.Error(t, err)
248 | assert.Equal(t, "test error", err.Error())
249 | })
250 |
251 | t.Run("string pointer return type", func(t *testing.T) {
252 | _, handler, err := ConvertTool("string_ptr_tool", "A string pointer tool", stringPtrToolHandler)
253 | require.NoError(t, err)
254 |
255 | // Test normal string pointer return
256 | ctx := context.Background()
257 | request := mcp.CallToolRequest{
258 | Params: struct {
259 | Name string `json:"name"`
260 | Arguments any `json:"arguments,omitempty"`
261 | Meta *mcp.Meta `json:"_meta,omitempty"`
262 | }{
263 | Name: "string_ptr_tool",
264 | Arguments: map[string]any{
265 | "name": "test",
266 | "value": 65, // ASCII 'A'
267 | },
268 | },
269 | }
270 |
271 | result, err := handler(ctx, request)
272 | require.NoError(t, err)
273 | require.NotNil(t, result)
274 | require.Len(t, result.Content, 1)
275 | resultString, ok := result.Content[0].(mcp.TextContent)
276 | require.True(t, ok)
277 | assert.Equal(t, "test: A", resultString.Text)
278 |
279 | // Test nil string pointer return
280 | nilRequest := mcp.CallToolRequest{
281 | Params: struct {
282 | Name string `json:"name"`
283 | Arguments any `json:"arguments,omitempty"`
284 | Meta *mcp.Meta `json:"_meta,omitempty"`
285 | }{
286 | Name: "string_ptr_tool",
287 | Arguments: map[string]any{
288 | "name": "nil",
289 | "value": 65,
290 | },
291 | },
292 | }
293 |
294 | result, err = handler(ctx, nilRequest)
295 | require.NoError(t, err)
296 | assert.Nil(t, result)
297 |
298 | // Test empty string pointer return
299 | emptyRequest := mcp.CallToolRequest{
300 | Params: struct {
301 | Name string `json:"name"`
302 | Arguments any `json:"arguments,omitempty"`
303 | Meta *mcp.Meta `json:"_meta,omitempty"`
304 | }{
305 | Name: "string_ptr_tool",
306 | Arguments: map[string]any{
307 | "name": "empty",
308 | "value": 65,
309 | },
310 | },
311 | }
312 |
313 | result, err = handler(ctx, emptyRequest)
314 | require.NoError(t, err)
315 | assert.Nil(t, result)
316 |
317 | // Test error return
318 | errorRequest := mcp.CallToolRequest{
319 | Params: struct {
320 | Name string `json:"name"`
321 | Arguments any `json:"arguments,omitempty"`
322 | Meta *mcp.Meta `json:"_meta,omitempty"`
323 | }{
324 | Name: "string_ptr_tool",
325 | Arguments: map[string]any{
326 | "name": "error",
327 | "value": 65,
328 | },
329 | },
330 | }
331 |
332 | _, err = handler(ctx, errorRequest)
333 | assert.Error(t, err)
334 | assert.Equal(t, "test error", err.Error())
335 | })
336 |
337 | t.Run("struct return type", func(t *testing.T) {
338 | _, handler, err := ConvertTool("struct_tool", "A struct tool", structToolHandler)
339 | require.NoError(t, err)
340 |
341 | // Test normal struct return
342 | ctx := context.Background()
343 | request := mcp.CallToolRequest{
344 | Params: struct {
345 | Name string `json:"name"`
346 | Arguments any `json:"arguments,omitempty"`
347 | Meta *mcp.Meta `json:"_meta,omitempty"`
348 | }{
349 | Name: "struct_tool",
350 | Arguments: map[string]any{
351 | "name": "test",
352 | "value": 65, // ASCII 'A'
353 | },
354 | },
355 | }
356 |
357 | result, err := handler(ctx, request)
358 | require.NoError(t, err)
359 | require.NotNil(t, result)
360 | require.Len(t, result.Content, 1)
361 | resultString, ok := result.Content[0].(mcp.TextContent)
362 | require.True(t, ok)
363 | assert.Contains(t, resultString.Text, `"name":"test"`)
364 | assert.Contains(t, resultString.Text, `"value":65`)
365 |
366 | // Test error return
367 | errorRequest := mcp.CallToolRequest{
368 | Params: struct {
369 | Name string `json:"name"`
370 | Arguments any `json:"arguments,omitempty"`
371 | Meta *mcp.Meta `json:"_meta,omitempty"`
372 | }{
373 | Name: "struct_tool",
374 | Arguments: map[string]any{
375 | "name": "error",
376 | "value": 65,
377 | },
378 | },
379 | }
380 |
381 | _, err = handler(ctx, errorRequest)
382 | assert.Error(t, err)
383 | assert.Equal(t, "test error", err.Error())
384 | })
385 |
386 | t.Run("struct pointer return type", func(t *testing.T) {
387 | _, handler, err := ConvertTool("struct_ptr_tool", "A struct pointer tool", structPtrToolHandler)
388 | require.NoError(t, err)
389 |
390 | // Test normal struct pointer return
391 | ctx := context.Background()
392 | request := mcp.CallToolRequest{
393 | Params: struct {
394 | Name string `json:"name"`
395 | Arguments any `json:"arguments,omitempty"`
396 | Meta *mcp.Meta `json:"_meta,omitempty"`
397 | }{
398 | Name: "struct_ptr_tool",
399 | Arguments: map[string]any{
400 | "name": "test",
401 | "value": 65, // ASCII 'A'
402 | },
403 | },
404 | }
405 |
406 | result, err := handler(ctx, request)
407 | require.NoError(t, err)
408 | require.NotNil(t, result)
409 | require.Len(t, result.Content, 1)
410 | resultString, ok := result.Content[0].(mcp.TextContent)
411 | require.True(t, ok)
412 | assert.Contains(t, resultString.Text, `"name":"test"`)
413 | assert.Contains(t, resultString.Text, `"value":65`)
414 |
415 | // Test nil struct pointer return
416 | nilRequest := mcp.CallToolRequest{
417 | Params: struct {
418 | Name string `json:"name"`
419 | Arguments any `json:"arguments,omitempty"`
420 | Meta *mcp.Meta `json:"_meta,omitempty"`
421 | }{
422 | Name: "struct_ptr_tool",
423 | Arguments: map[string]any{
424 | "name": "nil",
425 | "value": 65,
426 | },
427 | },
428 | }
429 |
430 | result, err = handler(ctx, nilRequest)
431 | require.NoError(t, err)
432 | assert.Nil(t, result)
433 |
434 | // Test error return
435 | errorRequest := mcp.CallToolRequest{
436 | Params: struct {
437 | Name string `json:"name"`
438 | Arguments any `json:"arguments,omitempty"`
439 | Meta *mcp.Meta `json:"_meta,omitempty"`
440 | }{
441 | Name: "struct_ptr_tool",
442 | Arguments: map[string]any{
443 | "name": "error",
444 | "value": 65,
445 | },
446 | },
447 | }
448 |
449 | _, err = handler(ctx, errorRequest)
450 | assert.Error(t, err)
451 | assert.Equal(t, "test error", err.Error())
452 | })
453 |
454 | t.Run("invalid handler types", func(t *testing.T) {
455 | // Test wrong second argument type (not a struct)
456 | wrongSecondArgFunc := func(ctx context.Context, s string) (*mcp.CallToolResult, error) {
457 | return nil, nil
458 | }
459 | _, _, err := ConvertTool("invalid", "description", wrongSecondArgFunc)
460 | assert.Error(t, err)
461 | assert.Contains(t, err.Error(), "second argument must be a struct")
462 | })
463 |
464 | t.Run("handler execution with invalid arguments", func(t *testing.T) {
465 | _, handler, err := ConvertTool("test_tool", "A test tool", testToolHandler)
466 | require.NoError(t, err)
467 |
468 | // Test with invalid JSON
469 | invalidRequest := mcp.CallToolRequest{
470 | Params: struct {
471 | Name string `json:"name"`
472 | Arguments any `json:"arguments,omitempty"`
473 | Meta *mcp.Meta `json:"_meta,omitempty"`
474 | }{
475 | Arguments: map[string]any{
476 | "name": make(chan int), // Channels can't be marshaled to JSON
477 | },
478 | },
479 | }
480 |
481 | _, err = handler(context.Background(), invalidRequest)
482 | assert.Error(t, err)
483 | assert.Contains(t, err.Error(), "marshal args")
484 |
485 | // Test with type mismatch
486 | mismatchRequest := mcp.CallToolRequest{
487 | Params: struct {
488 | Name string `json:"name"`
489 | Arguments any `json:"arguments,omitempty"`
490 | Meta *mcp.Meta `json:"_meta,omitempty"`
491 | }{
492 | Arguments: map[string]any{
493 | "name": 123, // Should be a string
494 | "value": "not an int",
495 | },
496 | },
497 | }
498 |
499 | _, err = handler(context.Background(), mismatchRequest)
500 | assert.Error(t, err)
501 | assert.Contains(t, err.Error(), "unmarshal args")
502 | })
503 | }
504 |
505 | func TestCreateJSONSchemaFromHandler(t *testing.T) {
506 | schema := createJSONSchemaFromHandler(testToolHandler)
507 |
508 | assert.Equal(t, "object", schema.Type)
509 | assert.Len(t, schema.Required, 2) // name and value are required, optional is not
510 |
511 | // Check properties
512 | nameProperty, ok := schema.Properties.Get("name")
513 | assert.True(t, ok)
514 | assert.Equal(t, "string", nameProperty.Type)
515 | assert.Equal(t, "The name parameter", nameProperty.Description)
516 |
517 | valueProperty, ok := schema.Properties.Get("value")
518 | assert.True(t, ok)
519 | assert.Equal(t, "integer", valueProperty.Type)
520 | assert.Equal(t, "The value parameter", valueProperty.Description)
521 |
522 | optionalProperty, ok := schema.Properties.Get("optional")
523 | assert.True(t, ok)
524 | assert.Equal(t, "boolean", optionalProperty.Type)
525 | assert.Equal(t, "An optional parameter", optionalProperty.Description)
526 | }
527 |
```
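
The handler variants above (CallToolResult, string, pointer, and struct returns) all flow through ConvertTool's schema generation, which is driven by the params struct's jsonschema tags. A minimal sketch of how a new tool's parameters and handler would typically be declared; the names here are illustrative and not part of the package:

```go
package mcpgrafana_sketch // illustrative; not part of the repository

import (
	"context"
	"errors"
	"fmt"
)

// listThingsParams is an illustrative params struct: its jsonschema tags are
// what ConvertTool turns into the tool's input schema, mirroring
// testToolParams above ("folder" required, "limit" optional).
type listThingsParams struct {
	Folder string `json:"folder" jsonschema:"required,description=The folder to list"`
	Limit  int    `json:"limit,omitempty" jsonschema:"description=Maximum number of results"`
}

// listThingsHandler uses the plain string return shape, one of the handler
// variants the tests above exercise.
func listThingsHandler(ctx context.Context, params listThingsParams) (string, error) {
	if params.Folder == "" {
		return "", errors.New("folder is required")
	}
	return fmt.Sprintf("listing up to %d things in %s", params.Limit, params.Folder), nil
}
```

Passing such a pair to ConvertTool (or wrapping it with MustTool plus annotations, as in tools/prometheus.go later on this page) yields a tool whose schema marks folder as required and limit as optional.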
--------------------------------------------------------------------------------
/cmd/mcp-grafana/main.go:
--------------------------------------------------------------------------------
```go
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "flag"
7 | "fmt"
8 | "log/slog"
9 | "net/http"
10 | "os"
11 | "os/signal"
12 | "slices"
13 | "strings"
14 | "syscall"
15 | "time"
16 |
17 | "github.com/mark3labs/mcp-go/mcp"
18 | "github.com/mark3labs/mcp-go/server"
19 |
20 | mcpgrafana "github.com/grafana/mcp-grafana"
21 | "github.com/grafana/mcp-grafana/tools"
22 | )
23 |
24 | func maybeAddTools(s *server.MCPServer, tf func(*server.MCPServer), enabledTools []string, disable bool, category string) {
25 | if !slices.Contains(enabledTools, category) {
26 | slog.Debug("Not enabling tools", "category", category)
27 | return
28 | }
29 | if disable {
30 | slog.Info("Disabling tools", "category", category)
31 | return
32 | }
33 | slog.Debug("Enabling tools", "category", category)
34 | tf(s)
35 | }
36 |
37 | // disabledTools indicates whether each category of tools should be disabled.
38 | type disabledTools struct {
39 | enabledTools string
40 |
41 | search, datasource, incident,
42 | prometheus, loki, alerting,
43 | dashboard, folder, oncall, asserts, sift, admin,
44 | pyroscope, navigation, proxied, annotations, write bool
45 | }
46 |
47 | // Configuration for the Grafana client.
48 | type grafanaConfig struct {
49 | // Whether to enable debug mode for the Grafana transport.
50 | debug bool
51 |
52 | // TLS configuration
53 | tlsCertFile string
54 | tlsKeyFile string
55 | tlsCAFile string
56 | tlsSkipVerify bool
57 | }
58 |
59 | func (dt *disabledTools) addFlags() {
60 | flag.StringVar(&dt.enabledTools, "enabled-tools", "search,datasource,incident,prometheus,loki,alerting,dashboard,folder,oncall,asserts,sift,admin,pyroscope,navigation,proxied,annotations", "A comma separated list of tools enabled for this server. Can be overwritten entirely or by disabling specific components, e.g. --disable-search.")
61 | flag.BoolVar(&dt.search, "disable-search", false, "Disable search tools")
62 | flag.BoolVar(&dt.datasource, "disable-datasource", false, "Disable datasource tools")
63 | flag.BoolVar(&dt.incident, "disable-incident", false, "Disable incident tools")
64 | flag.BoolVar(&dt.prometheus, "disable-prometheus", false, "Disable prometheus tools")
65 | flag.BoolVar(&dt.loki, "disable-loki", false, "Disable loki tools")
66 | flag.BoolVar(&dt.alerting, "disable-alerting", false, "Disable alerting tools")
67 | flag.BoolVar(&dt.dashboard, "disable-dashboard", false, "Disable dashboard tools")
68 | flag.BoolVar(&dt.folder, "disable-folder", false, "Disable folder tools")
69 | flag.BoolVar(&dt.oncall, "disable-oncall", false, "Disable oncall tools")
70 | flag.BoolVar(&dt.asserts, "disable-asserts", false, "Disable asserts tools")
71 | flag.BoolVar(&dt.sift, "disable-sift", false, "Disable sift tools")
72 | flag.BoolVar(&dt.admin, "disable-admin", false, "Disable admin tools")
73 | flag.BoolVar(&dt.pyroscope, "disable-pyroscope", false, "Disable pyroscope tools")
74 | flag.BoolVar(&dt.navigation, "disable-navigation", false, "Disable navigation tools")
75 | flag.BoolVar(&dt.proxied, "disable-proxied", false, "Disable proxied tools (tools from external MCP servers)")
76 | flag.BoolVar(&dt.write, "disable-write", false, "Disable write tools (create/update operations)")
77 | flag.BoolVar(&dt.annotations, "disable-annotations", false, "Disable annotation tools")
78 | }
79 |
80 | func (gc *grafanaConfig) addFlags() {
81 | flag.BoolVar(&gc.debug, "debug", false, "Enable debug mode for the Grafana transport")
82 |
83 | // TLS configuration flags
84 | flag.StringVar(&gc.tlsCertFile, "tls-cert-file", "", "Path to TLS certificate file for client authentication")
85 | flag.StringVar(&gc.tlsKeyFile, "tls-key-file", "", "Path to TLS private key file for client authentication")
86 | flag.StringVar(&gc.tlsCAFile, "tls-ca-file", "", "Path to TLS CA certificate file for server verification")
87 | flag.BoolVar(&gc.tlsSkipVerify, "tls-skip-verify", false, "Skip TLS certificate verification (insecure)")
88 | }
89 |
90 | func (dt *disabledTools) addTools(s *server.MCPServer) {
91 | enabledTools := strings.Split(dt.enabledTools, ",")
92 | enableWriteTools := !dt.write
93 | maybeAddTools(s, tools.AddSearchTools, enabledTools, dt.search, "search")
94 | maybeAddTools(s, tools.AddDatasourceTools, enabledTools, dt.datasource, "datasource")
95 | maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddIncidentTools(mcp, enableWriteTools) }, enabledTools, dt.incident, "incident")
96 | maybeAddTools(s, tools.AddPrometheusTools, enabledTools, dt.prometheus, "prometheus")
97 | maybeAddTools(s, tools.AddLokiTools, enabledTools, dt.loki, "loki")
98 | maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddAlertingTools(mcp, enableWriteTools) }, enabledTools, dt.alerting, "alerting")
99 | maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddDashboardTools(mcp, enableWriteTools) }, enabledTools, dt.dashboard, "dashboard")
100 | maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddFolderTools(mcp, enableWriteTools) }, enabledTools, dt.folder, "folder")
101 | maybeAddTools(s, tools.AddOnCallTools, enabledTools, dt.oncall, "oncall")
102 | maybeAddTools(s, tools.AddAssertsTools, enabledTools, dt.asserts, "asserts")
103 | maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddSiftTools(mcp, enableWriteTools) }, enabledTools, dt.sift, "sift")
104 | maybeAddTools(s, tools.AddAdminTools, enabledTools, dt.admin, "admin")
105 | maybeAddTools(s, tools.AddPyroscopeTools, enabledTools, dt.pyroscope, "pyroscope")
106 | maybeAddTools(s, tools.AddNavigationTools, enabledTools, dt.navigation, "navigation")
107 | maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddAnnotationTools(mcp, enableWriteTools) }, enabledTools, dt.annotations, "annotations")
108 | }
109 |
110 | func newServer(transport string, dt disabledTools) (*server.MCPServer, *mcpgrafana.ToolManager) {
111 | sm := mcpgrafana.NewSessionManager()
112 |
113 | // Declare variable for ToolManager that will be initialized after server creation
114 | var stm *mcpgrafana.ToolManager
115 |
116 | // Create hooks
117 | hooks := &server.Hooks{
118 | OnRegisterSession: []server.OnRegisterSessionHookFunc{sm.CreateSession},
119 | OnUnregisterSession: []server.OnUnregisterSessionHookFunc{sm.RemoveSession},
120 | }
121 |
122 | // Add proxied tools hooks if enabled and we're not running in stdio mode.
123 | // (stdio mode is handled by InitializeAndRegisterServerTools; per-session tools
124 | // are not supported).
125 | if transport != "stdio" && !dt.proxied {
126 | // OnBeforeListTools: Discover, connect, and register tools
127 | hooks.OnBeforeListTools = []server.OnBeforeListToolsFunc{
128 | func(ctx context.Context, id any, request *mcp.ListToolsRequest) {
129 | if stm != nil {
130 | if session := server.ClientSessionFromContext(ctx); session != nil {
131 | stm.InitializeAndRegisterProxiedTools(ctx, session)
132 | }
133 | }
134 | },
135 | }
136 |
137 | // OnBeforeCallTool: Fallback in case client calls tool without listing first
138 | hooks.OnBeforeCallTool = []server.OnBeforeCallToolFunc{
139 | func(ctx context.Context, id any, request *mcp.CallToolRequest) {
140 | if stm != nil {
141 | if session := server.ClientSessionFromContext(ctx); session != nil {
142 | stm.InitializeAndRegisterProxiedTools(ctx, session)
143 | }
144 | }
145 | },
146 | }
147 | }
148 | s := server.NewMCPServer("mcp-grafana", mcpgrafana.Version(),
149 | server.WithInstructions(`
150 | This server provides access to your Grafana instance and the surrounding ecosystem.
151 |
152 | Available Capabilities:
153 | - Dashboards: Search, retrieve, update, and create dashboards. Extract panel queries and datasource information.
154 | - Datasources: List and fetch details for datasources.
155 | - Prometheus & Loki: Run PromQL and LogQL queries, retrieve metric/log metadata, and explore label names/values.
156 | - Incidents: Search, create, update, and resolve incidents in Grafana Incident.
157 | - Sift Investigations: Start and manage Sift investigations, analyze logs/traces, find error patterns, and detect slow requests.
158 | - Alerting: List and fetch alert rules and notification contact points.
159 | - OnCall: View and manage on-call schedules, shifts, teams, and users.
160 | - Admin: List teams and perform administrative tasks.
161 | - Pyroscope: Profile applications and fetch profiling data.
162 | - Navigation: Generate deeplink URLs for Grafana resources like dashboards, panels, and Explore queries.
163 | - Proxied Tools: Access tools from external MCP servers (like Tempo) through dynamic discovery.
164 |
165 | Note that some of these capabilities may be disabled. Do not try to use features that are not available via tools.
166 | `),
167 | server.WithHooks(hooks),
168 | )
169 |
170 | // Initialize ToolManager now that server is created
171 | stm = mcpgrafana.NewToolManager(sm, s, mcpgrafana.WithProxiedTools(!dt.proxied))
172 |
173 | dt.addTools(s)
174 | return s, stm
175 | }
176 |
177 | type tlsConfig struct {
178 | certFile, keyFile string
179 | }
180 |
181 | func (tc *tlsConfig) addFlags() {
182 | flag.StringVar(&tc.certFile, "server.tls-cert-file", "", "Path to TLS certificate file for server HTTPS (required for TLS)")
183 | flag.StringVar(&tc.keyFile, "server.tls-key-file", "", "Path to TLS private key file for server HTTPS (required for TLS)")
184 | }
185 |
186 | // httpServer represents a server with Start and Shutdown methods
187 | type httpServer interface {
188 | Start(addr string) error
189 | Shutdown(ctx context.Context) error
190 | }
191 |
192 | // runHTTPServer handles the common logic for running HTTP-based servers
193 | func runHTTPServer(ctx context.Context, srv httpServer, addr, transportName string) error {
194 | // Start server in a goroutine
195 | serverErr := make(chan error, 1)
196 | go func() {
197 | if err := srv.Start(addr); err != nil {
198 | serverErr <- err
199 | }
200 | close(serverErr)
201 | }()
202 |
203 | // Wait for either server error or shutdown signal
204 | select {
205 | case err := <-serverErr:
206 | return err
207 | case <-ctx.Done():
208 | slog.Info(fmt.Sprintf("%s server shutting down...", transportName))
209 |
210 | // Create a timeout context for shutdown
211 | shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
212 | defer shutdownCancel()
213 |
214 | if err := srv.Shutdown(shutdownCtx); err != nil {
215 | return fmt.Errorf("shutdown error: %v", err)
216 | }
217 | slog.Debug("Shutdown called, waiting for connections to close...")
218 |
219 | // Wait for server to finish
220 | select {
221 | case err := <-serverErr:
222 | // http.ErrServerClosed is expected when shutting down
223 | if err != nil && !errors.Is(err, http.ErrServerClosed) {
224 | return fmt.Errorf("server error during shutdown: %v", err)
225 | }
226 | case <-shutdownCtx.Done():
227 | slog.Warn(fmt.Sprintf("%s server did not stop gracefully within timeout", transportName))
228 | }
229 | }
230 |
231 | return nil
232 | }
233 |
234 | func handleHealthz(w http.ResponseWriter, r *http.Request) {
235 | w.WriteHeader(http.StatusOK)
236 | _, _ = w.Write([]byte("ok"))
237 | }
238 |
239 | func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt disabledTools, gc mcpgrafana.GrafanaConfig, tls tlsConfig) error {
240 | slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})))
241 | s, tm := newServer(transport, dt)
242 |
243 | // Create a context that will be cancelled on shutdown
244 | ctx, cancel := context.WithCancel(context.Background())
245 | defer cancel()
246 |
247 | // Set up signal handling for graceful shutdown
248 | sigChan := make(chan os.Signal, 1)
249 | signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
250 | defer signal.Stop(sigChan)
251 |
252 | // Handle shutdown signals
253 | go func() {
254 | <-sigChan
255 | slog.Info("Received shutdown signal")
256 | cancel()
257 |
258 | // For stdio, close stdin to unblock the Listen call
259 | if transport == "stdio" {
260 | _ = os.Stdin.Close()
261 | }
262 | }()
263 |
264 | // Start the appropriate server based on transport
265 | switch transport {
266 | case "stdio":
267 | srv := server.NewStdioServer(s)
268 | cf := mcpgrafana.ComposedStdioContextFunc(gc)
269 | srv.SetContextFunc(cf)
270 |
271 | // For stdio (single-tenant), initialize proxied tools on the server directly
272 | if !dt.proxied {
273 | stdioCtx := cf(ctx)
274 | if err := tm.InitializeAndRegisterServerTools(stdioCtx); err != nil {
275 | slog.Error("failed to initialize proxied tools for stdio", "error", err)
276 | }
277 | }
278 |
279 | slog.Info("Starting Grafana MCP server using stdio transport", "version", mcpgrafana.Version())
280 |
281 | err := srv.Listen(ctx, os.Stdin, os.Stdout)
282 | if err != nil && err != context.Canceled {
283 | return fmt.Errorf("server error: %v", err)
284 | }
285 | return nil
286 |
287 | case "sse":
288 | httpSrv := &http.Server{Addr: addr}
289 | srv := server.NewSSEServer(s,
290 | server.WithSSEContextFunc(mcpgrafana.ComposedSSEContextFunc(gc)),
291 | server.WithStaticBasePath(basePath),
292 | server.WithHTTPServer(httpSrv),
293 | )
294 | mux := http.NewServeMux()
295 | if basePath == "" {
296 | basePath = "/"
297 | }
298 | mux.Handle(basePath, srv)
299 | mux.HandleFunc("/healthz", handleHealthz)
300 | httpSrv.Handler = mux
301 | slog.Info("Starting Grafana MCP server using SSE transport",
302 | "version", mcpgrafana.Version(), "address", addr, "basePath", basePath)
303 | return runHTTPServer(ctx, srv, addr, "SSE")
304 | case "streamable-http":
305 | httpSrv := &http.Server{Addr: addr}
306 | opts := []server.StreamableHTTPOption{
307 | server.WithHTTPContextFunc(mcpgrafana.ComposedHTTPContextFunc(gc)),
308 | server.WithStateLess(dt.proxied), // Stateful when proxied tools enabled (requires sessions)
309 | server.WithEndpointPath(endpointPath),
310 | server.WithStreamableHTTPServer(httpSrv),
311 | }
312 | if tls.certFile != "" || tls.keyFile != "" {
313 | opts = append(opts, server.WithTLSCert(tls.certFile, tls.keyFile))
314 | }
315 | srv := server.NewStreamableHTTPServer(s, opts...)
316 | mux := http.NewServeMux()
317 | mux.Handle(endpointPath, srv)
318 | mux.HandleFunc("/healthz", handleHealthz)
319 | httpSrv.Handler = mux
320 | slog.Info("Starting Grafana MCP server using StreamableHTTP transport",
321 | "version", mcpgrafana.Version(), "address", addr, "endpointPath", endpointPath)
322 | return runHTTPServer(ctx, srv, addr, "StreamableHTTP")
323 | default:
324 | return fmt.Errorf("invalid transport type: %s. Must be 'stdio', 'sse' or 'streamable-http'", transport)
325 | }
326 | }
327 |
328 | func main() {
329 | var transport string
330 | flag.StringVar(&transport, "t", "stdio", "Transport type (stdio, sse or streamable-http)")
331 | flag.StringVar(
332 | &transport,
333 | "transport",
334 | "stdio",
335 | "Transport type (stdio, sse or streamable-http)",
336 | )
337 | 	addr := flag.String("address", "localhost:8000", "The host and port to start the sse or streamable-http server on")
338 | basePath := flag.String("base-path", "", "Base path for the sse server")
339 | endpointPath := flag.String("endpoint-path", "/mcp", "Endpoint path for the streamable-http server")
340 | logLevel := flag.String("log-level", "info", "Log level (debug, info, warn, error)")
341 | showVersion := flag.Bool("version", false, "Print the version and exit")
342 | var dt disabledTools
343 | dt.addFlags()
344 | var gc grafanaConfig
345 | gc.addFlags()
346 | var tls tlsConfig
347 | tls.addFlags()
348 | flag.Parse()
349 |
350 | if *showVersion {
351 | fmt.Println(mcpgrafana.Version())
352 | os.Exit(0)
353 | }
354 |
355 | // Convert local grafanaConfig to mcpgrafana.GrafanaConfig
356 | grafanaConfig := mcpgrafana.GrafanaConfig{Debug: gc.debug}
357 | if gc.tlsCertFile != "" || gc.tlsKeyFile != "" || gc.tlsCAFile != "" || gc.tlsSkipVerify {
358 | grafanaConfig.TLSConfig = &mcpgrafana.TLSConfig{
359 | CertFile: gc.tlsCertFile,
360 | KeyFile: gc.tlsKeyFile,
361 | CAFile: gc.tlsCAFile,
362 | SkipVerify: gc.tlsSkipVerify,
363 | }
364 | }
365 |
366 | if err := run(transport, *addr, *basePath, *endpointPath, parseLevel(*logLevel), dt, grafanaConfig, tls); err != nil {
367 | panic(err)
368 | }
369 | }
370 |
371 | func parseLevel(level string) slog.Level {
372 | var l slog.Level
373 | if err := l.UnmarshalText([]byte(level)); err != nil {
374 | return slog.LevelInfo
375 | }
376 | return l
377 | }
378 |
```
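
runHTTPServer can drive both HTTP transports because the mcp-go SSE and StreamableHTTP servers expose the same Start/Shutdown pair that the unexported httpServer interface above requires. A compile-time sketch of that relationship, restating the interface since it lives unexported in package main:

```go
package main_sketch // illustrative; the real interface is defined in main.go above

import (
	"context"

	"github.com/mark3labs/mcp-go/server"
)

// httpServer restates the small interface from main.go.
type httpServer interface {
	Start(addr string) error
	Shutdown(ctx context.Context) error
}

// Compile-time checks: both transport servers used in run() satisfy it,
// which is what lets runHTTPServer handle graceful shutdown for either.
var (
	_ httpServer = (*server.SSEServer)(nil)
	_ httpServer = (*server.StreamableHTTPServer)(nil)
)
```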
--------------------------------------------------------------------------------
/session_test.go:
--------------------------------------------------------------------------------
```go
1 | //go:build integration
2 |
3 | // Integration tests for proxied MCP tools functionality.
4 | // Requires docker-compose to be running with Grafana and Tempo instances.
5 | // Run with: go test -tags integration -v ./...
6 |
7 | package mcpgrafana
8 |
9 | import (
10 | "context"
11 | "fmt"
12 | "net/url"
13 | "os"
14 | "strings"
15 | "sync"
16 | "testing"
17 |
18 | "github.com/go-openapi/strfmt"
19 | grafana_client "github.com/grafana/grafana-openapi-client-go/client"
20 | "github.com/mark3labs/mcp-go/mcp"
21 | "github.com/stretchr/testify/assert"
22 | "github.com/stretchr/testify/require"
23 | )
24 |
25 | // newProxiedToolsTestContext creates a test context with Grafana client and config
26 | func newProxiedToolsTestContext(t *testing.T) context.Context {
27 | cfg := grafana_client.DefaultTransportConfig()
28 | cfg.Host = "localhost:3000"
29 | cfg.Schemes = []string{"http"}
30 |
31 | // Extract transport config from env vars, and set it on the context.
32 | if u, ok := os.LookupEnv("GRAFANA_URL"); ok {
33 | parsedURL, err := url.Parse(u)
34 | require.NoError(t, err, "invalid GRAFANA_URL")
35 | cfg.Host = parsedURL.Host
36 | // The Grafana client will always prefer HTTPS even if the URL is HTTP,
37 | // so we need to limit the schemes to HTTP if the URL is HTTP.
38 | if parsedURL.Scheme == "http" {
39 | cfg.Schemes = []string{"http"}
40 | }
41 | }
42 |
43 | // Check for the new service account token environment variable first
44 | if apiKey := os.Getenv("GRAFANA_SERVICE_ACCOUNT_TOKEN"); apiKey != "" {
45 | cfg.APIKey = apiKey
46 | } else if apiKey := os.Getenv("GRAFANA_API_KEY"); apiKey != "" {
47 | // Fall back to the deprecated API key environment variable
48 | cfg.APIKey = apiKey
49 | } else {
50 | cfg.BasicAuth = url.UserPassword("admin", "admin")
51 | }
52 |
53 | grafanaClient := grafana_client.NewHTTPClientWithConfig(strfmt.Default, cfg)
54 |
55 | grafanaCfg := GrafanaConfig{
56 | Debug: true,
57 | URL: "http://localhost:3000",
58 | APIKey: cfg.APIKey,
59 | BasicAuth: cfg.BasicAuth,
60 | }
61 |
62 | ctx := WithGrafanaConfig(context.Background(), grafanaCfg)
63 | return WithGrafanaClient(ctx, grafanaClient)
64 | }
65 |
66 | func TestDiscoverMCPDatasources(t *testing.T) {
67 | ctx := newProxiedToolsTestContext(t)
68 |
69 | t.Run("discovers tempo datasources", func(t *testing.T) {
70 | discovered, err := discoverMCPDatasources(ctx)
71 | require.NoError(t, err)
72 |
73 | 		// Should find at least two Tempo datasources from docker-compose
74 | assert.GreaterOrEqual(t, len(discovered), 2, "Should discover at least 2 Tempo datasources")
75 |
76 | // Check that we found the expected datasources
77 | uids := make([]string, len(discovered))
78 | for i, ds := range discovered {
79 | uids[i] = ds.UID
80 | assert.Equal(t, "tempo", ds.Type, "All discovered datasources should be tempo type")
81 | assert.NotEmpty(t, ds.Name, "Datasource should have a name")
82 | assert.NotEmpty(t, ds.MCPURL, "Datasource should have MCP URL")
83 |
84 | // Verify URL format
85 | expectedURLPattern := fmt.Sprintf("http://localhost:3000/api/datasources/proxy/uid/%s/api/mcp", ds.UID)
86 | assert.Equal(t, expectedURLPattern, ds.MCPURL, "MCP URL should follow proxy pattern")
87 | }
88 |
89 | // Should contain our expected UIDs
90 | assert.Contains(t, uids, "tempo", "Should discover 'tempo' datasource")
91 | assert.Contains(t, uids, "tempo-secondary", "Should discover 'tempo-secondary' datasource")
92 | })
93 |
94 | t.Run("returns error when grafana client not in context", func(t *testing.T) {
95 | emptyCtx := context.Background()
96 | discovered, err := discoverMCPDatasources(emptyCtx)
97 | assert.Error(t, err)
98 | assert.Nil(t, discovered)
99 | assert.Contains(t, err.Error(), "grafana client not found in context")
100 | })
101 |
102 | t.Run("returns error when auth is missing", func(t *testing.T) {
103 | // Context with client but no auth credentials
104 | cfg := grafana_client.DefaultTransportConfig()
105 | cfg.Host = "localhost:3000"
106 | cfg.Schemes = []string{"http"}
107 | grafanaClient := grafana_client.NewHTTPClientWithConfig(strfmt.Default, cfg)
108 |
109 | grafanaCfg := GrafanaConfig{
110 | URL: "http://localhost:3000",
111 | // No APIKey or BasicAuth set
112 | }
113 | ctx := WithGrafanaConfig(context.Background(), grafanaCfg)
114 | ctx = WithGrafanaClient(ctx, grafanaClient)
115 |
116 | discovered, err := discoverMCPDatasources(ctx)
117 | assert.Error(t, err)
118 | assert.Nil(t, discovered)
119 | assert.Contains(t, err.Error(), "Unauthorized")
120 | })
121 | }
122 |
123 | func TestToolNamespacing(t *testing.T) {
124 | t.Run("parse proxied tool name", func(t *testing.T) {
125 | datasourceType, toolName, err := parseProxiedToolName("tempo_traceql-search")
126 | require.NoError(t, err)
127 | assert.Equal(t, "tempo", datasourceType)
128 | assert.Equal(t, "traceql-search", toolName)
129 | })
130 |
131 | t.Run("parse proxied tool name with multiple underscores", func(t *testing.T) {
132 | datasourceType, toolName, err := parseProxiedToolName("tempo_get-attribute-values")
133 | require.NoError(t, err)
134 | assert.Equal(t, "tempo", datasourceType)
135 | assert.Equal(t, "get-attribute-values", toolName)
136 | })
137 |
138 | t.Run("parse proxied tool name with invalid format", func(t *testing.T) {
139 | _, _, err := parseProxiedToolName("invalid")
140 | assert.Error(t, err)
141 | assert.Contains(t, err.Error(), "invalid proxied tool name format")
142 | })
143 |
144 | t.Run("add datasourceUid parameter to tool", func(t *testing.T) {
145 | originalTool := mcp.Tool{
146 | Name: "query_traces",
147 | Description: "Query traces from Tempo",
148 | InputSchema: mcp.ToolInputSchema{
149 | Properties: map[string]any{
150 | "query": map[string]any{
151 | "type": "string",
152 | },
153 | },
154 | Required: []string{"query"},
155 | },
156 | }
157 |
158 | modifiedTool := addDatasourceUidParameter(originalTool, "tempo")
159 |
160 | assert.Equal(t, "tempo_query_traces", modifiedTool.Name)
161 | assert.Equal(t, "Query traces from Tempo", modifiedTool.Description)
162 | assert.NotNil(t, modifiedTool.InputSchema.Properties["datasourceUid"])
163 | assert.Contains(t, modifiedTool.InputSchema.Required, "datasourceUid")
164 | assert.Contains(t, modifiedTool.InputSchema.Required, "query")
165 | })
166 |
167 | t.Run("add datasourceUid parameter with empty description", func(t *testing.T) {
168 | originalTool := mcp.Tool{
169 | Name: "test_tool",
170 | Description: "",
171 | InputSchema: mcp.ToolInputSchema{
172 | Properties: make(map[string]any),
173 | },
174 | }
175 |
176 | modifiedTool := addDatasourceUidParameter(originalTool, "tempo")
177 |
178 | assert.Equal(t, "tempo_test_tool", modifiedTool.Name)
179 | assert.Equal(t, "", modifiedTool.Description, "Should not modify empty description")
180 | assert.NotNil(t, modifiedTool.InputSchema.Properties["datasourceUid"])
181 | })
182 | }
183 |
184 | func TestSessionStateLifecycle(t *testing.T) {
185 | t.Run("create and get session", func(t *testing.T) {
186 | sm := NewSessionManager()
187 |
188 | // Create mock session
189 | mockSession := &mockClientSession{id: "test-session-123"}
190 |
191 | sm.CreateSession(context.Background(), mockSession)
192 |
193 | state, exists := sm.GetSession("test-session-123")
194 | assert.True(t, exists)
195 | assert.NotNil(t, state)
196 | assert.NotNil(t, state.proxiedClients)
197 | assert.False(t, state.proxiedToolsInitialized)
198 | })
199 |
200 | t.Run("remove session cleans up clients", func(t *testing.T) {
201 | sm := NewSessionManager()
202 |
203 | mockSession := &mockClientSession{id: "test-session-456"}
204 | sm.CreateSession(context.Background(), mockSession)
205 |
206 | state, _ := sm.GetSession("test-session-456")
207 |
208 | // Add a mock proxied client
209 | mockClient := &ProxiedClient{
210 | DatasourceUID: "test-uid",
211 | DatasourceName: "Test Datasource",
212 | DatasourceType: "tempo",
213 | }
214 | state.proxiedClients["tempo_test-uid"] = mockClient
215 |
216 | // Remove session
217 | sm.RemoveSession(context.Background(), mockSession)
218 |
219 | // Session should be gone
220 | _, exists := sm.GetSession("test-session-456")
221 | assert.False(t, exists)
222 | })
223 |
224 | t.Run("get non-existent session", func(t *testing.T) {
225 | sm := NewSessionManager()
226 |
227 | state, exists := sm.GetSession("non-existent")
228 | assert.False(t, exists)
229 | assert.Nil(t, state)
230 | })
231 | }
232 |
233 | func TestConcurrentInitializationRaceCondition(t *testing.T) {
234 | t.Run("concurrent initialization calls should be safe", func(t *testing.T) {
235 | sm := NewSessionManager()
236 | mockSession := &mockClientSession{id: "race-test-session"}
237 | sm.CreateSession(context.Background(), mockSession)
238 |
239 | state, exists := sm.GetSession("race-test-session")
240 | require.True(t, exists)
241 |
242 | // Track how many times the initialization logic runs
243 | var initCount int
244 | var initCountMutex sync.Mutex
245 |
246 | // Create a custom initOnce to track calls
247 | state.initOnce = sync.Once{}
248 |
249 | // Simulate the initialization work that should run exactly once
250 | initWork := func() {
251 | initCountMutex.Lock()
252 | initCount++
253 | initCountMutex.Unlock()
254 | // Simulate some work
255 | state.mutex.Lock()
256 | state.proxiedToolsInitialized = true
257 | state.proxiedClients["tempo_test"] = &ProxiedClient{
258 | DatasourceUID: "test",
259 | DatasourceName: "Test",
260 | DatasourceType: "tempo",
261 | }
262 | state.mutex.Unlock()
263 | }
264 |
265 | // Launch multiple goroutines that all try to initialize concurrently
266 | const numGoroutines = 10
267 | var wg sync.WaitGroup
268 | wg.Add(numGoroutines)
269 |
270 | for i := 0; i < numGoroutines; i++ {
271 | go func() {
272 | defer wg.Done()
273 | // This should be the pattern used in InitializeAndRegisterProxiedTools
274 | state.initOnce.Do(initWork)
275 | }()
276 | }
277 |
278 | wg.Wait()
279 |
280 | // Verify initialization ran exactly once
281 | assert.Equal(t, 1, initCount, "Initialization should run exactly once despite concurrent calls")
282 | assert.True(t, state.proxiedToolsInitialized, "State should be initialized")
283 | assert.Len(t, state.proxiedClients, 1, "Should have exactly one client")
284 | })
285 |
286 | t.Run("sync.Once prevents double initialization", func(t *testing.T) {
287 | sm := NewSessionManager()
288 | mockSession := &mockClientSession{id: "double-init-test"}
289 | sm.CreateSession(context.Background(), mockSession)
290 |
291 | state, _ := sm.GetSession("double-init-test")
292 |
293 | callCount := 0
294 |
295 | // First call
296 | state.initOnce.Do(func() {
297 | callCount++
298 | })
299 |
300 | // Second call should not execute
301 | state.initOnce.Do(func() {
302 | callCount++
303 | })
304 |
305 | // Third call should also not execute
306 | state.initOnce.Do(func() {
307 | callCount++
308 | })
309 |
310 | assert.Equal(t, 1, callCount, "sync.Once should ensure function runs exactly once")
311 | })
312 | }
313 |
314 | func TestProxiedClientLifecycle(t *testing.T) {
315 | ctx := newProxiedToolsTestContext(t)
316 |
317 | t.Run("list tools returns copy", func(t *testing.T) {
318 | pc := &ProxiedClient{
319 | DatasourceUID: "test-uid",
320 | DatasourceName: "Test",
321 | DatasourceType: "tempo",
322 | Tools: []mcp.Tool{
323 | {Name: "tool1", Description: "First tool"},
324 | {Name: "tool2", Description: "Second tool"},
325 | },
326 | }
327 |
328 | tools1 := pc.ListTools()
329 | tools2 := pc.ListTools()
330 |
331 | // Should return same content
332 | assert.Equal(t, tools1, tools2)
333 |
334 | // But different slice instances (copy)
335 | assert.NotSame(t, &tools1[0], &tools2[0])
336 | })
337 |
338 | t.Run("call tool validates tool exists", func(t *testing.T) {
339 | pc := &ProxiedClient{
340 | DatasourceUID: "test-uid",
341 | DatasourceName: "Test",
342 | DatasourceType: "tempo",
343 | Tools: []mcp.Tool{
344 | {Name: "valid_tool", Description: "Valid tool"},
345 | },
346 | }
347 |
348 | // Call non-existent tool
349 | result, err := pc.CallTool(ctx, "non_existent_tool", map[string]any{})
350 | assert.Error(t, err)
351 | assert.Nil(t, result)
352 | assert.Contains(t, err.Error(), "not found in remote MCP server")
353 | })
354 | }
355 |
356 | func TestEndToEndProxiedToolsFlow(t *testing.T) {
357 | ctx := newProxiedToolsTestContext(t)
358 |
359 | t.Run("full flow from discovery to tool call", func(t *testing.T) {
360 | // Step 1: Discover MCP datasources
361 | discovered, err := discoverMCPDatasources(ctx)
362 | require.NoError(t, err)
363 | require.GreaterOrEqual(t, len(discovered), 1, "Should discover at least one Tempo datasource")
364 |
365 | // Use the first discovered datasource
366 | ds := discovered[0]
367 | t.Logf("Testing with datasource: %s (UID: %s, URL: %s)", ds.Name, ds.UID, ds.MCPURL)
368 |
369 | // Step 2: Create a proxied client connection
370 | client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
371 | if err != nil {
372 | t.Skipf("Skipping end-to-end test: Tempo MCP endpoint not available: %v", err)
373 | return
374 | }
375 | defer func() {
376 | _ = client.Close()
377 | }()
378 |
379 | // Step 3: Verify we got tools from the remote server
380 | tools := client.ListTools()
381 | require.Greater(t, len(tools), 0, "Should have at least one tool from Tempo MCP server")
382 | t.Logf("Discovered %d tools from Tempo MCP server", len(tools))
383 |
384 | // Log the available tools
385 | for _, tool := range tools {
386 | t.Logf(" - Tool: %s - %s", tool.Name, tool.Description)
387 | }
388 |
389 | // Step 4: Test tool modification with datasourceUid parameter
390 | firstTool := tools[0]
391 | modifiedTool := addDatasourceUidParameter(firstTool, ds.Type)
392 |
393 | expectedName := ds.Type + "_" + firstTool.Name
394 | assert.Equal(t, expectedName, modifiedTool.Name, "Modified tool should have prefixed name")
395 | assert.Contains(t, modifiedTool.InputSchema.Required, "datasourceUid", "Modified tool should require datasourceUid")
396 |
397 | // Step 5: Test session integration
398 | sm := NewSessionManager()
399 | mockSession := &mockClientSession{id: "e2e-test-session"}
400 | sm.CreateSession(ctx, mockSession)
401 |
402 | state, exists := sm.GetSession("e2e-test-session")
403 | require.True(t, exists)
404 |
405 | // Store the proxied client in session state
406 | key := ds.Type + "_" + ds.UID
407 | state.proxiedClients[key] = client
408 |
409 | // Step 6: Verify client is stored correctly in session
410 | retrievedClient, exists := state.proxiedClients[key]
411 | require.True(t, exists, "Client should be stored in session state")
412 | assert.Equal(t, client, retrievedClient, "Should retrieve the same client from session")
413 |
414 | // Step 7: Test ProxiedToolHandler flow
415 | handler := NewProxiedToolHandler(sm, nil, modifiedTool.Name)
416 | assert.NotNil(t, handler)
417 |
418 | // Note: We can't actually call the tool without knowing what arguments it expects
419 | // and without the context having the proper session, but we've validated the setup
420 | t.Logf("Successfully validated end-to-end proxied tools flow")
421 | })
422 |
423 | t.Run("multiple datasources in single session", func(t *testing.T) {
424 | discovered, err := discoverMCPDatasources(ctx)
425 | require.NoError(t, err)
426 |
427 | if len(discovered) < 2 {
428 | t.Skip("Need at least 2 Tempo datasources for this test")
429 | }
430 |
431 | sm := NewSessionManager()
432 | mockSession := &mockClientSession{id: "multi-ds-test-session"}
433 | sm.CreateSession(ctx, mockSession)
434 |
435 | state, _ := sm.GetSession("multi-ds-test-session")
436 |
437 | // Try to connect to multiple datasources
438 | connectedCount := 0
439 | for i, ds := range discovered {
440 | if i >= 2 {
441 | break // Test with first 2 datasources
442 | }
443 |
444 | client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
445 | if err != nil {
446 | t.Logf("Could not connect to datasource %s: %v", ds.UID, err)
447 | continue
448 | }
449 | defer func() {
450 | _ = client.Close()
451 | }()
452 |
453 | key := ds.Type + "_" + ds.UID
454 | state.proxiedClients[key] = client
455 | connectedCount++
456 |
457 | t.Logf("Connected to datasource %s with %d tools", ds.UID, len(client.Tools))
458 | }
459 |
460 | if connectedCount == 0 {
461 | t.Skip("Could not connect to any Tempo datasources")
462 | }
463 |
464 | // Verify each client is stored correctly
465 | for key, client := range state.proxiedClients {
466 | parts := strings.Split(key, "_")
467 | require.Len(t, parts, 2, "Key should have format type_uid")
468 | assert.NotNil(t, client, "Client should not be nil")
469 | assert.Equal(t, parts[0], client.DatasourceType, "Client type should match key")
470 | assert.Equal(t, parts[1], client.DatasourceUID, "Client UID should match key")
471 | }
472 |
473 | t.Logf("Successfully managed %d datasources in single session", connectedCount)
474 | })
475 | }
476 |
```
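
The race-condition test above depends on each session guarding its proxied-tool setup with a sync.Once. A minimal, self-contained sketch of that idiom (not the repository's SessionState, just the pattern being verified):

```go
package main

import (
	"fmt"
	"sync"
)

// sessionState is a stand-in for the per-session state tested above: the
// expensive discovery/connection work must run exactly once per session,
// even when list-tools and call-tool hooks fire concurrently.
type sessionState struct {
	initOnce sync.Once
	mu       sync.Mutex
	clients  map[string]string
}

func (s *sessionState) ensureInitialized() {
	s.initOnce.Do(func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.clients = map[string]string{"tempo_test": "connected"}
	})
}

func main() {
	state := &sessionState{}
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			state.ensureInitialized() // safe to call concurrently
		}()
	}
	wg.Wait()
	fmt.Println(len(state.clients)) // 1: initialization ran exactly once
}
```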
--------------------------------------------------------------------------------
/tools/prometheus.go:
--------------------------------------------------------------------------------
```go
1 | package tools
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net/http"
7 | "regexp"
8 | "strings"
9 | "time"
10 |
11 | "github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
12 | mcpgrafana "github.com/grafana/mcp-grafana"
13 | "github.com/mark3labs/mcp-go/mcp"
14 | "github.com/mark3labs/mcp-go/server"
15 | "github.com/prometheus/client_golang/api"
16 | promv1 "github.com/prometheus/client_golang/api/prometheus/v1"
17 | "github.com/prometheus/common/config"
18 | "github.com/prometheus/common/model"
19 | "github.com/prometheus/prometheus/model/labels"
20 | )
21 |
22 | var (
23 | matchTypeMap = map[string]labels.MatchType{
24 | "": labels.MatchEqual,
25 | "=": labels.MatchEqual,
26 | "!=": labels.MatchNotEqual,
27 | "=~": labels.MatchRegexp,
28 | "!~": labels.MatchNotRegexp,
29 | }
30 | )
31 |
32 | func promClientFromContext(ctx context.Context, uid string) (promv1.API, error) {
33 | // First check if the datasource exists
34 | _, err := getDatasourceByUID(ctx, GetDatasourceByUIDParams{UID: uid})
35 | if err != nil {
36 | return nil, err
37 | }
38 |
39 | cfg := mcpgrafana.GrafanaConfigFromContext(ctx)
40 | url := fmt.Sprintf("%s/api/datasources/proxy/uid/%s", strings.TrimRight(cfg.URL, "/"), uid)
41 |
42 | // Create custom transport with TLS configuration if available
43 | rt := api.DefaultRoundTripper
44 | if tlsConfig := cfg.TLSConfig; tlsConfig != nil {
45 | customTransport, err := tlsConfig.HTTPTransport(rt.(*http.Transport))
46 | if err != nil {
47 | return nil, fmt.Errorf("failed to create custom transport: %w", err)
48 | }
49 | rt = customTransport
50 | }
51 |
52 | if cfg.AccessToken != "" && cfg.IDToken != "" {
53 | rt = config.NewHeadersRoundTripper(&config.Headers{
54 | Headers: map[string]config.Header{
55 | "X-Access-Token": {
56 | Secrets: []config.Secret{config.Secret(cfg.AccessToken)},
57 | },
58 | "X-Grafana-Id": {
59 | Secrets: []config.Secret{config.Secret(cfg.IDToken)},
60 | },
61 | },
62 | }, rt)
63 | } else if cfg.APIKey != "" {
64 | rt = config.NewAuthorizationCredentialsRoundTripper(
65 | "Bearer", config.NewInlineSecret(cfg.APIKey), rt,
66 | )
67 | } else if cfg.BasicAuth != nil {
68 | password, _ := cfg.BasicAuth.Password()
69 | rt = config.NewBasicAuthRoundTripper(config.NewInlineSecret(cfg.BasicAuth.Username()), config.NewInlineSecret(password), rt)
70 | }
71 |
72 | // Wrap with org ID support
73 | rt = mcpgrafana.NewOrgIDRoundTripper(rt, cfg.OrgID)
74 |
75 | c, err := api.NewClient(api.Config{
76 | Address: url,
77 | RoundTripper: rt,
78 | })
79 | if err != nil {
80 | return nil, fmt.Errorf("creating Prometheus client: %w", err)
81 | }
82 |
83 | return promv1.NewAPI(c), nil
84 | }
85 |
86 | type ListPrometheusMetricMetadataParams struct {
87 | DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
88 | Limit int `json:"limit" jsonschema:"description=The maximum number of metrics to return"`
89 | 	LimitPerMetric int    `json:"limitPerMetric" jsonschema:"description=The maximum number of metadata entries to return per metric"`
90 | Metric string `json:"metric" jsonschema:"description=The metric to query"`
91 | }
92 |
93 | func listPrometheusMetricMetadata(ctx context.Context, args ListPrometheusMetricMetadataParams) (map[string][]promv1.Metadata, error) {
94 | promClient, err := promClientFromContext(ctx, args.DatasourceUID)
95 | if err != nil {
96 | return nil, fmt.Errorf("getting Prometheus client: %w", err)
97 | }
98 |
99 | limit := args.Limit
100 | if limit == 0 {
101 | limit = 10
102 | }
103 |
104 | metadata, err := promClient.Metadata(ctx, args.Metric, fmt.Sprintf("%d", limit))
105 | if err != nil {
106 | return nil, fmt.Errorf("listing Prometheus metric metadata: %w", err)
107 | }
108 | return metadata, nil
109 | }
110 |
111 | var ListPrometheusMetricMetadata = mcpgrafana.MustTool(
112 | "list_prometheus_metric_metadata",
113 | "List Prometheus metric metadata. Returns metadata about metrics currently scraped from targets. Note: This endpoint is experimental.",
114 | listPrometheusMetricMetadata,
115 | mcp.WithTitleAnnotation("List Prometheus metric metadata"),
116 | mcp.WithIdempotentHintAnnotation(true),
117 | mcp.WithReadOnlyHintAnnotation(true),
118 | )
119 |
120 | type QueryPrometheusParams struct {
121 | DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
122 | Expr string `json:"expr" jsonschema:"required,description=The PromQL expression to query"`
123 | StartTime string `json:"startTime" jsonschema:"required,description=The start time. Supported formats are RFC3339 or relative to now (e.g. 'now'\\, 'now-1.5h'\\, 'now-2h45m'). Valid time units are 'ns'\\, 'us' (or 'µs')\\, 'ms'\\, 's'\\, 'm'\\, 'h'\\, 'd'."`
124 | 	EndTime       string `json:"endTime,omitempty" jsonschema:"description=The end time. Required if queryType is 'range'\\, ignored if queryType is 'instant'. Supported formats are RFC3339 or relative to now (e.g. 'now'\\, 'now-1.5h'\\, 'now-2h45m'). Valid time units are 'ns'\\, 'us' (or 'µs')\\, 'ms'\\, 's'\\, 'm'\\, 'h'\\, 'd'."`
125 | StepSeconds int `json:"stepSeconds,omitempty" jsonschema:"description=The time series step size in seconds. Required if queryType is 'range'\\, ignored if queryType is 'instant'"`
126 | QueryType string `json:"queryType,omitempty" jsonschema:"description=The type of query to use. Either 'range' or 'instant'"`
127 | }
128 |
129 | func parseTime(timeStr string) (time.Time, error) {
130 | tr := gtime.TimeRange{
131 | From: timeStr,
132 | Now: time.Now(),
133 | }
134 | return tr.ParseFrom()
135 | }
136 |
137 | func queryPrometheus(ctx context.Context, args QueryPrometheusParams) (model.Value, error) {
138 | promClient, err := promClientFromContext(ctx, args.DatasourceUID)
139 | if err != nil {
140 | return nil, fmt.Errorf("getting Prometheus client: %w", err)
141 | }
142 |
143 | queryType := args.QueryType
144 | if queryType == "" {
145 | queryType = "range"
146 | }
147 |
148 | var startTime time.Time
149 | startTime, err = parseTime(args.StartTime)
150 | if err != nil {
151 | return nil, fmt.Errorf("parsing start time: %w", err)
152 | }
153 |
154 | switch queryType {
155 | case "range":
156 | if args.StepSeconds == 0 {
157 | return nil, fmt.Errorf("stepSeconds must be provided when queryType is 'range'")
158 | }
159 |
160 | var endTime time.Time
161 | endTime, err = parseTime(args.EndTime)
162 | if err != nil {
163 | return nil, fmt.Errorf("parsing end time: %w", err)
164 | }
165 |
166 | step := time.Duration(args.StepSeconds) * time.Second
167 | result, _, err := promClient.QueryRange(ctx, args.Expr, promv1.Range{
168 | Start: startTime,
169 | End: endTime,
170 | Step: step,
171 | })
172 | if err != nil {
173 | return nil, fmt.Errorf("querying Prometheus range: %w", err)
174 | }
175 | return result, nil
176 | case "instant":
177 | result, _, err := promClient.Query(ctx, args.Expr, startTime)
178 | if err != nil {
179 | return nil, fmt.Errorf("querying Prometheus instant: %w", err)
180 | }
181 | return result, nil
182 | }
183 |
184 | return nil, fmt.Errorf("invalid query type: %s", queryType)
185 | }
186 |
187 | var QueryPrometheus = mcpgrafana.MustTool(
188 | "query_prometheus",
189 | "Query Prometheus using a PromQL expression. Supports both instant queries (at a single point in time) and range queries (over a time range). Time can be specified either in RFC3339 format or as relative time expressions like 'now', 'now-1h', 'now-30m', etc.",
190 | queryPrometheus,
191 | mcp.WithTitleAnnotation("Query Prometheus metrics"),
192 | mcp.WithIdempotentHintAnnotation(true),
193 | mcp.WithReadOnlyHintAnnotation(true),
194 | )
195 |
196 | type ListPrometheusMetricNamesParams struct {
197 | DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
198 | Regex string `json:"regex" jsonschema:"description=The regex to match against the metric names"`
199 | Limit int `json:"limit,omitempty" jsonschema:"description=The maximum number of results to return"`
200 | Page int `json:"page,omitempty" jsonschema:"description=The page number to return"`
201 | }
202 |
203 | func listPrometheusMetricNames(ctx context.Context, args ListPrometheusMetricNamesParams) ([]string, error) {
204 | promClient, err := promClientFromContext(ctx, args.DatasourceUID)
205 | if err != nil {
206 | return nil, fmt.Errorf("getting Prometheus client: %w", err)
207 | }
208 |
209 | limit := args.Limit
210 | if limit == 0 {
211 | limit = 10
212 | }
213 |
214 | page := args.Page
215 | if page == 0 {
216 | page = 1
217 | }
218 |
219 | // Get all metric names by querying for __name__ label values
220 | labelValues, _, err := promClient.LabelValues(ctx, "__name__", nil, time.Time{}, time.Time{})
221 | if err != nil {
222 | return nil, fmt.Errorf("listing Prometheus metric names: %w", err)
223 | }
224 |
225 | // Filter by regex if provided
226 | matches := []string{}
227 | if args.Regex != "" {
228 | re, err := regexp.Compile(args.Regex)
229 | if err != nil {
230 | return nil, fmt.Errorf("compiling regex: %w", err)
231 | }
232 | for _, val := range labelValues {
233 | if re.MatchString(string(val)) {
234 | matches = append(matches, string(val))
235 | }
236 | }
237 | } else {
238 | for _, val := range labelValues {
239 | matches = append(matches, string(val))
240 | }
241 | }
242 |
243 | // Apply pagination
244 | start := (page - 1) * limit
245 | end := start + limit
246 | if start >= len(matches) {
247 | matches = []string{}
248 | } else if end > len(matches) {
249 | matches = matches[start:]
250 | } else {
251 | matches = matches[start:end]
252 | }
253 |
254 | return matches, nil
255 | }
256 |
257 | var ListPrometheusMetricNames = mcpgrafana.MustTool(
258 | "list_prometheus_metric_names",
259 | "List metric names in a Prometheus datasource. Retrieves all metric names and then filters them locally using the provided regex. Supports pagination.",
260 | listPrometheusMetricNames,
261 | mcp.WithTitleAnnotation("List Prometheus metric names"),
262 | mcp.WithIdempotentHintAnnotation(true),
263 | mcp.WithReadOnlyHintAnnotation(true),
264 | )
265 |
266 | type LabelMatcher struct {
267 | Name string `json:"name" jsonschema:"required,description=The name of the label to match against"`
268 | Value string `json:"value" jsonschema:"required,description=The value to match against"`
269 | Type string `json:"type" jsonschema:"required,description=One of '=' or '!=' or '=~' or '!~'"`
270 | }
271 |
272 | type Selector struct {
273 | Filters []LabelMatcher `json:"filters"`
274 | }
275 |
276 | func (s Selector) String() string {
277 | b := strings.Builder{}
278 | b.WriteRune('{')
279 | for i, f := range s.Filters {
280 | if f.Type == "" {
281 | f.Type = "="
282 | }
283 | b.WriteString(fmt.Sprintf(`%s%s'%s'`, f.Name, f.Type, f.Value))
284 | if i < len(s.Filters)-1 {
285 | b.WriteString(", ")
286 | }
287 | }
288 | b.WriteRune('}')
289 | return b.String()
290 | }
291 |
292 | // Matches runs the matchers against the given labels and returns whether they match the selector.
293 | func (s Selector) Matches(lbls labels.Labels) (bool, error) {
294 | matchers := make(labels.Selector, 0, len(s.Filters))
295 |
296 | for _, filter := range s.Filters {
297 | matchType, ok := matchTypeMap[filter.Type]
298 | if !ok {
299 | return false, fmt.Errorf("invalid matcher type: %s", filter.Type)
300 | }
301 |
302 | matcher, err := labels.NewMatcher(matchType, filter.Name, filter.Value)
303 | if err != nil {
304 | return false, fmt.Errorf("creating matcher: %w", err)
305 | }
306 |
307 | matchers = append(matchers, matcher)
308 | }
309 |
310 | return matchers.Matches(lbls), nil
311 | }
312 |
313 | type ListPrometheusLabelNamesParams struct {
314 | DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
315 | Matches []Selector `json:"matches,omitempty" jsonschema:"description=Optionally\\, a list of label matchers to filter the results by"`
316 | StartRFC3339 string `json:"startRfc3339,omitempty" jsonschema:"description=Optionally\\, the start time of the time range to filter the results by"`
317 | EndRFC3339 string `json:"endRfc3339,omitempty" jsonschema:"description=Optionally\\, the end time of the time range to filter the results by"`
318 | Limit int `json:"limit,omitempty" jsonschema:"description=Optionally\\, the maximum number of results to return"`
319 | }
320 |
321 | func listPrometheusLabelNames(ctx context.Context, args ListPrometheusLabelNamesParams) ([]string, error) {
322 | promClient, err := promClientFromContext(ctx, args.DatasourceUID)
323 | if err != nil {
324 | return nil, fmt.Errorf("getting Prometheus client: %w", err)
325 | }
326 |
327 | limit := args.Limit
328 | if limit == 0 {
329 | limit = 100
330 | }
331 |
332 | var startTime, endTime time.Time
333 | if args.StartRFC3339 != "" {
334 | if startTime, err = time.Parse(time.RFC3339, args.StartRFC3339); err != nil {
335 | return nil, fmt.Errorf("parsing start time: %w", err)
336 | }
337 | }
338 | if args.EndRFC3339 != "" {
339 | if endTime, err = time.Parse(time.RFC3339, args.EndRFC3339); err != nil {
340 | return nil, fmt.Errorf("parsing end time: %w", err)
341 | }
342 | }
343 |
344 | var matchers []string
345 | for _, m := range args.Matches {
346 | matchers = append(matchers, m.String())
347 | }
348 |
349 | labelNames, _, err := promClient.LabelNames(ctx, matchers, startTime, endTime)
350 | if err != nil {
351 | return nil, fmt.Errorf("listing Prometheus label names: %w", err)
352 | }
353 |
354 | // Apply limit
355 | if len(labelNames) > limit {
356 | labelNames = labelNames[:limit]
357 | }
358 |
359 | return labelNames, nil
360 | }
361 |
362 | var ListPrometheusLabelNames = mcpgrafana.MustTool(
363 | "list_prometheus_label_names",
364 | "List label names in a Prometheus datasource. Allows filtering by series selectors and time range.",
365 | listPrometheusLabelNames,
366 | mcp.WithTitleAnnotation("List Prometheus label names"),
367 | mcp.WithIdempotentHintAnnotation(true),
368 | mcp.WithReadOnlyHintAnnotation(true),
369 | )
370 |
371 | type ListPrometheusLabelValuesParams struct {
372 | DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
373 | LabelName string `json:"labelName" jsonschema:"required,description=The name of the label to query"`
374 | Matches []Selector `json:"matches,omitempty" jsonschema:"description=Optionally\\, a list of selectors to filter the results by"`
375 | StartRFC3339 string `json:"startRfc3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query"`
376 | EndRFC3339 string `json:"endRfc3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query"`
377 | Limit int `json:"limit,omitempty" jsonschema:"description=Optionally\\, the maximum number of results to return"`
378 | }
379 |
380 | func listPrometheusLabelValues(ctx context.Context, args ListPrometheusLabelValuesParams) (model.LabelValues, error) {
381 | promClient, err := promClientFromContext(ctx, args.DatasourceUID)
382 | if err != nil {
383 | return nil, fmt.Errorf("getting Prometheus client: %w", err)
384 | }
385 |
386 | limit := args.Limit
387 | if limit == 0 {
388 | limit = 100
389 | }
390 |
391 | var startTime, endTime time.Time
392 | if args.StartRFC3339 != "" {
393 | if startTime, err = time.Parse(time.RFC3339, args.StartRFC3339); err != nil {
394 | return nil, fmt.Errorf("parsing start time: %w", err)
395 | }
396 | }
397 | if args.EndRFC3339 != "" {
398 | if endTime, err = time.Parse(time.RFC3339, args.EndRFC3339); err != nil {
399 | return nil, fmt.Errorf("parsing end time: %w", err)
400 | }
401 | }
402 |
403 | var matchers []string
404 | for _, m := range args.Matches {
405 | matchers = append(matchers, m.String())
406 | }
407 |
408 | labelValues, _, err := promClient.LabelValues(ctx, args.LabelName, matchers, startTime, endTime)
409 | if err != nil {
410 | return nil, fmt.Errorf("listing Prometheus label values: %w", err)
411 | }
412 |
413 | // Apply limit
414 | if len(labelValues) > limit {
415 | labelValues = labelValues[:limit]
416 | }
417 |
418 | return labelValues, nil
419 | }
420 |
421 | var ListPrometheusLabelValues = mcpgrafana.MustTool(
422 | "list_prometheus_label_values",
423 | "Get the values for a specific label name in Prometheus. Allows filtering by series selectors and time range.",
424 | listPrometheusLabelValues,
425 | mcp.WithTitleAnnotation("List Prometheus label values"),
426 | mcp.WithIdempotentHintAnnotation(true),
427 | mcp.WithReadOnlyHintAnnotation(true),
428 | )
429 |
430 | func AddPrometheusTools(mcp *server.MCPServer) {
431 | ListPrometheusMetricMetadata.Register(mcp)
432 | QueryPrometheus.Register(mcp)
433 | ListPrometheusMetricNames.Register(mcp)
434 | ListPrometheusLabelNames.Register(mcp)
435 | ListPrometheusLabelValues.Register(mcp)
436 | }
437 |
```
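
The exported `Selector` and `LabelMatcher` types above serialize into the PromQL-style series selector strings that the label-name and label-value tools pass to the Prometheus API as match parameters. A minimal standalone sketch of that serialization follows; it assumes the package can be imported as `github.com/grafana/mcp-grafana/tools` (the module path plus `/tools`), and the label names and values are purely illustrative.

```go
package main

import (
	"fmt"

	"github.com/grafana/mcp-grafana/tools"
)

func main() {
	// Build a selector equivalent to {job='node_exporter', env=~'prod.*'}.
	// An empty Type defaults to "=" inside Selector.String().
	sel := tools.Selector{
		Filters: []tools.LabelMatcher{
			{Name: "job", Type: "=", Value: "node_exporter"},
			{Name: "env", Type: "=~", Value: "prod.*"},
		},
	}

	// String() single-quotes each value and joins filters with ", ".
	fmt.Println(sel.String()) // {job='node_exporter', env=~'prod.*'}
}
```
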
--------------------------------------------------------------------------------
/tools/pyroscope.go:
--------------------------------------------------------------------------------
```go
1 | package tools
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | "net/url"
9 | "regexp"
10 | "strings"
11 | "time"
12 |
13 | "connectrpc.com/connect"
14 | mcpgrafana "github.com/grafana/mcp-grafana"
15 | querierv1 "github.com/grafana/pyroscope/api/gen/proto/go/querier/v1"
16 | "github.com/grafana/pyroscope/api/gen/proto/go/querier/v1/querierv1connect"
17 | typesv1 "github.com/grafana/pyroscope/api/gen/proto/go/types/v1"
18 | "github.com/mark3labs/mcp-go/mcp"
19 | "github.com/mark3labs/mcp-go/server"
20 | )
21 |
22 | func AddPyroscopeTools(mcp *server.MCPServer) {
23 | ListPyroscopeLabelNames.Register(mcp)
24 | ListPyroscopeLabelValues.Register(mcp)
25 | ListPyroscopeProfileTypes.Register(mcp)
26 | FetchPyroscopeProfile.Register(mcp)
27 | }
28 |
29 | const listPyroscopeLabelNamesToolPrompt = `
30 | Lists all available label names (keys) found in profiles within a specified Pyroscope datasource, time range, and
31 | optional label matchers. Label matchers are typically used to qualify a service name ({service_name="foo"}). Returns a
32 | list of unique label strings (e.g., ["app", "env", "pod"]). Label names with double underscores (e.g. __name__) are
33 | internal and rarely useful to users. If the time range is not provided, it defaults to the last hour.
34 | `
35 |
36 | var ListPyroscopeLabelNames = mcpgrafana.MustTool(
37 | "list_pyroscope_label_names",
38 | listPyroscopeLabelNamesToolPrompt,
39 | listPyroscopeLabelNames,
40 | mcp.WithTitleAnnotation("List Pyroscope label names"),
41 | mcp.WithIdempotentHintAnnotation(true),
42 | mcp.WithReadOnlyHintAnnotation(true),
43 | )
44 |
45 | type ListPyroscopeLabelNamesParams struct {
46 | DataSourceUID string `json:"data_source_uid" jsonschema:"required,description=The UID of the datasource to query"`
47 | Matchers string `json:"matchers,omitempty" jsonschema:"description=Optionally\\, Prometheus style matchers used to filter the result set (defaults to: {})"`
48 | StartRFC3339 string `json:"start_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query in RFC3339 format (defaults to 1 hour ago)"`
49 | EndRFC3339 string `json:"end_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query in RFC3339 format (defaults to now)"`
50 | }
51 |
52 | func listPyroscopeLabelNames(ctx context.Context, args ListPyroscopeLabelNamesParams) ([]string, error) {
53 | args.Matchers = stringOrDefault(args.Matchers, "{}")
54 |
55 | start, err := rfc3339OrDefault(args.StartRFC3339, time.Time{})
56 | if err != nil {
57 | return nil, fmt.Errorf("failed to parse start timestamp %q: %w", args.StartRFC3339, err)
58 | }
59 |
60 | end, err := rfc3339OrDefault(args.EndRFC3339, time.Time{})
61 | if err != nil {
62 | return nil, fmt.Errorf("failed to parse end timestamp %q: %w", args.EndRFC3339, err)
63 | }
64 |
65 | start, end, err = validateTimeRange(start, end)
66 | if err != nil {
67 | return nil, err
68 | }
69 |
70 | client, err := newPyroscopeClient(ctx, args.DataSourceUID)
71 | if err != nil {
72 | return nil, fmt.Errorf("failed to create Pyroscope client: %w", err)
73 | }
74 |
75 | req := &typesv1.LabelNamesRequest{
76 | Matchers: []string{args.Matchers},
77 | Start: start.UnixMilli(),
78 | End: end.UnixMilli(),
79 | }
80 | res, err := client.LabelNames(ctx, connect.NewRequest(req))
81 | if err != nil {
82 | return nil, fmt.Errorf("failed to call Pyroscope API: %w", err)
83 | }
84 |
85 | return res.Msg.Names, nil
86 | }
87 |
88 | const listPyroscopeLabelValuesToolPrompt = `
89 | Lists all available label values for a particular label name found in profiles within a specified Pyroscope datasource,
90 | time range, and optional label matchers. Label matchers are typically used to qualify a service name ({service_name="foo"}).
91 | Returns a list of unique label strings (e.g. for label name "env": ["dev", "staging", "prod"]). If the time range
92 | is not provided, it defaults to the last hour.
93 | `
94 |
95 | var ListPyroscopeLabelValues = mcpgrafana.MustTool(
96 | "list_pyroscope_label_values",
97 | listPyroscopeLabelValuesToolPrompt,
98 | listPyroscopeLabelValues,
99 | mcp.WithTitleAnnotation("List Pyroscope label values"),
100 | mcp.WithIdempotentHintAnnotation(true),
101 | mcp.WithReadOnlyHintAnnotation(true),
102 | )
103 |
104 | type ListPyroscopeLabelValuesParams struct {
105 | DataSourceUID string `json:"data_source_uid" jsonschema:"required,description=The UID of the datasource to query"`
106 | Name string `json:"name" jsonschema:"required,description=A label name"`
107 | Matchers string `json:"matchers,omitempty" jsonschema:"description=Optionally\\, Prometheus style matchers used to filter the result set (defaults to: {})"`
108 | StartRFC3339 string `json:"start_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query in RFC3339 format (defaults to 1 hour ago)"`
109 | EndRFC3339 string `json:"end_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query in RFC3339 format (defaults to now)"`
110 | }
111 |
112 | func listPyroscopeLabelValues(ctx context.Context, args ListPyroscopeLabelValuesParams) ([]string, error) {
113 | args.Name = strings.TrimSpace(args.Name)
114 | if args.Name == "" {
115 | return nil, fmt.Errorf("name is required")
116 | }
117 |
118 | args.Matchers = stringOrDefault(args.Matchers, "{}")
119 |
120 | start, err := rfc3339OrDefault(args.StartRFC3339, time.Time{})
121 | if err != nil {
122 | return nil, fmt.Errorf("failed to parse start timestamp %q: %w", args.StartRFC3339, err)
123 | }
124 |
125 | end, err := rfc3339OrDefault(args.EndRFC3339, time.Time{})
126 | if err != nil {
127 | return nil, fmt.Errorf("failed to parse end timestamp %q: %w", args.EndRFC3339, err)
128 | }
129 |
130 | start, end, err = validateTimeRange(start, end)
131 | if err != nil {
132 | return nil, err
133 | }
134 |
135 | client, err := newPyroscopeClient(ctx, args.DataSourceUID)
136 | if err != nil {
137 | return nil, fmt.Errorf("failed to create Pyroscope client: %w", err)
138 | }
139 |
140 | req := &typesv1.LabelValuesRequest{
141 | Name: args.Name,
142 | Matchers: []string{args.Matchers},
143 | Start: start.UnixMilli(),
144 | End: end.UnixMilli(),
145 | }
146 | res, err := client.LabelValues(ctx, connect.NewRequest(req))
147 | if err != nil {
148 | return nil, fmt.Errorf("failed to call Pyroscope API: %w", err)
149 | }
150 |
151 | return res.Msg.Names, nil
152 | }
153 |
154 | const listPyroscopeProfileTypesToolPrompt = `
155 | Lists all profile types available in a specified Pyroscope datasource and time range. Returns a list of all
156 | available profile types (example profile type: "process_cpu:cpu:nanoseconds:cpu:nanoseconds"). A profile type has the
157 | following structure: <name>:<sample type>:<sample unit>:<period type>:<period unit>. Not all profile types are available
158 | for every service. If the time range is not provided, it defaults to the last hour.
159 | `
160 |
161 | var ListPyroscopeProfileTypes = mcpgrafana.MustTool(
162 | "list_pyroscope_profile_types",
163 | listPyroscopeProfileTypesToolPrompt,
164 | listPyroscopeProfileTypes,
165 | mcp.WithTitleAnnotation("List Pyroscope profile types"),
166 | mcp.WithIdempotentHintAnnotation(true),
167 | mcp.WithReadOnlyHintAnnotation(true),
168 | )
169 |
170 | type ListPyroscopeProfileTypesParams struct {
171 | DataSourceUID string `json:"data_source_uid" jsonschema:"required,description=The UID of the datasource to query"`
172 | StartRFC3339 string `json:"start_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query in RFC3339 format (defaults to 1 hour ago)"`
173 | EndRFC3339 string `json:"end_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query in RFC3339 format (defaults to now)"`
174 | }
175 |
176 | func listPyroscopeProfileTypes(ctx context.Context, args ListPyroscopeProfileTypesParams) ([]string, error) {
177 | start, err := rfc3339OrDefault(args.StartRFC3339, time.Time{})
178 | if err != nil {
179 | return nil, fmt.Errorf("failed to parse start timestamp %q: %w", args.StartRFC3339, err)
180 | }
181 |
182 | end, err := rfc3339OrDefault(args.EndRFC3339, time.Time{})
183 | if err != nil {
184 | return nil, fmt.Errorf("failed to parse end timestamp %q: %w", args.EndRFC3339, err)
185 | }
186 |
187 | start, end, err = validateTimeRange(start, end)
188 | if err != nil {
189 | return nil, err
190 | }
191 |
192 | client, err := newPyroscopeClient(ctx, args.DataSourceUID)
193 | if err != nil {
194 | return nil, fmt.Errorf("failed to create Pyroscope client: %w", err)
195 | }
196 |
197 | req := &querierv1.ProfileTypesRequest{
198 | Start: start.UnixMilli(),
199 | End: end.UnixMilli(),
200 | }
201 | res, err := client.ProfileTypes(ctx, connect.NewRequest(req))
202 | if err != nil {
203 | return nil, fmt.Errorf("failed to call Pyroscope API: %w", err)
204 | }
205 |
206 | profileTypes := make([]string, len(res.Msg.ProfileTypes))
207 | for i, typ := range res.Msg.ProfileTypes {
208 | profileTypes[i] = fmt.Sprintf("%s:%s:%s:%s:%s", typ.Name, typ.SampleType, typ.SampleUnit, typ.PeriodType, typ.PeriodUnit)
209 | }
210 | return profileTypes, nil
211 | }
212 |
213 | const fetchPyroscopeProfileToolPrompt = `
214 | Fetches a profile from a Pyroscope data source for a given time range. By default, the time range is the past 1 hour.
215 | The profile type is required, available profile types can be fetched via the list_pyroscope_profile_types tool. Not all
216 | profile types are available for every service. Expect some queries to return empty result sets; this indicates the
217 | profile type does not exist for that query. In such a case, consider trying a related profile type or giving up.
218 | Matchers are not required, but highly recommended, they are generally used to select an application by the service_name
219 | label (e.g. {service_name="foo"}). Use the list_pyroscope_label_names tool to fetch available label names, and the
220 | list_pyroscope_label_values tool to fetch available label values. The returned profile is in DOT format.
221 | `
222 |
223 | var FetchPyroscopeProfile = mcpgrafana.MustTool(
224 | "fetch_pyroscope_profile",
225 | fetchPyroscopeProfileToolPrompt,
226 | fetchPyroscopeProfile,
227 | mcp.WithTitleAnnotation("Fetch Pyroscope profile"),
228 | mcp.WithIdempotentHintAnnotation(true),
229 | mcp.WithReadOnlyHintAnnotation(true),
230 | )
231 |
232 | type FetchPyroscopeProfileParams struct {
233 | DataSourceUID string `json:"data_source_uid" jsonschema:"required,description=The UID of the datasource to query"`
234 | ProfileType string `json:"profile_type" jsonschema:"required,description=The profile type\\, use the list_pyroscope_profile_types tool to fetch available profile types"`
235 | Matchers string `json:"matchers,omitempty" jsonschema:"description=Optionally\\, Prometheus style matchers used to filter the result set (defaults to: {})"`
236 | MaxNodeDepth int `json:"max_node_depth,omitempty" jsonschema:"description=Optionally\\, the maximum depth of nodes in the resulting profile. Less depth results in smaller profiles that execute faster\\, more depth results in larger profiles that have more detail. A value of -1 indicates to use an unbounded node depth (default: 100). Reducing max node depth from the default will negatively impact the accuracy of the profile"`
237 | StartRFC3339 string `json:"start_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query in RFC3339 format (defaults to 1 hour ago)"`
238 | EndRFC3339 string `json:"end_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query in RFC3339 format (defaults to now)"`
239 | }
240 |
241 | func fetchPyroscopeProfile(ctx context.Context, args FetchPyroscopeProfileParams) (string, error) {
242 | args.Matchers = stringOrDefault(args.Matchers, "{}")
243 | matchersRegex := regexp.MustCompile(`^\{.*\}$`)
244 | if !matchersRegex.MatchString(args.Matchers) {
245 | args.Matchers = fmt.Sprintf("{%s}", args.Matchers)
246 | }
247 |
248 | args.MaxNodeDepth = intOrDefault(args.MaxNodeDepth, 100)
249 |
250 | start, err := rfc3339OrDefault(args.StartRFC3339, time.Time{})
251 | if err != nil {
252 | return "", fmt.Errorf("failed to parse start timestamp %q: %w", args.StartRFC3339, err)
253 | }
254 |
255 | end, err := rfc3339OrDefault(args.EndRFC3339, time.Time{})
256 | if err != nil {
257 | return "", fmt.Errorf("failed to parse end timestamp %q: %w", args.EndRFC3339, err)
258 | }
259 |
260 | start, end, err = validateTimeRange(start, end)
261 | if err != nil {
262 | return "", err
263 | }
264 |
265 | client, err := newPyroscopeClient(ctx, args.DataSourceUID)
266 | if err != nil {
267 | return "", fmt.Errorf("failed to create Pyroscope client: %w", err)
268 | }
269 |
270 | req := &renderRequest{
271 | ProfileType: args.ProfileType,
272 | Matcher: args.Matchers,
273 | Start: start,
274 | End: end,
275 | Format: "dot",
276 | MaxNodes: args.MaxNodeDepth,
277 | }
278 | res, err := client.Render(ctx, req)
279 | if err != nil {
280 | return "", fmt.Errorf("failed to call Pyroscope API: %w", err)
281 | }
282 |
283 | res = cleanupDotProfile(res)
284 | return res, nil
285 | }
286 |
287 | func newPyroscopeClient(ctx context.Context, uid string) (*pyroscopeClient, error) {
288 | cfg := mcpgrafana.GrafanaConfigFromContext(ctx)
289 |
290 | var transport http.RoundTripper = NewAuthRoundTripper(http.DefaultTransport, cfg.AccessToken, cfg.IDToken, cfg.APIKey, cfg.BasicAuth)
291 | transport = mcpgrafana.NewOrgIDRoundTripper(transport, cfg.OrgID)
292 |
293 | httpClient := &http.Client{
294 | Transport: mcpgrafana.NewUserAgentTransport(
295 | transport,
296 | ),
297 | Timeout: 10 * time.Second,
298 | }
299 |
300 | _, err := getDatasourceByUID(ctx, GetDatasourceByUIDParams{UID: uid})
301 | if err != nil {
302 | return nil, err
303 | }
304 |
305 | base, err := url.Parse(cfg.URL)
306 | if err != nil {
307 | return nil, fmt.Errorf("failed to parse base url: %w", err)
308 | }
309 | base = base.JoinPath("api", "datasources", "proxy", "uid", uid)
310 |
311 | querierClient := querierv1connect.NewQuerierServiceClient(httpClient, base.String())
312 |
313 | client := &pyroscopeClient{
314 | QuerierServiceClient: querierClient,
315 | http: httpClient,
316 | base: base,
317 | }
318 | return client, nil
319 | }
320 |
321 | type renderRequest struct {
322 | ProfileType string
323 | Matcher string
324 | Start time.Time
325 | End time.Time
326 | Format string
327 | MaxNodes int
328 | }
329 |
330 | type pyroscopeClient struct {
331 | querierv1connect.QuerierServiceClient
332 | http *http.Client
333 | base *url.URL
334 | }
335 |
336 | // Calls the /render endpoint for Pyroscope. This returns a rendered flame graph
337 | // (typically in Flamebearer or DOT formats).
338 | func (c *pyroscopeClient) Render(ctx context.Context, args *renderRequest) (string, error) {
339 | params := url.Values{}
340 | params.Add("query", fmt.Sprintf("%s%s", args.ProfileType, args.Matcher))
341 | params.Add("from", fmt.Sprintf("%d", args.Start.UnixMilli()))
342 | params.Add("until", fmt.Sprintf("%d", args.End.UnixMilli()))
343 | params.Add("format", args.Format)
344 | params.Add("max-nodes", fmt.Sprintf("%d", args.MaxNodes))
345 |
346 | res, err := c.get(ctx, "/pyroscope/render", params)
347 | if err != nil {
348 | return "", err
349 | }
350 |
351 | return string(res), nil
352 | }
353 |
354 | func (c *pyroscopeClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
355 | u := c.base.JoinPath(path)
356 |
357 | q := u.Query()
358 | for k, vs := range params {
359 | for _, v := range vs {
360 | q.Add(k, v)
361 | }
362 | }
363 | u.RawQuery = q.Encode()
364 |
365 | req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
366 | if err != nil {
367 | return nil, fmt.Errorf("failed to create GET request: %w", err)
368 | }
369 |
370 | res, err := c.http.Do(req)
371 | if err != nil {
372 | return nil, fmt.Errorf("failed to send request: %w", err)
373 | }
374 | defer func() {
375 | _ = res.Body.Close() //nolint:errcheck
376 | }()
377 |
378 | if res.StatusCode < 200 || res.StatusCode > 299 {
379 | body, err := io.ReadAll(res.Body)
380 | if err != nil {
381 | return nil, fmt.Errorf("pyroscope API failed with status code %d", res.StatusCode)
382 | }
383 | return nil, fmt.Errorf("pyroscope API failed with status code %d: %s", res.StatusCode, string(body))
384 | }
385 |
386 | const limit = 1 << 25 // 32 MiB
387 | body, err := io.ReadAll(io.LimitReader(res.Body, limit))
388 | if err != nil {
389 | return nil, fmt.Errorf("failed to read response body: %w", err)
390 | }
391 |
392 | if len(body) == 0 {
393 | return nil, fmt.Errorf("pyroscope API returned an empty response")
394 | }
395 |
396 | if strings.Contains(string(body), "Showing nodes accounting for 0, 0% of 0 total") {
397 | return nil, fmt.Errorf("pyroscope API returned an empty profile")
398 | }
399 | return body, nil
400 | }
401 |
402 | func intOrDefault(n int, def int) int {
403 | if n == 0 {
404 | return def
405 | }
406 | return n
407 | }
408 |
409 | func stringOrDefault(s string, def string) string {
410 | if strings.TrimSpace(s) == "" {
411 | return def
412 | }
413 | return s
414 | }
415 |
416 | func rfc3339OrDefault(s string, def time.Time) (time.Time, error) {
417 | s = strings.TrimSpace(s)
418 |
419 | var err error
420 | if s != "" {
421 | def, err = time.Parse(time.RFC3339, s)
422 | if err != nil {
423 | return time.Time{}, err
424 | }
425 | }
426 |
427 | return def, nil
428 | }
429 |
430 | func validateTimeRange(start time.Time, end time.Time) (time.Time, time.Time, error) {
431 | if end.IsZero() {
432 | end = time.Now()
433 | }
434 |
435 | if start.IsZero() {
436 | start = end.Add(-1 * time.Hour)
437 | }
438 |
439 | if start.After(end) || start.Equal(end) {
440 | return time.Time{}, time.Time{}, fmt.Errorf("start timestamp %q must be strictly before end timestamp %q", start.Format(time.RFC3339), end.Format(time.RFC3339))
441 | }
442 |
443 | return start, end, nil
444 | }
445 |
446 | var cleanupRegex = regexp.MustCompile(`(?m)(fontsize=\d+ )|(id="node\d+" )|(labeltooltip=".*?\)" )|(tooltip=".*?\)" )|(N\d+ -> N\d+).*|(N\d+ \[label="other.*\n)|(shape=box )|(fillcolor="#\w{6}")|(color="#\w{6}" )`)
447 |
448 | func cleanupDotProfile(profile string) string {
449 | return cleanupRegex.ReplaceAllStringFunc(profile, func(match string) string {
450 | // Preserve edge declarations (e.g., "N1 -> N2") while stripping their attributes
451 | if m := regexp.MustCompile(`^N\d+ -> N\d+`).FindString(match); m != "" {
452 | return m
453 | }
454 | return ""
455 | })
456 | }
457 |
```
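
For reference, the query string assembled by `pyroscopeClient.Render` above can be reproduced standalone. This sketch only mirrors the parameter construction for the `/pyroscope/render` call made through the Grafana datasource proxy; the profile type and matcher values are hypothetical examples, not part of the source.

```go
package main

import (
	"fmt"
	"net/url"
	"time"
)

func main() {
	// Mirror the parameters built in pyroscopeClient.Render: the query is the
	// profile type concatenated with the matcher, timestamps are Unix millis,
	// and the output format is DOT with a bounded node count.
	profileType := "process_cpu:cpu:nanoseconds:cpu:nanoseconds" // hypothetical
	matchers := `{service_name="foo"}`                           // hypothetical
	end := time.Now()
	start := end.Add(-1 * time.Hour)

	params := url.Values{}
	params.Add("query", fmt.Sprintf("%s%s", profileType, matchers))
	params.Add("from", fmt.Sprintf("%d", start.UnixMilli()))
	params.Add("until", fmt.Sprintf("%d", end.UnixMilli()))
	params.Add("format", "dot")
	params.Add("max-nodes", "100")

	// The tool issues this against:
	// <grafana-url>/api/datasources/proxy/uid/<uid>/pyroscope/render?<params>
	fmt.Println("/pyroscope/render?" + params.Encode())
}
```
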