#
tokens: 44116/50000 9/96 files (page 3/5)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 3 of 5. Use http://codebase.md/grafana/mcp-grafana?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .dockerignore
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── docker.yml
│       ├── e2e.yml
│       ├── integration.yml
│       ├── release.yml
│       └── unit.yml
├── .gitignore
├── .golangci.yaml
├── .goreleaser.yaml
├── cmd
│   ├── linters
│   │   └── jsonschema
│   │       └── main.go
│   └── mcp-grafana
│       └── main.go
├── CODEOWNERS
├── docker-compose.yaml
├── Dockerfile
├── examples
│   └── tls_example.go
├── gemini-extension.json
├── go.mod
├── go.sum
├── image-tag
├── internal
│   └── linter
│       └── jsonschema
│           ├── jsonschema_lint_test.go
│           ├── jsonschema_lint.go
│           └── README.md
├── LICENSE
├── Makefile
├── mcpgrafana_test.go
├── mcpgrafana.go
├── proxied_client.go
├── proxied_handler.go
├── proxied_tools_test.go
├── proxied_tools.go
├── README.md
├── renovate.json
├── server.json
├── session_test.go
├── session.go
├── testdata
│   ├── dashboards
│   │   └── demo.json
│   ├── loki-config.yml
│   ├── prometheus-entrypoint.sh
│   ├── prometheus-seed.yml
│   ├── prometheus.yml
│   ├── promtail-config.yml
│   ├── provisioning
│   │   ├── alerting
│   │   │   ├── alert_rules.yaml
│   │   │   └── contact_points.yaml
│   │   ├── dashboards
│   │   │   └── dashboards.yaml
│   │   └── datasources
│   │       └── datasources.yaml
│   ├── tempo-config-2.yaml
│   └── tempo-config.yaml
├── tests
│   ├── .gitignore
│   ├── .python-version
│   ├── admin_test.py
│   ├── conftest.py
│   ├── dashboards_test.py
│   ├── disable_write_test.py
│   ├── health_test.py
│   ├── loki_test.py
│   ├── navigation_test.py
│   ├── pyproject.toml
│   ├── README.md
│   ├── tempo_test.py
│   ├── utils.py
│   └── uv.lock
├── tls_test.go
├── tools
│   ├── admin_test.go
│   ├── admin.go
│   ├── alerting_client_test.go
│   ├── alerting_client.go
│   ├── alerting_test.go
│   ├── alerting_unit_test.go
│   ├── alerting.go
│   ├── annotations_integration_test.go
│   ├── annotations_unit_test.go
│   ├── annotations.go
│   ├── asserts_cloud_test.go
│   ├── asserts_test.go
│   ├── asserts.go
│   ├── cloud_testing_utils.go
│   ├── dashboard_test.go
│   ├── dashboard.go
│   ├── datasources_test.go
│   ├── datasources.go
│   ├── folder.go
│   ├── incident_integration_test.go
│   ├── incident_test.go
│   ├── incident.go
│   ├── loki_test.go
│   ├── loki.go
│   ├── navigation_test.go
│   ├── navigation.go
│   ├── oncall_cloud_test.go
│   ├── oncall.go
│   ├── prometheus_test.go
│   ├── prometheus_unit_test.go
│   ├── prometheus.go
│   ├── pyroscope_test.go
│   ├── pyroscope.go
│   ├── search_test.go
│   ├── search.go
│   ├── sift_cloud_test.go
│   ├── sift.go
│   └── testcontext_test.go
├── tools_test.go
└── tools.go
```

# Files

--------------------------------------------------------------------------------
/tools/oncall_cloud_test.go:
--------------------------------------------------------------------------------

```go
  1 | //go:build cloud
  2 | // +build cloud
  3 | 
  4 | // This file contains cloud integration tests that run against a dedicated test instance
  5 | // at mcptests.grafana-dev.net. This instance is configured with a minimal setup on the OnCall side:
  6 | //   - One team
  7 | //   - Two schedules (only one has a team assigned)
  8 | //   - One shift in the schedule with a team assigned
  9 | //   - One user
 10 | // These tests expect this configuration to exist and will skip if the required
 11 | // environment variables (GRAFANA_URL, GRAFANA_SERVICE_ACCOUNT_TOKEN or GRAFANA_API_KEY) are not set.
 12 | // The GRAFANA_API_KEY variable is deprecated.
 13 | 
 14 | package tools
 15 | 
 16 | import (
 17 | 	"testing"
 18 | 
 19 | 	"github.com/stretchr/testify/assert"
 20 | 	"github.com/stretchr/testify/require"
 21 | )
 22 | 
 23 | func TestCloudOnCallSchedules(t *testing.T) {
 24 | 	ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")
 25 | 
 26 | 	// Test listing all schedules
 27 | 	t.Run("list all schedules", func(t *testing.T) {
 28 | 		result, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{})
 29 | 		require.NoError(t, err, "Should not error when listing schedules")
 30 | 		assert.NotNil(t, result, "Result should not be nil")
 31 | 	})
 32 | 
 33 | 	// Test pagination
 34 | 	t.Run("list schedules with pagination", func(t *testing.T) {
 35 | 		// Get first page
 36 | 		page1, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{Page: 1})
 37 | 		require.NoError(t, err, "Should not error when listing schedules page 1")
 38 | 		assert.NotNil(t, page1, "Page 1 should not be nil")
 39 | 
 40 | 		// Get second page
 41 | 		page2, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{Page: 2})
 42 | 		require.NoError(t, err, "Should not error when listing schedules page 2")
 43 | 		assert.NotNil(t, page2, "Page 2 should not be nil")
 44 | 	})
 45 | 
 46 | 	// Get a team ID from an existing schedule to test filtering
 47 | 	schedules, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{})
 48 | 	require.NoError(t, err, "Should not error when listing schedules")
 49 | 
 50 | 	if len(schedules) > 0 && schedules[0].TeamID != "" {
 51 | 		teamID := schedules[0].TeamID
 52 | 
 53 | 		// Test filtering by team ID
 54 | 		t.Run("list schedules by team ID", func(t *testing.T) {
 55 | 			result, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{
 56 | 				TeamID: teamID,
 57 | 			})
 58 | 			require.NoError(t, err, "Should not error when listing schedules by team")
 59 | 			assert.NotEmpty(t, result, "Should return at least one schedule")
 60 | 			for _, schedule := range result {
 61 | 				assert.Equal(t, teamID, schedule.TeamID, "All schedules should belong to the specified team")
 62 | 			}
 63 | 		})
 64 | 	}
 65 | 
 66 | 	// Test getting a specific schedule
 67 | 	if len(schedules) > 0 {
 68 | 		scheduleID := schedules[0].ID
 69 | 		t.Run("get specific schedule", func(t *testing.T) {
 70 | 			result, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{
 71 | 				ScheduleID: scheduleID,
 72 | 			})
 73 | 			require.NoError(t, err, "Should not error when getting specific schedule")
 74 | 			assert.Len(t, result, 1, "Should return exactly one schedule")
 75 | 			assert.Equal(t, scheduleID, result[0].ID, "Should return the correct schedule")
 76 | 
 77 | 			// Verify all summary fields are present
 78 | 			schedule := result[0]
 79 | 			assert.NotEmpty(t, schedule.Name, "Schedule should have a name")
 80 | 			assert.NotEmpty(t, schedule.Timezone, "Schedule should have a timezone")
 81 | 			assert.NotNil(t, schedule.Shifts, "Schedule should have a shifts field")
 82 | 		})
 83 | 	}
 84 | }
 85 | 
 86 | func TestCloudOnCallShift(t *testing.T) {
 87 | 	ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")
 88 | 
 89 | 	// First get a schedule to find a valid shift
 90 | 	schedules, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{})
 91 | 	require.NoError(t, err, "Should not error when listing schedules")
 92 | 	require.NotEmpty(t, schedules, "Should have at least one schedule to test with")
 93 | 	require.NotEmpty(t, schedules[0].Shifts, "Schedule should have at least one shift")
 94 | 
 95 | 	shifts := schedules[0].Shifts
 96 | 	shiftID := shifts[0]
 97 | 
 98 | 	// Test getting shift details with valid ID
 99 | 	t.Run("get shift details", func(t *testing.T) {
100 | 		result, err := getOnCallShift(ctx, GetOnCallShiftParams{
101 | 			ShiftID: shiftID,
102 | 		})
103 | 		require.NoError(t, err, "Should not error when getting shift details")
104 | 		assert.NotNil(t, result, "Result should not be nil")
105 | 		assert.Equal(t, shiftID, result.ID, "Should return the correct shift")
106 | 	})
107 | 
108 | 	t.Run("get shift with invalid ID", func(t *testing.T) {
109 | 		_, err := getOnCallShift(ctx, GetOnCallShiftParams{
110 | 			ShiftID: "invalid-shift-id",
111 | 		})
112 | 		assert.Error(t, err, "Should error when getting shift with invalid ID")
113 | 	})
114 | }
115 | 
// TestCloudGetCurrentOnCallUsers checks getCurrentOnCallUsers for a real
// schedule (taken from the schedule list) and for an invalid schedule ID.
func TestCloudGetCurrentOnCallUsers(t *testing.T) {
	ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")

	// First get a schedule to use for testing
	schedules, err := listOnCallSchedules(ctx, ListOnCallSchedulesParams{})
	require.NoError(t, err, "Should not error when listing schedules")
	require.NotEmpty(t, schedules, "Should have at least one schedule to test with")

	scheduleID := schedules[0].ID

	// Test getting current on-call users
	t.Run("get current on-call users", func(t *testing.T) {
		result, err := getCurrentOnCallUsers(ctx, GetCurrentOnCallUsersParams{
			ScheduleID: scheduleID,
		})
		require.NoError(t, err, "Should not error when getting current on-call users")
		// require (not assert): fields of result are read below.
		require.NotNil(t, result, "Result should not be nil")
		assert.Equal(t, scheduleID, result.ScheduleID, "Should return the correct schedule")
		assert.NotEmpty(t, result.ScheduleName, "Schedule should have a name")
		assert.NotNil(t, result.Users, "Users field should be present")

		// Nobody may be on call right now; only spot-check the first user's
		// fields when at least one is returned.
		if len(result.Users) > 0 {
			user := result.Users[0]
			assert.NotEmpty(t, user.ID, "User should have an ID")
			assert.NotEmpty(t, user.Username, "User should have a username")
		}
	})

	t.Run("get current on-call users with invalid schedule ID", func(t *testing.T) {
		_, err := getCurrentOnCallUsers(ctx, GetCurrentOnCallUsersParams{
			ScheduleID: "invalid-schedule-id",
		})
		assert.Error(t, err, "Should error when getting current on-call users with invalid schedule ID")
	})
}
152 | 
// TestCloudOnCallTeams verifies that OnCall teams can be listed against the
// cloud test instance, both in a single call and page by page.
func TestCloudOnCallTeams(t *testing.T) {
	ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")

	t.Run("list teams", func(t *testing.T) {
		teams, listErr := listOnCallTeams(ctx, ListOnCallTeamsParams{})
		require.NoError(t, listErr, "Should not error when listing teams")
		assert.NotNil(t, teams, "Result should not be nil")

		// An empty team list is acceptable; only inspect a team if one exists.
		if len(teams) == 0 {
			return
		}
		first := teams[0]
		assert.NotEmpty(t, first.ID, "Team should have an ID")
		assert.NotEmpty(t, first.Name, "Team should have a name")
	})

	// Both of the first two pages must be retrievable without error.
	t.Run("list teams with pagination", func(t *testing.T) {
		firstPage, firstErr := listOnCallTeams(ctx, ListOnCallTeamsParams{Page: 1})
		require.NoError(t, firstErr, "Should not error when listing teams page 1")
		assert.NotNil(t, firstPage, "Page 1 should not be nil")

		secondPage, secondErr := listOnCallTeams(ctx, ListOnCallTeamsParams{Page: 2})
		require.NoError(t, secondErr, "Should not error when listing teams page 2")
		assert.NotNil(t, secondPage, "Page 2 should not be nil")
	})
}
181 | 
// TestCloudOnCallUsers exercises listOnCallUsers on the cloud test instance:
// unfiltered listing, pagination, lookup by user ID and by username, and
// lookups with values that match no user. An unknown user ID is expected to
// produce an error, while an unknown username is expected to produce an empty
// result.
func TestCloudOnCallUsers(t *testing.T) {
	ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")

	t.Run("list all users", func(t *testing.T) {
		result, err := listOnCallUsers(ctx, ListOnCallUsersParams{})
		require.NoError(t, err, "Should not error when listing users")
		assert.NotNil(t, result, "Result should not be nil")

		if len(result) > 0 {
			user := result[0]
			assert.NotEmpty(t, user.ID, "User should have an ID")
			assert.NotEmpty(t, user.Username, "User should have a username")
		}
	})

	// Test pagination
	t.Run("list users with pagination", func(t *testing.T) {
		// Get first page
		page1, err := listOnCallUsers(ctx, ListOnCallUsersParams{Page: 1})
		require.NoError(t, err, "Should not error when listing users page 1")
		assert.NotNil(t, page1, "Page 1 should not be nil")

		// Get second page
		page2, err := listOnCallUsers(ctx, ListOnCallUsersParams{Page: 2})
		require.NoError(t, err, "Should not error when listing users page 2")
		assert.NotNil(t, page2, "Page 2 should not be nil")
	})

	// Get a user ID and username from the list to test filtering
	users, err := listOnCallUsers(ctx, ListOnCallUsersParams{})
	require.NoError(t, err, "Should not error when listing users")
	require.NotEmpty(t, users, "Should have at least one user to test with")

	userID := users[0].ID
	username := users[0].Username

	t.Run("get user by ID", func(t *testing.T) {
		result, err := listOnCallUsers(ctx, ListOnCallUsersParams{
			UserID: userID,
		})
		require.NoError(t, err, "Should not error when getting user by ID")
		// require (not assert): result[0] is indexed below and would panic
		// on an empty result.
		require.Len(t, result, 1, "Should return exactly one user")
		assert.Equal(t, userID, result[0].ID, "Should return the correct user")
		assert.NotEmpty(t, result[0].Username, "User should have a username")
	})

	t.Run("get user by username", func(t *testing.T) {
		result, err := listOnCallUsers(ctx, ListOnCallUsersParams{
			Username: username,
		})
		require.NoError(t, err, "Should not error when getting user by username")
		// require (not assert): result[0] is indexed below.
		require.Len(t, result, 1, "Should return exactly one user")
		assert.Equal(t, username, result[0].Username, "Should return the correct user")
		assert.NotEmpty(t, result[0].ID, "User should have an ID")
	})

	t.Run("get user with invalid ID", func(t *testing.T) {
		_, err := listOnCallUsers(ctx, ListOnCallUsersParams{
			UserID: "invalid-user-id",
		})
		assert.Error(t, err, "Should error when getting user with invalid ID")
	})

	t.Run("get user with invalid username", func(t *testing.T) {
		result, err := listOnCallUsers(ctx, ListOnCallUsersParams{
			Username: "invalid-username",
		})
		require.NoError(t, err, "Should not error when getting user with invalid username")
		assert.Empty(t, result, "Should return empty result set for invalid username")
	})
}
255 | 
// TestCloudGetAlertGroup fetches an existing alert group (found via
// listAlertGroups) and an invalid alert group ID with getAlertGroup.
func TestCloudGetAlertGroup(t *testing.T) {
	ctx := createCloudTestContext(t, "OnCall", "GRAFANA_URL", "GRAFANA_API_KEY")

	// First, get a list of alert groups to find a valid ID to test with
	alertGroups, err := listAlertGroups(ctx, ListAlertGroupsParams{})
	require.NoError(t, err, "Should not error when listing alert groups")
	require.NotEmpty(t, alertGroups, "Should have at least one alert group to test with")

	alertGroupID := alertGroups[0].ID

	t.Run("get alert group by ID", func(t *testing.T) {
		result, err := getAlertGroup(ctx, GetAlertGroupParams{
			AlertGroupID: alertGroupID,
		})
		require.NoError(t, err, "Should not error when getting alert group by ID")
		// require (not assert): fields of result are read below.
		require.NotNil(t, result, "Result should not be nil")
		assert.Equal(t, alertGroupID, result.ID, "Should return the correct alert group")
		assert.NotEmpty(t, result.Title, "Alert group should have a title")
		assert.NotEmpty(t, result.State, "Alert group should have a state")
	})

	t.Run("get alert group with invalid ID", func(t *testing.T) {
		_, err := getAlertGroup(ctx, GetAlertGroupParams{
			AlertGroupID: "invalid-alert-group-id",
		})
		assert.Error(t, err, "Should error when getting alert group with invalid ID")
	})
}
284 | 
```

--------------------------------------------------------------------------------
/proxied_tools.go:
--------------------------------------------------------------------------------

```go
  1 | package mcpgrafana
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"fmt"
  6 | 	"log/slog"
  7 | 	"net/http"
  8 | 	"strings"
  9 | 	"sync"
 10 | 
 11 | 	"github.com/go-openapi/runtime"
 12 | 	"github.com/mark3labs/mcp-go/mcp"
 13 | 	"github.com/mark3labs/mcp-go/server"
 14 | )
 15 | 
 16 | // MCPDatasourceConfig defines configuration for a datasource type that supports MCP
 17 | type MCPDatasourceConfig struct {
 18 | 	Type         string
 19 | 	EndpointPath string // e.g., "/api/mcp"
 20 | }
 21 | 
 22 | // mcpEnabledDatasources is a registry of datasource types that support MCP
 23 | var mcpEnabledDatasources = map[string]MCPDatasourceConfig{
 24 | 	"tempo": {Type: "tempo", EndpointPath: "/api/mcp"},
 25 | 	// Future: add other datasource types here
 26 | }
 27 | 
 28 | // DiscoveredDatasource represents a datasource that supports MCP
 29 | type DiscoveredDatasource struct {
 30 | 	UID    string
 31 | 	Name   string
 32 | 	Type   string
 33 | 	MCPURL string // The MCP endpoint URL
 34 | }
 35 | 
 36 | // discoverMCPDatasources discovers datasources that support MCP
 37 | // Returns a list of datasources with MCP endpoints
 38 | func discoverMCPDatasources(ctx context.Context) ([]DiscoveredDatasource, error) {
 39 | 	gc := GrafanaClientFromContext(ctx)
 40 | 	if gc == nil {
 41 | 		return nil, fmt.Errorf("grafana client not found in context")
 42 | 	}
 43 | 
 44 | 	var discovered []DiscoveredDatasource
 45 | 
 46 | 	// List all datasources
 47 | 	resp, err := gc.Datasources.GetDataSources()
 48 | 	if err != nil {
 49 | 		return nil, fmt.Errorf("failed to list datasources: %w", err)
 50 | 	}
 51 | 
 52 | 	// Get the Grafana base URL from context
 53 | 	config := GrafanaConfigFromContext(ctx)
 54 | 	if config.URL == "" {
 55 | 		return nil, fmt.Errorf("grafana url not found in context")
 56 | 	}
 57 | 	grafanaBaseURL := config.URL
 58 | 
 59 | 	// Filter for datasources that support MCP
 60 | 	for _, ds := range resp.Payload {
 61 | 		// Check if this datasource type supports MCP
 62 | 		dsConfig, supported := mcpEnabledDatasources[ds.Type]
 63 | 		if !supported {
 64 | 			continue
 65 | 		}
 66 | 
 67 | 		// Check if the datasource instance has MCP enabled
 68 | 		// We use a DELETE request to probe the MCP endpoint since:
 69 | 		// - GET would start an event stream and hang
 70 | 		// - POST doesn't work with the Grafana OpenAPI client
 71 | 		// - DELETE returns 200 if MCP is enabled, 404 if not
 72 | 		_, err := gc.Datasources.DatasourceProxyDELETEByUIDcalls(ds.UID, strings.TrimPrefix(dsConfig.EndpointPath, "/"))
 73 | 		if err == nil {
 74 | 			// Something strange happened - the server should never return a 202 for this really. Skip.
 75 | 			continue
 76 | 		}
 77 | 		if apiErr, ok := err.(*runtime.APIError); !ok || (ok && !apiErr.IsCode(http.StatusOK)) {
 78 | 			// Not a 200 response, MCP not enabled
 79 | 			continue
 80 | 		}
 81 | 
 82 | 		// Build the MCP endpoint URL using Grafana's datasource proxy API
 83 | 		// Format: <grafana URL>/api/datasources/proxy/uid/<uid><endpoint_path>
 84 | 		mcpURL := fmt.Sprintf("%s/api/datasources/proxy/uid/%s%s", grafanaBaseURL, ds.UID, dsConfig.EndpointPath)
 85 | 
 86 | 		discovered = append(discovered, DiscoveredDatasource{
 87 | 			UID:    ds.UID,
 88 | 			Name:   ds.Name,
 89 | 			Type:   ds.Type,
 90 | 			MCPURL: mcpURL,
 91 | 		})
 92 | 	}
 93 | 
 94 | 	slog.DebugContext(ctx, "discovered MCP datasources", "count", len(discovered))
 95 | 	return discovered, nil
 96 | }
 97 | 
 98 | // addDatasourceUidParameter adds a required datasourceUid parameter to a tool's input schema
 99 | func addDatasourceUidParameter(tool mcp.Tool, datasourceType string) mcp.Tool {
100 | 	modifiedTool := tool
101 | 	// Prefix tool name with datasource type (e.g., "tempo_traceql-search")
102 | 	modifiedTool.Name = datasourceType + "_" + tool.Name
103 | 
104 | 	// Add datasourceUid to the input schema
105 | 	if modifiedTool.InputSchema.Properties == nil {
106 | 		modifiedTool.InputSchema.Properties = make(map[string]any)
107 | 	}
108 | 
109 | 	modifiedTool.InputSchema.Properties["datasourceUid"] = map[string]any{
110 | 		"type":        "string",
111 | 		"description": "UID of the " + datasourceType + " datasource to query",
112 | 	}
113 | 
114 | 	// Add to required fields
115 | 	modifiedTool.InputSchema.Required = append(modifiedTool.InputSchema.Required, "datasourceUid")
116 | 
117 | 	return modifiedTool
118 | }
119 | 
// parseProxiedToolName splits a proxied tool name of the form
// <datasource_type>_<original_tool_name> into its two parts.
//
// The split happens at the first underscore, so underscores inside the
// original tool name are preserved (e.g. "tempo_get_trace" yields "tempo" and
// "get_trace"). It returns an error when there is no underscore or when
// either part is empty.
func parseProxiedToolName(toolName string) (string, string, error) {
	datasourceType, originalName, found := strings.Cut(toolName, "_")
	if !found || datasourceType == "" || originalName == "" {
		return "", "", fmt.Errorf("invalid proxied tool name format: %s", toolName)
	}
	return datasourceType, originalName, nil
}
130 | 
131 | // ToolManager manages proxied tools (either per-session or server-wide)
132 | type ToolManager struct {
133 | 	sm     *SessionManager
134 | 	server *server.MCPServer
135 | 
136 | 	// Whether to enable proxied tools.
137 | 	enableProxiedTools bool
138 | 
139 | 	// For stdio transport: store clients at manager level (single-tenant).
140 | 	// These will be unused for HTTP/SSE transports.
141 | 	serverMode    bool // true if using server-wide tools (stdio), false for per-session (HTTP/SSE)
142 | 	serverClients map[string]*ProxiedClient
143 | 	clientsMutex  sync.RWMutex
144 | }
145 | 
146 | // NewToolManager creates a new ToolManager
147 | func NewToolManager(sm *SessionManager, mcpServer *server.MCPServer, opts ...toolManagerOption) *ToolManager {
148 | 	tm := &ToolManager{
149 | 		sm:            sm,
150 | 		server:        mcpServer,
151 | 		serverClients: make(map[string]*ProxiedClient),
152 | 	}
153 | 	for _, opt := range opts {
154 | 		opt(tm)
155 | 	}
156 | 	return tm
157 | }
158 | 
159 | type toolManagerOption func(*ToolManager)
160 | 
161 | // WithProxiedTools sets whether proxied tools are enabled
162 | func WithProxiedTools(enabled bool) toolManagerOption {
163 | 	return func(tm *ToolManager) {
164 | 		tm.enableProxiedTools = enabled
165 | 	}
166 | }
167 | 
// InitializeAndRegisterServerTools discovers MCP-capable datasources, creates
// a proxied client for each one, and registers their tools on the server
// itself rather than per session. It is intended for the single-tenant stdio
// transport and should be called once at server startup. It does nothing when
// proxied tools are disabled.
//
// Only a discovery failure is returned as an error. Clients that cannot be
// created are logged and skipped.
func (tm *ToolManager) InitializeAndRegisterServerTools(ctx context.Context) error {
	if !tm.enableProxiedTools {
		return nil
	}

	// Mark as server mode (stdio transport)
	tm.serverMode = true

	// Discover datasources with MCP support
	discovered, err := discoverMCPDatasources(ctx)
	if err != nil {
		return fmt.Errorf("failed to discover MCP datasources: %w", err)
	}
	if len(discovered) == 0 {
		slog.Info("no MCP datasources discovered")
		return nil
	}

	// Create the clients before taking clientsMutex. NewProxiedClient gets ctx
	// and an endpoint URL, so it may do network I/O (the log below reports
	// these as connections). Holding the write lock across that would block
	// every GetServerClient call until all datasources had been processed.
	connected := make(map[string]*ProxiedClient, len(discovered))
	for _, ds := range discovered {
		client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
		if err != nil {
			slog.Error("failed to create proxied client", "datasource", ds.UID, "error", err)
			continue
		}
		connected[ds.Type+"_"+ds.UID] = client
	}

	tm.clientsMutex.Lock()
	for key, client := range connected {
		tm.serverClients[key] = client
	}
	clientCount := len(tm.serverClients)
	tm.clientsMutex.Unlock()

	if clientCount == 0 {
		slog.Warn("no proxied clients created")
		return nil
	}

	slog.Info("connected to proxied MCP servers", "datasources", clientCount)

	// Collect each prefixed tool name once. Datasources of the same type share
	// one registration; callers choose the datasource through the required
	// datasourceUid parameter added by addDatasourceUidParameter.
	tm.clientsMutex.RLock()
	toolMap := make(map[string]mcp.Tool)
	for _, client := range tm.serverClients {
		for _, tool := range client.ListTools() {
			toolName := client.DatasourceType + "_" + tool.Name
			if _, exists := toolMap[toolName]; !exists {
				toolMap[toolName] = addDatasourceUidParameter(tool, client.DatasourceType)
			}
		}
	}
	tm.clientsMutex.RUnlock()

	// Register tools on the server (not per-session)
	for toolName, tool := range toolMap {
		handler := NewProxiedToolHandler(tm.sm, tm, toolName)
		tm.server.AddTool(tool, handler.Handle)
	}

	slog.Info("registered proxied tools on server", "tools", len(toolMap))
	return nil
}
233 | 
234 | // InitializeAndRegisterProxiedTools discovers datasources, creates clients, and registers tools per-session
235 | // This should be called in OnBeforeListTools and OnBeforeCallTool hooks for HTTP/SSE transports
236 | func (tm *ToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, session server.ClientSession) {
237 | 	if !tm.enableProxiedTools {
238 | 		return
239 | 	}
240 | 
241 | 	sessionID := session.SessionID()
242 | 	state, exists := tm.sm.GetSession(sessionID)
243 | 	if !exists {
244 | 		// Session exists in server context but not in our SessionManager yet
245 | 		tm.sm.CreateSession(ctx, session)
246 | 		state, exists = tm.sm.GetSession(sessionID)
247 | 		if !exists {
248 | 			slog.Error("failed to create session in SessionManager", "sessionID", sessionID)
249 | 			return
250 | 		}
251 | 	}
252 | 
253 | 	// Step 1: Discover and connect (guaranteed to run exactly once per session)
254 | 	state.initOnce.Do(func() {
255 | 		// Discover datasources with MCP support
256 | 		discovered, err := discoverMCPDatasources(ctx)
257 | 		if err != nil {
258 | 			slog.Error("failed to discover MCP datasources", "error", err)
259 | 			state.mutex.Lock()
260 | 			state.proxiedToolsInitialized = true
261 | 			state.mutex.Unlock()
262 | 			return
263 | 		}
264 | 
265 | 		state.mutex.Lock()
266 | 		// For each discovered datasource, create a proxied client
267 | 		for _, ds := range discovered {
268 | 			client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
269 | 			if err != nil {
270 | 				slog.Error("failed to create proxied client", "datasource", ds.UID, "error", err)
271 | 				continue
272 | 			}
273 | 
274 | 			// Store the client
275 | 			key := ds.Type + "_" + ds.UID
276 | 			state.proxiedClients[key] = client
277 | 		}
278 | 		state.proxiedToolsInitialized = true
279 | 		state.mutex.Unlock()
280 | 
281 | 		slog.Info("connected to proxied MCP servers", "session", sessionID, "datasources", len(state.proxiedClients))
282 | 	})
283 | 
284 | 	// Step 2: Register tools with the MCP server
285 | 	state.mutex.Lock()
286 | 	defer state.mutex.Unlock()
287 | 
288 | 	// Check if tools already registered
289 | 	if len(state.proxiedTools) > 0 {
290 | 		return
291 | 	}
292 | 
293 | 	// Check if we have any clients (discovery should have happened above)
294 | 	if len(state.proxiedClients) == 0 {
295 | 		return
296 | 	}
297 | 
298 | 	// First pass: collect all unique tools and track which datasources support them
299 | 	toolMap := make(map[string]mcp.Tool) // unique tools by name
300 | 
301 | 	for key, client := range state.proxiedClients {
302 | 		remoteTools := client.ListTools()
303 | 
304 | 		for _, tool := range remoteTools {
305 | 			// Tool name format: datasourceType_originalToolName (e.g., "tempo_traceql-search")
306 | 			toolName := client.DatasourceType + "_" + tool.Name
307 | 
308 | 			// Store the tool if we haven't seen it yet
309 | 			if _, exists := toolMap[toolName]; !exists {
310 | 				// Add datasourceUid parameter to the tool
311 | 				modifiedTool := addDatasourceUidParameter(tool, client.DatasourceType)
312 | 				toolMap[toolName] = modifiedTool
313 | 			}
314 | 
315 | 			// Track which datasources support this tool
316 | 			state.toolToDatasources[toolName] = append(state.toolToDatasources[toolName], key)
317 | 		}
318 | 	}
319 | 
320 | 	// Second pass: register all unique tools at once (reduces listChanged notifications)
321 | 	var serverTools []server.ServerTool
322 | 	for toolName, tool := range toolMap {
323 | 		handler := NewProxiedToolHandler(tm.sm, tm, toolName)
324 | 		serverTools = append(serverTools, server.ServerTool{
325 | 			Tool:    tool,
326 | 			Handler: handler.Handle,
327 | 		})
328 | 		state.proxiedTools = append(state.proxiedTools, tool)
329 | 	}
330 | 
331 | 	if err := tm.server.AddSessionTools(sessionID, serverTools...); err != nil {
332 | 		slog.Warn("failed to add session tools", "session", sessionID, "error", err)
333 | 	} else {
334 | 		slog.Info("registered proxied tools", "session", sessionID, "tools", len(state.proxiedTools))
335 | 	}
336 | }
337 | 
// GetServerClient returns the server-wide proxied client (stdio transport) for
// the given datasource type and UID. When there is no match, the error lists
// the UIDs of the registered clients of that type, if any, so the caller can
// pick a valid one.
func (tm *ToolManager) GetServerClient(datasourceType, datasourceUID string) (*ProxiedClient, error) {
	tm.clientsMutex.RLock()
	defer tm.clientsMutex.RUnlock()

	if client, found := tm.serverClients[datasourceType+"_"+datasourceUID]; found {
		return client, nil
	}

	// No exact match: gather same-type UIDs to make the error actionable.
	var sameType []string
	for _, candidate := range tm.serverClients {
		if candidate.DatasourceType != datasourceType {
			continue
		}
		sameType = append(sameType, candidate.DatasourceUID)
	}

	if len(sameType) == 0 {
		return nil, fmt.Errorf("datasource '%s' not found. No %s datasources with MCP support are configured", datasourceUID, datasourceType)
	}
	return nil, fmt.Errorf("datasource '%s' not found. Available %s datasources: %v", datasourceUID, datasourceType, sameType)
}
362 | 
```

--------------------------------------------------------------------------------
/proxied_tools_test.go:
--------------------------------------------------------------------------------

```go
  1 | package mcpgrafana
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"sync"
  6 | 	"sync/atomic"
  7 | 	"testing"
  8 | 	"time"
  9 | 
 10 | 	"github.com/mark3labs/mcp-go/mcp"
 11 | 	"github.com/stretchr/testify/assert"
 12 | )
 13 | 
 14 | func TestSessionStateRaceConditions(t *testing.T) {
 15 | 	t.Run("concurrent initialization with sync.Once is safe", func(t *testing.T) {
 16 | 		state := newSessionState()
 17 | 
 18 | 		var initCounter int32
 19 | 		var wg sync.WaitGroup
 20 | 
 21 | 		// Launch 100 goroutines that all try to initialize at once
 22 | 		const numGoroutines = 100
 23 | 		wg.Add(numGoroutines)
 24 | 
 25 | 		for i := 0; i < numGoroutines; i++ {
 26 | 			go func() {
 27 | 				defer wg.Done()
 28 | 				state.initOnce.Do(func() {
 29 | 					// Simulate initialization work
 30 | 					atomic.AddInt32(&initCounter, 1)
 31 | 					time.Sleep(10 * time.Millisecond) // Simulate some work
 32 | 					state.mutex.Lock()
 33 | 					state.proxiedToolsInitialized = true
 34 | 					state.mutex.Unlock()
 35 | 				})
 36 | 			}()
 37 | 		}
 38 | 
 39 | 		wg.Wait()
 40 | 
 41 | 		// Verify initialization happened exactly once
 42 | 		assert.Equal(t, int32(1), atomic.LoadInt32(&initCounter),
 43 | 			"Initialization should run exactly once despite 100 concurrent calls")
 44 | 		assert.True(t, state.proxiedToolsInitialized)
 45 | 	})
 46 | 
 47 | 	t.Run("concurrent reads and writes with mutex protection", func(t *testing.T) {
 48 | 		state := newSessionState()
 49 | 		var wg sync.WaitGroup
 50 | 
 51 | 		// Writer goroutines
 52 | 		for i := 0; i < 10; i++ {
 53 | 			wg.Add(1)
 54 | 			go func(id int) {
 55 | 				defer wg.Done()
 56 | 				state.mutex.Lock()
 57 | 				key := "tempo_" + string(rune('a'+id))
 58 | 				state.proxiedClients[key] = &ProxiedClient{
 59 | 					DatasourceUID:  key,
 60 | 					DatasourceName: "Test " + key,
 61 | 					DatasourceType: "tempo",
 62 | 				}
 63 | 				state.mutex.Unlock()
 64 | 			}(i)
 65 | 		}
 66 | 
 67 | 		// Reader goroutines
 68 | 		for i := 0; i < 10; i++ {
 69 | 			wg.Add(1)
 70 | 			go func() {
 71 | 				defer wg.Done()
 72 | 				state.mutex.RLock()
 73 | 				_ = len(state.proxiedClients)
 74 | 				state.mutex.RUnlock()
 75 | 			}()
 76 | 		}
 77 | 
 78 | 		wg.Wait()
 79 | 
 80 | 		// Verify all writes succeeded
 81 | 		state.mutex.RLock()
 82 | 		count := len(state.proxiedClients)
 83 | 		state.mutex.RUnlock()
 84 | 
 85 | 		assert.Equal(t, 10, count, "All 10 clients should be stored")
 86 | 	})
 87 | 
 88 | 	t.Run("concurrent tool registration is safe", func(t *testing.T) {
 89 | 		state := newSessionState()
 90 | 		var wg sync.WaitGroup
 91 | 
 92 | 		// Multiple goroutines trying to register tools
 93 | 		const numGoroutines = 50
 94 | 		wg.Add(numGoroutines)
 95 | 
 96 | 		for i := 0; i < numGoroutines; i++ {
 97 | 			go func(id int) {
 98 | 				defer wg.Done()
 99 | 				state.mutex.Lock()
100 | 				toolName := "tempo_tool-" + string(rune('a'+id%26))
101 | 				if state.toolToDatasources[toolName] == nil {
102 | 					state.toolToDatasources[toolName] = []string{}
103 | 				}
104 | 				state.toolToDatasources[toolName] = append(
105 | 					state.toolToDatasources[toolName],
106 | 					"datasource_"+string(rune('a'+id%26)),
107 | 				)
108 | 				state.mutex.Unlock()
109 | 			}(i)
110 | 		}
111 | 
112 | 		wg.Wait()
113 | 
114 | 		// Verify the tool mappings exist
115 | 		state.mutex.RLock()
116 | 		defer state.mutex.RUnlock()
117 | 		assert.Greater(t, len(state.toolToDatasources), 0, "Should have tool mappings")
118 | 	})
119 | }
120 | 
121 | func TestSessionManagerConcurrency(t *testing.T) {
122 | 	t.Run("concurrent session creation is safe", func(t *testing.T) {
123 | 		sm := NewSessionManager()
124 | 		var wg sync.WaitGroup
125 | 
126 | 		// Create many sessions concurrently
127 | 		const numSessions = 100
128 | 		wg.Add(numSessions)
129 | 
130 | 		for i := 0; i < numSessions; i++ {
131 | 			go func(id int) {
132 | 				defer wg.Done()
133 | 				sessionID := "session-" + string(rune('a'+id%26)) + "-" + string(rune('0'+id/26))
134 | 				mockSession := &mockClientSession{id: sessionID}
135 | 				sm.CreateSession(context.Background(), mockSession)
136 | 			}(i)
137 | 		}
138 | 
139 | 		wg.Wait()
140 | 
141 | 		// Verify all sessions were created
142 | 		sm.mutex.RLock()
143 | 		count := len(sm.sessions)
144 | 		sm.mutex.RUnlock()
145 | 
146 | 		assert.Equal(t, numSessions, count, "All sessions should be created")
147 | 	})
148 | 
149 | 	t.Run("concurrent get and remove is safe", func(t *testing.T) {
150 | 		sm := NewSessionManager()
151 | 
152 | 		// Pre-populate sessions
153 | 		for i := 0; i < 50; i++ {
154 | 			sessionID := "session-" + string(rune('a'+i%26))
155 | 			mockSession := &mockClientSession{id: sessionID}
156 | 			sm.CreateSession(context.Background(), mockSession)
157 | 		}
158 | 
159 | 		var wg sync.WaitGroup
160 | 
161 | 		// Readers
162 | 		for i := 0; i < 50; i++ {
163 | 			wg.Add(1)
164 | 			go func(id int) {
165 | 				defer wg.Done()
166 | 				sessionID := "session-" + string(rune('a'+id%26))
167 | 				_, _ = sm.GetSession(sessionID)
168 | 			}(i)
169 | 		}
170 | 
171 | 		// Writers (removers)
172 | 		for i := 0; i < 25; i++ {
173 | 			wg.Add(1)
174 | 			go func(id int) {
175 | 				defer wg.Done()
176 | 				sessionID := "session-" + string(rune('a'+id%26))
177 | 				mockSession := &mockClientSession{id: sessionID}
178 | 				sm.RemoveSession(context.Background(), mockSession)
179 | 			}(i)
180 | 		}
181 | 
182 | 		wg.Wait()
183 | 
184 | 		// Test passed if no race conditions occurred
185 | 	})
186 | }
187 | 
188 | func TestInitOncePattern(t *testing.T) {
189 | 	t.Run("verify sync.Once guarantees single execution", func(t *testing.T) {
190 | 		var once sync.Once
191 | 		var counter int32
192 | 		var wg sync.WaitGroup
193 | 
194 | 		// Simulate what happens in InitializeAndRegisterProxiedTools
195 | 		initFunc := func() {
196 | 			atomic.AddInt32(&counter, 1)
197 | 			// Simulate expensive initialization
198 | 			time.Sleep(50 * time.Millisecond)
199 | 		}
200 | 
201 | 		// Launch many concurrent calls
202 | 		for i := 0; i < 1000; i++ {
203 | 			wg.Add(1)
204 | 			go func() {
205 | 				defer wg.Done()
206 | 				once.Do(initFunc)
207 | 			}()
208 | 		}
209 | 
210 | 		wg.Wait()
211 | 
212 | 		assert.Equal(t, int32(1), atomic.LoadInt32(&counter),
213 | 			"sync.Once should guarantee function runs exactly once")
214 | 	})
215 | 
216 | 	t.Run("sync.Once with different functions only runs first", func(t *testing.T) {
217 | 		var once sync.Once
218 | 		var result string
219 | 		var mu sync.Mutex
220 | 
221 | 		once.Do(func() {
222 | 			mu.Lock()
223 | 			result = "first"
224 | 			mu.Unlock()
225 | 		})
226 | 
227 | 		once.Do(func() {
228 | 			mu.Lock()
229 | 			result = "second"
230 | 			mu.Unlock()
231 | 		})
232 | 
233 | 		mu.Lock()
234 | 		finalResult := result
235 | 		mu.Unlock()
236 | 
237 | 		assert.Equal(t, "first", finalResult, "Only first function should execute")
238 | 	})
239 | }
240 | 
241 | func TestProxiedToolsInitializationFlow(t *testing.T) {
242 | 	t.Run("initialization state transitions are correct", func(t *testing.T) {
243 | 		state := newSessionState()
244 | 
245 | 		// Initial state
246 | 		assert.False(t, state.proxiedToolsInitialized)
247 | 		assert.Empty(t, state.proxiedClients)
248 | 		assert.Empty(t, state.proxiedTools)
249 | 
250 | 		// Simulate initialization
251 | 		state.initOnce.Do(func() {
252 | 			state.mutex.Lock()
253 | 			state.proxiedToolsInitialized = true
254 | 			state.proxiedClients["tempo_test"] = &ProxiedClient{
255 | 				DatasourceUID:  "test",
256 | 				DatasourceName: "Test",
257 | 				DatasourceType: "tempo",
258 | 			}
259 | 			state.mutex.Unlock()
260 | 		})
261 | 
262 | 		// Verify state after initialization
263 | 		state.mutex.RLock()
264 | 		initialized := state.proxiedToolsInitialized
265 | 		clientCount := len(state.proxiedClients)
266 | 		state.mutex.RUnlock()
267 | 
268 | 		assert.True(t, initialized)
269 | 		assert.Equal(t, 1, clientCount)
270 | 	})
271 | 
272 | 	t.Run("multiple sessions maintain separate state", func(t *testing.T) {
273 | 		sm := NewSessionManager()
274 | 
275 | 		// Create two sessions
276 | 		session1 := &mockClientSession{id: "session-1"}
277 | 		session2 := &mockClientSession{id: "session-2"}
278 | 
279 | 		sm.CreateSession(context.Background(), session1)
280 | 		sm.CreateSession(context.Background(), session2)
281 | 
282 | 		state1, _ := sm.GetSession("session-1")
283 | 		state2, _ := sm.GetSession("session-2")
284 | 
285 | 		// Initialize only session1
286 | 		state1.initOnce.Do(func() {
287 | 			state1.mutex.Lock()
288 | 			state1.proxiedToolsInitialized = true
289 | 			state1.mutex.Unlock()
290 | 		})
291 | 
292 | 		// Verify states are independent
293 | 		assert.True(t, state1.proxiedToolsInitialized)
294 | 		assert.False(t, state2.proxiedToolsInitialized)
295 | 		assert.NotSame(t, state1, state2)
296 | 	})
297 | }
298 | 
299 | func TestRaceConditionDemonstration(t *testing.T) {
300 | 	t.Run("old pattern WITHOUT sync.Once would have race condition", func(t *testing.T) {
301 | 		// This test demonstrates what WOULD happen with the old mutex-check pattern
302 | 		state := newSessionState()
303 | 
304 | 		var discoveryCallCount int32
305 | 		var wg sync.WaitGroup
306 | 
307 | 		// Simulate the OLD pattern (mutex check, unlock, then do work)
308 | 		oldPatternInitialize := func() {
309 | 			state.mutex.Lock()
310 | 			// Check if already initialized
311 | 			if state.proxiedToolsInitialized {
312 | 				state.mutex.Unlock()
313 | 				return
314 | 			}
315 | 			alreadyDiscovered := state.proxiedToolsInitialized
316 | 			state.mutex.Unlock() // ❌ OLD PATTERN: Unlock before expensive work
317 | 
318 | 			if !alreadyDiscovered {
319 | 				// Simulate discovery work that should only happen once
320 | 				atomic.AddInt32(&discoveryCallCount, 1)
321 | 				time.Sleep(10 * time.Millisecond) // Simulate expensive operation
322 | 
323 | 				state.mutex.Lock()
324 | 				state.proxiedToolsInitialized = true
325 | 				state.mutex.Unlock()
326 | 			}
327 | 		}
328 | 
329 | 		// Launch concurrent initializations
330 | 		const numGoroutines = 10
331 | 		wg.Add(numGoroutines)
332 | 		for i := 0; i < numGoroutines; i++ {
333 | 			go func() {
334 | 				defer wg.Done()
335 | 				oldPatternInitialize()
336 | 			}()
337 | 		}
338 | 		wg.Wait()
339 | 
340 | 		// With the old pattern, multiple goroutines can get past the check
341 | 		// and call discovery multiple times
342 | 		count := atomic.LoadInt32(&discoveryCallCount)
343 | 		if count > 1 {
344 | 			t.Logf("OLD PATTERN: Discovery called %d times (race condition!)", count)
345 | 		}
346 | 		// We can't assert > 1 reliably because timing matters, but this demonstrates the problem
347 | 	})
348 | 
349 | 	t.Run("new pattern WITH sync.Once prevents race condition", func(t *testing.T) {
350 | 		// This test demonstrates the NEW pattern with sync.Once
351 | 		state := newSessionState()
352 | 
353 | 		var discoveryCallCount int32
354 | 		var wg sync.WaitGroup
355 | 
356 | 		// NEW pattern: sync.Once guarantees single execution
357 | 		newPatternInitialize := func() {
358 | 			state.initOnce.Do(func() {
359 | 				// Simulate discovery work that should only happen once
360 | 				atomic.AddInt32(&discoveryCallCount, 1)
361 | 				time.Sleep(10 * time.Millisecond) // Simulate expensive operation
362 | 
363 | 				state.mutex.Lock()
364 | 				state.proxiedToolsInitialized = true
365 | 				state.mutex.Unlock()
366 | 			})
367 | 		}
368 | 
369 | 		// Launch concurrent initializations
370 | 		const numGoroutines = 10
371 | 		wg.Add(numGoroutines)
372 | 		for i := 0; i < numGoroutines; i++ {
373 | 			go func() {
374 | 				defer wg.Done()
375 | 				newPatternInitialize()
376 | 			}()
377 | 		}
378 | 		wg.Wait()
379 | 
380 | 		// With sync.Once, discovery is guaranteed to run exactly once
381 | 		count := atomic.LoadInt32(&discoveryCallCount)
382 | 		assert.Equal(t, int32(1), count, "NEW PATTERN: Discovery must be called exactly once")
383 | 	})
384 | }
385 | 
386 | func TestRaceDetector(t *testing.T) {
387 | 	// This test is primarily valuable when run with -race flag
388 | 	t.Run("stress test with race detector", func(t *testing.T) {
389 | 
390 | 		sm := NewSessionManager()
391 | 		var wg sync.WaitGroup
392 | 
393 | 		// Create a mix of operations happening concurrently
394 | 		for i := 0; i < 20; i++ {
395 | 			sessionID := "stress-session-" + string(rune('a'+i%10))
396 | 
397 | 			// Create session
398 | 			wg.Add(1)
399 | 			go func(sid string) {
400 | 				defer wg.Done()
401 | 				mockSession := &mockClientSession{id: sid}
402 | 				sm.CreateSession(context.Background(), mockSession)
403 | 			}(sessionID)
404 | 
405 | 			// Initialize session state
406 | 			wg.Add(1)
407 | 			go func(sid string) {
408 | 				defer wg.Done()
409 | 				time.Sleep(time.Millisecond) // Let creation happen first
410 | 				state, exists := sm.GetSession(sid)
411 | 				if exists {
412 | 					state.initOnce.Do(func() {
413 | 						state.mutex.Lock()
414 | 						state.proxiedToolsInitialized = true
415 | 						state.mutex.Unlock()
416 | 					})
417 | 				}
418 | 			}(sessionID)
419 | 
420 | 			// Read session state
421 | 			wg.Add(1)
422 | 			go func(sid string) {
423 | 				defer wg.Done()
424 | 				time.Sleep(2 * time.Millisecond)
425 | 				state, exists := sm.GetSession(sid)
426 | 				if exists {
427 | 					state.mutex.RLock()
428 | 					_ = state.proxiedToolsInitialized
429 | 					state.mutex.RUnlock()
430 | 				}
431 | 			}(sessionID)
432 | 		}
433 | 
434 | 		wg.Wait()
435 | 
436 | 		// If we get here without race detector warnings, we're good
437 | 		t.Log("Stress test completed without race conditions")
438 | 	})
439 | }
440 | 
441 | // mockClientSession implements server.ClientSession for testing
442 | type mockClientSession struct {
443 | 	id            string
444 | 	notifChannel  chan mcp.JSONRPCNotification
445 | 	isInitialized bool
446 | }
447 | 
448 | func (m *mockClientSession) SessionID() string {
449 | 	return m.id
450 | }
451 | 
452 | func (m *mockClientSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
453 | 	if m.notifChannel == nil {
454 | 		m.notifChannel = make(chan mcp.JSONRPCNotification, 10)
455 | 	}
456 | 	return m.notifChannel
457 | }
458 | 
459 | func (m *mockClientSession) Initialize() {
460 | 	m.isInitialized = true
461 | }
462 | 
463 | func (m *mockClientSession) Initialized() bool {
464 | 	return m.isInitialized
465 | }
466 | 
```

--------------------------------------------------------------------------------
/tools/dashboard_test.go:
--------------------------------------------------------------------------------

```go
  1 | // Requires a Grafana instance running on localhost:3000,
  2 | // with a dashboard provisioned.
  3 | // Run with `go test -tags integration`.
  4 | //go:build integration
  5 | 
  6 | package tools
  7 | 
  8 | import (
  9 | 	"context"
 10 | 	"testing"
 11 | 
 12 | 	"github.com/grafana/grafana-openapi-client-go/models"
 13 | 	"github.com/stretchr/testify/assert"
 14 | 	"github.com/stretchr/testify/require"
 15 | )
 16 | 
 17 | const (
 18 | 	newTestDashboardName = "Integration Test"
 19 | )
 20 | 
 21 | // getExistingDashboardUID will fetch an existing dashboard for test purposes
 22 | // It will search for exisiting dashboards and return the first, otherwise
 23 | // will trigger a test error
 24 | func getExistingTestDashboard(t *testing.T, ctx context.Context, dashboardName string) *models.Hit {
 25 | 	// Make sure we query for the existing dashboard, not a folder
 26 | 	if dashboardName == "" {
 27 | 		dashboardName = "Demo"
 28 | 	}
 29 | 	searchResults, err := searchDashboards(ctx, SearchDashboardsParams{
 30 | 		Query: dashboardName,
 31 | 	})
 32 | 	require.NoError(t, err)
 33 | 	require.Greater(t, len(searchResults), 0, "No dashboards found")
 34 | 	return searchResults[0]
 35 | }
 36 | 
 37 | // getExistingTestDashboardJSON will fetch the JSON map for an existing
 38 | // dashboard in the test environment
 39 | func getTestDashboardJSON(t *testing.T, ctx context.Context, dashboard *models.Hit) map[string]interface{} {
 40 | 	result, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
 41 | 		UID: dashboard.UID,
 42 | 	})
 43 | 	require.NoError(t, err)
 44 | 	dashboardMap, ok := result.Dashboard.(map[string]interface{})
 45 | 	require.True(t, ok, "Dashboard should be a map")
 46 | 	return dashboardMap
 47 | }
 48 | 
 49 | func TestDashboardTools(t *testing.T) {
 50 | 	t.Run("get dashboard by uid", func(t *testing.T) {
 51 | 		ctx := newTestContext()
 52 | 
 53 | 		// First, let's search for a dashboard to get its UID
 54 | 		dashboard := getExistingTestDashboard(t, ctx, "")
 55 | 
 56 | 		// Now test the get dashboard by uid functionality
 57 | 		result, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
 58 | 			UID: dashboard.UID,
 59 | 		})
 60 | 		require.NoError(t, err)
 61 | 		dashboardMap, ok := result.Dashboard.(map[string]interface{})
 62 | 		require.True(t, ok, "Dashboard should be a map")
 63 | 		assert.Equal(t, dashboard.UID, dashboardMap["uid"])
 64 | 		assert.NotNil(t, result.Meta)
 65 | 	})
 66 | 
 67 | 	t.Run("get dashboard by uid - invalid uid", func(t *testing.T) {
 68 | 		ctx := newTestContext()
 69 | 
 70 | 		_, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
 71 | 			UID: "non-existent-uid",
 72 | 		})
 73 | 		require.Error(t, err)
 74 | 	})
 75 | 
 76 | 	t.Run("update dashboard - create new", func(t *testing.T) {
 77 | 		ctx := newTestContext()
 78 | 
 79 | 		// Get the dashboard JSON
 80 | 		// In this case, we will create a new dashboard with the same
 81 | 		// content but different Title, and disable "overwrite"
 82 | 		dashboard := getExistingTestDashboard(t, ctx, "")
 83 | 		dashboardMap := getTestDashboardJSON(t, ctx, dashboard)
 84 | 
 85 | 		// Avoid a clash by unsetting the existing IDs
 86 | 		delete(dashboardMap, "uid")
 87 | 		delete(dashboardMap, "id")
 88 | 
 89 | 		// Set a new title and tag
 90 | 		dashboardMap["title"] = newTestDashboardName
 91 | 		dashboardMap["tags"] = []string{"integration-test"}
 92 | 
 93 | 		params := UpdateDashboardParams{
 94 | 			Dashboard: dashboardMap,
 95 | 			Message:   "creating a new dashboard",
 96 | 			Overwrite: false,
 97 | 			UserID:    1,
 98 | 		}
 99 | 
100 | 		// Only pass in the Folder UID if it exists
101 | 		if dashboard.FolderUID != "" {
102 | 			params.FolderUID = dashboard.FolderUID
103 | 		}
104 | 
105 | 		// create the dashboard
106 | 		_, err := updateDashboard(ctx, params)
107 | 		require.NoError(t, err)
108 | 	})
109 | 
110 | 	t.Run("update dashboard - overwrite existing", func(t *testing.T) {
111 | 		ctx := newTestContext()
112 | 
113 | 		// Get the dashboard JSON for the non-provisioned dashboard we've created
114 | 		dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
115 | 		dashboardMap := getTestDashboardJSON(t, ctx, dashboard)
116 | 
117 | 		params := UpdateDashboardParams{
118 | 			Dashboard: dashboardMap,
119 | 			Message:   "updating existing dashboard",
120 | 			Overwrite: true,
121 | 			UserID:    1,
122 | 		}
123 | 
124 | 		// Only pass in the Folder UID if it exists
125 | 		if dashboard.FolderUID != "" {
126 | 			params.FolderUID = dashboard.FolderUID
127 | 		}
128 | 
129 | 		// update the dashboard
130 | 		_, err := updateDashboard(ctx, params)
131 | 		require.NoError(t, err)
132 | 	})
133 | 
134 | 	t.Run("get dashboard panel queries", func(t *testing.T) {
135 | 		ctx := newTestContext()
136 | 
137 | 		// Get the test dashboard
138 | 		dashboard := getExistingTestDashboard(t, ctx, "")
139 | 
140 | 		result, err := GetDashboardPanelQueriesTool(ctx, DashboardPanelQueriesParams{
141 | 			UID: dashboard.UID,
142 | 		})
143 | 		require.NoError(t, err)
144 | 		assert.Greater(t, len(result), 0, "Should return at least one panel query")
145 | 
146 | 		// The initial demo dashboard plus for all dashboards created by the integration tests,
147 | 		// every panel should have identical title and query values.
148 | 		// Datasource UID may differ. Datasource type can be an empty string as well but on the demo and test dashboards it should be "prometheus".
149 | 		for _, panelQuery := range result {
150 | 			assert.Equal(t, panelQuery.Title, "Node Load")
151 | 			assert.Equal(t, panelQuery.Query, "node_load1")
152 | 			assert.NotEmpty(t, panelQuery.Datasource.UID)
153 | 			assert.Equal(t, panelQuery.Datasource.Type, "prometheus")
154 | 		}
155 | 	})
156 | 
157 | 	// Tests for new Issue #101 context window management tools
158 | 	t.Run("get dashboard summary", func(t *testing.T) {
159 | 		ctx := newTestContext()
160 | 
161 | 		// Get the test dashboard
162 | 		dashboard := getExistingTestDashboard(t, ctx, "")
163 | 
164 | 		result, err := getDashboardSummary(ctx, GetDashboardSummaryParams{
165 | 			UID: dashboard.UID,
166 | 		})
167 | 		require.NoError(t, err)
168 | 
169 | 		assert.Equal(t, dashboard.UID, result.UID)
170 | 		assert.NotEmpty(t, result.Title)
171 | 		assert.Greater(t, result.PanelCount, 0, "Should have at least one panel")
172 | 		assert.Len(t, result.Panels, result.PanelCount, "Panel count should match panels array length")
173 | 		assert.NotNil(t, result.Meta)
174 | 
175 | 		// Check that panels have expected structure
176 | 		for _, panel := range result.Panels {
177 | 			assert.NotEmpty(t, panel.Title)
178 | 			assert.NotEmpty(t, panel.Type)
179 | 			assert.GreaterOrEqual(t, panel.QueryCount, 0)
180 | 		}
181 | 	})
182 | 
183 | 	t.Run("get dashboard property - title", func(t *testing.T) {
184 | 		ctx := newTestContext()
185 | 
186 | 		dashboard := getExistingTestDashboard(t, ctx, "")
187 | 
188 | 		result, err := getDashboardProperty(ctx, GetDashboardPropertyParams{
189 | 			UID:      dashboard.UID,
190 | 			JSONPath: "$.title",
191 | 		})
192 | 		require.NoError(t, err)
193 | 
194 | 		title, ok := result.(string)
195 | 		require.True(t, ok, "Title should be a string")
196 | 		assert.NotEmpty(t, title)
197 | 	})
198 | 
199 | 	t.Run("get dashboard property - panel titles", func(t *testing.T) {
200 | 		ctx := newTestContext()
201 | 
202 | 		dashboard := getExistingTestDashboard(t, ctx, "")
203 | 
204 | 		result, err := getDashboardProperty(ctx, GetDashboardPropertyParams{
205 | 			UID:      dashboard.UID,
206 | 			JSONPath: "$.panels[*].title",
207 | 		})
208 | 		require.NoError(t, err)
209 | 
210 | 		titles, ok := result.([]interface{})
211 | 		require.True(t, ok, "Panel titles should be an array")
212 | 		assert.Greater(t, len(titles), 0, "Should have at least one panel title")
213 | 
214 | 		for _, title := range titles {
215 | 			titleStr, ok := title.(string)
216 | 			require.True(t, ok, "Each title should be a string")
217 | 			assert.NotEmpty(t, titleStr)
218 | 		}
219 | 	})
220 | 
221 | 	t.Run("get dashboard property - invalid path", func(t *testing.T) {
222 | 		ctx := newTestContext()
223 | 
224 | 		dashboard := getExistingTestDashboard(t, ctx, "")
225 | 
226 | 		_, err := getDashboardProperty(ctx, GetDashboardPropertyParams{
227 | 			UID:      dashboard.UID,
228 | 			JSONPath: "$.nonexistent.path",
229 | 		})
230 | 		require.Error(t, err, "Should fail for non-existent path")
231 | 	})
232 | 
233 | 	t.Run("update dashboard - patch title", func(t *testing.T) {
234 | 		ctx := newTestContext()
235 | 
236 | 		// Get our test dashboard (not the provisioned one)
237 | 		dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
238 | 
239 | 		newTitle := "Updated Integration Test Dashboard"
240 | 
241 | 		result, err := updateDashboard(ctx, UpdateDashboardParams{
242 | 			UID: dashboard.UID,
243 | 			Operations: []PatchOperation{
244 | 				{
245 | 					Op:    "replace",
246 | 					Path:  "$.title",
247 | 					Value: newTitle,
248 | 				},
249 | 			},
250 | 			Message: "Updated title via patch",
251 | 		})
252 | 		require.NoError(t, err)
253 | 		assert.NotNil(t, result)
254 | 
255 | 		// Verify the change was applied
256 | 		updatedDashboard, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
257 | 			UID: dashboard.UID,
258 | 		})
259 | 		require.NoError(t, err)
260 | 
261 | 		dashboardMap, ok := updatedDashboard.Dashboard.(map[string]interface{})
262 | 		require.True(t, ok, "Dashboard should be a map")
263 | 		assert.Equal(t, newTitle, dashboardMap["title"])
264 | 	})
265 | 
266 | 	t.Run("update dashboard - patch add description", func(t *testing.T) {
267 | 		ctx := newTestContext()
268 | 
269 | 		dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
270 | 
271 | 		description := "This is a test description added via patch"
272 | 
273 | 		_, err := updateDashboard(ctx, UpdateDashboardParams{
274 | 			UID: dashboard.UID,
275 | 			Operations: []PatchOperation{
276 | 				{
277 | 					Op:    "add",
278 | 					Path:  "$.description",
279 | 					Value: description,
280 | 				},
281 | 			},
282 | 			Message: "Added description via patch",
283 | 		})
284 | 		require.NoError(t, err)
285 | 
286 | 		// Verify the description was added
287 | 		updatedDashboard, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
288 | 			UID: dashboard.UID,
289 | 		})
290 | 		require.NoError(t, err)
291 | 
292 | 		dashboardMap, ok := updatedDashboard.Dashboard.(map[string]interface{})
293 | 		require.True(t, ok, "Dashboard should be a map")
294 | 		assert.Equal(t, description, dashboardMap["description"])
295 | 	})
296 | 
297 | 	t.Run("update dashboard - patch remove description", func(t *testing.T) {
298 | 		ctx := newTestContext()
299 | 
300 | 		dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
301 | 
302 | 		_, err := updateDashboard(ctx, UpdateDashboardParams{
303 | 			UID: dashboard.UID,
304 | 			Operations: []PatchOperation{
305 | 				{
306 | 					Op:   "remove",
307 | 					Path: "$.description",
308 | 				},
309 | 			},
310 | 			Message: "Removed description via patch",
311 | 		})
312 | 		require.NoError(t, err)
313 | 
314 | 		// Verify the description was removed
315 | 		updatedDashboard, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
316 | 			UID: dashboard.UID,
317 | 		})
318 | 		require.NoError(t, err)
319 | 
320 | 		dashboardMap, ok := updatedDashboard.Dashboard.(map[string]interface{})
321 | 		require.True(t, ok, "Dashboard should be a map")
322 | 		_, hasDescription := dashboardMap["description"]
323 | 		assert.False(t, hasDescription, "Description should be removed")
324 | 	})
325 | 
326 | 	t.Run("update dashboard - unsupported operation", func(t *testing.T) {
327 | 		ctx := newTestContext()
328 | 
329 | 		dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
330 | 
331 | 		_, err := updateDashboard(ctx, UpdateDashboardParams{
332 | 			UID: dashboard.UID,
333 | 			Operations: []PatchOperation{
334 | 				{
335 | 					Op:    "copy", // Unsupported operation
336 | 					Path:  "$.title",
337 | 					Value: "New Title",
338 | 				},
339 | 			},
340 | 		})
341 | 		require.Error(t, err, "Should fail for unsupported operation")
342 | 	})
343 | 
344 | 	t.Run("update dashboard - invalid parameters", func(t *testing.T) {
345 | 		ctx := newTestContext()
346 | 
347 | 		_, err := updateDashboard(ctx, UpdateDashboardParams{
348 | 			// Neither dashboard nor (uid + operations) provided
349 | 		})
350 | 		require.Error(t, err, "Should fail when no valid parameters provided")
351 | 	})
352 | 
353 | 	t.Run("update dashboard - append to panels array", func(t *testing.T) {
354 | 		ctx := newTestContext()
355 | 
356 | 		// Get our test dashboard
357 | 		dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
358 | 
359 | 		// Create a new panel to append
360 | 		newPanel := map[string]interface{}{
361 | 			"id":    999,
362 | 			"title": "New Appended Panel",
363 | 			"type":  "stat",
364 | 			"targets": []interface{}{
365 | 				map[string]interface{}{
366 | 					"expr": "up",
367 | 				},
368 | 			},
369 | 			"gridPos": map[string]interface{}{
370 | 				"h": 8,
371 | 				"w": 12,
372 | 				"x": 0,
373 | 				"y": 8,
374 | 			},
375 | 		}
376 | 
377 | 		_, err := updateDashboard(ctx, UpdateDashboardParams{
378 | 			UID: dashboard.UID,
379 | 			Operations: []PatchOperation{
380 | 				{
381 | 					Op:    "add",
382 | 					Path:  "$.panels/-",
383 | 					Value: newPanel,
384 | 				},
385 | 			},
386 | 			Message: "Appended new panel via /- syntax",
387 | 		})
388 | 		require.NoError(t, err)
389 | 
390 | 		// Verify the panel was appended
391 | 		updatedDashboard, err := getDashboardByUID(ctx, GetDashboardByUIDParams{
392 | 			UID: dashboard.UID,
393 | 		})
394 | 		require.NoError(t, err)
395 | 
396 | 		dashboardMap, ok := updatedDashboard.Dashboard.(map[string]interface{})
397 | 		require.True(t, ok, "Dashboard should be a map")
398 | 
399 | 		panels, ok := dashboardMap["panels"].([]interface{})
400 | 		require.True(t, ok, "Panels should be an array")
401 | 
402 | 		// Check that the new panel was appended (should be the last panel)
403 | 		lastPanel, ok := panels[len(panels)-1].(map[string]interface{})
404 | 		require.True(t, ok, "Last panel should be an object")
405 | 		assert.Equal(t, "New Appended Panel", lastPanel["title"])
406 | 		assert.Equal(t, float64(999), lastPanel["id"]) // JSON unmarshaling converts to float64
407 | 	})
408 | 
409 | 	t.Run("update dashboard - remove with append syntax should fail", func(t *testing.T) {
410 | 		ctx := newTestContext()
411 | 
412 | 		dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
413 | 
414 | 		_, err := updateDashboard(ctx, UpdateDashboardParams{
415 | 			UID: dashboard.UID,
416 | 			Operations: []PatchOperation{
417 | 				{
418 | 					Op:   "remove",
419 | 					Path: "$.panels/-", // Invalid: remove with append syntax
420 | 				},
421 | 			},
422 | 		})
423 | 		require.Error(t, err, "Should fail when using remove operation with append syntax")
424 | 	})
425 | 
426 | 	t.Run("update dashboard - append to non-array should fail", func(t *testing.T) {
427 | 		ctx := newTestContext()
428 | 
429 | 		dashboard := getExistingTestDashboard(t, ctx, newTestDashboardName)
430 | 
431 | 		_, err := updateDashboard(ctx, UpdateDashboardParams{
432 | 			UID: dashboard.UID,
433 | 			Operations: []PatchOperation{
434 | 				{
435 | 					Op:    "add",
436 | 					Path:  "$.title/-", // Invalid: title is not an array
437 | 					Value: "Invalid",
438 | 				},
439 | 			},
440 | 		})
441 | 		require.Error(t, err, "Should fail when trying to append to non-array field")
442 | 	})
443 | }
444 | 
```

--------------------------------------------------------------------------------
/tools_test.go:
--------------------------------------------------------------------------------

```go
  1 | //go:build unit
  2 | // +build unit
  3 | 
  4 | package mcpgrafana
  5 | 
  6 | import (
  7 | 	"context"
  8 | 	"errors"
  9 | 	"testing"
 10 | 
 11 | 	"github.com/mark3labs/mcp-go/mcp"
 12 | 	"github.com/stretchr/testify/assert"
 13 | 	"github.com/stretchr/testify/require"
 14 | )
 15 | 
 16 | type testToolParams struct {
 17 | 	Name     string `json:"name" jsonschema:"required,description=The name parameter"`
 18 | 	Value    int    `json:"value" jsonschema:"required,description=The value parameter"`
 19 | 	Optional bool   `json:"optional,omitempty" jsonschema:"description=An optional parameter"`
 20 | }
 21 | 
 22 | func testToolHandler(ctx context.Context, params testToolParams) (*mcp.CallToolResult, error) {
 23 | 	if params.Name == "error" {
 24 | 		return nil, errors.New("test error")
 25 | 	}
 26 | 	return mcp.NewToolResultText(params.Name + ": " + string(rune(params.Value))), nil
 27 | }
 28 | 
 29 | type emptyToolParams struct{}
 30 | 
 31 | func emptyToolHandler(ctx context.Context, params emptyToolParams) (*mcp.CallToolResult, error) {
 32 | 	return mcp.NewToolResultText("empty"), nil
 33 | }
 34 | 
 35 | // New handlers for different return types
 36 | func stringToolHandler(ctx context.Context, params testToolParams) (string, error) {
 37 | 	if params.Name == "error" {
 38 | 		return "", errors.New("test error")
 39 | 	}
 40 | 	if params.Name == "empty" {
 41 | 		return "", nil
 42 | 	}
 43 | 	return params.Name + ": " + string(rune(params.Value)), nil
 44 | }
 45 | 
 46 | func stringPtrToolHandler(ctx context.Context, params testToolParams) (*string, error) {
 47 | 	if params.Name == "error" {
 48 | 		return nil, errors.New("test error")
 49 | 	}
 50 | 	if params.Name == "nil" {
 51 | 		return nil, nil
 52 | 	}
 53 | 	if params.Name == "empty" {
 54 | 		empty := ""
 55 | 		return &empty, nil
 56 | 	}
 57 | 	result := params.Name + ": " + string(rune(params.Value))
 58 | 	return &result, nil
 59 | }
 60 | 
 61 | type TestResult struct {
 62 | 	Name  string `json:"name"`
 63 | 	Value int    `json:"value"`
 64 | }
 65 | 
 66 | func structToolHandler(ctx context.Context, params testToolParams) (TestResult, error) {
 67 | 	if params.Name == "error" {
 68 | 		return TestResult{}, errors.New("test error")
 69 | 	}
 70 | 	return TestResult{
 71 | 		Name:  params.Name,
 72 | 		Value: params.Value,
 73 | 	}, nil
 74 | }
 75 | 
 76 | func structPtrToolHandler(ctx context.Context, params testToolParams) (*TestResult, error) {
 77 | 	if params.Name == "error" {
 78 | 		return nil, errors.New("test error")
 79 | 	}
 80 | 	if params.Name == "nil" {
 81 | 		return nil, nil
 82 | 	}
 83 | 	return &TestResult{
 84 | 		Name:  params.Name,
 85 | 		Value: params.Value,
 86 | 	}, nil
 87 | }
 88 | 
 89 | func TestConvertTool(t *testing.T) {
 90 | 	t.Run("valid handler conversion", func(t *testing.T) {
 91 | 		tool, handler, err := ConvertTool("test_tool", "A test tool", testToolHandler)
 92 | 
 93 | 		require.NoError(t, err)
 94 | 		require.NotNil(t, tool)
 95 | 		require.NotNil(t, handler)
 96 | 
 97 | 		// Check tool properties
 98 | 		assert.Equal(t, "test_tool", tool.Name)
 99 | 		assert.Equal(t, "A test tool", tool.Description)
100 | 
101 | 		// Check schema properties
102 | 		assert.Equal(t, "object", tool.InputSchema.Type)
103 | 		assert.Contains(t, tool.InputSchema.Properties, "name")
104 | 		assert.Contains(t, tool.InputSchema.Properties, "value")
105 | 		assert.Contains(t, tool.InputSchema.Properties, "optional")
106 | 
107 | 		// Test handler execution
108 | 		ctx := context.Background()
109 | 		request := mcp.CallToolRequest{
110 | 			Params: struct {
111 | 				Name      string    `json:"name"`
112 | 				Arguments any       `json:"arguments,omitempty"`
113 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
114 | 			}{
115 | 				Name: "test_tool",
116 | 				Arguments: map[string]any{
117 | 					"name":  "test",
118 | 					"value": 65, // ASCII 'A'
119 | 				},
120 | 			},
121 | 		}
122 | 
123 | 		result, err := handler(ctx, request)
124 | 		require.NoError(t, err)
125 | 		require.Len(t, result.Content, 1)
126 | 		resultString, ok := result.Content[0].(mcp.TextContent)
127 | 		require.True(t, ok)
128 | 		assert.Equal(t, "test: A", resultString.Text)
129 | 
130 | 		// Test error handling
131 | 		errorRequest := mcp.CallToolRequest{
132 | 			Params: struct {
133 | 				Name      string    `json:"name"`
134 | 				Arguments any       `json:"arguments,omitempty"`
135 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
136 | 			}{
137 | 				Name: "test_tool",
138 | 				Arguments: map[string]any{
139 | 					"name":  "error",
140 | 					"value": 66,
141 | 				},
142 | 			},
143 | 		}
144 | 
145 | 		_, err = handler(ctx, errorRequest)
146 | 		assert.Error(t, err)
147 | 		assert.Equal(t, "test error", err.Error())
148 | 	})
149 | 
150 | 	t.Run("empty handler params", func(t *testing.T) {
151 | 		tool, handler, err := ConvertTool("empty", "description", emptyToolHandler)
152 | 
153 | 		require.NoError(t, err)
154 | 		require.NotNil(t, tool)
155 | 		require.NotNil(t, handler)
156 | 
157 | 		// Check tool properties
158 | 		assert.Equal(t, "empty", tool.Name)
159 | 		assert.Equal(t, "description", tool.Description)
160 | 
161 | 		// Check schema properties
162 | 		assert.Equal(t, "object", tool.InputSchema.Type)
163 | 		assert.Len(t, tool.InputSchema.Properties, 0)
164 | 
165 | 		// Test handler execution
166 | 		ctx := context.Background()
167 | 		request := mcp.CallToolRequest{
168 | 			Params: struct {
169 | 				Name      string    `json:"name"`
170 | 				Arguments any       `json:"arguments,omitempty"`
171 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
172 | 			}{
173 | 				Name: "empty",
174 | 			},
175 | 		}
176 | 		result, err := handler(ctx, request)
177 | 		require.NoError(t, err)
178 | 		require.Len(t, result.Content, 1)
179 | 		resultString, ok := result.Content[0].(mcp.TextContent)
180 | 		require.True(t, ok)
181 | 		assert.Equal(t, "empty", resultString.Text)
182 | 	})
183 | 
184 | 	t.Run("string return type", func(t *testing.T) {
185 | 		_, handler, err := ConvertTool("string_tool", "A string tool", stringToolHandler)
186 | 		require.NoError(t, err)
187 | 
188 | 		// Test normal string return
189 | 		ctx := context.Background()
190 | 		request := mcp.CallToolRequest{
191 | 			Params: struct {
192 | 				Name      string    `json:"name"`
193 | 				Arguments any       `json:"arguments,omitempty"`
194 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
195 | 			}{
196 | 				Name: "string_tool",
197 | 				Arguments: map[string]any{
198 | 					"name":  "test",
199 | 					"value": 65, // ASCII 'A'
200 | 				},
201 | 			},
202 | 		}
203 | 
204 | 		result, err := handler(ctx, request)
205 | 		require.NoError(t, err)
206 | 		require.NotNil(t, result)
207 | 		require.Len(t, result.Content, 1)
208 | 		resultString, ok := result.Content[0].(mcp.TextContent)
209 | 		require.True(t, ok)
210 | 		assert.Equal(t, "test: A", resultString.Text)
211 | 
212 | 		// Test empty string return
213 | 		emptyRequest := mcp.CallToolRequest{
214 | 			Params: struct {
215 | 				Name      string    `json:"name"`
216 | 				Arguments any       `json:"arguments,omitempty"`
217 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
218 | 			}{
219 | 				Name: "string_tool",
220 | 				Arguments: map[string]any{
221 | 					"name":  "empty",
222 | 					"value": 65,
223 | 				},
224 | 			},
225 | 		}
226 | 
227 | 		result, err = handler(ctx, emptyRequest)
228 | 		require.NoError(t, err)
229 | 		assert.Nil(t, result)
230 | 
231 | 		// Test error return
232 | 		errorRequest := mcp.CallToolRequest{
233 | 			Params: struct {
234 | 				Name      string    `json:"name"`
235 | 				Arguments any       `json:"arguments,omitempty"`
236 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
237 | 			}{
238 | 				Name: "string_tool",
239 | 				Arguments: map[string]any{
240 | 					"name":  "error",
241 | 					"value": 65,
242 | 				},
243 | 			},
244 | 		}
245 | 
246 | 		_, err = handler(ctx, errorRequest)
247 | 		assert.Error(t, err)
248 | 		assert.Equal(t, "test error", err.Error())
249 | 	})
250 | 
251 | 	t.Run("string pointer return type", func(t *testing.T) {
252 | 		_, handler, err := ConvertTool("string_ptr_tool", "A string pointer tool", stringPtrToolHandler)
253 | 		require.NoError(t, err)
254 | 
255 | 		// Test normal string pointer return
256 | 		ctx := context.Background()
257 | 		request := mcp.CallToolRequest{
258 | 			Params: struct {
259 | 				Name      string    `json:"name"`
260 | 				Arguments any       `json:"arguments,omitempty"`
261 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
262 | 			}{
263 | 				Name: "string_ptr_tool",
264 | 				Arguments: map[string]any{
265 | 					"name":  "test",
266 | 					"value": 65, // ASCII 'A'
267 | 				},
268 | 			},
269 | 		}
270 | 
271 | 		result, err := handler(ctx, request)
272 | 		require.NoError(t, err)
273 | 		require.NotNil(t, result)
274 | 		require.Len(t, result.Content, 1)
275 | 		resultString, ok := result.Content[0].(mcp.TextContent)
276 | 		require.True(t, ok)
277 | 		assert.Equal(t, "test: A", resultString.Text)
278 | 
279 | 		// Test nil string pointer return
280 | 		nilRequest := mcp.CallToolRequest{
281 | 			Params: struct {
282 | 				Name      string    `json:"name"`
283 | 				Arguments any       `json:"arguments,omitempty"`
284 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
285 | 			}{
286 | 				Name: "string_ptr_tool",
287 | 				Arguments: map[string]any{
288 | 					"name":  "nil",
289 | 					"value": 65,
290 | 				},
291 | 			},
292 | 		}
293 | 
294 | 		result, err = handler(ctx, nilRequest)
295 | 		require.NoError(t, err)
296 | 		assert.Nil(t, result)
297 | 
298 | 		// Test empty string pointer return
299 | 		emptyRequest := mcp.CallToolRequest{
300 | 			Params: struct {
301 | 				Name      string    `json:"name"`
302 | 				Arguments any       `json:"arguments,omitempty"`
303 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
304 | 			}{
305 | 				Name: "string_ptr_tool",
306 | 				Arguments: map[string]any{
307 | 					"name":  "empty",
308 | 					"value": 65,
309 | 				},
310 | 			},
311 | 		}
312 | 
313 | 		result, err = handler(ctx, emptyRequest)
314 | 		require.NoError(t, err)
315 | 		assert.Nil(t, result)
316 | 
317 | 		// Test error return
318 | 		errorRequest := mcp.CallToolRequest{
319 | 			Params: struct {
320 | 				Name      string    `json:"name"`
321 | 				Arguments any       `json:"arguments,omitempty"`
322 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
323 | 			}{
324 | 				Name: "string_ptr_tool",
325 | 				Arguments: map[string]any{
326 | 					"name":  "error",
327 | 					"value": 65,
328 | 				},
329 | 			},
330 | 		}
331 | 
332 | 		_, err = handler(ctx, errorRequest)
333 | 		assert.Error(t, err)
334 | 		assert.Equal(t, "test error", err.Error())
335 | 	})
336 | 
337 | 	t.Run("struct return type", func(t *testing.T) {
338 | 		_, handler, err := ConvertTool("struct_tool", "A struct tool", structToolHandler)
339 | 		require.NoError(t, err)
340 | 
341 | 		// Test normal struct return
342 | 		ctx := context.Background()
343 | 		request := mcp.CallToolRequest{
344 | 			Params: struct {
345 | 				Name      string    `json:"name"`
346 | 				Arguments any       `json:"arguments,omitempty"`
347 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
348 | 			}{
349 | 				Name: "struct_tool",
350 | 				Arguments: map[string]any{
351 | 					"name":  "test",
352 | 					"value": 65, // ASCII 'A'
353 | 				},
354 | 			},
355 | 		}
356 | 
357 | 		result, err := handler(ctx, request)
358 | 		require.NoError(t, err)
359 | 		require.NotNil(t, result)
360 | 		require.Len(t, result.Content, 1)
361 | 		resultString, ok := result.Content[0].(mcp.TextContent)
362 | 		require.True(t, ok)
363 | 		assert.Contains(t, resultString.Text, `"name":"test"`)
364 | 		assert.Contains(t, resultString.Text, `"value":65`)
365 | 
366 | 		// Test error return
367 | 		errorRequest := mcp.CallToolRequest{
368 | 			Params: struct {
369 | 				Name      string    `json:"name"`
370 | 				Arguments any       `json:"arguments,omitempty"`
371 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
372 | 			}{
373 | 				Name: "struct_tool",
374 | 				Arguments: map[string]any{
375 | 					"name":  "error",
376 | 					"value": 65,
377 | 				},
378 | 			},
379 | 		}
380 | 
381 | 		_, err = handler(ctx, errorRequest)
382 | 		assert.Error(t, err)
383 | 		assert.Equal(t, "test error", err.Error())
384 | 	})
385 | 
386 | 	t.Run("struct pointer return type", func(t *testing.T) {
387 | 		_, handler, err := ConvertTool("struct_ptr_tool", "A struct pointer tool", structPtrToolHandler)
388 | 		require.NoError(t, err)
389 | 
390 | 		// Test normal struct pointer return
391 | 		ctx := context.Background()
392 | 		request := mcp.CallToolRequest{
393 | 			Params: struct {
394 | 				Name      string    `json:"name"`
395 | 				Arguments any       `json:"arguments,omitempty"`
396 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
397 | 			}{
398 | 				Name: "struct_ptr_tool",
399 | 				Arguments: map[string]any{
400 | 					"name":  "test",
401 | 					"value": 65, // ASCII 'A'
402 | 				},
403 | 			},
404 | 		}
405 | 
406 | 		result, err := handler(ctx, request)
407 | 		require.NoError(t, err)
408 | 		require.NotNil(t, result)
409 | 		require.Len(t, result.Content, 1)
410 | 		resultString, ok := result.Content[0].(mcp.TextContent)
411 | 		require.True(t, ok)
412 | 		assert.Contains(t, resultString.Text, `"name":"test"`)
413 | 		assert.Contains(t, resultString.Text, `"value":65`)
414 | 
415 | 		// Test nil struct pointer return
416 | 		nilRequest := mcp.CallToolRequest{
417 | 			Params: struct {
418 | 				Name      string    `json:"name"`
419 | 				Arguments any       `json:"arguments,omitempty"`
420 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
421 | 			}{
422 | 				Name: "struct_ptr_tool",
423 | 				Arguments: map[string]any{
424 | 					"name":  "nil",
425 | 					"value": 65,
426 | 				},
427 | 			},
428 | 		}
429 | 
430 | 		result, err = handler(ctx, nilRequest)
431 | 		require.NoError(t, err)
432 | 		assert.Nil(t, result)
433 | 
434 | 		// Test error return
435 | 		errorRequest := mcp.CallToolRequest{
436 | 			Params: struct {
437 | 				Name      string    `json:"name"`
438 | 				Arguments any       `json:"arguments,omitempty"`
439 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
440 | 			}{
441 | 				Name: "struct_ptr_tool",
442 | 				Arguments: map[string]any{
443 | 					"name":  "error",
444 | 					"value": 65,
445 | 				},
446 | 			},
447 | 		}
448 | 
449 | 		_, err = handler(ctx, errorRequest)
450 | 		assert.Error(t, err)
451 | 		assert.Equal(t, "test error", err.Error())
452 | 	})
453 | 
454 | 	t.Run("invalid handler types", func(t *testing.T) {
455 | 		// Test wrong second argument type (not a struct)
456 | 		wrongSecondArgFunc := func(ctx context.Context, s string) (*mcp.CallToolResult, error) {
457 | 			return nil, nil
458 | 		}
459 | 		_, _, err := ConvertTool("invalid", "description", wrongSecondArgFunc)
460 | 		assert.Error(t, err)
461 | 		assert.Contains(t, err.Error(), "second argument must be a struct")
462 | 	})
463 | 
464 | 	t.Run("handler execution with invalid arguments", func(t *testing.T) {
465 | 		_, handler, err := ConvertTool("test_tool", "A test tool", testToolHandler)
466 | 		require.NoError(t, err)
467 | 
468 | 		// Test with invalid JSON
469 | 		invalidRequest := mcp.CallToolRequest{
470 | 			Params: struct {
471 | 				Name      string    `json:"name"`
472 | 				Arguments any       `json:"arguments,omitempty"`
473 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
474 | 			}{
475 | 				Arguments: map[string]any{
476 | 					"name": make(chan int), // Channels can't be marshaled to JSON
477 | 				},
478 | 			},
479 | 		}
480 | 
481 | 		_, err = handler(context.Background(), invalidRequest)
482 | 		assert.Error(t, err)
483 | 		assert.Contains(t, err.Error(), "marshal args")
484 | 
485 | 		// Test with type mismatch
486 | 		mismatchRequest := mcp.CallToolRequest{
487 | 			Params: struct {
488 | 				Name      string    `json:"name"`
489 | 				Arguments any       `json:"arguments,omitempty"`
490 | 				Meta      *mcp.Meta `json:"_meta,omitempty"`
491 | 			}{
492 | 				Arguments: map[string]any{
493 | 					"name":  123, // Should be a string
494 | 					"value": "not an int",
495 | 				},
496 | 			},
497 | 		}
498 | 
499 | 		_, err = handler(context.Background(), mismatchRequest)
500 | 		assert.Error(t, err)
501 | 		assert.Contains(t, err.Error(), "unmarshal args")
502 | 	})
503 | }
504 | 
505 | func TestCreateJSONSchemaFromHandler(t *testing.T) {
506 | 	schema := createJSONSchemaFromHandler(testToolHandler)
507 | 
508 | 	assert.Equal(t, "object", schema.Type)
509 | 	assert.Len(t, schema.Required, 2) // name and value are required, optional is not
510 | 
511 | 	// Check properties
512 | 	nameProperty, ok := schema.Properties.Get("name")
513 | 	assert.True(t, ok)
514 | 	assert.Equal(t, "string", nameProperty.Type)
515 | 	assert.Equal(t, "The name parameter", nameProperty.Description)
516 | 
517 | 	valueProperty, ok := schema.Properties.Get("value")
518 | 	assert.True(t, ok)
519 | 	assert.Equal(t, "integer", valueProperty.Type)
520 | 	assert.Equal(t, "The value parameter", valueProperty.Description)
521 | 
522 | 	optionalProperty, ok := schema.Properties.Get("optional")
523 | 	assert.True(t, ok)
524 | 	assert.Equal(t, "boolean", optionalProperty.Type)
525 | 	assert.Equal(t, "An optional parameter", optionalProperty.Description)
526 | }
527 | 
```

--------------------------------------------------------------------------------
/cmd/mcp-grafana/main.go:
--------------------------------------------------------------------------------

```go
  1 | package main
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"errors"
  6 | 	"flag"
  7 | 	"fmt"
  8 | 	"log/slog"
  9 | 	"net/http"
 10 | 	"os"
 11 | 	"os/signal"
 12 | 	"slices"
 13 | 	"strings"
 14 | 	"syscall"
 15 | 	"time"
 16 | 
 17 | 	"github.com/mark3labs/mcp-go/mcp"
 18 | 	"github.com/mark3labs/mcp-go/server"
 19 | 
 20 | 	mcpgrafana "github.com/grafana/mcp-grafana"
 21 | 	"github.com/grafana/mcp-grafana/tools"
 22 | )
 23 | 
 24 | func maybeAddTools(s *server.MCPServer, tf func(*server.MCPServer), enabledTools []string, disable bool, category string) {
 25 | 	if !slices.Contains(enabledTools, category) {
 26 | 		slog.Debug("Not enabling tools", "category", category)
 27 | 		return
 28 | 	}
 29 | 	if disable {
 30 | 		slog.Info("Disabling tools", "category", category)
 31 | 		return
 32 | 	}
 33 | 	slog.Debug("Enabling tools", "category", category)
 34 | 	tf(s)
 35 | }
 36 | 
 37 | // disabledTools indicates whether each category of tools should be disabled.
 38 | type disabledTools struct {
 39 | 	enabledTools string
 40 | 
 41 | 	search, datasource, incident,
 42 | 	prometheus, loki, alerting,
 43 | 	dashboard, folder, oncall, asserts, sift, admin,
 44 | 	pyroscope, navigation, proxied, annotations, write bool
 45 | }
 46 | 
 47 | // Configuration for the Grafana client.
 48 | type grafanaConfig struct {
 49 | 	// Whether to enable debug mode for the Grafana transport.
 50 | 	debug bool
 51 | 
 52 | 	// TLS configuration
 53 | 	tlsCertFile   string
 54 | 	tlsKeyFile    string
 55 | 	tlsCAFile     string
 56 | 	tlsSkipVerify bool
 57 | }
 58 | 
 59 | func (dt *disabledTools) addFlags() {
 60 | 	flag.StringVar(&dt.enabledTools, "enabled-tools", "search,datasource,incident,prometheus,loki,alerting,dashboard,folder,oncall,asserts,sift,admin,pyroscope,navigation,proxied,annotations", "A comma separated list of tools enabled for this server. Can be overwritten entirely or by disabling specific components, e.g. --disable-search.")
 61 | 	flag.BoolVar(&dt.search, "disable-search", false, "Disable search tools")
 62 | 	flag.BoolVar(&dt.datasource, "disable-datasource", false, "Disable datasource tools")
 63 | 	flag.BoolVar(&dt.incident, "disable-incident", false, "Disable incident tools")
 64 | 	flag.BoolVar(&dt.prometheus, "disable-prometheus", false, "Disable prometheus tools")
 65 | 	flag.BoolVar(&dt.loki, "disable-loki", false, "Disable loki tools")
 66 | 	flag.BoolVar(&dt.alerting, "disable-alerting", false, "Disable alerting tools")
 67 | 	flag.BoolVar(&dt.dashboard, "disable-dashboard", false, "Disable dashboard tools")
 68 | 	flag.BoolVar(&dt.folder, "disable-folder", false, "Disable folder tools")
 69 | 	flag.BoolVar(&dt.oncall, "disable-oncall", false, "Disable oncall tools")
 70 | 	flag.BoolVar(&dt.asserts, "disable-asserts", false, "Disable asserts tools")
 71 | 	flag.BoolVar(&dt.sift, "disable-sift", false, "Disable sift tools")
 72 | 	flag.BoolVar(&dt.admin, "disable-admin", false, "Disable admin tools")
 73 | 	flag.BoolVar(&dt.pyroscope, "disable-pyroscope", false, "Disable pyroscope tools")
 74 | 	flag.BoolVar(&dt.navigation, "disable-navigation", false, "Disable navigation tools")
 75 | 	flag.BoolVar(&dt.proxied, "disable-proxied", false, "Disable proxied tools (tools from external MCP servers)")
 76 | 	flag.BoolVar(&dt.write, "disable-write", false, "Disable write tools (create/update operations)")
 77 | 	flag.BoolVar(&dt.annotations, "disable-annotations", false, "Disable annotation tools")
 78 | }
 79 | 
 80 | func (gc *grafanaConfig) addFlags() {
 81 | 	flag.BoolVar(&gc.debug, "debug", false, "Enable debug mode for the Grafana transport")
 82 | 
 83 | 	// TLS configuration flags
 84 | 	flag.StringVar(&gc.tlsCertFile, "tls-cert-file", "", "Path to TLS certificate file for client authentication")
 85 | 	flag.StringVar(&gc.tlsKeyFile, "tls-key-file", "", "Path to TLS private key file for client authentication")
 86 | 	flag.StringVar(&gc.tlsCAFile, "tls-ca-file", "", "Path to TLS CA certificate file for server verification")
 87 | 	flag.BoolVar(&gc.tlsSkipVerify, "tls-skip-verify", false, "Skip TLS certificate verification (insecure)")
 88 | }
 89 | 
 90 | func (dt *disabledTools) addTools(s *server.MCPServer) {
 91 | 	enabledTools := strings.Split(dt.enabledTools, ",")
 92 | 	enableWriteTools := !dt.write
 93 | 	maybeAddTools(s, tools.AddSearchTools, enabledTools, dt.search, "search")
 94 | 	maybeAddTools(s, tools.AddDatasourceTools, enabledTools, dt.datasource, "datasource")
 95 | 	maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddIncidentTools(mcp, enableWriteTools) }, enabledTools, dt.incident, "incident")
 96 | 	maybeAddTools(s, tools.AddPrometheusTools, enabledTools, dt.prometheus, "prometheus")
 97 | 	maybeAddTools(s, tools.AddLokiTools, enabledTools, dt.loki, "loki")
 98 | 	maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddAlertingTools(mcp, enableWriteTools) }, enabledTools, dt.alerting, "alerting")
 99 | 	maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddDashboardTools(mcp, enableWriteTools) }, enabledTools, dt.dashboard, "dashboard")
100 | 	maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddFolderTools(mcp, enableWriteTools) }, enabledTools, dt.folder, "folder")
101 | 	maybeAddTools(s, tools.AddOnCallTools, enabledTools, dt.oncall, "oncall")
102 | 	maybeAddTools(s, tools.AddAssertsTools, enabledTools, dt.asserts, "asserts")
103 | 	maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddSiftTools(mcp, enableWriteTools) }, enabledTools, dt.sift, "sift")
104 | 	maybeAddTools(s, tools.AddAdminTools, enabledTools, dt.admin, "admin")
105 | 	maybeAddTools(s, tools.AddPyroscopeTools, enabledTools, dt.pyroscope, "pyroscope")
106 | 	maybeAddTools(s, tools.AddNavigationTools, enabledTools, dt.navigation, "navigation")
107 | 	maybeAddTools(s, func(mcp *server.MCPServer) { tools.AddAnnotationTools(mcp, enableWriteTools) }, enabledTools, dt.annotations, "annotations")
108 | }
109 | 
110 | func newServer(transport string, dt disabledTools) (*server.MCPServer, *mcpgrafana.ToolManager) {
111 | 	sm := mcpgrafana.NewSessionManager()
112 | 
113 | 	// Declare variable for ToolManager that will be initialized after server creation
114 | 	var stm *mcpgrafana.ToolManager
115 | 
116 | 	// Create hooks
117 | 	hooks := &server.Hooks{
118 | 		OnRegisterSession:   []server.OnRegisterSessionHookFunc{sm.CreateSession},
119 | 		OnUnregisterSession: []server.OnUnregisterSessionHookFunc{sm.RemoveSession},
120 | 	}
121 | 
122 | 	// Add proxied tools hooks if enabled and we're not running in stdio mode.
123 | 	// (stdio mode is handled by InitializeAndRegisterServerTools; per-session tools
124 | 	// are not supported).
125 | 	if transport != "stdio" && !dt.proxied {
126 | 		// OnBeforeListTools: Discover, connect, and register tools
127 | 		hooks.OnBeforeListTools = []server.OnBeforeListToolsFunc{
128 | 			func(ctx context.Context, id any, request *mcp.ListToolsRequest) {
129 | 				if stm != nil {
130 | 					if session := server.ClientSessionFromContext(ctx); session != nil {
131 | 						stm.InitializeAndRegisterProxiedTools(ctx, session)
132 | 					}
133 | 				}
134 | 			},
135 | 		}
136 | 
137 | 		// OnBeforeCallTool: Fallback in case client calls tool without listing first
138 | 		hooks.OnBeforeCallTool = []server.OnBeforeCallToolFunc{
139 | 			func(ctx context.Context, id any, request *mcp.CallToolRequest) {
140 | 				if stm != nil {
141 | 					if session := server.ClientSessionFromContext(ctx); session != nil {
142 | 						stm.InitializeAndRegisterProxiedTools(ctx, session)
143 | 					}
144 | 				}
145 | 			},
146 | 		}
147 | 	}
148 | 	s := server.NewMCPServer("mcp-grafana", mcpgrafana.Version(),
149 | 		server.WithInstructions(`
150 | This server provides access to your Grafana instance and the surrounding ecosystem.
151 | 
152 | Available Capabilities:
153 | - Dashboards: Search, retrieve, update, and create dashboards. Extract panel queries and datasource information.
154 | - Datasources: List and fetch details for datasources.
155 | - Prometheus & Loki: Run PromQL and LogQL queries, retrieve metric/log metadata, and explore label names/values.
156 | - Incidents: Search, create, update, and resolve incidents in Grafana Incident.
157 | - Sift Investigations: Start and manage Sift investigations, analyze logs/traces, find error patterns, and detect slow requests.
158 | - Alerting: List and fetch alert rules and notification contact points.
159 | - OnCall: View and manage on-call schedules, shifts, teams, and users.
160 | - Admin: List teams and perform administrative tasks.
161 | - Pyroscope: Profile applications and fetch profiling data.
162 | - Navigation: Generate deeplink URLs for Grafana resources like dashboards, panels, and Explore queries.
163 | - Proxied Tools: Access tools from external MCP servers (like Tempo) through dynamic discovery.
164 | 
165 | Note that some of these capabilities may be disabled. Do not try to use features that are not available via tools.
166 | `),
167 | 		server.WithHooks(hooks),
168 | 	)
169 | 
170 | 	// Initialize ToolManager now that server is created
171 | 	stm = mcpgrafana.NewToolManager(sm, s, mcpgrafana.WithProxiedTools(!dt.proxied))
172 | 
173 | 	dt.addTools(s)
174 | 	return s, stm
175 | }
176 | 
// tlsConfig holds the certificate and private key the MCP server uses for its
// own HTTPS listener (the server.tls-* flags).
type tlsConfig struct {
	certFile string
	keyFile  string
}
180 | 
181 | func (tc *tlsConfig) addFlags() {
182 | 	flag.StringVar(&tc.certFile, "server.tls-cert-file", "", "Path to TLS certificate file for server HTTPS (required for TLS)")
183 | 	flag.StringVar(&tc.keyFile, "server.tls-key-file", "", "Path to TLS private key file for server HTTPS (required for TLS)")
184 | }
185 | 
// httpServer is the minimal start/stop surface runHTTPServer needs from the
// SSE and streamable-HTTP servers.
type httpServer interface {
	// Start serves on addr. runHTTPServer calls it in a goroutine and expects
	// it to return once the server stops.
	Start(addr string) error
	// Shutdown stops the server; runHTTPServer passes a context with a
	// shutdown timeout.
	Shutdown(ctx context.Context) error
}
191 | 
192 | // runHTTPServer handles the common logic for running HTTP-based servers
193 | func runHTTPServer(ctx context.Context, srv httpServer, addr, transportName string) error {
194 | 	// Start server in a goroutine
195 | 	serverErr := make(chan error, 1)
196 | 	go func() {
197 | 		if err := srv.Start(addr); err != nil {
198 | 			serverErr <- err
199 | 		}
200 | 		close(serverErr)
201 | 	}()
202 | 
203 | 	// Wait for either server error or shutdown signal
204 | 	select {
205 | 	case err := <-serverErr:
206 | 		return err
207 | 	case <-ctx.Done():
208 | 		slog.Info(fmt.Sprintf("%s server shutting down...", transportName))
209 | 
210 | 		// Create a timeout context for shutdown
211 | 		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
212 | 		defer shutdownCancel()
213 | 
214 | 		if err := srv.Shutdown(shutdownCtx); err != nil {
215 | 			return fmt.Errorf("shutdown error: %v", err)
216 | 		}
217 | 		slog.Debug("Shutdown called, waiting for connections to close...")
218 | 
219 | 		// Wait for server to finish
220 | 		select {
221 | 		case err := <-serverErr:
222 | 			// http.ErrServerClosed is expected when shutting down
223 | 			if err != nil && !errors.Is(err, http.ErrServerClosed) {
224 | 				return fmt.Errorf("server error during shutdown: %v", err)
225 | 			}
226 | 		case <-shutdownCtx.Done():
227 | 			slog.Warn(fmt.Sprintf("%s server did not stop gracefully within timeout", transportName))
228 | 		}
229 | 	}
230 | 
231 | 	return nil
232 | }
233 | 
// handleHealthz answers liveness probes with 200 OK and the body "ok".
func handleHealthz(w http.ResponseWriter, r *http.Request) {
	payload := []byte("ok")
	w.WriteHeader(http.StatusOK)
	// A failed write only means the probe client went away; nothing to report.
	_, _ = w.Write(payload)
}
238 | 
239 | func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt disabledTools, gc mcpgrafana.GrafanaConfig, tls tlsConfig) error {
240 | 	slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})))
241 | 	s, tm := newServer(transport, dt)
242 | 
243 | 	// Create a context that will be cancelled on shutdown
244 | 	ctx, cancel := context.WithCancel(context.Background())
245 | 	defer cancel()
246 | 
247 | 	// Set up signal handling for graceful shutdown
248 | 	sigChan := make(chan os.Signal, 1)
249 | 	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
250 | 	defer signal.Stop(sigChan)
251 | 
252 | 	// Handle shutdown signals
253 | 	go func() {
254 | 		<-sigChan
255 | 		slog.Info("Received shutdown signal")
256 | 		cancel()
257 | 
258 | 		// For stdio, close stdin to unblock the Listen call
259 | 		if transport == "stdio" {
260 | 			_ = os.Stdin.Close()
261 | 		}
262 | 	}()
263 | 
264 | 	// Start the appropriate server based on transport
265 | 	switch transport {
266 | 	case "stdio":
267 | 		srv := server.NewStdioServer(s)
268 | 		cf := mcpgrafana.ComposedStdioContextFunc(gc)
269 | 		srv.SetContextFunc(cf)
270 | 
271 | 		// For stdio (single-tenant), initialize proxied tools on the server directly
272 | 		if !dt.proxied {
273 | 			stdioCtx := cf(ctx)
274 | 			if err := tm.InitializeAndRegisterServerTools(stdioCtx); err != nil {
275 | 				slog.Error("failed to initialize proxied tools for stdio", "error", err)
276 | 			}
277 | 		}
278 | 
279 | 		slog.Info("Starting Grafana MCP server using stdio transport", "version", mcpgrafana.Version())
280 | 
281 | 		err := srv.Listen(ctx, os.Stdin, os.Stdout)
282 | 		if err != nil && err != context.Canceled {
283 | 			return fmt.Errorf("server error: %v", err)
284 | 		}
285 | 		return nil
286 | 
287 | 	case "sse":
288 | 		httpSrv := &http.Server{Addr: addr}
289 | 		srv := server.NewSSEServer(s,
290 | 			server.WithSSEContextFunc(mcpgrafana.ComposedSSEContextFunc(gc)),
291 | 			server.WithStaticBasePath(basePath),
292 | 			server.WithHTTPServer(httpSrv),
293 | 		)
294 | 		mux := http.NewServeMux()
295 | 		if basePath == "" {
296 | 			basePath = "/"
297 | 		}
298 | 		mux.Handle(basePath, srv)
299 | 		mux.HandleFunc("/healthz", handleHealthz)
300 | 		httpSrv.Handler = mux
301 | 		slog.Info("Starting Grafana MCP server using SSE transport",
302 | 			"version", mcpgrafana.Version(), "address", addr, "basePath", basePath)
303 | 		return runHTTPServer(ctx, srv, addr, "SSE")
304 | 	case "streamable-http":
305 | 		httpSrv := &http.Server{Addr: addr}
306 | 		opts := []server.StreamableHTTPOption{
307 | 			server.WithHTTPContextFunc(mcpgrafana.ComposedHTTPContextFunc(gc)),
308 | 			server.WithStateLess(dt.proxied), // Stateful when proxied tools enabled (requires sessions)
309 | 			server.WithEndpointPath(endpointPath),
310 | 			server.WithStreamableHTTPServer(httpSrv),
311 | 		}
312 | 		if tls.certFile != "" || tls.keyFile != "" {
313 | 			opts = append(opts, server.WithTLSCert(tls.certFile, tls.keyFile))
314 | 		}
315 | 		srv := server.NewStreamableHTTPServer(s, opts...)
316 | 		mux := http.NewServeMux()
317 | 		mux.Handle(endpointPath, srv)
318 | 		mux.HandleFunc("/healthz", handleHealthz)
319 | 		httpSrv.Handler = mux
320 | 		slog.Info("Starting Grafana MCP server using StreamableHTTP transport",
321 | 			"version", mcpgrafana.Version(), "address", addr, "endpointPath", endpointPath)
322 | 		return runHTTPServer(ctx, srv, addr, "StreamableHTTP")
323 | 	default:
324 | 		return fmt.Errorf("invalid transport type: %s. Must be 'stdio', 'sse' or 'streamable-http'", transport)
325 | 	}
326 | }
327 | 
328 | func main() {
329 | 	var transport string
330 | 	flag.StringVar(&transport, "t", "stdio", "Transport type (stdio, sse or streamable-http)")
331 | 	flag.StringVar(
332 | 		&transport,
333 | 		"transport",
334 | 		"stdio",
335 | 		"Transport type (stdio, sse or streamable-http)",
336 | 	)
337 | 	addr := flag.String("address", "localhost:8000", "The host and port to start the sse server on")
338 | 	basePath := flag.String("base-path", "", "Base path for the sse server")
339 | 	endpointPath := flag.String("endpoint-path", "/mcp", "Endpoint path for the streamable-http server")
340 | 	logLevel := flag.String("log-level", "info", "Log level (debug, info, warn, error)")
341 | 	showVersion := flag.Bool("version", false, "Print the version and exit")
342 | 	var dt disabledTools
343 | 	dt.addFlags()
344 | 	var gc grafanaConfig
345 | 	gc.addFlags()
346 | 	var tls tlsConfig
347 | 	tls.addFlags()
348 | 	flag.Parse()
349 | 
350 | 	if *showVersion {
351 | 		fmt.Println(mcpgrafana.Version())
352 | 		os.Exit(0)
353 | 	}
354 | 
355 | 	// Convert local grafanaConfig to mcpgrafana.GrafanaConfig
356 | 	grafanaConfig := mcpgrafana.GrafanaConfig{Debug: gc.debug}
357 | 	if gc.tlsCertFile != "" || gc.tlsKeyFile != "" || gc.tlsCAFile != "" || gc.tlsSkipVerify {
358 | 		grafanaConfig.TLSConfig = &mcpgrafana.TLSConfig{
359 | 			CertFile:   gc.tlsCertFile,
360 | 			KeyFile:    gc.tlsKeyFile,
361 | 			CAFile:     gc.tlsCAFile,
362 | 			SkipVerify: gc.tlsSkipVerify,
363 | 		}
364 | 	}
365 | 
366 | 	if err := run(transport, *addr, *basePath, *endpointPath, parseLevel(*logLevel), dt, grafanaConfig, tls); err != nil {
367 | 		panic(err)
368 | 	}
369 | }
370 | 
// parseLevel converts a level name such as "debug" or "WARN" into an
// slog.Level, falling back to slog.LevelInfo when the text is not recognised.
func parseLevel(level string) slog.Level {
	var lvl slog.Level
	if err := lvl.UnmarshalText([]byte(level)); err == nil {
		return lvl
	}
	return slog.LevelInfo
}
378 | 
```

--------------------------------------------------------------------------------
/session_test.go:
--------------------------------------------------------------------------------

```go
  1 | //go:build integration
  2 | 
  3 | // Integration tests for proxied MCP tools functionality.
  4 | // Requires docker-compose to be running with Grafana and Tempo instances.
  5 | // Run with: go test -tags integration -v ./...
  6 | 
  7 | package mcpgrafana
  8 | 
  9 | import (
 10 | 	"context"
 11 | 	"fmt"
 12 | 	"net/url"
 13 | 	"os"
 14 | 	"strings"
 15 | 	"sync"
 16 | 	"testing"
 17 | 
 18 | 	"github.com/go-openapi/strfmt"
 19 | 	grafana_client "github.com/grafana/grafana-openapi-client-go/client"
 20 | 	"github.com/mark3labs/mcp-go/mcp"
 21 | 	"github.com/stretchr/testify/assert"
 22 | 	"github.com/stretchr/testify/require"
 23 | )
 24 | 
 25 | // newProxiedToolsTestContext creates a test context with Grafana client and config
 26 | func newProxiedToolsTestContext(t *testing.T) context.Context {
 27 | 	cfg := grafana_client.DefaultTransportConfig()
 28 | 	cfg.Host = "localhost:3000"
 29 | 	cfg.Schemes = []string{"http"}
 30 | 
 31 | 	// Extract transport config from env vars, and set it on the context.
 32 | 	if u, ok := os.LookupEnv("GRAFANA_URL"); ok {
 33 | 		parsedURL, err := url.Parse(u)
 34 | 		require.NoError(t, err, "invalid GRAFANA_URL")
 35 | 		cfg.Host = parsedURL.Host
 36 | 		// The Grafana client will always prefer HTTPS even if the URL is HTTP,
 37 | 		// so we need to limit the schemes to HTTP if the URL is HTTP.
 38 | 		if parsedURL.Scheme == "http" {
 39 | 			cfg.Schemes = []string{"http"}
 40 | 		}
 41 | 	}
 42 | 
 43 | 	// Check for the new service account token environment variable first
 44 | 	if apiKey := os.Getenv("GRAFANA_SERVICE_ACCOUNT_TOKEN"); apiKey != "" {
 45 | 		cfg.APIKey = apiKey
 46 | 	} else if apiKey := os.Getenv("GRAFANA_API_KEY"); apiKey != "" {
 47 | 		// Fall back to the deprecated API key environment variable
 48 | 		cfg.APIKey = apiKey
 49 | 	} else {
 50 | 		cfg.BasicAuth = url.UserPassword("admin", "admin")
 51 | 	}
 52 | 
 53 | 	grafanaClient := grafana_client.NewHTTPClientWithConfig(strfmt.Default, cfg)
 54 | 
 55 | 	grafanaCfg := GrafanaConfig{
 56 | 		Debug:     true,
 57 | 		URL:       "http://localhost:3000",
 58 | 		APIKey:    cfg.APIKey,
 59 | 		BasicAuth: cfg.BasicAuth,
 60 | 	}
 61 | 
 62 | 	ctx := WithGrafanaConfig(context.Background(), grafanaCfg)
 63 | 	return WithGrafanaClient(ctx, grafanaClient)
 64 | }
 65 | 
 66 | func TestDiscoverMCPDatasources(t *testing.T) {
 67 | 	ctx := newProxiedToolsTestContext(t)
 68 | 
 69 | 	t.Run("discovers tempo datasources", func(t *testing.T) {
 70 | 		discovered, err := discoverMCPDatasources(ctx)
 71 | 		require.NoError(t, err)
 72 | 
 73 | 		// Should find two Tempo datasources from docker-compose
 74 | 		assert.GreaterOrEqual(t, len(discovered), 2, "Should discover at least 2 Tempo datasources")
 75 | 
 76 | 		// Check that we found the expected datasources
 77 | 		uids := make([]string, len(discovered))
 78 | 		for i, ds := range discovered {
 79 | 			uids[i] = ds.UID
 80 | 			assert.Equal(t, "tempo", ds.Type, "All discovered datasources should be tempo type")
 81 | 			assert.NotEmpty(t, ds.Name, "Datasource should have a name")
 82 | 			assert.NotEmpty(t, ds.MCPURL, "Datasource should have MCP URL")
 83 | 
 84 | 			// Verify URL format
 85 | 			expectedURLPattern := fmt.Sprintf("http://localhost:3000/api/datasources/proxy/uid/%s/api/mcp", ds.UID)
 86 | 			assert.Equal(t, expectedURLPattern, ds.MCPURL, "MCP URL should follow proxy pattern")
 87 | 		}
 88 | 
 89 | 		// Should contain our expected UIDs
 90 | 		assert.Contains(t, uids, "tempo", "Should discover 'tempo' datasource")
 91 | 		assert.Contains(t, uids, "tempo-secondary", "Should discover 'tempo-secondary' datasource")
 92 | 	})
 93 | 
 94 | 	t.Run("returns error when grafana client not in context", func(t *testing.T) {
 95 | 		emptyCtx := context.Background()
 96 | 		discovered, err := discoverMCPDatasources(emptyCtx)
 97 | 		assert.Error(t, err)
 98 | 		assert.Nil(t, discovered)
 99 | 		assert.Contains(t, err.Error(), "grafana client not found in context")
100 | 	})
101 | 
102 | 	t.Run("returns error when auth is missing", func(t *testing.T) {
103 | 		// Context with client but no auth credentials
104 | 		cfg := grafana_client.DefaultTransportConfig()
105 | 		cfg.Host = "localhost:3000"
106 | 		cfg.Schemes = []string{"http"}
107 | 		grafanaClient := grafana_client.NewHTTPClientWithConfig(strfmt.Default, cfg)
108 | 
109 | 		grafanaCfg := GrafanaConfig{
110 | 			URL: "http://localhost:3000",
111 | 			// No APIKey or BasicAuth set
112 | 		}
113 | 		ctx := WithGrafanaConfig(context.Background(), grafanaCfg)
114 | 		ctx = WithGrafanaClient(ctx, grafanaClient)
115 | 
116 | 		discovered, err := discoverMCPDatasources(ctx)
117 | 		assert.Error(t, err)
118 | 		assert.Nil(t, discovered)
119 | 		assert.Contains(t, err.Error(), "Unauthorized")
120 | 	})
121 | }
122 | 
123 | func TestToolNamespacing(t *testing.T) {
124 | 	t.Run("parse proxied tool name", func(t *testing.T) {
125 | 		datasourceType, toolName, err := parseProxiedToolName("tempo_traceql-search")
126 | 		require.NoError(t, err)
127 | 		assert.Equal(t, "tempo", datasourceType)
128 | 		assert.Equal(t, "traceql-search", toolName)
129 | 	})
130 | 
131 | 	t.Run("parse proxied tool name with multiple underscores", func(t *testing.T) {
132 | 		datasourceType, toolName, err := parseProxiedToolName("tempo_get-attribute-values")
133 | 		require.NoError(t, err)
134 | 		assert.Equal(t, "tempo", datasourceType)
135 | 		assert.Equal(t, "get-attribute-values", toolName)
136 | 	})
137 | 
138 | 	t.Run("parse proxied tool name with invalid format", func(t *testing.T) {
139 | 		_, _, err := parseProxiedToolName("invalid")
140 | 		assert.Error(t, err)
141 | 		assert.Contains(t, err.Error(), "invalid proxied tool name format")
142 | 	})
143 | 
144 | 	t.Run("add datasourceUid parameter to tool", func(t *testing.T) {
145 | 		originalTool := mcp.Tool{
146 | 			Name:        "query_traces",
147 | 			Description: "Query traces from Tempo",
148 | 			InputSchema: mcp.ToolInputSchema{
149 | 				Properties: map[string]any{
150 | 					"query": map[string]any{
151 | 						"type": "string",
152 | 					},
153 | 				},
154 | 				Required: []string{"query"},
155 | 			},
156 | 		}
157 | 
158 | 		modifiedTool := addDatasourceUidParameter(originalTool, "tempo")
159 | 
160 | 		assert.Equal(t, "tempo_query_traces", modifiedTool.Name)
161 | 		assert.Equal(t, "Query traces from Tempo", modifiedTool.Description)
162 | 		assert.NotNil(t, modifiedTool.InputSchema.Properties["datasourceUid"])
163 | 		assert.Contains(t, modifiedTool.InputSchema.Required, "datasourceUid")
164 | 		assert.Contains(t, modifiedTool.InputSchema.Required, "query")
165 | 	})
166 | 
167 | 	t.Run("add datasourceUid parameter with empty description", func(t *testing.T) {
168 | 		originalTool := mcp.Tool{
169 | 			Name:        "test_tool",
170 | 			Description: "",
171 | 			InputSchema: mcp.ToolInputSchema{
172 | 				Properties: make(map[string]any),
173 | 			},
174 | 		}
175 | 
176 | 		modifiedTool := addDatasourceUidParameter(originalTool, "tempo")
177 | 
178 | 		assert.Equal(t, "tempo_test_tool", modifiedTool.Name)
179 | 		assert.Equal(t, "", modifiedTool.Description, "Should not modify empty description")
180 | 		assert.NotNil(t, modifiedTool.InputSchema.Properties["datasourceUid"])
181 | 	})
182 | }
183 | 
184 | func TestSessionStateLifecycle(t *testing.T) {
185 | 	t.Run("create and get session", func(t *testing.T) {
186 | 		sm := NewSessionManager()
187 | 
188 | 		// Create mock session
189 | 		mockSession := &mockClientSession{id: "test-session-123"}
190 | 
191 | 		sm.CreateSession(context.Background(), mockSession)
192 | 
193 | 		state, exists := sm.GetSession("test-session-123")
194 | 		assert.True(t, exists)
195 | 		assert.NotNil(t, state)
196 | 		assert.NotNil(t, state.proxiedClients)
197 | 		assert.False(t, state.proxiedToolsInitialized)
198 | 	})
199 | 
200 | 	t.Run("remove session cleans up clients", func(t *testing.T) {
201 | 		sm := NewSessionManager()
202 | 
203 | 		mockSession := &mockClientSession{id: "test-session-456"}
204 | 		sm.CreateSession(context.Background(), mockSession)
205 | 
206 | 		state, _ := sm.GetSession("test-session-456")
207 | 
208 | 		// Add a mock proxied client
209 | 		mockClient := &ProxiedClient{
210 | 			DatasourceUID:  "test-uid",
211 | 			DatasourceName: "Test Datasource",
212 | 			DatasourceType: "tempo",
213 | 		}
214 | 		state.proxiedClients["tempo_test-uid"] = mockClient
215 | 
216 | 		// Remove session
217 | 		sm.RemoveSession(context.Background(), mockSession)
218 | 
219 | 		// Session should be gone
220 | 		_, exists := sm.GetSession("test-session-456")
221 | 		assert.False(t, exists)
222 | 	})
223 | 
224 | 	t.Run("get non-existent session", func(t *testing.T) {
225 | 		sm := NewSessionManager()
226 | 
227 | 		state, exists := sm.GetSession("non-existent")
228 | 		assert.False(t, exists)
229 | 		assert.Nil(t, state)
230 | 	})
231 | }
232 | 
233 | func TestConcurrentInitializationRaceCondition(t *testing.T) {
234 | 	t.Run("concurrent initialization calls should be safe", func(t *testing.T) {
235 | 		sm := NewSessionManager()
236 | 		mockSession := &mockClientSession{id: "race-test-session"}
237 | 		sm.CreateSession(context.Background(), mockSession)
238 | 
239 | 		state, exists := sm.GetSession("race-test-session")
240 | 		require.True(t, exists)
241 | 
242 | 		// Track how many times the initialization logic runs
243 | 		var initCount int
244 | 		var initCountMutex sync.Mutex
245 | 
246 | 		// Create a custom initOnce to track calls
247 | 		state.initOnce = sync.Once{}
248 | 
249 | 		// Simulate the initialization work that should run exactly once
250 | 		initWork := func() {
251 | 			initCountMutex.Lock()
252 | 			initCount++
253 | 			initCountMutex.Unlock()
254 | 			// Simulate some work
255 | 			state.mutex.Lock()
256 | 			state.proxiedToolsInitialized = true
257 | 			state.proxiedClients["tempo_test"] = &ProxiedClient{
258 | 				DatasourceUID:  "test",
259 | 				DatasourceName: "Test",
260 | 				DatasourceType: "tempo",
261 | 			}
262 | 			state.mutex.Unlock()
263 | 		}
264 | 
265 | 		// Launch multiple goroutines that all try to initialize concurrently
266 | 		const numGoroutines = 10
267 | 		var wg sync.WaitGroup
268 | 		wg.Add(numGoroutines)
269 | 
270 | 		for i := 0; i < numGoroutines; i++ {
271 | 			go func() {
272 | 				defer wg.Done()
273 | 				// This should be the pattern used in InitializeAndRegisterProxiedTools
274 | 				state.initOnce.Do(initWork)
275 | 			}()
276 | 		}
277 | 
278 | 		wg.Wait()
279 | 
280 | 		// Verify initialization ran exactly once
281 | 		assert.Equal(t, 1, initCount, "Initialization should run exactly once despite concurrent calls")
282 | 		assert.True(t, state.proxiedToolsInitialized, "State should be initialized")
283 | 		assert.Len(t, state.proxiedClients, 1, "Should have exactly one client")
284 | 	})
285 | 
286 | 	t.Run("sync.Once prevents double initialization", func(t *testing.T) {
287 | 		sm := NewSessionManager()
288 | 		mockSession := &mockClientSession{id: "double-init-test"}
289 | 		sm.CreateSession(context.Background(), mockSession)
290 | 
291 | 		state, _ := sm.GetSession("double-init-test")
292 | 
293 | 		callCount := 0
294 | 
295 | 		// First call
296 | 		state.initOnce.Do(func() {
297 | 			callCount++
298 | 		})
299 | 
300 | 		// Second call should not execute
301 | 		state.initOnce.Do(func() {
302 | 			callCount++
303 | 		})
304 | 
305 | 		// Third call should also not execute
306 | 		state.initOnce.Do(func() {
307 | 			callCount++
308 | 		})
309 | 
310 | 		assert.Equal(t, 1, callCount, "sync.Once should ensure function runs exactly once")
311 | 	})
312 | }
313 | 
314 | func TestProxiedClientLifecycle(t *testing.T) {
315 | 	ctx := newProxiedToolsTestContext(t)
316 | 
317 | 	t.Run("list tools returns copy", func(t *testing.T) {
318 | 		pc := &ProxiedClient{
319 | 			DatasourceUID:  "test-uid",
320 | 			DatasourceName: "Test",
321 | 			DatasourceType: "tempo",
322 | 			Tools: []mcp.Tool{
323 | 				{Name: "tool1", Description: "First tool"},
324 | 				{Name: "tool2", Description: "Second tool"},
325 | 			},
326 | 		}
327 | 
328 | 		tools1 := pc.ListTools()
329 | 		tools2 := pc.ListTools()
330 | 
331 | 		// Should return same content
332 | 		assert.Equal(t, tools1, tools2)
333 | 
334 | 		// But different slice instances (copy)
335 | 		assert.NotSame(t, &tools1[0], &tools2[0])
336 | 	})
337 | 
338 | 	t.Run("call tool validates tool exists", func(t *testing.T) {
339 | 		pc := &ProxiedClient{
340 | 			DatasourceUID:  "test-uid",
341 | 			DatasourceName: "Test",
342 | 			DatasourceType: "tempo",
343 | 			Tools: []mcp.Tool{
344 | 				{Name: "valid_tool", Description: "Valid tool"},
345 | 			},
346 | 		}
347 | 
348 | 		// Call non-existent tool
349 | 		result, err := pc.CallTool(ctx, "non_existent_tool", map[string]any{})
350 | 		assert.Error(t, err)
351 | 		assert.Nil(t, result)
352 | 		assert.Contains(t, err.Error(), "not found in remote MCP server")
353 | 	})
354 | }
355 | 
356 | func TestEndToEndProxiedToolsFlow(t *testing.T) {
357 | 	ctx := newProxiedToolsTestContext(t)
358 | 
359 | 	t.Run("full flow from discovery to tool call", func(t *testing.T) {
360 | 		// Step 1: Discover MCP datasources
361 | 		discovered, err := discoverMCPDatasources(ctx)
362 | 		require.NoError(t, err)
363 | 		require.GreaterOrEqual(t, len(discovered), 1, "Should discover at least one Tempo datasource")
364 | 
365 | 		// Use the first discovered datasource
366 | 		ds := discovered[0]
367 | 		t.Logf("Testing with datasource: %s (UID: %s, URL: %s)", ds.Name, ds.UID, ds.MCPURL)
368 | 
369 | 		// Step 2: Create a proxied client connection
370 | 		client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
371 | 		if err != nil {
372 | 			t.Skipf("Skipping end-to-end test: Tempo MCP endpoint not available: %v", err)
373 | 			return
374 | 		}
375 | 		defer func() {
376 | 			_ = client.Close()
377 | 		}()
378 | 
379 | 		// Step 3: Verify we got tools from the remote server
380 | 		tools := client.ListTools()
381 | 		require.Greater(t, len(tools), 0, "Should have at least one tool from Tempo MCP server")
382 | 		t.Logf("Discovered %d tools from Tempo MCP server", len(tools))
383 | 
384 | 		// Log the available tools
385 | 		for _, tool := range tools {
386 | 			t.Logf("  - Tool: %s - %s", tool.Name, tool.Description)
387 | 		}
388 | 
389 | 		// Step 4: Test tool modification with datasourceUid parameter
390 | 		firstTool := tools[0]
391 | 		modifiedTool := addDatasourceUidParameter(firstTool, ds.Type)
392 | 
393 | 		expectedName := ds.Type + "_" + firstTool.Name
394 | 		assert.Equal(t, expectedName, modifiedTool.Name, "Modified tool should have prefixed name")
395 | 		assert.Contains(t, modifiedTool.InputSchema.Required, "datasourceUid", "Modified tool should require datasourceUid")
396 | 
397 | 		// Step 5: Test session integration
398 | 		sm := NewSessionManager()
399 | 		mockSession := &mockClientSession{id: "e2e-test-session"}
400 | 		sm.CreateSession(ctx, mockSession)
401 | 
402 | 		state, exists := sm.GetSession("e2e-test-session")
403 | 		require.True(t, exists)
404 | 
405 | 		// Store the proxied client in session state
406 | 		key := ds.Type + "_" + ds.UID
407 | 		state.proxiedClients[key] = client
408 | 
409 | 		// Step 6: Verify client is stored correctly in session
410 | 		retrievedClient, exists := state.proxiedClients[key]
411 | 		require.True(t, exists, "Client should be stored in session state")
412 | 		assert.Equal(t, client, retrievedClient, "Should retrieve the same client from session")
413 | 
414 | 		// Step 7: Test ProxiedToolHandler flow
415 | 		handler := NewProxiedToolHandler(sm, nil, modifiedTool.Name)
416 | 		assert.NotNil(t, handler)
417 | 
418 | 		// Note: We can't actually call the tool without knowing what arguments it expects
419 | 		// and without the context having the proper session, but we've validated the setup
420 | 		t.Logf("Successfully validated end-to-end proxied tools flow")
421 | 	})
422 | 
423 | 	t.Run("multiple datasources in single session", func(t *testing.T) {
424 | 		discovered, err := discoverMCPDatasources(ctx)
425 | 		require.NoError(t, err)
426 | 
427 | 		if len(discovered) < 2 {
428 | 			t.Skip("Need at least 2 Tempo datasources for this test")
429 | 		}
430 | 
431 | 		sm := NewSessionManager()
432 | 		mockSession := &mockClientSession{id: "multi-ds-test-session"}
433 | 		sm.CreateSession(ctx, mockSession)
434 | 
435 | 		state, _ := sm.GetSession("multi-ds-test-session")
436 | 
437 | 		// Try to connect to multiple datasources
438 | 		connectedCount := 0
439 | 		for i, ds := range discovered {
440 | 			if i >= 2 {
441 | 				break // Test with first 2 datasources
442 | 			}
443 | 
444 | 			client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL)
445 | 			if err != nil {
446 | 				t.Logf("Could not connect to datasource %s: %v", ds.UID, err)
447 | 				continue
448 | 			}
449 | 			defer func() {
450 | 				_ = client.Close()
451 | 			}()
452 | 
453 | 			key := ds.Type + "_" + ds.UID
454 | 			state.proxiedClients[key] = client
455 | 			connectedCount++
456 | 
457 | 			t.Logf("Connected to datasource %s with %d tools", ds.UID, len(client.Tools))
458 | 		}
459 | 
460 | 		if connectedCount == 0 {
461 | 			t.Skip("Could not connect to any Tempo datasources")
462 | 		}
463 | 
464 | 		// Verify each client is stored correctly
465 | 		for key, client := range state.proxiedClients {
466 | 			parts := strings.Split(key, "_")
467 | 			require.Len(t, parts, 2, "Key should have format type_uid")
468 | 			assert.NotNil(t, client, "Client should not be nil")
469 | 			assert.Equal(t, parts[0], client.DatasourceType, "Client type should match key")
470 | 			assert.Equal(t, parts[1], client.DatasourceUID, "Client UID should match key")
471 | 		}
472 | 
473 | 		t.Logf("Successfully managed %d datasources in single session", connectedCount)
474 | 	})
475 | }
476 | 
```

--------------------------------------------------------------------------------
/tools/prometheus.go:
--------------------------------------------------------------------------------

```go
  1 | package tools
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"fmt"
  6 | 	"net/http"
  7 | 	"regexp"
  8 | 	"strings"
  9 | 	"time"
 10 | 
 11 | 	"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
 12 | 	mcpgrafana "github.com/grafana/mcp-grafana"
 13 | 	"github.com/mark3labs/mcp-go/mcp"
 14 | 	"github.com/mark3labs/mcp-go/server"
 15 | 	"github.com/prometheus/client_golang/api"
 16 | 	promv1 "github.com/prometheus/client_golang/api/prometheus/v1"
 17 | 	"github.com/prometheus/common/config"
 18 | 	"github.com/prometheus/common/model"
 19 | 	"github.com/prometheus/prometheus/model/labels"
 20 | )
 21 | 
 22 | var (
 23 | 	matchTypeMap = map[string]labels.MatchType{
 24 | 		"":   labels.MatchEqual,
 25 | 		"=":  labels.MatchEqual,
 26 | 		"!=": labels.MatchNotEqual,
 27 | 		"=~": labels.MatchRegexp,
 28 | 		"!~": labels.MatchNotRegexp,
 29 | 	}
 30 | )
 31 | 
 32 | func promClientFromContext(ctx context.Context, uid string) (promv1.API, error) {
 33 | 	// First check if the datasource exists
 34 | 	_, err := getDatasourceByUID(ctx, GetDatasourceByUIDParams{UID: uid})
 35 | 	if err != nil {
 36 | 		return nil, err
 37 | 	}
 38 | 
 39 | 	cfg := mcpgrafana.GrafanaConfigFromContext(ctx)
 40 | 	url := fmt.Sprintf("%s/api/datasources/proxy/uid/%s", strings.TrimRight(cfg.URL, "/"), uid)
 41 | 
 42 | 	// Create custom transport with TLS configuration if available
 43 | 	rt := api.DefaultRoundTripper
 44 | 	if tlsConfig := cfg.TLSConfig; tlsConfig != nil {
 45 | 		customTransport, err := tlsConfig.HTTPTransport(rt.(*http.Transport))
 46 | 		if err != nil {
 47 | 			return nil, fmt.Errorf("failed to create custom transport: %w", err)
 48 | 		}
 49 | 		rt = customTransport
 50 | 	}
 51 | 
 52 | 	if cfg.AccessToken != "" && cfg.IDToken != "" {
 53 | 		rt = config.NewHeadersRoundTripper(&config.Headers{
 54 | 			Headers: map[string]config.Header{
 55 | 				"X-Access-Token": {
 56 | 					Secrets: []config.Secret{config.Secret(cfg.AccessToken)},
 57 | 				},
 58 | 				"X-Grafana-Id": {
 59 | 					Secrets: []config.Secret{config.Secret(cfg.IDToken)},
 60 | 				},
 61 | 			},
 62 | 		}, rt)
 63 | 	} else if cfg.APIKey != "" {
 64 | 		rt = config.NewAuthorizationCredentialsRoundTripper(
 65 | 			"Bearer", config.NewInlineSecret(cfg.APIKey), rt,
 66 | 		)
 67 | 	} else if cfg.BasicAuth != nil {
 68 | 		password, _ := cfg.BasicAuth.Password()
 69 | 		rt = config.NewBasicAuthRoundTripper(config.NewInlineSecret(cfg.BasicAuth.Username()), config.NewInlineSecret(password), rt)
 70 | 	}
 71 | 
 72 | 	// Wrap with org ID support
 73 | 	rt = mcpgrafana.NewOrgIDRoundTripper(rt, cfg.OrgID)
 74 | 
 75 | 	c, err := api.NewClient(api.Config{
 76 | 		Address:      url,
 77 | 		RoundTripper: rt,
 78 | 	})
 79 | 	if err != nil {
 80 | 		return nil, fmt.Errorf("creating Prometheus client: %w", err)
 81 | 	}
 82 | 
 83 | 	return promv1.NewAPI(c), nil
 84 | }
 85 | 
 86 | type ListPrometheusMetricMetadataParams struct {
 87 | 	DatasourceUID  string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
 88 | 	Limit          int    `json:"limit" jsonschema:"description=The maximum number of metrics to return"`
 89 | 	LimitPerMetric int    `json:"limitPerMetric" jsonschema:"description=The maximum number of metrics to return per metric"`
 90 | 	Metric         string `json:"metric" jsonschema:"description=The metric to query"`
 91 | }
 92 | 
 93 | func listPrometheusMetricMetadata(ctx context.Context, args ListPrometheusMetricMetadataParams) (map[string][]promv1.Metadata, error) {
 94 | 	promClient, err := promClientFromContext(ctx, args.DatasourceUID)
 95 | 	if err != nil {
 96 | 		return nil, fmt.Errorf("getting Prometheus client: %w", err)
 97 | 	}
 98 | 
 99 | 	limit := args.Limit
100 | 	if limit == 0 {
101 | 		limit = 10
102 | 	}
103 | 
104 | 	metadata, err := promClient.Metadata(ctx, args.Metric, fmt.Sprintf("%d", limit))
105 | 	if err != nil {
106 | 		return nil, fmt.Errorf("listing Prometheus metric metadata: %w", err)
107 | 	}
108 | 	return metadata, nil
109 | }
110 | 
111 | var ListPrometheusMetricMetadata = mcpgrafana.MustTool(
112 | 	"list_prometheus_metric_metadata",
113 | 	"List Prometheus metric metadata. Returns metadata about metrics currently scraped from targets. Note: This endpoint is experimental.",
114 | 	listPrometheusMetricMetadata,
115 | 	mcp.WithTitleAnnotation("List Prometheus metric metadata"),
116 | 	mcp.WithIdempotentHintAnnotation(true),
117 | 	mcp.WithReadOnlyHintAnnotation(true),
118 | )
119 | 
120 | type QueryPrometheusParams struct {
121 | 	DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
122 | 	Expr          string `json:"expr" jsonschema:"required,description=The PromQL expression to query"`
123 | 	StartTime     string `json:"startTime" jsonschema:"required,description=The start time. Supported formats are RFC3339 or relative to now (e.g. 'now'\\, 'now-1.5h'\\, 'now-2h45m'). Valid time units are 'ns'\\, 'us' (or 'µs')\\, 'ms'\\, 's'\\, 'm'\\, 'h'\\, 'd'."`
124 | 	EndTime       string `json:"endTime,omitempty" jsonschema:"description=The end time. Required if queryType is 'range'\\, ignored if queryType is 'instant' Supported formats are RFC3339 or relative to now (e.g. 'now'\\, 'now-1.5h'\\, 'now-2h45m'). Valid time units are 'ns'\\, 'us' (or 'µs')\\, 'ms'\\, 's'\\, 'm'\\, 'h'\\, 'd'."`
125 | 	StepSeconds   int    `json:"stepSeconds,omitempty" jsonschema:"description=The time series step size in seconds. Required if queryType is 'range'\\, ignored if queryType is 'instant'"`
126 | 	QueryType     string `json:"queryType,omitempty" jsonschema:"description=The type of query to use. Either 'range' or 'instant'"`
127 | }
128 | 
129 | func parseTime(timeStr string) (time.Time, error) {
130 | 	tr := gtime.TimeRange{
131 | 		From: timeStr,
132 | 		Now:  time.Now(),
133 | 	}
134 | 	return tr.ParseFrom()
135 | }
136 | 
137 | func queryPrometheus(ctx context.Context, args QueryPrometheusParams) (model.Value, error) {
138 | 	promClient, err := promClientFromContext(ctx, args.DatasourceUID)
139 | 	if err != nil {
140 | 		return nil, fmt.Errorf("getting Prometheus client: %w", err)
141 | 	}
142 | 
143 | 	queryType := args.QueryType
144 | 	if queryType == "" {
145 | 		queryType = "range"
146 | 	}
147 | 
148 | 	var startTime time.Time
149 | 	startTime, err = parseTime(args.StartTime)
150 | 	if err != nil {
151 | 		return nil, fmt.Errorf("parsing start time: %w", err)
152 | 	}
153 | 
154 | 	switch queryType {
155 | 	case "range":
156 | 		if args.StepSeconds == 0 {
157 | 			return nil, fmt.Errorf("stepSeconds must be provided when queryType is 'range'")
158 | 		}
159 | 
160 | 		var endTime time.Time
161 | 		endTime, err = parseTime(args.EndTime)
162 | 		if err != nil {
163 | 			return nil, fmt.Errorf("parsing end time: %w", err)
164 | 		}
165 | 
166 | 		step := time.Duration(args.StepSeconds) * time.Second
167 | 		result, _, err := promClient.QueryRange(ctx, args.Expr, promv1.Range{
168 | 			Start: startTime,
169 | 			End:   endTime,
170 | 			Step:  step,
171 | 		})
172 | 		if err != nil {
173 | 			return nil, fmt.Errorf("querying Prometheus range: %w", err)
174 | 		}
175 | 		return result, nil
176 | 	case "instant":
177 | 		result, _, err := promClient.Query(ctx, args.Expr, startTime)
178 | 		if err != nil {
179 | 			return nil, fmt.Errorf("querying Prometheus instant: %w", err)
180 | 		}
181 | 		return result, nil
182 | 	}
183 | 
184 | 	return nil, fmt.Errorf("invalid query type: %s", queryType)
185 | }
186 | 
187 | var QueryPrometheus = mcpgrafana.MustTool(
188 | 	"query_prometheus",
189 | 	"Query Prometheus using a PromQL expression. Supports both instant queries (at a single point in time) and range queries (over a time range). Time can be specified either in RFC3339 format or as relative time expressions like 'now', 'now-1h', 'now-30m', etc.",
190 | 	queryPrometheus,
191 | 	mcp.WithTitleAnnotation("Query Prometheus metrics"),
192 | 	mcp.WithIdempotentHintAnnotation(true),
193 | 	mcp.WithReadOnlyHintAnnotation(true),
194 | )
195 | 
196 | type ListPrometheusMetricNamesParams struct {
197 | 	DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
198 | 	Regex         string `json:"regex" jsonschema:"description=The regex to match against the metric names"`
199 | 	Limit         int    `json:"limit,omitempty" jsonschema:"description=The maximum number of results to return"`
200 | 	Page          int    `json:"page,omitempty" jsonschema:"description=The page number to return"`
201 | }
202 | 
203 | func listPrometheusMetricNames(ctx context.Context, args ListPrometheusMetricNamesParams) ([]string, error) {
204 | 	promClient, err := promClientFromContext(ctx, args.DatasourceUID)
205 | 	if err != nil {
206 | 		return nil, fmt.Errorf("getting Prometheus client: %w", err)
207 | 	}
208 | 
209 | 	limit := args.Limit
210 | 	if limit == 0 {
211 | 		limit = 10
212 | 	}
213 | 
214 | 	page := args.Page
215 | 	if page == 0 {
216 | 		page = 1
217 | 	}
218 | 
219 | 	// Get all metric names by querying for __name__ label values
220 | 	labelValues, _, err := promClient.LabelValues(ctx, "__name__", nil, time.Time{}, time.Time{})
221 | 	if err != nil {
222 | 		return nil, fmt.Errorf("listing Prometheus metric names: %w", err)
223 | 	}
224 | 
225 | 	// Filter by regex if provided
226 | 	matches := []string{}
227 | 	if args.Regex != "" {
228 | 		re, err := regexp.Compile(args.Regex)
229 | 		if err != nil {
230 | 			return nil, fmt.Errorf("compiling regex: %w", err)
231 | 		}
232 | 		for _, val := range labelValues {
233 | 			if re.MatchString(string(val)) {
234 | 				matches = append(matches, string(val))
235 | 			}
236 | 		}
237 | 	} else {
238 | 		for _, val := range labelValues {
239 | 			matches = append(matches, string(val))
240 | 		}
241 | 	}
242 | 
243 | 	// Apply pagination
244 | 	start := (page - 1) * limit
245 | 	end := start + limit
246 | 	if start >= len(matches) {
247 | 		matches = []string{}
248 | 	} else if end > len(matches) {
249 | 		matches = matches[start:]
250 | 	} else {
251 | 		matches = matches[start:end]
252 | 	}
253 | 
254 | 	return matches, nil
255 | }
256 | 
257 | var ListPrometheusMetricNames = mcpgrafana.MustTool(
258 | 	"list_prometheus_metric_names",
259 | 	"List metric names in a Prometheus datasource. Retrieves all metric names and then filters them locally using the provided regex. Supports pagination.",
260 | 	listPrometheusMetricNames,
261 | 	mcp.WithTitleAnnotation("List Prometheus metric names"),
262 | 	mcp.WithIdempotentHintAnnotation(true),
263 | 	mcp.WithReadOnlyHintAnnotation(true),
264 | )
265 | 
266 | type LabelMatcher struct {
267 | 	Name  string `json:"name" jsonschema:"required,description=The name of the label to match against"`
268 | 	Value string `json:"value" jsonschema:"required,description=The value to match against"`
269 | 	Type  string `json:"type" jsonschema:"required,description=One of the '=' or '!=' or '=~' or '!~'"`
270 | }
271 | 
272 | type Selector struct {
273 | 	Filters []LabelMatcher `json:"filters"`
274 | }
275 | 
276 | func (s Selector) String() string {
277 | 	b := strings.Builder{}
278 | 	b.WriteRune('{')
279 | 	for i, f := range s.Filters {
280 | 		if f.Type == "" {
281 | 			f.Type = "="
282 | 		}
283 | 		b.WriteString(fmt.Sprintf(`%s%s'%s'`, f.Name, f.Type, f.Value))
284 | 		if i < len(s.Filters)-1 {
285 | 			b.WriteString(", ")
286 | 		}
287 | 	}
288 | 	b.WriteRune('}')
289 | 	return b.String()
290 | }
291 | 
292 | // Matches runs the matchers against the given labels and returns whether they match the selector.
293 | func (s Selector) Matches(lbls labels.Labels) (bool, error) {
294 | 	matchers := make(labels.Selector, 0, len(s.Filters))
295 | 
296 | 	for _, filter := range s.Filters {
297 | 		matchType, ok := matchTypeMap[filter.Type]
298 | 		if !ok {
299 | 			return false, fmt.Errorf("invalid matcher type: %s", filter.Type)
300 | 		}
301 | 
302 | 		matcher, err := labels.NewMatcher(matchType, filter.Name, filter.Value)
303 | 		if err != nil {
304 | 			return false, fmt.Errorf("creating matcher: %w", err)
305 | 		}
306 | 
307 | 		matchers = append(matchers, matcher)
308 | 	}
309 | 
310 | 	return matchers.Matches(lbls), nil
311 | }
312 | 
313 | type ListPrometheusLabelNamesParams struct {
314 | 	DatasourceUID string     `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
315 | 	Matches       []Selector `json:"matches,omitempty" jsonschema:"description=Optionally\\, a list of label matchers to filter the results by"`
316 | 	StartRFC3339  string     `json:"startRfc3339,omitempty" jsonschema:"description=Optionally\\, the start time of the time range to filter the results by"`
317 | 	EndRFC3339    string     `json:"endRfc3339,omitempty" jsonschema:"description=Optionally\\, the end time of the time range to filter the results by"`
318 | 	Limit         int        `json:"limit,omitempty" jsonschema:"description=Optionally\\, the maximum number of results to return"`
319 | }
320 | 
321 | func listPrometheusLabelNames(ctx context.Context, args ListPrometheusLabelNamesParams) ([]string, error) {
322 | 	promClient, err := promClientFromContext(ctx, args.DatasourceUID)
323 | 	if err != nil {
324 | 		return nil, fmt.Errorf("getting Prometheus client: %w", err)
325 | 	}
326 | 
327 | 	limit := args.Limit
328 | 	if limit == 0 {
329 | 		limit = 100
330 | 	}
331 | 
332 | 	var startTime, endTime time.Time
333 | 	if args.StartRFC3339 != "" {
334 | 		if startTime, err = time.Parse(time.RFC3339, args.StartRFC3339); err != nil {
335 | 			return nil, fmt.Errorf("parsing start time: %w", err)
336 | 		}
337 | 	}
338 | 	if args.EndRFC3339 != "" {
339 | 		if endTime, err = time.Parse(time.RFC3339, args.EndRFC3339); err != nil {
340 | 			return nil, fmt.Errorf("parsing end time: %w", err)
341 | 		}
342 | 	}
343 | 
344 | 	var matchers []string
345 | 	for _, m := range args.Matches {
346 | 		matchers = append(matchers, m.String())
347 | 	}
348 | 
349 | 	labelNames, _, err := promClient.LabelNames(ctx, matchers, startTime, endTime)
350 | 	if err != nil {
351 | 		return nil, fmt.Errorf("listing Prometheus label names: %w", err)
352 | 	}
353 | 
354 | 	// Apply limit
355 | 	if len(labelNames) > limit {
356 | 		labelNames = labelNames[:limit]
357 | 	}
358 | 
359 | 	return labelNames, nil
360 | }
361 | 
// ListPrometheusLabelNames is the "list_prometheus_label_names" MCP tool,
// backed by listPrometheusLabelNames and annotated as read-only and
// idempotent.
var ListPrometheusLabelNames = mcpgrafana.MustTool(
	"list_prometheus_label_names",
	"List label names in a Prometheus datasource. Allows filtering by series selectors and time range.",
	listPrometheusLabelNames,
	mcp.WithTitleAnnotation("List Prometheus label names"),
	mcp.WithIdempotentHintAnnotation(true),
	mcp.WithReadOnlyHintAnnotation(true),
)
370 | 
// ListPrometheusLabelValuesParams are the arguments of the
// list_prometheus_label_values tool.
type ListPrometheusLabelValuesParams struct {
	// DatasourceUID identifies the Prometheus datasource to query.
	DatasourceUID string `json:"datasourceUid" jsonschema:"required,description=The UID of the datasource to query"`
	// LabelName is the label whose values are listed.
	LabelName string `json:"labelName" jsonschema:"required,description=The name of the label to query"`
	// Matches optionally restricts the series considered.
	Matches []Selector `json:"matches,omitempty" jsonschema:"description=Optionally\\, a list of selectors to filter the results by"`
	// StartRFC3339 and EndRFC3339 are parsed with time.RFC3339 when non-empty.
	StartRFC3339 string `json:"startRfc3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query"`
	EndRFC3339   string `json:"endRfc3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query"`
	// Limit caps the number of returned values; 0 means the default of 100.
	Limit int `json:"limit,omitempty" jsonschema:"description=Optionally\\, the maximum number of results to return"`
}
379 | 
380 | func listPrometheusLabelValues(ctx context.Context, args ListPrometheusLabelValuesParams) (model.LabelValues, error) {
381 | 	promClient, err := promClientFromContext(ctx, args.DatasourceUID)
382 | 	if err != nil {
383 | 		return nil, fmt.Errorf("getting Prometheus client: %w", err)
384 | 	}
385 | 
386 | 	limit := args.Limit
387 | 	if limit == 0 {
388 | 		limit = 100
389 | 	}
390 | 
391 | 	var startTime, endTime time.Time
392 | 	if args.StartRFC3339 != "" {
393 | 		if startTime, err = time.Parse(time.RFC3339, args.StartRFC3339); err != nil {
394 | 			return nil, fmt.Errorf("parsing start time: %w", err)
395 | 		}
396 | 	}
397 | 	if args.EndRFC3339 != "" {
398 | 		if endTime, err = time.Parse(time.RFC3339, args.EndRFC3339); err != nil {
399 | 			return nil, fmt.Errorf("parsing end time: %w", err)
400 | 		}
401 | 	}
402 | 
403 | 	var matchers []string
404 | 	for _, m := range args.Matches {
405 | 		matchers = append(matchers, m.String())
406 | 	}
407 | 
408 | 	labelValues, _, err := promClient.LabelValues(ctx, args.LabelName, matchers, startTime, endTime)
409 | 	if err != nil {
410 | 		return nil, fmt.Errorf("listing Prometheus label values: %w", err)
411 | 	}
412 | 
413 | 	// Apply limit
414 | 	if len(labelValues) > limit {
415 | 		labelValues = labelValues[:limit]
416 | 	}
417 | 
418 | 	return labelValues, nil
419 | }
420 | 
// ListPrometheusLabelValues is the "list_prometheus_label_values" MCP tool,
// backed by listPrometheusLabelValues and annotated as read-only and
// idempotent.
var ListPrometheusLabelValues = mcpgrafana.MustTool(
	"list_prometheus_label_values",
	"Get the values for a specific label name in Prometheus. Allows filtering by series selectors and time range.",
	listPrometheusLabelValues,
	mcp.WithTitleAnnotation("List Prometheus label values"),
	mcp.WithIdempotentHintAnnotation(true),
	mcp.WithReadOnlyHintAnnotation(true),
)
429 | 
// AddPrometheusTools registers the Prometheus tools (metric metadata, query,
// metric names, label names and label values) on the given MCP server.
// Note: the parameter name shadows the mcp package inside this function.
func AddPrometheusTools(mcp *server.MCPServer) {
	ListPrometheusMetricMetadata.Register(mcp)
	QueryPrometheus.Register(mcp)
	ListPrometheusMetricNames.Register(mcp)
	ListPrometheusLabelNames.Register(mcp)
	ListPrometheusLabelValues.Register(mcp)
}
437 | 
```

--------------------------------------------------------------------------------
/tools/pyroscope.go:
--------------------------------------------------------------------------------

```go
  1 | package tools
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"fmt"
  6 | 	"io"
  7 | 	"net/http"
  8 | 	"net/url"
  9 | 	"regexp"
 10 | 	"strings"
 11 | 	"time"
 12 | 
 13 | 	"connectrpc.com/connect"
 14 | 	mcpgrafana "github.com/grafana/mcp-grafana"
 15 | 	querierv1 "github.com/grafana/pyroscope/api/gen/proto/go/querier/v1"
 16 | 	"github.com/grafana/pyroscope/api/gen/proto/go/querier/v1/querierv1connect"
 17 | 	typesv1 "github.com/grafana/pyroscope/api/gen/proto/go/types/v1"
 18 | 	"github.com/mark3labs/mcp-go/mcp"
 19 | 	"github.com/mark3labs/mcp-go/server"
 20 | )
 21 | 
// AddPyroscopeTools registers the Pyroscope tools (label names, label values,
// profile types and profile fetching) on the given MCP server.
// Note: the parameter name shadows the mcp package inside this function.
func AddPyroscopeTools(mcp *server.MCPServer) {
	ListPyroscopeLabelNames.Register(mcp)
	ListPyroscopeLabelValues.Register(mcp)
	ListPyroscopeProfileTypes.Register(mcp)
	FetchPyroscopeProfile.Register(mcp)
}
 28 | 
// listPyroscopeLabelNamesToolPrompt is the tool description shown to MCP
// clients for list_pyroscope_label_names.
const listPyroscopeLabelNamesToolPrompt = `
Lists all available label names (keys) found in profiles within a specified Pyroscope datasource, time range, and
optional label matchers. Label matchers are typically used to qualify a service name ({service_name="foo"}). Returns a
list of unique label strings (e.g., ["app", "env", "pod"]). Label names with double underscores (e.g. __name__) are
internal and rarely useful to users. If the time range is not provided, it defaults to the last hour.
`
 35 | 
// ListPyroscopeLabelNames is the "list_pyroscope_label_names" MCP tool,
// backed by listPyroscopeLabelNames and annotated as read-only and idempotent.
var ListPyroscopeLabelNames = mcpgrafana.MustTool(
	"list_pyroscope_label_names",
	listPyroscopeLabelNamesToolPrompt,
	listPyroscopeLabelNames,
	mcp.WithTitleAnnotation("List Pyroscope label names"),
	mcp.WithIdempotentHintAnnotation(true),
	mcp.WithReadOnlyHintAnnotation(true),
)
 44 | 
 45 | type ListPyroscopeLabelNamesParams struct {
 46 | 	DataSourceUID string `json:"data_source_uid" jsonschema:"required,description=The UID of the datasource to query"`
 47 | 	Matchers      string `json:"matchers,omitempty" jsonschema:"Prometheus style matchers used t0 filter the result set (defaults to: {})"`
 48 | 	StartRFC3339  string `json:"start_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query in RFC3339 format (defaults to 1 hour ago)"`
 49 | 	EndRFC3339    string `json:"end_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query in RFC3339 format (defaults to now)"`
 50 | }
 51 | 
 52 | func listPyroscopeLabelNames(ctx context.Context, args ListPyroscopeLabelNamesParams) ([]string, error) {
 53 | 	args.Matchers = stringOrDefault(args.Matchers, "{}")
 54 | 
 55 | 	start, err := rfc3339OrDefault(args.StartRFC3339, time.Time{})
 56 | 	if err != nil {
 57 | 		return nil, fmt.Errorf("failed to parse start timestamp %q: %w", args.StartRFC3339, err)
 58 | 	}
 59 | 
 60 | 	end, err := rfc3339OrDefault(args.EndRFC3339, time.Time{})
 61 | 	if err != nil {
 62 | 		return nil, fmt.Errorf("failed to parse end timestamp %q: %w", args.EndRFC3339, err)
 63 | 	}
 64 | 
 65 | 	start, end, err = validateTimeRange(start, end)
 66 | 	if err != nil {
 67 | 		return nil, err
 68 | 	}
 69 | 
 70 | 	client, err := newPyroscopeClient(ctx, args.DataSourceUID)
 71 | 	if err != nil {
 72 | 		return nil, fmt.Errorf("failed to create Pyroscope client: %w", err)
 73 | 	}
 74 | 
 75 | 	req := &typesv1.LabelNamesRequest{
 76 | 		Matchers: []string{args.Matchers},
 77 | 		Start:    start.UnixMilli(),
 78 | 		End:      end.UnixMilli(),
 79 | 	}
 80 | 	res, err := client.LabelNames(ctx, connect.NewRequest(req))
 81 | 	if err != nil {
 82 | 		return nil, fmt.Errorf("failed to call Pyroscope API: %w", err)
 83 | 	}
 84 | 
 85 | 	return res.Msg.Names, nil
 86 | }
 87 | 
// listPyroscopeLabelValuesToolPrompt is the tool description shown to MCP
// clients for list_pyroscope_label_values.
const listPyroscopeLabelValuesToolPrompt = `
Lists all available label values for a particular label name found in profiles within a specified Pyroscope datasource,
time range, and optional label matchers. Label matchers are typically used to qualify a service name ({service_name="foo"}).
Returns a list of unique label strings (e.g. for label name "env": ["dev", "staging", "prod"]). If the time range
is not provided, it defaults to the last hour.
`
 94 | 
// ListPyroscopeLabelValues is the "list_pyroscope_label_values" MCP tool,
// backed by listPyroscopeLabelValues and annotated as read-only and idempotent.
var ListPyroscopeLabelValues = mcpgrafana.MustTool(
	"list_pyroscope_label_values",
	listPyroscopeLabelValuesToolPrompt,
	listPyroscopeLabelValues,
	mcp.WithTitleAnnotation("List Pyroscope label values"),
	mcp.WithIdempotentHintAnnotation(true),
	mcp.WithReadOnlyHintAnnotation(true),
)
103 | 
// ListPyroscopeLabelValuesParams are the arguments of the
// list_pyroscope_label_values tool.
type ListPyroscopeLabelValuesParams struct {
	// DataSourceUID identifies the Pyroscope datasource to query.
	DataSourceUID string `json:"data_source_uid" jsonschema:"required,description=The UID of the datasource to query"`
	// Name is the label whose values are listed; it is whitespace-trimmed and
	// must be non-empty.
	Name string `json:"name" jsonschema:"required,description=A label name"`
	// Matchers is a Prometheus-style label selector; empty means "{}".
	Matchers string `json:"matchers,omitempty" jsonschema:"description=Optionally\\, Prometheus style matchers used to filter the result set (defaults to: {})"`
	// StartRFC3339 and EndRFC3339 bound the query; empty defaults to the last hour.
	StartRFC3339 string `json:"start_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query in RFC3339 format (defaults to 1 hour ago)"`
	EndRFC3339   string `json:"end_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query in RFC3339 format (defaults to now)"`
}
111 | 
112 | func listPyroscopeLabelValues(ctx context.Context, args ListPyroscopeLabelValuesParams) ([]string, error) {
113 | 	args.Name = strings.TrimSpace(args.Name)
114 | 	if args.Name == "" {
115 | 		return nil, fmt.Errorf("name is required")
116 | 	}
117 | 
118 | 	args.Matchers = stringOrDefault(args.Matchers, "{}")
119 | 
120 | 	start, err := rfc3339OrDefault(args.StartRFC3339, time.Time{})
121 | 	if err != nil {
122 | 		return nil, fmt.Errorf("failed to parse start timestamp %q: %w", args.StartRFC3339, err)
123 | 	}
124 | 
125 | 	end, err := rfc3339OrDefault(args.EndRFC3339, time.Time{})
126 | 	if err != nil {
127 | 		return nil, fmt.Errorf("failed to parse end timestamp %q: %w", args.EndRFC3339, err)
128 | 	}
129 | 
130 | 	start, end, err = validateTimeRange(start, end)
131 | 	if err != nil {
132 | 		return nil, err
133 | 	}
134 | 
135 | 	client, err := newPyroscopeClient(ctx, args.DataSourceUID)
136 | 	if err != nil {
137 | 		return nil, fmt.Errorf("failed to create Pyroscope client: %w", err)
138 | 	}
139 | 
140 | 	req := &typesv1.LabelValuesRequest{
141 | 		Name:     args.Name,
142 | 		Matchers: []string{args.Matchers},
143 | 		Start:    start.UnixMilli(),
144 | 		End:      end.UnixMilli(),
145 | 	}
146 | 	res, err := client.LabelValues(ctx, connect.NewRequest(req))
147 | 	if err != nil {
148 | 		return nil, fmt.Errorf("failed to call Pyroscope API: %w", err)
149 | 	}
150 | 
151 | 	return res.Msg.Names, nil
152 | }
153 | 
// listPyroscopeProfileTypesToolPrompt is the tool description shown to MCP
// clients for list_pyroscope_profile_types.
const listPyroscopeProfileTypesToolPrompt = `
Lists all available profile types available in a specified Pyroscope datasource and time range. Returns a list of all
available profile types (example profile type: "process_cpu:cpu:nanoseconds:cpu:nanoseconds"). A profile type has the
following structure: <name>:<sample type>:<sample unit>:<period type>:<period unit>. Not all profile types are available
for every service. If the time range is not provided, it defaults to the last hour.
`
160 | 
// ListPyroscopeProfileTypes is the "list_pyroscope_profile_types" MCP tool,
// backed by listPyroscopeProfileTypes and annotated as read-only and idempotent.
var ListPyroscopeProfileTypes = mcpgrafana.MustTool(
	"list_pyroscope_profile_types",
	listPyroscopeProfileTypesToolPrompt,
	listPyroscopeProfileTypes,
	mcp.WithTitleAnnotation("List Pyroscope profile types"),
	mcp.WithIdempotentHintAnnotation(true),
	mcp.WithReadOnlyHintAnnotation(true),
)
169 | 
// ListPyroscopeProfileTypesParams are the arguments of the
// list_pyroscope_profile_types tool.
type ListPyroscopeProfileTypesParams struct {
	// DataSourceUID identifies the Pyroscope datasource to query.
	DataSourceUID string `json:"data_source_uid" jsonschema:"required,description=The UID of the datasource to query"`
	// StartRFC3339 and EndRFC3339 bound the query; empty defaults to the last hour.
	StartRFC3339 string `json:"start_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query in RFC3339 format (defaults to 1 hour ago)"`
	EndRFC3339   string `json:"end_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query in RFC3339 format (defaults to now)"`
}
175 | 
176 | func listPyroscopeProfileTypes(ctx context.Context, args ListPyroscopeProfileTypesParams) ([]string, error) {
177 | 	start, err := rfc3339OrDefault(args.StartRFC3339, time.Time{})
178 | 	if err != nil {
179 | 		return nil, fmt.Errorf("failed to parse start timestamp %q: %w", args.StartRFC3339, err)
180 | 	}
181 | 
182 | 	end, err := rfc3339OrDefault(args.EndRFC3339, time.Time{})
183 | 	if err != nil {
184 | 		return nil, fmt.Errorf("failed to parse end timestamp %q: %w", args.EndRFC3339, err)
185 | 	}
186 | 
187 | 	start, end, err = validateTimeRange(start, end)
188 | 	if err != nil {
189 | 		return nil, err
190 | 	}
191 | 
192 | 	client, err := newPyroscopeClient(ctx, args.DataSourceUID)
193 | 	if err != nil {
194 | 		return nil, fmt.Errorf("failed to create Pyroscope client: %w", err)
195 | 	}
196 | 
197 | 	req := &querierv1.ProfileTypesRequest{
198 | 		Start: start.UnixMilli(),
199 | 		End:   end.UnixMilli(),
200 | 	}
201 | 	res, err := client.ProfileTypes(ctx, connect.NewRequest(req))
202 | 	if err != nil {
203 | 		return nil, fmt.Errorf("failed to call Pyroscope API: %w", err)
204 | 	}
205 | 
206 | 	profileTypes := make([]string, len(res.Msg.ProfileTypes))
207 | 	for i, typ := range res.Msg.ProfileTypes {
208 | 		profileTypes[i] = fmt.Sprintf("%s:%s:%s:%s:%s", typ.Name, typ.SampleType, typ.SampleUnit, typ.PeriodType, typ.PeriodUnit)
209 | 	}
210 | 	return profileTypes, nil
211 | }
212 | 
// fetchPyroscopeProfileToolPrompt is the tool description shown to MCP clients
// for fetch_pyroscope_profile.
const fetchPyroscopeProfileToolPrompt = `
Fetches a profile from a Pyroscope data source for a given time range. By default, the time range is the past 1 hour.
The profile type is required; available profile types can be fetched via the list_pyroscope_profile_types tool. Not all
profile types are available for every service. Expect some queries to return empty result sets, this indicates the
profile type does not exist for that query. In such a case, consider trying a related profile type or giving up.
Matchers are not required but are highly recommended; they are generally used to select an application by the service_name
label (e.g. {service_name="foo"}). Use the list_pyroscope_label_names tool to fetch available label names, and the
list_pyroscope_label_values tool to fetch available label values. The returned profile is in DOT format.
`
222 | 
// FetchPyroscopeProfile is the "fetch_pyroscope_profile" MCP tool, backed by
// fetchPyroscopeProfile and annotated as read-only and idempotent.
var FetchPyroscopeProfile = mcpgrafana.MustTool(
	"fetch_pyroscope_profile",
	fetchPyroscopeProfileToolPrompt,
	fetchPyroscopeProfile,
	mcp.WithTitleAnnotation("Fetch Pyroscope profile"),
	mcp.WithIdempotentHintAnnotation(true),
	mcp.WithReadOnlyHintAnnotation(true),
)
231 | 
// FetchPyroscopeProfileParams are the arguments of the fetch_pyroscope_profile
// tool.
type FetchPyroscopeProfileParams struct {
	// DataSourceUID identifies the Pyroscope datasource to query.
	DataSourceUID string `json:"data_source_uid" jsonschema:"required,description=The UID of the datasource to query"`
	// ProfileType is a profile type ID as returned by list_pyroscope_profile_types.
	ProfileType string `json:"profile_type" jsonschema:"required,description=The profile type\\, use the list_pyroscope_profile_types tool to fetch available profile types"`
	// Matchers is a Prometheus-style label selector; empty means "{}".
	Matchers string `json:"matchers,omitempty" jsonschema:"description=Optionally\\, Prometheus style matchers used to filter the result set (defaults to: {})"`
	// MaxNodeDepth is sent as max-nodes; 0 means the default of 100.
	MaxNodeDepth int `json:"max_node_depth,omitempty" jsonschema:"description=Optionally\\, the maximum depth of nodes in the resulting profile. Less depth results in smaller profiles that execute faster\\, more depth results in larger profiles that have more detail. A value of -1 indicates an unbounded node depth (default: 100). Reducing max node depth from the default will negatively impact the accuracy of the profile"`
	// StartRFC3339 and EndRFC3339 bound the query; empty defaults to the last hour.
	StartRFC3339 string `json:"start_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the start time of the query in RFC3339 format (defaults to 1 hour ago)"`
	EndRFC3339   string `json:"end_rfc_3339,omitempty" jsonschema:"description=Optionally\\, the end time of the query in RFC3339 format (defaults to now)"`
}
240 | 
241 | func fetchPyroscopeProfile(ctx context.Context, args FetchPyroscopeProfileParams) (string, error) {
242 | 	args.Matchers = stringOrDefault(args.Matchers, "{}")
243 | 	matchersRegex := regexp.MustCompile(`^\{.*\}$`)
244 | 	if !matchersRegex.MatchString(args.Matchers) {
245 | 		args.Matchers = fmt.Sprintf("{%s}", args.Matchers)
246 | 	}
247 | 
248 | 	args.MaxNodeDepth = intOrDefault(args.MaxNodeDepth, 100)
249 | 
250 | 	start, err := rfc3339OrDefault(args.StartRFC3339, time.Time{})
251 | 	if err != nil {
252 | 		return "", fmt.Errorf("failed to parse start timestamp %q: %w", args.StartRFC3339, err)
253 | 	}
254 | 
255 | 	end, err := rfc3339OrDefault(args.EndRFC3339, time.Time{})
256 | 	if err != nil {
257 | 		return "", fmt.Errorf("failed to parse end timestamp %q: %w", args.EndRFC3339, err)
258 | 	}
259 | 
260 | 	start, end, err = validateTimeRange(start, end)
261 | 	if err != nil {
262 | 		return "", err
263 | 	}
264 | 
265 | 	client, err := newPyroscopeClient(ctx, args.DataSourceUID)
266 | 	if err != nil {
267 | 		return "", fmt.Errorf("failed to create Pyroscope client: %w", err)
268 | 	}
269 | 
270 | 	req := &renderRequest{
271 | 		ProfileType: args.ProfileType,
272 | 		Matcher:     args.Matchers,
273 | 		Start:       start,
274 | 		End:         end,
275 | 		Format:      "dot",
276 | 		MaxNodes:    args.MaxNodeDepth,
277 | 	}
278 | 	res, err := client.Render(ctx, req)
279 | 	if err != nil {
280 | 		return "", fmt.Errorf("failed to call Pyroscope API: %w", err)
281 | 	}
282 | 
283 | 	res = cleanupDotProfile(res)
284 | 	return res, nil
285 | }
286 | 
287 | func newPyroscopeClient(ctx context.Context, uid string) (*pyroscopeClient, error) {
288 | 	cfg := mcpgrafana.GrafanaConfigFromContext(ctx)
289 | 
290 | 	var transport http.RoundTripper = NewAuthRoundTripper(http.DefaultTransport, cfg.AccessToken, cfg.IDToken, cfg.APIKey, cfg.BasicAuth)
291 | 	transport = mcpgrafana.NewOrgIDRoundTripper(transport, cfg.OrgID)
292 | 
293 | 	httpClient := &http.Client{
294 | 		Transport: mcpgrafana.NewUserAgentTransport(
295 | 			transport,
296 | 		),
297 | 		Timeout: 10 * time.Second,
298 | 	}
299 | 
300 | 	_, err := getDatasourceByUID(ctx, GetDatasourceByUIDParams{UID: uid})
301 | 	if err != nil {
302 | 		return nil, err
303 | 	}
304 | 
305 | 	base, err := url.Parse(cfg.URL)
306 | 	if err != nil {
307 | 		return nil, fmt.Errorf("failed to parse base url: %w", err)
308 | 	}
309 | 	base = base.JoinPath("api", "datasources", "proxy", "uid", uid)
310 | 
311 | 	querierClient := querierv1connect.NewQuerierServiceClient(httpClient, base.String())
312 | 
313 | 	client := &pyroscopeClient{
314 | 		QuerierServiceClient: querierClient,
315 | 		http:                 httpClient,
316 | 		base:                 base,
317 | 	}
318 | 	return client, nil
319 | }
320 | 
// renderRequest holds the parameters for pyroscopeClient.Render.
type renderRequest struct {
	// ProfileType is prepended to Matcher to form the "query" parameter.
	ProfileType string
	// Matcher is the label selector appended to ProfileType.
	Matcher string
	// Start and End are sent as Unix milliseconds in "from" and "until".
	Start time.Time
	End   time.Time
	// Format is sent as "format" (fetchPyroscopeProfile uses "dot").
	Format string
	// MaxNodes is sent as "max-nodes".
	MaxNodes int
}
329 | 
// pyroscopeClient talks to Pyroscope through Grafana's datasource proxy. The
// embedded Connect client serves RPCs such as LabelNames, LabelValues and
// ProfileTypes. The http and base fields are used for plain HTTP endpoints
// such as /pyroscope/render.
type pyroscopeClient struct {
	querierv1connect.QuerierServiceClient
	// http is the authenticated client; the Connect client is built on it too.
	http *http.Client
	// base is the datasource proxy URL (.../api/datasources/proxy/uid/<uid>).
	base *url.URL
}
335 | 
336 | // Calls the /render endpoint for Pyroscope. This returns a rendered flame graph
337 | // (typically in Flamebearer or DOT formats).
338 | func (c *pyroscopeClient) Render(ctx context.Context, args *renderRequest) (string, error) {
339 | 	params := url.Values{}
340 | 	params.Add("query", fmt.Sprintf("%s%s", args.ProfileType, args.Matcher))
341 | 	params.Add("from", fmt.Sprintf("%d", args.Start.UnixMilli()))
342 | 	params.Add("until", fmt.Sprintf("%d", args.End.UnixMilli()))
343 | 	params.Add("format", args.Format)
344 | 	params.Add("max-nodes", fmt.Sprintf("%d", args.MaxNodes))
345 | 
346 | 	res, err := c.get(ctx, "/pyroscope/render", params)
347 | 	if err != nil {
348 | 		return "", err
349 | 	}
350 | 
351 | 	return string(res), nil
352 | }
353 | 
354 | func (c *pyroscopeClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
355 | 	u := c.base.JoinPath(path)
356 | 
357 | 	q := u.Query()
358 | 	for k, vs := range params {
359 | 		for _, v := range vs {
360 | 			q.Add(k, v)
361 | 		}
362 | 	}
363 | 	u.RawQuery = q.Encode()
364 | 
365 | 	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
366 | 	if err != nil {
367 | 		return nil, fmt.Errorf("failed to create GET request: %w", err)
368 | 	}
369 | 
370 | 	res, err := c.http.Do(req)
371 | 	if err != nil {
372 | 		return nil, fmt.Errorf("failed to send request: %w", err)
373 | 	}
374 | 	defer func() {
375 | 		_ = res.Body.Close() //nolint:errcheck
376 | 	}()
377 | 
378 | 	if res.StatusCode < 200 || res.StatusCode > 299 {
379 | 		body, err := io.ReadAll(res.Body)
380 | 		if err != nil {
381 | 			return nil, fmt.Errorf("pyroscope API failed with status code %d", res.StatusCode)
382 | 		}
383 | 		return nil, fmt.Errorf("pyroscope API failed with status code %d: %s", res.StatusCode, string(body))
384 | 	}
385 | 
386 | 	const limit = 1 << 25 // 32 MiB
387 | 	body, err := io.ReadAll(io.LimitReader(res.Body, limit))
388 | 	if err != nil {
389 | 		return nil, fmt.Errorf("failed to read response body: %w", err)
390 | 	}
391 | 
392 | 	if len(body) == 0 {
393 | 		return nil, fmt.Errorf("pyroscope API returned an empty response")
394 | 	}
395 | 
396 | 	if strings.Contains(string(body), "Showing nodes accounting for 0, 0% of 0 total") {
397 | 		return nil, fmt.Errorf("pyroscope API returned a empty profile")
398 | 	}
399 | 	return body, nil
400 | }
401 | 
// intOrDefault returns n unless it is zero, in which case def is returned.
// Negative values are passed through unchanged.
func intOrDefault(n int, def int) int {
	if n != 0 {
		return n
	}
	return def
}
408 | 
// stringOrDefault returns s as-is (untrimmed) unless it is empty or consists
// only of whitespace, in which case def is returned.
func stringOrDefault(s string, def string) string {
	if strings.TrimSpace(s) != "" {
		return s
	}
	return def
}
415 | 
// rfc3339OrDefault parses s (after trimming whitespace) as an RFC3339
// timestamp. A blank s yields def. On a parse error, it returns the zero time
// together with the error.
func rfc3339OrDefault(s string, def time.Time) (time.Time, error) {
	trimmed := strings.TrimSpace(s)
	if trimmed == "" {
		return def, nil
	}

	parsed, err := time.Parse(time.RFC3339, trimmed)
	if err != nil {
		return time.Time{}, err
	}
	return parsed, nil
}
429 | 
// validateTimeRange fills in and checks a query window. A zero end becomes
// now, and a zero start becomes one hour before end. It returns an error
// unless start is strictly before end.
func validateTimeRange(start time.Time, end time.Time) (time.Time, time.Time, error) {
	if end.IsZero() {
		end = time.Now()
	}
	if start.IsZero() {
		start = end.Add(-time.Hour)
	}

	if !start.Before(end) {
		return time.Time{}, time.Time{}, fmt.Errorf("start timestamp %q must be strictly before end timestamp %q", start.Format(time.RFC3339), end.Format(time.RFC3339))
	}
	return start, end, nil
}
445 | 
// cleanupRegex matches DOT noise to strip from rendered profiles: font sizes,
// node ids, (label)tooltips, box shapes, colors, and the aggregated "other"
// node lines (including their newline). Capture group 5 matches the
// "N<a> -> N<b>" head of an edge line, with the rest of that line matched
// outside the group.
var cleanupRegex = regexp.MustCompile(`(?m)(fontsize=\d+ )|(id="node\d+" )|(labeltooltip=".*?\)" )|(tooltip=".*?\)" )|(N\d+ -> N\d+).*|(N\d+ \[label="other.*\n)|(shape=box )|(fillcolor="#\w{6}")|(color="#\w{6}" )`)

// cleanupDotProfile shrinks a DOT profile by removing the attributes matched
// by cleanupRegex. Edge lines are reduced to their "N<a> -> N<b>" head, and
// every other match is deleted.
//
// Replacing each match with capture group 5 does this in one pass: the group
// holds the edge head for edge matches and is empty for all others. This
// avoids compiling a second regexp for every match.
func cleanupDotProfile(profile string) string {
	return cleanupRegex.ReplaceAllString(profile, "${5}")
}
457 | 
```
Page 3/5FirstPrevNextLast