This is page 3 of 7. Use http://codebase.md/freepeak/db-mcp-server?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .cm
│ └── gitstream.cm
├── .cursor
│ ├── mcp-example.json
│ ├── mcp.json
│ └── rules
│ └── global.mdc
├── .dockerignore
├── .DS_Store
├── .env.example
├── .github
│ ├── FUNDING.yml
│ └── workflows
│ └── go.yml
├── .gitignore
├── .golangci.yml
├── assets
│ └── logo.svg
├── CHANGELOG.md
├── cmd
│ └── server
│ └── main.go
├── commit-message.txt
├── config.json
├── config.timescaledb-test.json
├── docker-compose.mcp-test.yml
├── docker-compose.test.yml
├── docker-compose.timescaledb-test.yml
├── docker-compose.yml
├── docker-wrapper.sh
├── Dockerfile
├── docs
│ ├── REFACTORING.md
│ ├── TIMESCALEDB_FUNCTIONS.md
│ ├── TIMESCALEDB_IMPLEMENTATION.md
│ ├── TIMESCALEDB_PRD.md
│ └── TIMESCALEDB_TOOLS.md
├── examples
│ └── postgres_connection.go
├── glama.json
├── go.mod
├── go.sum
├── init-scripts
│ └── timescaledb
│ ├── 01-init.sql
│ ├── 02-sample-data.sql
│ ├── 03-continuous-aggregates.sql
│ └── README.md
├── internal
│ ├── config
│ │ ├── config_test.go
│ │ └── config.go
│ ├── delivery
│ │ └── mcp
│ │ ├── compression_policy_test.go
│ │ ├── context
│ │ │ ├── hypertable_schema_test.go
│ │ │ ├── timescale_completion_test.go
│ │ │ ├── timescale_context_test.go
│ │ │ └── timescale_query_suggestion_test.go
│ │ ├── mock_test.go
│ │ ├── response_test.go
│ │ ├── response.go
│ │ ├── retention_policy_test.go
│ │ ├── server_wrapper.go
│ │ ├── timescale_completion.go
│ │ ├── timescale_context.go
│ │ ├── timescale_schema.go
│ │ ├── timescale_tool_test.go
│ │ ├── timescale_tool.go
│ │ ├── timescale_tools_test.go
│ │ ├── tool_registry.go
│ │ └── tool_types.go
│ ├── domain
│ │ └── database.go
│ ├── logger
│ │ ├── logger_test.go
│ │ └── logger.go
│ ├── repository
│ │ └── database_repository.go
│ └── usecase
│ └── database_usecase.go
├── LICENSE
├── Makefile
├── pkg
│ ├── core
│ │ ├── core.go
│ │ └── logging.go
│ ├── db
│ │ ├── db_test.go
│ │ ├── db.go
│ │ ├── manager.go
│ │ ├── README.md
│ │ └── timescale
│ │ ├── config_test.go
│ │ ├── config.go
│ │ ├── connection_test.go
│ │ ├── connection.go
│ │ ├── continuous_aggregate_test.go
│ │ ├── continuous_aggregate.go
│ │ ├── hypertable_test.go
│ │ ├── hypertable.go
│ │ ├── metadata.go
│ │ ├── mocks_test.go
│ │ ├── policy_test.go
│ │ ├── policy.go
│ │ ├── query.go
│ │ ├── timeseries_test.go
│ │ └── timeseries.go
│ ├── dbtools
│ │ ├── db_helpers.go
│ │ ├── dbtools_test.go
│ │ ├── dbtools.go
│ │ ├── exec.go
│ │ ├── performance_test.go
│ │ ├── performance.go
│ │ ├── query.go
│ │ ├── querybuilder_test.go
│ │ ├── querybuilder.go
│ │ ├── README.md
│ │ ├── schema_test.go
│ │ ├── schema.go
│ │ ├── tx_test.go
│ │ └── tx.go
│ ├── internal
│ │ └── logger
│ │ └── logger.go
│ ├── jsonrpc
│ │ └── jsonrpc.go
│ ├── logger
│ │ └── logger.go
│ └── tools
│ └── tools.go
├── README-old.md
├── README.md
├── repomix-output.txt
├── request.json
├── start-mcp.sh
├── test.Dockerfile
├── timescaledb-test.sh
└── wait-for-it.sh
```
# Files
--------------------------------------------------------------------------------
/internal/logger/logger.go:
--------------------------------------------------------------------------------
```go
1 | package logger
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "os"
7 | "path/filepath"
8 | "runtime/debug"
9 | "strings"
10 | "sync"
11 | "time"
12 |
13 | "go.uber.org/zap"
14 | "go.uber.org/zap/zapcore"
15 | )
16 |
// Level represents the severity of a log message.
type Level int

const (
	// LevelDebug for detailed troubleshooting
	LevelDebug Level = iota
	// LevelInfo for general operational entries
	LevelInfo
	// LevelWarn for non-critical issues
	LevelWarn
	// LevelError for errors that should be addressed
	LevelError
)

var (
	// zapLogger is the package-wide logger that all logging helpers write
	// through; it is installed by Initialize/InitializeWithWriter.
	zapLogger *zap.Logger
	// logLevel is the minimum severity that will be emitted.
	logLevel Level
	// isStdioMode is true when TRANSPORT_MODE=stdio, meaning stdout must be
	// kept free of log output (it carries protocol traffic).
	isStdioMode bool
	// stdioLogFile receives all log output while in stdio mode.
	stdioLogFile *os.File
	// logMutex protects access to stdioLogFile across goroutines.
	logMutex sync.Mutex
)
42 |
// safeStdioWriter is a writer that ensures no output goes to stdout in stdio
// mode, so stdout stays reserved for the stdio transport protocol.
type safeStdioWriter struct {
	// file is set at construction time, but the methods below write through
	// the package-level stdioLogFile instead. NOTE(review): this field
	// appears unused — confirm and either use it or drop it.
	file *os.File
}

// Write implements io.Writer and filters all output in stdio mode.
// Output goes to the package-level stdio log file when one is open, and
// otherwise falls back to stderr — never stdout.
func (w *safeStdioWriter) Write(p []byte) (n int, err error) {
	// In stdio mode, write to the log file instead of stdout
	logMutex.Lock()
	defer logMutex.Unlock()

	if stdioLogFile != nil {
		return stdioLogFile.Write(p)
	}

	// Last resort: write to stderr, never stdout
	return os.Stderr.Write(p)
}

// Sync implements zapcore.WriteSyncer by flushing the stdio log file when one
// is open; it is a no-op otherwise.
func (w *safeStdioWriter) Sync() error {
	logMutex.Lock()
	defer logMutex.Unlock()

	if stdioLogFile != nil {
		return stdioLogFile.Sync()
	}
	return nil
}
72 |
73 | // Initialize sets up the logger with the specified level
74 | func Initialize(level string) {
75 | setLogLevel(level)
76 |
77 | // Check if we're in stdio mode
78 | transportMode := os.Getenv("TRANSPORT_MODE")
79 | isStdioMode = transportMode == "stdio"
80 |
81 | if isStdioMode {
82 | // In stdio mode, we need to avoid ANY JSON output to stdout
83 |
84 | // Create a log file in logs directory
85 | logsDir := "logs"
86 | if _, err := os.Stat(logsDir); os.IsNotExist(err) {
87 | if err := os.Mkdir(logsDir, 0755); err != nil {
88 | fmt.Fprintf(os.Stderr, "Failed to create logs directory: %v\n", err)
89 | }
90 | }
91 |
92 | timestamp := time.Now().Format("20060102-150405")
93 | logFileName := filepath.Join(logsDir, fmt.Sprintf("mcp-logger-%s.log", timestamp))
94 |
95 | // Try to create the log file
96 | var err error
97 | stdioLogFile, err = os.OpenFile(logFileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
98 | if err != nil {
99 | // If we can't create a log file, we'll use a null logger
100 | fmt.Fprintf(os.Stderr, "Failed to create log file: %v - all logs will be suppressed\n", err)
101 | } else {
102 | // Write initial log message to stderr only (as a last message before full redirection)
103 | fmt.Fprintf(os.Stderr, "Stdio mode detected - all logs redirected to: %s\n", logFileName)
104 |
105 | // Create a custom writer that never writes to stdout
106 | safeWriter := &safeStdioWriter{file: stdioLogFile}
107 |
108 | // Create a development encoder for more readable logs
109 | encoderConfig := zap.NewDevelopmentEncoderConfig()
110 | encoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
111 | encoder := zapcore.NewConsoleEncoder(encoderConfig)
112 |
113 | // Create core that writes to our safe writer
114 | core := zapcore.NewCore(encoder, zapcore.AddSync(safeWriter), getZapLevel(logLevel))
115 |
116 | // Create the logger with the core
117 | zapLogger = zap.New(core)
118 | return
119 | }
120 | }
121 |
122 | // Standard logger initialization for non-stdio mode or fallback
123 | config := zap.NewProductionConfig()
124 | config.EncoderConfig.TimeKey = "time"
125 | config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
126 |
127 | // In stdio mode with no log file, use a no-op logger to avoid any stdout output
128 | if isStdioMode {
129 | zapLogger = zap.NewNop()
130 | return
131 | } else {
132 | config.OutputPaths = []string{"stdout"}
133 | }
134 |
135 | config.Level = getZapLevel(logLevel)
136 |
137 | var err error
138 | zapLogger, err = config.Build()
139 | if err != nil {
140 | // If Zap logger cannot be built, fall back to noop logger
141 | zapLogger = zap.NewNop()
142 | }
143 | }
144 |
145 | // InitializeWithWriter sets up the logger with the specified level and output writer
146 | func InitializeWithWriter(level string, writer *os.File) {
147 | setLogLevel(level)
148 |
149 | config := zap.NewProductionConfig()
150 | config.EncoderConfig.TimeKey = "time"
151 | config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
152 |
153 | // Create custom core with the provided writer
154 | encoderConfig := zap.NewProductionEncoderConfig()
155 | encoderConfig.TimeKey = "time"
156 | encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
157 |
158 | core := zapcore.NewCore(
159 | zapcore.NewJSONEncoder(encoderConfig),
160 | zapcore.AddSync(writer),
161 | getZapLevel(logLevel),
162 | )
163 |
164 | zapLogger = zap.New(core)
165 | }
166 |
167 | // setLogLevel sets the log level from a string
168 | func setLogLevel(level string) {
169 | switch strings.ToLower(level) {
170 | case "debug":
171 | logLevel = LevelDebug
172 | case "info":
173 | logLevel = LevelInfo
174 | case "warn":
175 | logLevel = LevelWarn
176 | case "error":
177 | logLevel = LevelError
178 | default:
179 | logLevel = LevelInfo
180 | }
181 | }
182 |
183 | // getZapLevel converts our level to zap.AtomicLevel
184 | func getZapLevel(level Level) zap.AtomicLevel {
185 | switch level {
186 | case LevelDebug:
187 | return zap.NewAtomicLevelAt(zapcore.DebugLevel)
188 | case LevelInfo:
189 | return zap.NewAtomicLevelAt(zapcore.InfoLevel)
190 | case LevelWarn:
191 | return zap.NewAtomicLevelAt(zapcore.WarnLevel)
192 | case LevelError:
193 | return zap.NewAtomicLevelAt(zapcore.ErrorLevel)
194 | default:
195 | return zap.NewAtomicLevelAt(zapcore.InfoLevel)
196 | }
197 | }
198 |
199 | // Debug logs a debug message
200 | func Debug(format string, v ...interface{}) {
201 | if logLevel > LevelDebug {
202 | return
203 | }
204 | msg := fmt.Sprintf(format, v...)
205 | zapLogger.Debug(msg)
206 | }
207 |
208 | // Info logs an info message
209 | func Info(format string, v ...interface{}) {
210 | if logLevel > LevelInfo {
211 | return
212 | }
213 | msg := fmt.Sprintf(format, v...)
214 | zapLogger.Info(msg)
215 | }
216 |
217 | // Warn logs a warning message
218 | func Warn(format string, v ...interface{}) {
219 | if logLevel > LevelWarn {
220 | return
221 | }
222 | msg := fmt.Sprintf(format, v...)
223 | zapLogger.Warn(msg)
224 | }
225 |
226 | // Error logs an error message
227 | func Error(format string, v ...interface{}) {
228 | if logLevel > LevelError {
229 | return
230 | }
231 | msg := fmt.Sprintf(format, v...)
232 | zapLogger.Error(msg)
233 | }
234 |
235 | // ErrorWithStack logs an error with a stack trace
236 | func ErrorWithStack(err error) {
237 | if err == nil {
238 | return
239 | }
240 | zapLogger.Error(
241 | err.Error(),
242 | zap.String("stack", string(debug.Stack())),
243 | )
244 | }
245 |
246 | // RequestLog logs details of an HTTP request
247 | func RequestLog(method, url, sessionID, body string) {
248 | if logLevel > LevelDebug {
249 | return
250 | }
251 | zapLogger.Debug("HTTP Request",
252 | zap.String("method", method),
253 | zap.String("url", url),
254 | zap.String("sessionID", sessionID),
255 | zap.String("body", body),
256 | )
257 | }
258 |
259 | // ResponseLog logs details of an HTTP response
260 | func ResponseLog(statusCode int, sessionID, body string) {
261 | if logLevel > LevelDebug {
262 | return
263 | }
264 | zapLogger.Debug("HTTP Response",
265 | zap.Int("statusCode", statusCode),
266 | zap.String("sessionID", sessionID),
267 | zap.String("body", body),
268 | )
269 | }
270 |
271 | // SSEEventLog logs details of an SSE event
272 | func SSEEventLog(eventType, sessionID, data string) {
273 | if logLevel > LevelDebug {
274 | return
275 | }
276 | zapLogger.Debug("SSE Event",
277 | zap.String("eventType", eventType),
278 | zap.String("sessionID", sessionID),
279 | zap.String("data", data),
280 | )
281 | }
282 |
283 | // RequestResponseLog logs a combined request and response log entry
284 | func RequestResponseLog(method, sessionID string, requestData, responseData string) {
285 | if logLevel > LevelDebug {
286 | return
287 | }
288 |
289 | // Format for more readable logs
290 | formattedRequest := requestData
291 | formattedResponse := responseData
292 |
293 | // Try to format JSON if it's valid
294 | if strings.HasPrefix(requestData, "{") || strings.HasPrefix(requestData, "[") {
295 | var obj interface{}
296 | if err := json.Unmarshal([]byte(requestData), &obj); err == nil {
297 | if formatted, err := json.MarshalIndent(obj, "", " "); err == nil {
298 | formattedRequest = string(formatted)
299 | }
300 | }
301 | }
302 |
303 | if strings.HasPrefix(responseData, "{") || strings.HasPrefix(responseData, "[") {
304 | var obj interface{}
305 | if err := json.Unmarshal([]byte(responseData), &obj); err == nil {
306 | if formatted, err := json.MarshalIndent(obj, "", " "); err == nil {
307 | formattedResponse = string(formatted)
308 | }
309 | }
310 | }
311 |
312 | zapLogger.Debug("Request/Response",
313 | zap.String("method", method),
314 | zap.String("sessionID", sessionID),
315 | zap.String("request", formattedRequest),
316 | zap.String("response", formattedResponse),
317 | )
318 | }
319 |
```
--------------------------------------------------------------------------------
/pkg/db/timescale/connection_test.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "errors"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/assert"
10 |
11 | "github.com/FreePeak/db-mcp-server/pkg/db"
12 | )
13 |
// TestNewTimescaleDB verifies that NewTimescaleDB constructs a DB instance
// without error and preserves the embedded PostgreSQL configuration.
func TestNewTimescaleDB(t *testing.T) {
	// Create a config with test values
	pgConfig := db.Config{
		Type:     "postgres",
		Host:     "localhost",
		Port:     5432,
		User:     "postgres",
		Password: "password",
		Name:     "testdb",
	}
	config := DBConfig{
		PostgresConfig: pgConfig,
		UseTimescaleDB: true,
	}

	// Create a new DB instance
	tsdb, err := NewTimescaleDB(config)
	assert.NoError(t, err)
	assert.NotNil(t, tsdb)
	assert.Equal(t, pgConfig, tsdb.config.PostgresConfig)
}
35 |
// TestConnect exercises DB.Connect across four scenarios: successful
// TimescaleDB extension detection, a failing underlying connection, a
// PostgreSQL server without the extension (sql.ErrNoRows), and an
// unexpected error from the extension-version query.
func TestConnect(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		config:        DBConfig{UseTimescaleDB: true},
		isTimescaleDB: false,
	}

	// Mock the QueryRow method to simulate a successful TimescaleDB detection
	mockDB.RegisterQueryResult("SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'", "2.8.0", nil)

	// Connect to the database
	err := tsdb.Connect()
	if err != nil {
		t.Fatalf("Failed to connect: %v", err)
	}

	// Check that the TimescaleDB extension was detected
	if !tsdb.isTimescaleDB {
		t.Error("Expected isTimescaleDB to be true, got false")
	}
	if tsdb.extVersion != "2.8.0" {
		t.Errorf("Expected extVersion to be '2.8.0', got '%s'", tsdb.extVersion)
	}

	// Test error case when database connection fails
	mockDB = NewMockDB()
	mockDB.SetConnectError(errors.New("mocked connection error"))
	tsdb = &DB{
		Database:      mockDB,
		config:        DBConfig{UseTimescaleDB: true},
		isTimescaleDB: false,
	}

	err = tsdb.Connect()
	if err == nil {
		t.Error("Expected connection error, got nil")
	}

	// Test case when TimescaleDB extension is not available. sql.ErrNoRows
	// is expected to be treated as "extension absent", not a failure.
	mockDB = NewMockDB()
	mockDB.RegisterQueryResult("SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'", nil, sql.ErrNoRows)
	tsdb = &DB{
		Database:      mockDB,
		config:        DBConfig{UseTimescaleDB: true},
		isTimescaleDB: false,
	}

	err = tsdb.Connect()
	if err != nil {
		t.Fatalf("Failed to connect: %v", err)
	}

	// Check that TimescaleDB features are disabled
	if tsdb.isTimescaleDB {
		t.Error("Expected isTimescaleDB to be false, got true")
	}

	// Test case when TimescaleDB check fails with an unexpected error
	mockDB = NewMockDB()
	mockDB.RegisterQueryResult("SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'", nil, errors.New("mocked query error"))
	tsdb = &DB{
		Database:      mockDB,
		config:        DBConfig{UseTimescaleDB: true},
		isTimescaleDB: false,
	}

	err = tsdb.Connect()
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
108 |
109 | func TestClose(t *testing.T) {
110 | mockDB := NewMockDB()
111 | tsdb := &DB{
112 | Database: mockDB,
113 | }
114 |
115 | // Close should delegate to the underlying database
116 | err := tsdb.Close()
117 | if err != nil {
118 | t.Fatalf("Failed to close: %v", err)
119 | }
120 |
121 | // Test error case
122 | mockDB = NewMockDB()
123 | mockDB.SetCloseError(errors.New("mocked close error"))
124 | tsdb = &DB{
125 | Database: mockDB,
126 | }
127 |
128 | err = tsdb.Close()
129 | if err == nil {
130 | t.Error("Expected close error, got nil")
131 | }
132 | }
133 |
134 | func TestExtVersion(t *testing.T) {
135 | tsdb := &DB{
136 | extVersion: "2.8.0",
137 | }
138 |
139 | if tsdb.ExtVersion() != "2.8.0" {
140 | t.Errorf("Expected ExtVersion() to return '2.8.0', got '%s'", tsdb.ExtVersion())
141 | }
142 | }
143 |
144 | func TestTimescaleDBInstance(t *testing.T) {
145 | tsdb := &DB{
146 | isTimescaleDB: true,
147 | }
148 |
149 | if !tsdb.IsTimescaleDB() {
150 | t.Error("Expected IsTimescaleDB() to return true, got false")
151 | }
152 |
153 | tsdb.isTimescaleDB = false
154 | if tsdb.IsTimescaleDB() {
155 | t.Error("Expected IsTimescaleDB() to return false, got true")
156 | }
157 | }
158 |
159 | func TestApplyConfig(t *testing.T) {
160 | // Test when TimescaleDB is not available
161 | tsdb := &DB{
162 | isTimescaleDB: false,
163 | }
164 |
165 | err := tsdb.ApplyConfig()
166 | if err == nil {
167 | t.Error("Expected error when TimescaleDB is not available, got nil")
168 | }
169 |
170 | // Test when TimescaleDB is available
171 | tsdb = &DB{
172 | isTimescaleDB: true,
173 | }
174 |
175 | err = tsdb.ApplyConfig()
176 | if err != nil {
177 | t.Errorf("Expected no error, got %v", err)
178 | }
179 | }
180 |
// TestExecuteSQLWithoutParams covers the parameterless execution path:
// a SELECT query, a non-SELECT statement (routed to Exec and returning a
// MockResult), and a query error surfaced from the driver.
func TestExecuteSQLWithoutParams(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database: mockDB,
	}

	ctx := context.Background()

	// Test SELECT query
	mockResult := []map[string]interface{}{
		{"id": 1, "name": "Test"},
	}
	mockDB.RegisterQueryResult("SELECT * FROM test", mockResult, nil)

	result, err := tsdb.ExecuteSQLWithoutParams(ctx, "SELECT * FROM test")
	if err != nil {
		t.Fatalf("Failed to execute query: %v", err)
	}

	// Verify the result is not nil
	if result == nil {
		t.Error("Expected non-nil result")
	}

	// Test non-SELECT query (e.g., INSERT)
	insertResult, err := tsdb.ExecuteSQLWithoutParams(ctx, "INSERT INTO test (id, name) VALUES (1, 'Test')")
	if err != nil {
		t.Fatalf("Failed to execute statement: %v", err)
	}

	// Since the mock doesn't do much, just verify it's a MockResult
	_, ok := insertResult.(*MockResult)
	if !ok {
		t.Error("Expected result to be a MockResult")
	}

	// Test query error
	mockDB.RegisterQueryResult("SELECT * FROM error_table", nil, errors.New("mocked query error"))
	_, err = tsdb.ExecuteSQLWithoutParams(ctx, "SELECT * FROM error_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
224 |
// TestExecuteSQL covers the parameterized execution path: a SELECT with a
// positional argument, a non-SELECT statement with arguments, and a query
// error surfaced from the driver.
func TestExecuteSQL(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database: mockDB,
	}

	ctx := context.Background()

	// Test SELECT query with parameters
	mockResult := []map[string]interface{}{
		{"id": 1, "name": "Test"},
	}
	mockDB.RegisterQueryResult("SELECT * FROM test WHERE id = $1", mockResult, nil)

	result, err := tsdb.ExecuteSQL(ctx, "SELECT * FROM test WHERE id = $1", 1)
	if err != nil {
		t.Fatalf("Failed to execute query: %v", err)
	}

	// Verify the result is not nil
	if result == nil {
		t.Error("Expected non-nil result")
	}

	// Test non-SELECT query with parameters (e.g., INSERT)
	insertResult, err := tsdb.ExecuteSQL(ctx, "INSERT INTO test (id, name) VALUES ($1, $2)", 1, "Test")
	if err != nil {
		t.Fatalf("Failed to execute statement: %v", err)
	}

	// Since the mock doesn't do much, just verify it's not nil
	if insertResult == nil {
		t.Error("Expected non-nil result for INSERT")
	}

	// Test query error
	mockDB.RegisterQueryResult("SELECT * FROM error_table WHERE id = $1", nil, errors.New("mocked query error"))
	_, err = tsdb.ExecuteSQL(ctx, "SELECT * FROM error_table WHERE id = $1", 1)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
267 |
// TestIsSelectQuery table-tests isSelectQuery: SELECT is recognized
// case-insensitively and after leading whitespace (space, tab, newline);
// DML/DDL statements and the empty string are rejected.
func TestIsSelectQuery(t *testing.T) {
	testCases := []struct {
		query    string
		expected bool
	}{
		{"SELECT * FROM test", true},
		{"select * from test", true},
		{" SELECT * FROM test", true},
		{"\tSELECT * FROM test", true},
		{"\nSELECT * FROM test", true},
		{"INSERT INTO test VALUES (1)", false},
		{"UPDATE test SET name = 'Test'", false},
		{"DELETE FROM test", false},
		{"CREATE TABLE test (id INT)", false},
		{"", false},
	}

	for _, tc := range testCases {
		result := isSelectQuery(tc.query)
		if result != tc.expected {
			t.Errorf("isSelectQuery(%q) = %v, expected %v", tc.query, result, tc.expected)
		}
	}
}
292 |
// TestTimescaleDB_Connect verifies successful extension detection on
// Connect. NOTE(review): this duplicates the happy-path portion of
// TestConnect above — consider consolidating.
func TestTimescaleDB_Connect(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		config:        DBConfig{UseTimescaleDB: true},
		isTimescaleDB: false,
	}

	// Mock the QueryRow method to simulate a successful TimescaleDB detection
	mockDB.RegisterQueryResult("SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'", "2.8.0", nil)

	// Connect to the database
	err := tsdb.Connect()
	if err != nil {
		t.Fatalf("Failed to connect: %v", err)
	}

	// Check that the TimescaleDB extension was detected
	if !tsdb.isTimescaleDB {
		t.Error("Expected isTimescaleDB to be true, got false")
	}
	if tsdb.extVersion != "2.8.0" {
		t.Errorf("Expected extVersion to be '2.8.0', got '%s'", tsdb.extVersion)
	}
}
318 |
// TestTimescaleDB_ConnectNoExtension verifies that Connect succeeds but
// leaves TimescaleDB features disabled when the extension query returns
// sql.ErrNoRows. NOTE(review): overlaps the "extension not available" case
// in TestConnect above.
func TestTimescaleDB_ConnectNoExtension(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		config:        DBConfig{UseTimescaleDB: true},
		isTimescaleDB: false,
	}

	// Mock the QueryRow method to simulate no TimescaleDB extension
	mockDB.RegisterQueryResult("SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'", nil, sql.ErrNoRows)

	// Connect to the database
	err := tsdb.Connect()
	if err != nil {
		t.Fatalf("Failed to connect: %v", err)
	}

	// Check that TimescaleDB features are disabled
	if tsdb.isTimescaleDB {
		t.Error("Expected isTimescaleDB to be false, got true")
	}
}
341 |
```
--------------------------------------------------------------------------------
/internal/delivery/mcp/context/timescale_completion_test.go:
--------------------------------------------------------------------------------
```go
1 | package context_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | "github.com/stretchr/testify/mock"
9 |
10 | "github.com/FreePeak/db-mcp-server/internal/delivery/mcp"
11 | )
12 |
// TestTimescaleDBCompletionProvider exercises the completion provider's
// time-bucket, hypertable, and combined function-completion paths, plus the
// failure modes for a PostgreSQL database without TimescaleDB and for a
// non-PostgreSQL database. All database access is stubbed through
// MockDatabaseUseCase; each expectation is scoped with Once() so the exact
// number of provider-to-database calls is asserted.
func TestTimescaleDBCompletionProvider(t *testing.T) {
	// Create a mock use case provider
	mockUseCase := new(MockDatabaseUseCase)

	// Create a context for testing
	ctx := context.Background()

	t.Run("get_time_bucket_completions", func(t *testing.T) {
		// Set up expectations for the mock: the provider checks the database
		// type, then probes for the timescaledb extension.
		mockUseCase.On("GetDatabaseType", "timescale_db").Return("postgres", nil).Once()
		mockUseCase.On("ExecuteStatement", mock.Anything, "timescale_db", mock.MatchedBy(func(sql string) bool {
			return sql == "SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'"
		}), mock.Anything).Return(`[{"extversion": "2.8.0"}]`, nil).Once()

		// Create the completion provider
		provider := mcp.NewTimescaleDBCompletionProvider()

		// Call the method to get time bucket function completions
		completions, err := provider.GetTimeBucketCompletions(ctx, "timescale_db", mockUseCase)

		// Verify the result
		assert.NoError(t, err)
		assert.NotNil(t, completions)
		assert.NotEmpty(t, completions)

		// Check for essential time_bucket functions
		var foundBasicTimeBucket, foundGapfill, foundTzTimeBucket bool
		for _, completion := range completions {
			if completion.Name == "time_bucket" && completion.Type == "function" {
				foundBasicTimeBucket = true
				assert.Contains(t, completion.Documentation, "buckets")
				assert.Contains(t, completion.InsertText, "time_bucket")
			}
			if completion.Name == "time_bucket_gapfill" && completion.Type == "function" {
				foundGapfill = true
				assert.Contains(t, completion.Documentation, "gap")
			}
			if completion.Name == "time_bucket_ng" && completion.Type == "function" {
				foundTzTimeBucket = true
				assert.Contains(t, completion.Documentation, "timezone")
			}
		}

		assert.True(t, foundBasicTimeBucket, "time_bucket function completion not found")
		assert.True(t, foundGapfill, "time_bucket_gapfill function completion not found")
		assert.True(t, foundTzTimeBucket, "time_bucket_ng function completion not found")

		// Verify the mock expectations
		mockUseCase.AssertExpectations(t)
	})

	t.Run("get_hypertable_function_completions", func(t *testing.T) {
		// Set up expectations for the mock
		mockUseCase.On("GetDatabaseType", "timescale_db").Return("postgres", nil).Once()
		mockUseCase.On("ExecuteStatement", mock.Anything, "timescale_db", mock.MatchedBy(func(sql string) bool {
			return sql == "SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'"
		}), mock.Anything).Return(`[{"extversion": "2.8.0"}]`, nil).Once()

		// Create the completion provider
		provider := mcp.NewTimescaleDBCompletionProvider()

		// Call the method to get hypertable function completions
		completions, err := provider.GetHypertableFunctionCompletions(ctx, "timescale_db", mockUseCase)

		// Verify the result
		assert.NoError(t, err)
		assert.NotNil(t, completions)
		assert.NotEmpty(t, completions)

		// Check for essential hypertable functions
		var foundCreate, foundCompression, foundRetention bool
		for _, completion := range completions {
			if completion.Name == "create_hypertable" && completion.Type == "function" {
				foundCreate = true
				assert.Contains(t, completion.Documentation, "hypertable")
				assert.Contains(t, completion.InsertText, "create_hypertable")
			}
			if completion.Name == "add_compression_policy" && completion.Type == "function" {
				foundCompression = true
				assert.Contains(t, completion.Documentation, "compression")
			}
			if completion.Name == "add_retention_policy" && completion.Type == "function" {
				foundRetention = true
				assert.Contains(t, completion.Documentation, "retention")
			}
		}

		assert.True(t, foundCreate, "create_hypertable function completion not found")
		assert.True(t, foundCompression, "add_compression_policy function completion not found")
		assert.True(t, foundRetention, "add_retention_policy function completion not found")

		// Verify the mock expectations
		mockUseCase.AssertExpectations(t)
	})

	t.Run("get_all_function_completions", func(t *testing.T) {
		// Create a separate mock for this test to avoid issues with expectations
		localMock := new(MockDatabaseUseCase)

		// The new implementation makes fewer calls to GetDatabaseType
		localMock.On("GetDatabaseType", "timescale_db").Return("postgres", nil).Once()

		// It also calls ExecuteStatement once through DetectTimescaleDB
		localMock.On("ExecuteStatement", mock.Anything, "timescale_db", mock.MatchedBy(func(sql string) bool {
			return sql == "SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'"
		}), mock.Anything).Return(`[{"extversion": "2.8.0"}]`, nil).Once()

		// Create the completion provider
		provider := mcp.NewTimescaleDBCompletionProvider()

		// Call the method to get all function completions
		completions, err := provider.GetAllFunctionCompletions(ctx, "timescale_db", localMock)

		// Verify the result
		assert.NoError(t, err)
		assert.NotNil(t, completions)
		assert.NotEmpty(t, completions)

		// Check for categories of functions
		var foundTimeBucket, foundHypertable, foundContinuousAggregates, foundAnalytics bool
		for _, completion := range completions {
			if completion.Name == "time_bucket" && completion.Type == "function" {
				foundTimeBucket = true
			}
			if completion.Name == "create_hypertable" && completion.Type == "function" {
				foundHypertable = true
			}
			if completion.Name == "create_materialized_view" && completion.Type == "function" {
				foundContinuousAggregates = true
				// Special case - materialized view does not include parentheses
				assert.Contains(t, completion.InsertText, "CREATE MATERIALIZED VIEW")
			}
			if completion.Name == "first" || completion.Name == "last" || completion.Name == "time_weight" {
				foundAnalytics = true
			}
		}

		assert.True(t, foundTimeBucket, "time_bucket function completion not found")
		assert.True(t, foundHypertable, "hypertable function completion not found")
		assert.True(t, foundContinuousAggregates, "continuous aggregate function completion not found")
		assert.True(t, foundAnalytics, "analytics function completion not found")

		// Check that returned completions have properly formatted insert text
		for _, completion := range completions {
			if completion.Type == "function" && completion.Name != "create_materialized_view" {
				assert.Contains(t, completion.InsertText, completion.Name+"(")
				assert.Contains(t, completion.Documentation, "TimescaleDB")
			}
		}

		// Verify the mock expectations
		localMock.AssertExpectations(t)
	})

	t.Run("get_function_completions_with_non_timescaledb", func(t *testing.T) {
		// Create a separate mock for this test to avoid issues with expectations
		localMock := new(MockDatabaseUseCase)

		// With the new implementation, we only need one GetDatabaseType call
		localMock.On("GetDatabaseType", "postgres_db").Return("postgres", nil).Once()

		// It also calls ExecuteStatement through DetectTimescaleDB; an empty
		// result set means the extension is not installed.
		localMock.On("ExecuteStatement", mock.Anything, "postgres_db", mock.MatchedBy(func(sql string) bool {
			return sql == "SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'"
		}), mock.Anything).Return(`[]`, nil).Once()

		// Create the completion provider
		provider := mcp.NewTimescaleDBCompletionProvider()

		// Call the method to get function completions
		completions, err := provider.GetAllFunctionCompletions(ctx, "postgres_db", localMock)

		// Verify the result
		assert.Error(t, err)
		assert.Nil(t, completions)
		assert.Contains(t, err.Error(), "TimescaleDB is not available")

		// Verify the mock expectations
		localMock.AssertExpectations(t)
	})

	t.Run("get_function_completions_with_non_postgres", func(t *testing.T) {
		// Create a separate mock for this test
		localMock := new(MockDatabaseUseCase)

		// Set up expectations for the mock
		localMock.On("GetDatabaseType", "mysql_db").Return("mysql", nil).Once()

		// Create the completion provider
		provider := mcp.NewTimescaleDBCompletionProvider()

		// Call the method to get function completions
		completions, err := provider.GetAllFunctionCompletions(ctx, "mysql_db", localMock)

		// Verify the result
		assert.Error(t, err)
		assert.Nil(t, completions)
		// The error message is now "not available" instead of "not a PostgreSQL database"
		assert.Contains(t, err.Error(), "not available")

		// Verify the mock expectations
		localMock.AssertExpectations(t)
	})
}
217 |
```
--------------------------------------------------------------------------------
/internal/delivery/mcp/compression_policy_test.go:
--------------------------------------------------------------------------------
```go
1 | package mcp
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/FreePeak/cortex/pkg/server"
8 | "github.com/stretchr/testify/assert"
9 | "github.com/stretchr/testify/mock"
10 | )
11 |
12 | func TestHandleEnableCompression(t *testing.T) {
13 | // Create a mock use case
14 | mockUseCase := new(MockDatabaseUseCase)
15 |
16 | // Set up expectations
17 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
18 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`{"message":"Compression enabled"}`, nil)
19 |
20 | // Create the tool
21 | tool := NewTimescaleDBTool()
22 |
23 | // Create a request
24 | request := server.ToolCallRequest{
25 | Parameters: map[string]interface{}{
26 | "operation": "enable_compression",
27 | "target_table": "test_table",
28 | },
29 | }
30 |
31 | // Call the handler
32 | result, err := tool.HandleRequest(context.Background(), request, "test_db", mockUseCase)
33 |
34 | // Assertions
35 | assert.NoError(t, err)
36 | assert.NotNil(t, result)
37 |
38 | // Check the result
39 | resultMap, ok := result.(map[string]interface{})
40 | assert.True(t, ok)
41 | assert.Contains(t, resultMap, "message")
42 |
43 | // Verify mock expectations
44 | mockUseCase.AssertExpectations(t)
45 | }
46 |
47 | func TestHandleEnableCompressionWithInterval(t *testing.T) {
48 | // Create a mock use case
49 | mockUseCase := new(MockDatabaseUseCase)
50 |
51 | // Set up expectations
52 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
53 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`{"message":"Compression enabled"}`, nil).Twice()
54 |
55 | // Create the tool
56 | tool := NewTimescaleDBTool()
57 |
58 | // Create a request
59 | request := server.ToolCallRequest{
60 | Parameters: map[string]interface{}{
61 | "operation": "enable_compression",
62 | "target_table": "test_table",
63 | "after": "7 days",
64 | },
65 | }
66 |
67 | // Call the handler
68 | result, err := tool.HandleRequest(context.Background(), request, "test_db", mockUseCase)
69 |
70 | // Assertions
71 | assert.NoError(t, err)
72 | assert.NotNil(t, result)
73 |
74 | // Check the result
75 | resultMap, ok := result.(map[string]interface{})
76 | assert.True(t, ok)
77 | assert.Contains(t, resultMap, "message")
78 |
79 | // Verify mock expectations
80 | mockUseCase.AssertExpectations(t)
81 | }
82 |
83 | func TestHandleDisableCompression(t *testing.T) {
84 | // Create a mock use case
85 | mockUseCase := new(MockDatabaseUseCase)
86 |
87 | // Set up expectations
88 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
89 |
90 | // First should try to remove any policy
91 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`[{"job_id": 123}]`, nil).Once()
92 |
93 | // Then remove the policy and disable compression
94 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`{"message":"Policy removed"}`, nil).Once()
95 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`{"message":"Compression disabled"}`, nil).Once()
96 |
97 | // Create the tool
98 | tool := NewTimescaleDBTool()
99 |
100 | // Create a request
101 | request := server.ToolCallRequest{
102 | Parameters: map[string]interface{}{
103 | "operation": "disable_compression",
104 | "target_table": "test_table",
105 | },
106 | }
107 |
108 | // Call the handler
109 | result, err := tool.HandleRequest(context.Background(), request, "test_db", mockUseCase)
110 |
111 | // Assertions
112 | assert.NoError(t, err)
113 | assert.NotNil(t, result)
114 |
115 | // Check the result
116 | resultMap, ok := result.(map[string]interface{})
117 | assert.True(t, ok)
118 | assert.Contains(t, resultMap, "message")
119 |
120 | // Verify mock expectations
121 | mockUseCase.AssertExpectations(t)
122 | }
123 |
124 | func TestHandleAddCompressionPolicy(t *testing.T) {
125 | // Create a mock use case
126 | mockUseCase := new(MockDatabaseUseCase)
127 |
128 | // Set up expectations
129 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
130 |
131 | // Check compression status
132 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`[{"compress": true}]`, nil).Once()
133 |
134 | // Add compression policy
135 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`{"message":"Compression policy added"}`, nil).Once()
136 |
137 | // Create the tool
138 | tool := NewTimescaleDBTool()
139 |
140 | // Create a request
141 | request := server.ToolCallRequest{
142 | Parameters: map[string]interface{}{
143 | "operation": "add_compression_policy",
144 | "target_table": "test_table",
145 | "interval": "30 days",
146 | },
147 | }
148 |
149 | // Call the handler
150 | result, err := tool.HandleRequest(context.Background(), request, "test_db", mockUseCase)
151 |
152 | // Assertions
153 | assert.NoError(t, err)
154 | assert.NotNil(t, result)
155 |
156 | // Check the result
157 | resultMap, ok := result.(map[string]interface{})
158 | assert.True(t, ok)
159 | assert.Contains(t, resultMap, "message")
160 |
161 | // Verify mock expectations
162 | mockUseCase.AssertExpectations(t)
163 | }
164 |
165 | func TestHandleAddCompressionPolicyWithOptions(t *testing.T) {
166 | // Create a mock use case
167 | mockUseCase := new(MockDatabaseUseCase)
168 |
169 | // Set up expectations
170 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
171 |
172 | // Check compression status
173 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`[{"compress": true}]`, nil).Once()
174 |
175 | // Add compression policy with options
176 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`{"message":"Compression policy added"}`, nil).Once()
177 |
178 | // Create the tool
179 | tool := NewTimescaleDBTool()
180 |
181 | // Create a request
182 | request := server.ToolCallRequest{
183 | Parameters: map[string]interface{}{
184 | "operation": "add_compression_policy",
185 | "target_table": "test_table",
186 | "interval": "30 days",
187 | "segment_by": "device_id",
188 | "order_by": "time DESC",
189 | },
190 | }
191 |
192 | // Call the handler
193 | result, err := tool.HandleRequest(context.Background(), request, "test_db", mockUseCase)
194 |
195 | // Assertions
196 | assert.NoError(t, err)
197 | assert.NotNil(t, result)
198 |
199 | // Check the result
200 | resultMap, ok := result.(map[string]interface{})
201 | assert.True(t, ok)
202 | assert.Contains(t, resultMap, "message")
203 |
204 | // Verify mock expectations
205 | mockUseCase.AssertExpectations(t)
206 | }
207 |
208 | func TestHandleRemoveCompressionPolicy(t *testing.T) {
209 | // Create a mock use case
210 | mockUseCase := new(MockDatabaseUseCase)
211 |
212 | // Set up expectations
213 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
214 |
215 | // Find policy ID
216 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`[{"job_id": 123}]`, nil).Once()
217 |
218 | // Remove policy
219 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`{"message":"Policy removed"}`, nil).Once()
220 |
221 | // Create the tool
222 | tool := NewTimescaleDBTool()
223 |
224 | // Create a request
225 | request := server.ToolCallRequest{
226 | Parameters: map[string]interface{}{
227 | "operation": "remove_compression_policy",
228 | "target_table": "test_table",
229 | },
230 | }
231 |
232 | // Call the handler
233 | result, err := tool.HandleRequest(context.Background(), request, "test_db", mockUseCase)
234 |
235 | // Assertions
236 | assert.NoError(t, err)
237 | assert.NotNil(t, result)
238 |
239 | // Check the result
240 | resultMap, ok := result.(map[string]interface{})
241 | assert.True(t, ok)
242 | assert.Contains(t, resultMap, "message")
243 |
244 | // Verify mock expectations
245 | mockUseCase.AssertExpectations(t)
246 | }
247 |
248 | func TestHandleGetCompressionSettings(t *testing.T) {
249 | // Create a mock use case
250 | mockUseCase := new(MockDatabaseUseCase)
251 |
252 | // Set up expectations
253 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
254 |
255 | // Check compression enabled
256 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`[{"compress": true}]`, nil).Once()
257 |
258 | // Get compression settings
259 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`[{"segmentby": "device_id", "orderby": "time DESC"}]`, nil).Once()
260 |
261 | // Get policy info
262 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.Anything, mock.Anything).Return(`[{"schedule_interval": "30 days", "chunk_time_interval": "1 day"}]`, nil).Once()
263 |
264 | // Create the tool
265 | tool := NewTimescaleDBTool()
266 |
267 | // Create a request
268 | request := server.ToolCallRequest{
269 | Parameters: map[string]interface{}{
270 | "operation": "get_compression_settings",
271 | "target_table": "test_table",
272 | },
273 | }
274 |
275 | // Call the handler
276 | result, err := tool.HandleRequest(context.Background(), request, "test_db", mockUseCase)
277 |
278 | // Assertions
279 | assert.NoError(t, err)
280 | assert.NotNil(t, result)
281 |
282 | // Check the result
283 | resultMap, ok := result.(map[string]interface{})
284 | assert.True(t, ok)
285 | assert.Contains(t, resultMap, "message")
286 | assert.Contains(t, resultMap, "settings")
287 |
288 | // Verify mock expectations
289 | mockUseCase.AssertExpectations(t)
290 | }
291 |
```
--------------------------------------------------------------------------------
/internal/delivery/mcp/tool_registry.go:
--------------------------------------------------------------------------------
```go
1 | package mcp
2 |
3 | // TODO: Refactor tool registration to reduce code duplication
4 | // TODO: Implement better error handling with error types instead of generic errors
5 | // TODO: Add metrics collection for tool usage and performance
6 | // TODO: Improve logging with structured logs and log levels
7 | // TODO: Consider implementing tool discovery mechanism to avoid hardcoded tool lists
8 |
9 | import (
10 | "context"
11 | "fmt"
12 |
13 | "github.com/FreePeak/cortex/pkg/server"
14 |
15 | "github.com/FreePeak/db-mcp-server/internal/logger"
16 | )
17 |
// ToolRegistry structure to handle tool registration
type ToolRegistry struct {
	server          *ServerWrapper    // wrapper used to add tools to the MCP server
	mcpServer       *server.MCPServer // underlying MCP server instance
	databaseUseCase UseCaseProvider   // set by RegisterAllTools; provides database access for handlers
	factory         *ToolTypeFactory  // resolves tool type implementations by name
}
25 |
26 | // NewToolRegistry creates a new tool registry
27 | func NewToolRegistry(mcpServer *server.MCPServer) *ToolRegistry {
28 | factory := NewToolTypeFactory()
29 | return &ToolRegistry{
30 | server: NewServerWrapper(mcpServer),
31 | mcpServer: mcpServer,
32 | factory: factory,
33 | }
34 | }
35 |
36 | // RegisterAllTools registers all tools with the server
37 | func (tr *ToolRegistry) RegisterAllTools(ctx context.Context, useCase UseCaseProvider) error {
38 | tr.databaseUseCase = useCase
39 |
40 | // Get available databases
41 | dbList := useCase.ListDatabases()
42 | logger.Info("Found %d database connections for tool registration: %v", len(dbList), dbList)
43 |
44 | if len(dbList) == 0 {
45 | logger.Info("No databases available, registering mock tools")
46 | return tr.RegisterMockTools(ctx)
47 | }
48 |
49 | // Register database-specific tools
50 | registrationErrors := 0
51 | for _, dbID := range dbList {
52 | if err := tr.registerDatabaseTools(ctx, dbID); err != nil {
53 | logger.Error("Error registering tools for database %s: %v", dbID, err)
54 | registrationErrors++
55 | } else {
56 | logger.Info("Successfully registered tools for database %s", dbID)
57 | }
58 | }
59 |
60 | // Register common tools
61 | tr.registerCommonTools(ctx)
62 |
63 | if registrationErrors > 0 {
64 | return fmt.Errorf("errors occurred while registering tools for %d databases", registrationErrors)
65 | }
66 | return nil
67 | }
68 |
69 | // registerDatabaseTools registers all tools for a specific database
70 | func (tr *ToolRegistry) registerDatabaseTools(ctx context.Context, dbID string) error {
71 | // Get all tool types from the factory
72 | toolTypeNames := []string{
73 | "query", "execute", "transaction", "performance", "schema",
74 | }
75 |
76 | logger.Info("Registering tools for database %s", dbID)
77 |
78 | // Special case for postgres - skip the database info call that's failing
79 | dbType, err := tr.databaseUseCase.GetDatabaseType(dbID)
80 | if err == nil && dbType == "postgres" {
81 | // For PostgreSQL, we'll manually create a minimal info structure
82 | // rather than calling the problematic method
83 | logger.Info("Using special handling for PostgreSQL database: %s", dbID)
84 |
85 | // Create a mock database info for PostgreSQL
86 | dbInfo := map[string]interface{}{
87 | "database": dbID,
88 | "tables": []map[string]interface{}{},
89 | }
90 |
91 | logger.Info("Created mock database info for PostgreSQL database %s: %+v", dbID, dbInfo)
92 |
93 | // Register each tool type for this database
94 | registrationErrors := 0
95 | for _, typeName := range toolTypeNames {
96 | // Use simpler tool names: <tooltype>_<dbID>
97 | toolName := fmt.Sprintf("%s_%s", typeName, dbID)
98 | if err := tr.registerTool(ctx, typeName, toolName, dbID); err != nil {
99 | logger.Error("Error registering tool %s: %v", toolName, err)
100 | registrationErrors++
101 | } else {
102 | logger.Info("Successfully registered tool %s", toolName)
103 | }
104 | }
105 |
106 | // Check if TimescaleDB is available for this PostgreSQL database
107 | // by executing a simple check query
108 | checkQuery := "SELECT 1 FROM pg_extension WHERE extname = 'timescaledb'"
109 | result, err := tr.databaseUseCase.ExecuteQuery(ctx, dbID, checkQuery, nil)
110 | if err == nil && result != "[]" && result != "" {
111 | logger.Info("TimescaleDB extension detected for database %s, registering TimescaleDB tools", dbID)
112 |
113 | // Register TimescaleDB-specific tools
114 | timescaleTool := NewTimescaleDBTool()
115 |
116 | // Register time series query tool
117 | tsQueryToolName := fmt.Sprintf("timescaledb_timeseries_query_%s", dbID)
118 | tsQueryTool := timescaleTool.CreateTimeSeriesQueryTool(tsQueryToolName, dbID)
119 | if err := tr.server.AddTool(ctx, tsQueryTool, func(ctx context.Context, request server.ToolCallRequest) (interface{}, error) {
120 | response, err := timescaleTool.HandleRequest(ctx, request, dbID, tr.databaseUseCase)
121 | return FormatResponse(response, err)
122 | }); err != nil {
123 | logger.Error("Error registering TimescaleDB time series query tool: %v", err)
124 | registrationErrors++
125 | } else {
126 | logger.Info("Successfully registered TimescaleDB time series query tool: %s", tsQueryToolName)
127 | }
128 |
129 | // Register time series analyze tool
130 | tsAnalyzeToolName := fmt.Sprintf("timescaledb_analyze_timeseries_%s", dbID)
131 | tsAnalyzeTool := timescaleTool.CreateTimeSeriesAnalyzeTool(tsAnalyzeToolName, dbID)
132 | if err := tr.server.AddTool(ctx, tsAnalyzeTool, func(ctx context.Context, request server.ToolCallRequest) (interface{}, error) {
133 | response, err := timescaleTool.HandleRequest(ctx, request, dbID, tr.databaseUseCase)
134 | return FormatResponse(response, err)
135 | }); err != nil {
136 | logger.Error("Error registering TimescaleDB time series analyze tool: %v", err)
137 | registrationErrors++
138 | } else {
139 | logger.Info("Successfully registered TimescaleDB time series analyze tool: %s", tsAnalyzeToolName)
140 | }
141 | }
142 |
143 | if registrationErrors > 0 {
144 | return fmt.Errorf("errors occurred while registering %d tools", registrationErrors)
145 | }
146 |
147 | logger.Info("Completed registering tools for database %s", dbID)
148 | return nil
149 | }
150 |
151 | // For other database types, continue with the normal approach
152 | // Check if this database actually exists
153 | dbInfo, err := tr.databaseUseCase.GetDatabaseInfo(dbID)
154 | if err != nil {
155 | return fmt.Errorf("failed to get database info for %s: %w", dbID, err)
156 | }
157 |
158 | logger.Info("Database %s info: %+v", dbID, dbInfo)
159 |
160 | // Register each tool type for this database
161 | registrationErrors := 0
162 | for _, typeName := range toolTypeNames {
163 | // Use simpler tool names: <tooltype>_<dbID>
164 | toolName := fmt.Sprintf("%s_%s", typeName, dbID)
165 | if err := tr.registerTool(ctx, typeName, toolName, dbID); err != nil {
166 | logger.Error("Error registering tool %s: %v", toolName, err)
167 | registrationErrors++
168 | } else {
169 | logger.Info("Successfully registered tool %s", toolName)
170 | }
171 | }
172 |
173 | if registrationErrors > 0 {
174 | return fmt.Errorf("errors occurred while registering %d tools", registrationErrors)
175 | }
176 |
177 | logger.Info("Completed registering tools for database %s", dbID)
178 | return nil
179 | }
180 |
181 | // registerTool registers a tool with the server
182 | func (tr *ToolRegistry) registerTool(ctx context.Context, toolTypeName string, name string, dbID string) error {
183 | logger.Info("Registering tool '%s' of type '%s' (database: %s)", name, toolTypeName, dbID)
184 |
185 | toolTypeImpl, ok := tr.factory.GetToolType(toolTypeName)
186 | if !ok {
187 | return fmt.Errorf("failed to get tool type for '%s'", toolTypeName)
188 | }
189 |
190 | tool := toolTypeImpl.CreateTool(name, dbID)
191 |
192 | return tr.server.AddTool(ctx, tool, func(ctx context.Context, request server.ToolCallRequest) (interface{}, error) {
193 | response, err := toolTypeImpl.HandleRequest(ctx, request, dbID, tr.databaseUseCase)
194 | return FormatResponse(response, err)
195 | })
196 | }
197 |
198 | // registerCommonTools registers tools that are not specific to a database
199 | func (tr *ToolRegistry) registerCommonTools(ctx context.Context) {
200 | // Register the list_databases tool with simple name
201 | _, ok := tr.factory.GetToolType("list_databases")
202 | if ok {
203 | // Use simple name for list_databases tool
204 | listDbName := "list_databases"
205 | if err := tr.registerTool(ctx, "list_databases", listDbName, ""); err != nil {
206 | logger.Error("Error registering %s tool: %v", listDbName, err)
207 | } else {
208 | logger.Info("Successfully registered tool %s", listDbName)
209 | }
210 | }
211 | }
212 |
213 | // RegisterMockTools registers mock tools with the server when no db connections available
214 | func (tr *ToolRegistry) RegisterMockTools(ctx context.Context) error {
215 | logger.Info("Registering mock tools")
216 |
217 | // For each tool type, register a simplified mock tool
218 | for toolTypeName := range tr.factory.toolTypes {
219 | // Format: mock_<tooltype>
220 | mockToolName := fmt.Sprintf("mock_%s", toolTypeName)
221 |
222 | toolTypeImpl, ok := tr.factory.GetToolType(toolTypeName)
223 | if !ok {
224 | logger.Warn("Failed to get tool type for '%s'", toolTypeName)
225 | continue
226 | }
227 |
228 | tool := toolTypeImpl.CreateTool(mockToolName, "mock")
229 |
230 | err := tr.server.AddTool(ctx, tool, func(ctx context.Context, request server.ToolCallRequest) (interface{}, error) {
231 | response, err := toolTypeImpl.HandleRequest(ctx, request, "mock", tr.databaseUseCase)
232 | return FormatResponse(response, err)
233 | })
234 |
235 | if err != nil {
236 | logger.Error("Failed to register mock tool '%s': %v", mockToolName, err)
237 | continue
238 | }
239 | }
240 |
241 | return nil
242 | }
243 |
// RegisterCursorCompatibleTools is kept for backward compatibility but does nothing
// as we now register tools with simple names directly.
//
// Deprecated: tools are registered with their simple names in RegisterAllTools;
// this method remains only so existing callers continue to compile.
func (tr *ToolRegistry) RegisterCursorCompatibleTools(ctx context.Context) error {
	// This function is intentionally empty as we now register tools with simple names directly
	return nil
}
250 |
```
--------------------------------------------------------------------------------
/pkg/db/timescale/continuous_aggregate.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 | )
8 |
// ContinuousAggregateOptions encapsulates options for creating a continuous aggregate.
//
// NOTE(review): the string fields are interpolated directly into generated SQL
// by CreateContinuousAggregate, so they must come from trusted input.
type ContinuousAggregateOptions struct {
	// Required parameters
	ViewName       string // Name of the continuous aggregate view to create
	SourceTable    string // Source table with raw data
	TimeColumn     string // Time column to bucket
	BucketInterval string // Time bucket interval (e.g., '1 hour', '1 day')

	// Optional parameters
	Aggregations      []ColumnAggregation // Aggregations to include in the view (defaults to COUNT(*) when empty)
	WhereCondition    string              // WHERE condition to filter source data
	WithData          bool                // Whether to materialize data immediately (WITH DATA)
	RefreshPolicy     bool                // Whether to add a refresh policy
	RefreshInterval   string              // Refresh interval (default: '1 day')
	RefreshLookback   string              // How far back to look when refreshing (default: '1 week')
	MaterializedOnly  bool                // Whether to materialize only (no real-time)
	CreateIfNotExists bool                // Whether to use IF NOT EXISTS
}
27 |
// ContinuousAggregatePolicyOptions encapsulates options for refresh policies.
// All fields are passed to add_continuous_aggregate_policy as INTERVAL values,
// except as written literally into the generated SQL.
type ContinuousAggregatePolicyOptions struct {
	ViewName         string // Name of the continuous aggregate view
	Start            string // Start offset (e.g., '-2 days')
	End              string // End offset (e.g., 'now()')
	ScheduleInterval string // Execution interval (e.g., '1 hour')
}
35 |
36 | // CreateContinuousAggregate creates a new continuous aggregate view
37 | func (t *DB) CreateContinuousAggregate(ctx context.Context, options ContinuousAggregateOptions) error {
38 | if !t.isTimescaleDB {
39 | return fmt.Errorf("TimescaleDB extension not available")
40 | }
41 |
42 | var builder strings.Builder
43 |
44 | // Build CREATE MATERIALIZED VIEW statement
45 | builder.WriteString("CREATE MATERIALIZED VIEW ")
46 |
47 | // Add IF NOT EXISTS clause if requested
48 | if options.CreateIfNotExists {
49 | builder.WriteString("IF NOT EXISTS ")
50 | }
51 |
52 | // Add view name
53 | builder.WriteString(options.ViewName)
54 | builder.WriteString("\n")
55 |
56 | // Add WITH clause for materialized_only if requested
57 | if options.MaterializedOnly {
58 | builder.WriteString("WITH (timescaledb.materialized_only=true)\n")
59 | }
60 |
61 | // Start SELECT statement
62 | builder.WriteString("AS SELECT\n ")
63 |
64 | // Add time bucket
65 | builder.WriteString(fmt.Sprintf("time_bucket('%s', %s) as time_bucket",
66 | options.BucketInterval, options.TimeColumn))
67 |
68 | // Add aggregations
69 | if len(options.Aggregations) > 0 {
70 | for _, agg := range options.Aggregations {
71 | colName := agg.Alias
72 | if colName == "" {
73 | colName = strings.ToLower(string(agg.Function)) + "_" + agg.Column
74 | }
75 |
76 | builder.WriteString(fmt.Sprintf(",\n %s(%s) as %s",
77 | agg.Function, agg.Column, colName))
78 | }
79 | } else {
80 | // Default to count(*) if no aggregations specified
81 | builder.WriteString(",\n COUNT(*) as count")
82 | }
83 |
84 | // Add FROM clause
85 | builder.WriteString(fmt.Sprintf("\nFROM %s\n", options.SourceTable))
86 |
87 | // Add WHERE clause if specified
88 | if options.WhereCondition != "" {
89 | builder.WriteString(fmt.Sprintf("WHERE %s\n", options.WhereCondition))
90 | }
91 |
92 | // Add GROUP BY clause
93 | builder.WriteString("GROUP BY time_bucket\n")
94 |
95 | // Add WITH DATA or WITH NO DATA
96 | if options.WithData {
97 | builder.WriteString("WITH DATA")
98 | } else {
99 | builder.WriteString("WITH NO DATA")
100 | }
101 |
102 | // Execute the statement
103 | _, err := t.ExecuteSQLWithoutParams(ctx, builder.String())
104 | if err != nil {
105 | return fmt.Errorf("failed to create continuous aggregate: %w", err)
106 | }
107 |
108 | // Add refresh policy if requested
109 | if options.RefreshPolicy {
110 | refreshInterval := options.RefreshInterval
111 | if refreshInterval == "" {
112 | refreshInterval = "1 day"
113 | }
114 |
115 | refreshLookback := options.RefreshLookback
116 | if refreshLookback == "" {
117 | refreshLookback = "1 week"
118 | }
119 |
120 | err = t.AddContinuousAggregatePolicy(ctx, ContinuousAggregatePolicyOptions{
121 | ViewName: options.ViewName,
122 | Start: fmt.Sprintf("-%s", refreshLookback),
123 | End: "now()",
124 | ScheduleInterval: refreshInterval,
125 | })
126 |
127 | if err != nil {
128 | return fmt.Errorf("created continuous aggregate but failed to add refresh policy: %w", err)
129 | }
130 | }
131 |
132 | return nil
133 | }
134 |
135 | // RefreshContinuousAggregate refreshes a continuous aggregate for a specific time range
136 | func (t *DB) RefreshContinuousAggregate(ctx context.Context, viewName, startTime, endTime string) error {
137 | if !t.isTimescaleDB {
138 | return fmt.Errorf("TimescaleDB extension not available")
139 | }
140 |
141 | var builder strings.Builder
142 |
143 | // Build CALL statement
144 | builder.WriteString("CALL refresh_continuous_aggregate(")
145 |
146 | // Add view name
147 | builder.WriteString(fmt.Sprintf("'%s'", viewName))
148 |
149 | // Add time range if specified
150 | if startTime != "" && endTime != "" {
151 | builder.WriteString(fmt.Sprintf(", '%s'::timestamptz, '%s'::timestamptz",
152 | startTime, endTime))
153 | } else {
154 | builder.WriteString(", NULL, NULL")
155 | }
156 |
157 | builder.WriteString(")")
158 |
159 | // Execute the statement
160 | _, err := t.ExecuteSQLWithoutParams(ctx, builder.String())
161 | if err != nil {
162 | return fmt.Errorf("failed to refresh continuous aggregate: %w", err)
163 | }
164 |
165 | return nil
166 | }
167 |
168 | // AddContinuousAggregatePolicy adds a refresh policy to a continuous aggregate
169 | func (t *DB) AddContinuousAggregatePolicy(ctx context.Context, options ContinuousAggregatePolicyOptions) error {
170 | if !t.isTimescaleDB {
171 | return fmt.Errorf("TimescaleDB extension not available")
172 | }
173 |
174 | // Build policy creation SQL
175 | sql := fmt.Sprintf(
176 | "SELECT add_continuous_aggregate_policy('%s', start_offset => INTERVAL '%s', "+
177 | "end_offset => INTERVAL '%s', schedule_interval => INTERVAL '%s')",
178 | options.ViewName,
179 | options.Start,
180 | options.End,
181 | options.ScheduleInterval,
182 | )
183 |
184 | // Execute the statement
185 | _, err := t.ExecuteSQLWithoutParams(ctx, sql)
186 | if err != nil {
187 | return fmt.Errorf("failed to add continuous aggregate policy: %w", err)
188 | }
189 |
190 | return nil
191 | }
192 |
193 | // RemoveContinuousAggregatePolicy removes a refresh policy from a continuous aggregate
194 | func (t *DB) RemoveContinuousAggregatePolicy(ctx context.Context, viewName string) error {
195 | if !t.isTimescaleDB {
196 | return fmt.Errorf("TimescaleDB extension not available")
197 | }
198 |
199 | // Build policy removal SQL
200 | sql := fmt.Sprintf(
201 | "SELECT remove_continuous_aggregate_policy('%s')",
202 | viewName,
203 | )
204 |
205 | // Execute the statement
206 | _, err := t.ExecuteSQLWithoutParams(ctx, sql)
207 | if err != nil {
208 | return fmt.Errorf("failed to remove continuous aggregate policy: %w", err)
209 | }
210 |
211 | return nil
212 | }
213 |
214 | // DropContinuousAggregate drops a continuous aggregate
215 | func (t *DB) DropContinuousAggregate(ctx context.Context, viewName string, cascade bool) error {
216 | if !t.isTimescaleDB {
217 | return fmt.Errorf("TimescaleDB extension not available")
218 | }
219 |
220 | var builder strings.Builder
221 |
222 | // Build DROP statement
223 | builder.WriteString(fmt.Sprintf("DROP MATERIALIZED VIEW %s", viewName))
224 |
225 | // Add CASCADE if requested
226 | if cascade {
227 | builder.WriteString(" CASCADE")
228 | }
229 |
230 | // Execute the statement
231 | _, err := t.ExecuteSQLWithoutParams(ctx, builder.String())
232 | if err != nil {
233 | return fmt.Errorf("failed to drop continuous aggregate: %w", err)
234 | }
235 |
236 | return nil
237 | }
238 |
// GetContinuousAggregateInfo gets detailed information about a continuous aggregate.
//
// It joins timescaledb_information.continuous_aggregates with refresh-policy
// details, the on-disk size of the materialization, and the min/max
// time_bucket values present in the view, returning the first matching row as
// a map with an added "has_policy" boolean. An error is returned when the
// extension is unavailable or no row matches viewName.
//
// NOTE(review): viewName is interpolated directly into the SQL five times and
// must come from a trusted source. The refresh_lag/refresh_interval columns
// and the timescaledb_information.policies view differ across TimescaleDB
// versions — confirm against the deployed version.
func (t *DB) GetContinuousAggregateInfo(ctx context.Context, viewName string) (map[string]interface{}, error) {
	if !t.isTimescaleDB {
		return nil, fmt.Errorf("TimescaleDB extension not available")
	}

	// Query for continuous aggregate information
	query := fmt.Sprintf(`
		WITH policy_info AS (
			SELECT
				ca.user_view_name,
				p.schedule_interval,
				p.start_offset,
				p.end_offset
			FROM timescaledb_information.continuous_aggregates ca
			LEFT JOIN timescaledb_information.jobs j ON j.hypertable_name = ca.user_view_name
			LEFT JOIN timescaledb_information.policies p ON p.job_id = j.job_id
			WHERE p.proc_name = 'policy_refresh_continuous_aggregate'
			AND ca.view_name = '%s'
		),
		size_info AS (
			SELECT
				pg_size_pretty(pg_total_relation_size(format('%%I.%%I', schemaname, tablename)))
				as view_size
			FROM pg_tables
			WHERE tablename = '%s'
		)
		SELECT
			ca.view_name,
			ca.view_schema,
			ca.materialized_only,
			ca.view_definition,
			ca.refresh_lag,
			ca.refresh_interval,
			ca.hypertable_name,
			ca.hypertable_schema,
			pi.schedule_interval,
			pi.start_offset,
			pi.end_offset,
			si.view_size,
			(
				SELECT min(time_bucket)
				FROM %s
			) as min_time,
			(
				SELECT max(time_bucket)
				FROM %s
			) as max_time
		FROM timescaledb_information.continuous_aggregates ca
		LEFT JOIN policy_info pi ON pi.user_view_name = ca.user_view_name
		CROSS JOIN size_info si
		WHERE ca.view_name = '%s'
	`, viewName, viewName, viewName, viewName, viewName)

	// Execute query
	result, err := t.ExecuteSQLWithoutParams(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to get continuous aggregate info: %w", err)
	}

	// Convert result to map
	rows, ok := result.([]map[string]interface{})
	if !ok || len(rows) == 0 {
		return nil, fmt.Errorf("continuous aggregate '%s' not found", viewName)
	}

	// Extract the first row
	info := rows[0]

	// Add computed fields: a non-nil schedule_interval means a refresh policy exists.
	info["has_policy"] = info["schedule_interval"] != nil

	return info, nil
}
313 |
```
--------------------------------------------------------------------------------
/pkg/db/db.go:
--------------------------------------------------------------------------------
```go
1 | package db
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "errors"
7 | "fmt"
8 | "net/url"
9 | "strings"
10 | "time"
11 |
12 | "github.com/FreePeak/db-mcp-server/pkg/logger"
13 | // Import database drivers
14 | _ "github.com/go-sql-driver/mysql"
15 | _ "github.com/lib/pq"
16 | )
17 |
// Common database errors, intended for comparison with errors.Is.
var (
	ErrNotFound       = errors.New("record not found")      // no row matched the lookup
	ErrAlreadyExists  = errors.New("record already exists") // unique/primary-key conflict
	ErrInvalidInput   = errors.New("invalid input")         // caller supplied bad arguments
	ErrNotImplemented = errors.New("not implemented")       // operation not supported yet
	ErrNoDatabase     = errors.New("no database connection") // connection missing or not established
)
26 |
// PostgresSSLMode defines the SSL mode for PostgreSQL connections.
// Values correspond to the libpq "sslmode" connection parameter.
type PostgresSSLMode string

// SSLMode constants for PostgreSQL
const (
	SSLDisable    PostgresSSLMode = "disable"     // no SSL
	SSLRequire    PostgresSSLMode = "require"     // SSL, no certificate verification
	SSLVerifyCA   PostgresSSLMode = "verify-ca"   // SSL, verify server certificate CA
	SSLVerifyFull PostgresSSLMode = "verify-full" // SSL, verify CA and host name
	SSLPrefer     PostgresSSLMode = "prefer"      // SSL if available, else plaintext
)
38 |
39 | // Config represents database connection configuration
40 | type Config struct {
41 | 	Type     string // "mysql" or "postgres" (see NewDatabase)
42 | 	Host     string
43 | 	Port     int
44 | 	User     string
45 | 	Password string
46 | 	Name     string // database name
47 | 
48 | 	// Additional PostgreSQL specific options
49 | 	SSLMode            PostgresSSLMode
50 | 	SSLCert            string
51 | 	SSLKey             string
52 | 	SSLRootCert        string
53 | 	ApplicationName    string // shown in pg_stat_activity
54 | 	ConnectTimeout     int    // in seconds
55 | 	QueryTimeout       int    // in seconds, default is 30 seconds
56 | 	TargetSessionAttrs string // for PostgreSQL 10+
57 | 	Options            map[string]string // Extra connection options
58 | 
59 | 	// Connection pool settings
60 | 	MaxOpenConns    int
61 | 	MaxIdleConns    int
62 | 	ConnMaxLifetime time.Duration
63 | 	ConnMaxIdleTime time.Duration
64 | }
65 |
66 | // SetDefaults sets default values for the configuration if they are not set
67 | func (c *Config) SetDefaults() {
68 | 	if c.MaxOpenConns == 0 {
69 | 		c.MaxOpenConns = 25 // pool upper bound passed to database/sql SetMaxOpenConns
70 | 	}
71 | 	if c.MaxIdleConns == 0 {
72 | 		c.MaxIdleConns = 5
73 | 	}
74 | 	if c.ConnMaxLifetime == 0 {
75 | 		c.ConnMaxLifetime = 5 * time.Minute
76 | 	}
77 | 	if c.ConnMaxIdleTime == 0 {
78 | 		c.ConnMaxIdleTime = 5 * time.Minute
79 | 	}
80 | 	if c.Type == "postgres" && c.SSLMode == "" {
81 | 		c.SSLMode = SSLDisable // NOTE(review): SSL off by default — confirm this is intended for production
82 | 	}
83 | 	if c.ConnectTimeout == 0 {
84 | 		c.ConnectTimeout = 10 // Default 10 seconds
85 | 	}
86 | 	if c.QueryTimeout == 0 {
87 | 		c.QueryTimeout = 30 // Default 30 seconds
88 | 	}
89 | }
90 |
91 | // Database represents a generic database interface
92 | type Database interface {
93 | // Core database operations
94 | Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
95 | QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row
96 | Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
97 |
98 | // Transaction support
99 | BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
100 |
101 | // Connection management
102 | Connect() error
103 | Close() error
104 | Ping(ctx context.Context) error
105 |
106 | // Metadata
107 | DriverName() string
108 | ConnectionString() string
109 | QueryTimeout() int
110 |
111 | // DB object access (for specific DB operations)
112 | DB() *sql.DB
113 | }
114 |
115 | // database is the concrete implementation of the Database interface
116 | type database struct {
117 | config Config
118 | db *sql.DB
119 | driverName string
120 | dsn string
121 | }
122 |
123 | // buildPostgresConnStr builds a libpq keyword=value connection string with all options.
124 | // Free-text values are single-quoted so embedded spaces or quote characters cannot
125 | // break the DSN or inject extra parameters (url.QueryEscape is URL escaping, which
126 | // is not valid conninfo escaping).
127 | func buildPostgresConnStr(config Config) string {
128 | 	// quote escapes a value per libpq conninfo rules: backslash-escape \ and ',
129 | 	// then wrap in single quotes. Quoting is always valid, even for plain values.
130 | 	quote := func(v string) string {
131 | 		v = strings.ReplaceAll(v, `\`, `\\`)
132 | 		v = strings.ReplaceAll(v, `'`, `\'`)
133 | 		return "'" + v + "'"
134 | 	}
135 | 
136 | 	params := make([]string, 0)
137 | 
138 | 	// Required parameters
139 | 	params = append(params, fmt.Sprintf("host=%s", quote(config.Host)))
140 | 	params = append(params, fmt.Sprintf("port=%d", config.Port))
141 | 	params = append(params, fmt.Sprintf("user=%s", quote(config.User)))
142 | 
143 | 	if config.Password != "" {
144 | 		params = append(params, fmt.Sprintf("password=%s", quote(config.Password)))
145 | 	}
146 | 
147 | 	if config.Name != "" {
148 | 		params = append(params, fmt.Sprintf("dbname=%s", quote(config.Name)))
149 | 	}
150 | 
151 | 	// SSL configuration (sslmode is an enum constant, no quoting needed)
152 | 	params = append(params, fmt.Sprintf("sslmode=%s", config.SSLMode))
153 | 
154 | 	if config.SSLCert != "" {
155 | 		params = append(params, fmt.Sprintf("sslcert=%s", quote(config.SSLCert)))
156 | 	}
157 | 
158 | 	if config.SSLKey != "" {
159 | 		params = append(params, fmt.Sprintf("sslkey=%s", quote(config.SSLKey)))
160 | 	}
161 | 
162 | 	if config.SSLRootCert != "" {
163 | 		params = append(params, fmt.Sprintf("sslrootcert=%s", quote(config.SSLRootCert)))
164 | 	}
165 | 
166 | 	// Connection timeout
167 | 	if config.ConnectTimeout > 0 {
168 | 		params = append(params, fmt.Sprintf("connect_timeout=%d", config.ConnectTimeout))
169 | 	}
170 | 
171 | 	// Application name for better identification in pg_stat_activity
172 | 	if config.ApplicationName != "" {
173 | 		params = append(params, fmt.Sprintf("application_name=%s", quote(config.ApplicationName)))
174 | 	}
175 | 
176 | 	// Target session attributes for load balancing and failover (PostgreSQL 10+)
177 | 	if config.TargetSessionAttrs != "" {
178 | 		params = append(params, fmt.Sprintf("target_session_attrs=%s", config.TargetSessionAttrs))
179 | 	}
180 | 
181 | 	// Add any additional options from the map (kept URL-escaped as before;
182 | 	// TODO(review): consider conninfo quoting here too)
183 | 	if config.Options != nil {
184 | 		for key, value := range config.Options {
185 | 			params = append(params, fmt.Sprintf("%s=%s", key, url.QueryEscape(value)))
186 | 		}
187 | 	}
188 | 
189 | 	return strings.Join(params, " ")
190 | }
179 |
180 | // NewDatabase creates a new database connection based on the provided configuration
181 | func NewDatabase(config Config) (Database, error) {
182 | 	// Set default values for the configuration
183 | 	config.SetDefaults()
184 | 
185 | 	var dsn string
186 | 	var driverName string
187 | 
188 | 	// Create DSN string based on database type
189 | 	switch config.Type {
190 | 	case "mysql":
191 | 		driverName = "mysql"
192 | 		dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true", // parseTime: scan DATE/DATETIME into time.Time; NOTE(review): user/password are not escaped — confirm they contain no DSN metacharacters (e.g. '@', '/')
193 | 			config.User, config.Password, config.Host, config.Port, config.Name)
194 | 	case "postgres":
195 | 		driverName = "postgres"
196 | 		dsn = buildPostgresConnStr(config)
197 | 	default:
198 | 		return nil, fmt.Errorf("unsupported database type: %s", config.Type)
199 | 	}
200 | 	// No connection is opened here; callers must invoke Connect().
201 | 	return &database{
202 | 		config:     config,
203 | 		driverName: driverName,
204 | 		dsn:        dsn,
205 | 	}, nil
206 | }
207 |
208 | // Connect establishes a connection to the database, configures the pool, and
209 | // verifies connectivity with a ping bounded by the configured connect timeout.
210 | func (d *database) Connect() error {
211 | 	db, err := sql.Open(d.driverName, d.dsn)
212 | 	if err != nil {
213 | 		return fmt.Errorf("failed to open database connection: %w", err)
214 | 	}
215 | 
216 | 	// Configure connection pool
217 | 	db.SetMaxOpenConns(d.config.MaxOpenConns)
218 | 	db.SetMaxIdleConns(d.config.MaxIdleConns)
219 | 	db.SetConnMaxLifetime(d.config.ConnMaxLifetime)
220 | 	db.SetConnMaxIdleTime(d.config.ConnMaxIdleTime)
221 | 
222 | 	// Verify the connection using the configured ConnectTimeout (defaulted to 10s
223 | 	// by SetDefaults) instead of a hard-coded 5s that ignored configuration.
224 | 	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(d.config.ConnectTimeout)*time.Second)
225 | 	defer cancel()
226 | 
227 | 	if err := db.PingContext(ctx); err != nil {
228 | 		closeErr := db.Close()
229 | 		if closeErr != nil {
230 | 			logger.Error("Error closing database connection: %v", closeErr)
231 | 		}
232 | 		return fmt.Errorf("failed to ping database: %w", err)
233 | 	}
234 | 
235 | 	d.db = db
236 | 	logger.Info("Connected to %s database at %s:%d/%s", d.config.Type, d.config.Host, d.config.Port, d.config.Name)
237 | 
238 | 	return nil
239 | }
238 |
239 | // Close releases the underlying connection pool; it is safe to call when the
240 | // database was never connected.
241 | func (d *database) Close() error {
242 | 	if d.db == nil {
243 | 		return nil // nothing to release
244 | 	}
245 | 	err := d.db.Close()
246 | 	if err != nil {
247 | 		logger.Error("Error closing database connection: %v", err)
248 | 	}
249 | 	return err
250 | }
250 |
251 | // Ping checks if the database connection is still alive
252 | func (d *database) Ping(ctx context.Context) error {
253 | 	if d.db == nil {
254 | 		return ErrNoDatabase // Connect() has not been called (or failed)
255 | 	}
256 | 	return d.db.PingContext(ctx)
257 | }
258 |
259 | // Query executes a query that returns rows
260 | func (d *database) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
261 | 	if d.db == nil {
262 | 		return nil, ErrNoDatabase // fail fast rather than nil-deref
263 | 	}
264 | 	return d.db.QueryContext(ctx, query, args...)
265 | }
266 |
267 | // QueryRow executes a query that is expected to return at most one row
268 | func (d *database) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row {
269 | 	if d.db == nil {
270 | 		return nil // NOTE(review): callers cannot distinguish this from a real row and will nil-deref on Scan — confirm all call sites connect first
271 | 	}
272 | 	return d.db.QueryRowContext(ctx, query, args...)
273 | }
274 |
275 | // Exec executes a query without returning any rows
276 | func (d *database) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
277 | 	if d.db == nil {
278 | 		return nil, ErrNoDatabase // fail fast rather than nil-deref
279 | 	}
280 | 	return d.db.ExecContext(ctx, query, args...)
281 | }
282 |
283 | // BeginTx starts a transaction
284 | func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
285 | 	if d.db == nil {
286 | 		return nil, ErrNoDatabase // fail fast rather than nil-deref
287 | 	}
288 | 	return d.db.BeginTx(ctx, opts)
289 | }
290 |
291 | // DB returns the underlying database connection
292 | func (d *database) DB() *sql.DB {
293 | 	return d.db // may be nil before Connect() succeeds
294 | }
295 |
296 | // DriverName returns the name of the database driver
297 | func (d *database) DriverName() string {
298 | 	return d.driverName // "mysql" or "postgres", set by NewDatabase
299 | }
300 |
301 | // ConnectionString returns the database connection string with password masked
302 | func (d *database) ConnectionString() string {
303 | 	// Return masked DSN (hide password)
304 | 	switch d.config.Type {
305 | 	case "mysql":
306 | 		return fmt.Sprintf("%s:***@tcp(%s:%d)/%s",
307 | 			d.config.User, d.config.Host, d.config.Port, d.config.Name)
308 | 	case "postgres":
309 | 		// Create a sanitized version of the connection string
310 | 		params := make([]string, 0)
311 | 
312 | 		params = append(params, fmt.Sprintf("host=%s", d.config.Host))
313 | 		params = append(params, fmt.Sprintf("port=%d", d.config.Port))
314 | 		params = append(params, fmt.Sprintf("user=%s", d.config.User))
315 | 		params = append(params, "password=***") // never expose the real password
316 | 		params = append(params, fmt.Sprintf("dbname=%s", d.config.Name))
317 | 
318 | 		if string(d.config.SSLMode) != "" {
319 | 			params = append(params, fmt.Sprintf("sslmode=%s", d.config.SSLMode))
320 | 		}
321 | 
322 | 		if d.config.ApplicationName != "" {
323 | 			params = append(params, fmt.Sprintf("application_name=%s", d.config.ApplicationName))
324 | 		}
325 | 
326 | 		return strings.Join(params, " ")
327 | 	default:
328 | 		return "unknown" // unreachable after NewDatabase validation; kept as a safe fallback
329 | 	}
330 | }
331 |
332 | // QueryTimeout returns the configured query timeout in seconds
333 | func (d *database) QueryTimeout() int {
334 | 	return d.config.QueryTimeout // defaulted to 30 by SetDefaults
335 | }
336 |
```
--------------------------------------------------------------------------------
/internal/delivery/mcp/timescale_schema.go:
--------------------------------------------------------------------------------
```go
1 | package mcp
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "strconv"
8 | )
9 |
10 | // UseCaseProvider interface defined in the package
11 |
12 | // HypertableSchemaInfo represents schema information for a TimescaleDB hypertable
13 | type HypertableSchemaInfo struct {
14 | TableName string `json:"tableName"`
15 | SchemaName string `json:"schemaName"`
16 | TimeColumn string `json:"timeColumn"`
17 | ChunkTimeInterval string `json:"chunkTimeInterval"`
18 | Size string `json:"size"`
19 | ChunkCount int `json:"chunkCount"`
20 | RowCount int64 `json:"rowCount"`
21 | SpacePartitioning []string `json:"spacePartitioning,omitempty"`
22 | CompressionEnabled bool `json:"compressionEnabled"`
23 | CompressionConfig CompressionConfig `json:"compressionConfig,omitempty"`
24 | RetentionEnabled bool `json:"retentionEnabled"`
25 | RetentionInterval string `json:"retentionInterval,omitempty"`
26 | Columns []HypertableColumnInfo `json:"columns"`
27 | }
28 |
29 | // HypertableColumnInfo represents column information for a hypertable
30 | type HypertableColumnInfo struct {
31 | Name string `json:"name"`
32 | Type string `json:"type"`
33 | Nullable bool `json:"nullable"`
34 | PrimaryKey bool `json:"primaryKey"`
35 | Indexed bool `json:"indexed"`
36 | Description string `json:"description,omitempty"`
37 | }
38 |
39 | // CompressionConfig represents compression configuration for a hypertable
40 | type CompressionConfig struct {
41 | SegmentBy string `json:"segmentBy,omitempty"`
42 | OrderBy string `json:"orderBy,omitempty"`
43 | Interval string `json:"interval,omitempty"`
44 | }
45 |
46 | // HypertableSchemaProvider provides schema information for hypertables
47 | type HypertableSchemaProvider struct {
48 | // We use the TimescaleDBContextProvider from timescale_context.go
49 | contextProvider *TimescaleDBContextProvider
50 | }
51 |
52 | // NewHypertableSchemaProvider creates a new hypertable schema provider
53 | func NewHypertableSchemaProvider() *HypertableSchemaProvider {
54 | 	return &HypertableSchemaProvider{
55 | 		contextProvider: NewTimescaleDBContextProvider(), // used to verify the TimescaleDB extension before schema queries
56 | 	}
57 | }
58 |
59 | // GetHypertableSchema gets schema information for a specific hypertable: core metadata, compression and retention policies, and column details.
60 | func (p *HypertableSchemaProvider) GetHypertableSchema(
61 | 	ctx context.Context,
62 | 	dbID string,
63 | 	tableName string,
64 | 	useCase UseCaseProvider,
65 | ) (*HypertableSchemaInfo, error) {
66 | 	// First check if TimescaleDB is available
67 | 	tsdbContext, err := p.contextProvider.DetectTimescaleDB(ctx, dbID, useCase)
68 | 	if err != nil {
69 | 		return nil, fmt.Errorf("failed to detect TimescaleDB: %w", err)
70 | 	}
71 | 
72 | 	if !tsdbContext.IsTimescaleDB {
73 | 		return nil, fmt.Errorf("TimescaleDB is not available in the database %s", dbID)
74 | 	}
75 | 	// NOTE(review): tableName is interpolated into SQL via fmt.Sprintf throughout — injection risk if untrusted; prefer identifier quoting or parameterized queries.
76 | 	// Get hypertable metadata
77 | 	query := fmt.Sprintf(`
78 | 		SELECT 
79 | 			h.table_name,
80 | 			h.schema_name,
81 | 			t.tableowner as owner,
82 | 			h.num_dimensions,
83 | 			dc.column_name as time_dimension,
84 | 			dc.column_type as time_dimension_type,
85 | 			dc.time_interval as chunk_time_interval,
86 | 			h.compression_enabled,
87 | 			pg_size_pretty(pg_total_relation_size(format('%%I.%%I', h.schema_name, h.table_name))) as total_size,
88 | 			(SELECT count(*) FROM timescaledb_information.chunks WHERE hypertable_name = h.table_name) as chunks,
89 | 			(SELECT count(*) FROM %s.%s) as total_rows
90 | 		FROM timescaledb_information.hypertables h
91 | 		JOIN pg_tables t ON h.table_name = t.tablename AND h.schema_name = t.schemaname
92 | 		JOIN timescaledb_information.dimensions dc ON h.hypertable_name = dc.hypertable_name
93 | 		WHERE h.table_name = '%s' AND dc.dimension_number = 1
94 | 	`, tableName, tableName, tableName)
95 | 
96 | 	metadataResult, err := useCase.ExecuteStatement(ctx, dbID, query, nil)
97 | 	if err != nil {
98 | 		return nil, fmt.Errorf("failed to get hypertable metadata: %w", err)
99 | 	}
100 | 
101 | 	// Parse the result to determine if the table is a hypertable
102 | 	var metadata []map[string]interface{}
103 | 	if err := json.Unmarshal([]byte(metadataResult), &metadata); err != nil {
104 | 		return nil, fmt.Errorf("failed to parse metadata result: %w", err)
105 | 	}
106 | 
107 | 	// If no results, the table is not a hypertable
108 | 	if len(metadata) == 0 {
109 | 		return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
110 | 	}
111 | 
112 | 	// Create schema info from metadata
113 | 	schemaInfo := &HypertableSchemaInfo{
114 | 		TableName: tableName,
115 | 		Columns:   []HypertableColumnInfo{},
116 | 	}
117 | 
118 | 	// Extract metadata fields
119 | 	row := metadata[0]
120 | 
121 | 	if schemaName, ok := row["schema_name"].(string); ok {
122 | 		schemaInfo.SchemaName = schemaName
123 | 	}
124 | 
125 | 	if timeDimension, ok := row["time_dimension"].(string); ok {
126 | 		schemaInfo.TimeColumn = timeDimension
127 | 	}
128 | 
129 | 	if chunkInterval, ok := row["chunk_time_interval"].(string); ok {
130 | 		schemaInfo.ChunkTimeInterval = chunkInterval
131 | 	}
132 | 
133 | 	if size, ok := row["total_size"].(string); ok {
134 | 		schemaInfo.Size = size
135 | 	}
136 | 
137 | 	// Convert numeric fields (JSON decoding may yield float64, int, or string)
138 | 	if chunks, ok := row["chunks"].(float64); ok {
139 | 		schemaInfo.ChunkCount = int(chunks)
140 | 	} else if chunks, ok := row["chunks"].(int); ok {
141 | 		schemaInfo.ChunkCount = chunks
142 | 	} else if chunksStr, ok := row["chunks"].(string); ok {
143 | 		if chunks, err := strconv.Atoi(chunksStr); err == nil {
144 | 			schemaInfo.ChunkCount = chunks
145 | 		}
146 | 	}
147 | 
148 | 	if rows, ok := row["total_rows"].(float64); ok {
149 | 		schemaInfo.RowCount = int64(rows)
150 | 	} else if rows, ok := row["total_rows"].(int64); ok {
151 | 		schemaInfo.RowCount = rows
152 | 	} else if rowsStr, ok := row["total_rows"].(string); ok {
153 | 		if rows, err := strconv.ParseInt(rowsStr, 10, 64); err == nil {
154 | 			schemaInfo.RowCount = rows
155 | 		}
156 | 	}
157 | 
158 | 	// Handle boolean fields
159 | 	if compression, ok := row["compression_enabled"].(bool); ok {
160 | 		schemaInfo.CompressionEnabled = compression
161 | 	} else if compressionStr, ok := row["compression_enabled"].(string); ok {
162 | 		schemaInfo.CompressionEnabled = compressionStr == "t" || compressionStr == "true" || compressionStr == "1"
163 | 	}
164 | 
165 | 	// Get compression settings if compression is enabled
166 | 	if schemaInfo.CompressionEnabled {
167 | 		compressionQuery := fmt.Sprintf(`
168 | 			SELECT segmentby, orderby, compression_interval
169 | 			FROM (
170 | 				SELECT 
171 | 					cs.segmentby,
172 | 					cs.orderby,
173 | 					(SELECT schedule_interval FROM timescaledb_information.job_stats js 
174 | 					JOIN timescaledb_information.jobs j ON js.job_id = j.job_id 
175 | 					WHERE j.hypertable_name = '%s' AND j.proc_name = 'policy_compression' 
176 | 					LIMIT 1) as compression_interval
177 | 				FROM timescaledb_information.compression_settings cs
178 | 				WHERE cs.hypertable_name = '%s'
179 | 			) t
180 | 		`, tableName, tableName)
181 | 
182 | 		compressionResult, err := useCase.ExecuteStatement(ctx, dbID, compressionQuery, nil)
183 | 		if err == nil { // best-effort: missing compression details are non-fatal
184 | 			var compressionSettings []map[string]interface{}
185 | 			if err := json.Unmarshal([]byte(compressionResult), &compressionSettings); err == nil && len(compressionSettings) > 0 {
186 | 				settings := compressionSettings[0]
187 | 
188 | 				if segmentBy, ok := settings["segmentby"].(string); ok {
189 | 					schemaInfo.CompressionConfig.SegmentBy = segmentBy
190 | 				}
191 | 
192 | 				if orderBy, ok := settings["orderby"].(string); ok {
193 | 					schemaInfo.CompressionConfig.OrderBy = orderBy
194 | 				}
195 | 
196 | 				if interval, ok := settings["compression_interval"].(string); ok {
197 | 					schemaInfo.CompressionConfig.Interval = interval
198 | 				}
199 | 			}
200 | 		}
201 | 	}
202 | 
203 | 	// Get retention settings
204 | 	retentionQuery := fmt.Sprintf(`
205 | 		SELECT 
206 | 			hypertable_name,
207 | 			schedule_interval as retention_interval,
208 | 			TRUE as retention_enabled
209 | 		FROM 
210 | 			timescaledb_information.jobs j
211 | 		JOIN 
212 | 			timescaledb_information.job_stats js ON j.job_id = js.job_id
213 | 		WHERE 
214 | 			j.hypertable_name = '%s' AND j.proc_name = 'policy_retention'
215 | 	`, tableName)
216 | 
217 | 	retentionResult, err := useCase.ExecuteStatement(ctx, dbID, retentionQuery, nil)
218 | 	if err == nil { // best-effort: absence of a retention job leaves defaults
219 | 		var retentionSettings []map[string]interface{}
220 | 		if err := json.Unmarshal([]byte(retentionResult), &retentionSettings); err == nil && len(retentionSettings) > 0 {
221 | 			settings := retentionSettings[0]
222 | 
223 | 			schemaInfo.RetentionEnabled = true
224 | 
225 | 			if interval, ok := settings["retention_interval"].(string); ok {
226 | 				schemaInfo.RetentionInterval = interval
227 | 			}
228 | 		}
229 | 	}
230 | 
231 | 	// Get column information
232 | 	columnsQuery := fmt.Sprintf(`
233 | 		SELECT 
234 | 			c.column_name, 
235 | 			c.data_type,
236 | 			c.is_nullable = 'YES' as is_nullable,
237 | 			(
238 | 				SELECT COUNT(*) > 0
239 | 				FROM pg_index i
240 | 				JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
241 | 				WHERE i.indrelid = format('%%I.%%I', c.table_schema, c.table_name)::regclass
242 | 				AND i.indisprimary
243 | 				AND a.attname = c.column_name
244 | 			) as is_primary_key,
245 | 			(
246 | 				SELECT COUNT(*) > 0
247 | 				FROM pg_index i
248 | 				JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
249 | 				WHERE i.indrelid = format('%%I.%%I', c.table_schema, c.table_name)::regclass
250 | 				AND a.attname = c.column_name
251 | 			) as is_indexed,
252 | 			col_description(format('%%I.%%I', c.table_schema, c.table_name)::regclass::oid, 
253 | 						   ordinal_position) as description
254 | 		FROM information_schema.columns c
55 | 		WHERE c.table_name = '%s'
256 | 		ORDER BY c.ordinal_position
257 | 	`, tableName)
258 | 
259 | 	columnsResult, err := useCase.ExecuteStatement(ctx, dbID, columnsQuery, nil)
260 | 	if err == nil { // best-effort: column metadata is optional
261 | 		var columns []map[string]interface{}
262 | 		if err := json.Unmarshal([]byte(columnsResult), &columns); err == nil {
263 | 			for _, column := range columns {
264 | 				columnInfo := HypertableColumnInfo{}
265 | 
266 | 				if name, ok := column["column_name"].(string); ok {
267 | 					columnInfo.Name = name
268 | 				}
269 | 
270 | 				if dataType, ok := column["data_type"].(string); ok {
271 | 					columnInfo.Type = dataType
272 | 				}
273 | 
274 | 				if nullable, ok := column["is_nullable"].(bool); ok {
275 | 					columnInfo.Nullable = nullable
276 | 				}
277 | 
278 | 				if primaryKey, ok := column["is_primary_key"].(bool); ok {
279 | 					columnInfo.PrimaryKey = primaryKey
280 | 				}
281 | 
282 | 				if indexed, ok := column["is_indexed"].(bool); ok {
283 | 					columnInfo.Indexed = indexed
284 | 				}
285 | 
286 | 				if description, ok := column["description"].(string); ok {
287 | 					columnInfo.Description = description
288 | 				}
289 | 
290 | 				schemaInfo.Columns = append(schemaInfo.Columns, columnInfo)
291 | 			}
292 | 		}
293 | 	}
294 | 
295 | 	return schemaInfo, nil
296 | }
297 |
```
--------------------------------------------------------------------------------
/internal/usecase/database_usecase.go:
--------------------------------------------------------------------------------
```go
1 | package usecase
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 | "time"
8 |
9 | "github.com/FreePeak/db-mcp-server/internal/domain"
10 | "github.com/FreePeak/db-mcp-server/internal/logger"
11 | )
12 |
13 | // TODO: Improve error handling with custom error types and better error messages
14 | // TODO: Add extensive unit tests for all business logic
15 | // TODO: Consider implementing domain events for better decoupling
16 | // TODO: Add request validation layer before processing in usecases
17 | // TODO: Implement proper context propagation and timeout handling
18 |
19 | // QueryFactory provides database-specific queries
20 | type QueryFactory interface {
21 | GetTablesQueries() []string
22 | }
23 |
24 | // PostgresQueryFactory creates queries for PostgreSQL
25 | type PostgresQueryFactory struct{}
26 | 
27 | // GetTablesQueries returns table-listing queries ordered most- to least-reliable; callers try each until one succeeds.
27 | func (f *PostgresQueryFactory) GetTablesQueries() []string {
28 | 	return []string{
29 | 		// Primary PostgreSQL query using pg_catalog (most reliable)
30 | 		"SELECT tablename AS table_name FROM pg_catalog.pg_tables WHERE schemaname = 'public'",
31 | 		// Fallback 1: Using information_schema
32 | 		"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'",
33 | 		// Fallback 2: Using pg_class for relations
34 | 		"SELECT relname AS table_name FROM pg_catalog.pg_class WHERE relkind = 'r' AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')",
35 | 	}
36 | }
37 |
38 | // MySQLQueryFactory creates queries for MySQL
39 | type MySQLQueryFactory struct{}
40 | 
41 | // GetTablesQueries returns table-listing queries ordered most- to least-reliable; callers try each until one succeeds.
41 | func (f *MySQLQueryFactory) GetTablesQueries() []string {
42 | 	return []string{
43 | 		// Primary MySQL query
44 | 		"SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
45 | 		// Fallback MySQL query
46 | 		"SHOW TABLES",
47 | 	}
48 | }
49 |
50 | // GenericQueryFactory creates generic queries for unknown database types
51 | type GenericQueryFactory struct{}
52 | 
53 | // GetTablesQueries returns information_schema-based fallbacks for engines with no dedicated factory.
53 | func (f *GenericQueryFactory) GetTablesQueries() []string {
54 | 	return []string{
55 | 		"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'",
56 | 		"SELECT table_name FROM information_schema.tables",
57 | 	}
58 | }
59 |
60 | // NewQueryFactory creates the appropriate query factory for the database type
61 | func NewQueryFactory(dbType string) QueryFactory {
62 | 	switch dbType {
63 | 	case "postgres":
64 | 		return &PostgresQueryFactory{}
65 | 	case "mysql":
66 | 		return &MySQLQueryFactory{}
67 | 	default:
68 | 		logger.Warn("Unknown database type: %s, will use generic query factory", dbType)
69 | 		return &GenericQueryFactory{} // information_schema-based; may still fail on non-standard engines
70 | 	}
71 | }
72 |
73 | // executeQueriesWithFallback tries multiple queries until one succeeds
74 | func executeQueriesWithFallback(ctx context.Context, db domain.Database, queries []string) (domain.Rows, error) {
75 | 	var lastErr error
76 | 	var rows domain.Rows
77 | 	// NOTE(review): fallbacks are also attempted after context cancellation — consider checking ctx.Err() between tries.
78 | 	for i, query := range queries {
79 | 		var err error
80 | 		rows, err = db.Query(ctx, query)
81 | 		if err == nil {
82 | 			return rows, nil // Query succeeded
83 | 		}
84 | 		lastErr = err
85 | 		logger.Warn("Query %d failed: %s - Error: %v", i+1, query, err)
86 | 	}
87 | 
88 | 	// All queries failed; surface the last error for context
89 | 	return nil, fmt.Errorf("all queries failed: %w", lastErr)
90 | }
91 |
92 | // DatabaseUseCase defines operations for managing database functionality
93 | type DatabaseUseCase struct {
94 | 	repo domain.DatabaseRepository // source of connections and per-database metadata
95 | }
96 | 
97 | // NewDatabaseUseCase creates a new database use case
98 | func NewDatabaseUseCase(repo domain.DatabaseRepository) *DatabaseUseCase {
99 | 	return &DatabaseUseCase{
100 | 		repo: repo,
101 | 	}
102 | }
103 |
104 | // ListDatabases returns a list of available databases
105 | func (uc *DatabaseUseCase) ListDatabases() []string {
106 | 	return uc.repo.ListDatabases() // straight delegation to the repository
107 | }
108 |
109 | // GetDatabaseInfo returns information about a database: its ID, engine type,
110 | // and a best-effort listing of its tables.
111 | func (uc *DatabaseUseCase) GetDatabaseInfo(dbID string) (map[string]interface{}, error) {
112 | 	// Get database connection
113 | 	db, err := uc.repo.GetDatabase(dbID)
114 | 	if err != nil {
115 | 		return nil, fmt.Errorf("failed to get database: %w", err)
116 | 	}
117 | 
118 | 	// Get the database type
119 | 	dbType, err := uc.repo.GetDatabaseType(dbID)
120 | 	if err != nil {
121 | 		return nil, fmt.Errorf("failed to get database type: %w", err)
122 | 	}
123 | 
124 | 	// Create appropriate query factory based on database type
125 | 	factory := NewQueryFactory(dbType)
126 | 
127 | 	// Get queries for tables
128 | 	tableQueries := factory.GetTablesQueries()
129 | 
130 | 	// Execute queries with fallback
131 | 	// TODO: accept a context from the caller instead of using Background
132 | 	ctx := context.Background()
133 | 	rows, err := executeQueriesWithFallback(ctx, db, tableQueries)
134 | 	if err != nil {
135 | 		return nil, fmt.Errorf("failed to get schema information: %w", err)
136 | 	}
137 | 
138 | 	defer func() {
139 | 		if closeErr := rows.Close(); closeErr != nil {
140 | 			logger.Error("error closing rows: %v", closeErr)
141 | 		}
142 | 	}()
143 | 
144 | 	// Process results
145 | 	tables := []map[string]interface{}{}
146 | 	columns, err := rows.Columns()
147 | 	if err != nil {
148 | 		return nil, fmt.Errorf("failed to get column names: %w", err)
149 | 	}
150 | 
151 | 	// Prepare for scanning
152 | 	values := make([]interface{}, len(columns))
153 | 	valuePtrs := make([]interface{}, len(columns))
154 | 	for i := range columns {
155 | 		valuePtrs[i] = &values[i]
156 | 	}
157 | 
158 | 	// Process each row
159 | 	for rows.Next() {
160 | 		if err := rows.Scan(valuePtrs...); err != nil {
161 | 			// Log and skip unreadable rows instead of silently dropping them
162 | 			logger.Warn("failed to scan table row: %v", err)
163 | 			continue
164 | 		}
165 | 
166 | 		// Convert to map
167 | 		tableInfo := make(map[string]interface{})
168 | 		for i, colName := range columns {
169 | 			val := values[i]
170 | 			if val == nil {
171 | 				tableInfo[colName] = nil
172 | 			} else {
173 | 				switch v := val.(type) {
174 | 				case []byte:
175 | 					tableInfo[colName] = string(v)
176 | 				default:
177 | 					tableInfo[colName] = v
178 | 				}
179 | 			}
180 | 		}
181 | 		tables = append(tables, tableInfo)
182 | 	}
183 | 
184 | 	// Surface iteration errors (previously unchecked, so partial results could
185 | 	// be returned as if complete)
186 | 	if err := rows.Err(); err != nil {
187 | 		return nil, fmt.Errorf("error reading schema rows: %w", err)
188 | 	}
189 | 
190 | 	// Create result
191 | 	result := map[string]interface{}{
192 | 		"database": dbID,
193 | 		"dbType":   dbType,
194 | 		"tables":   tables,
195 | 	}
196 | 
197 | 	return result, nil
198 | }
189 |
190 | // ExecuteQuery executes a SQL query and returns the formatted results.
191 | // Named return values let the deferred rows.Close error propagate to the
192 | // caller; previously the assignment inside the defer was silently lost
193 | // because the returns were unnamed.
194 | func (uc *DatabaseUseCase) ExecuteQuery(ctx context.Context, dbID, query string, params []interface{}) (resultStr string, err error) {
195 | 	db, err := uc.repo.GetDatabase(dbID)
196 | 	if err != nil {
197 | 		return "", fmt.Errorf("failed to get database: %w", err)
198 | 	}
199 | 
200 | 	// Execute query
201 | 	rows, err := db.Query(ctx, query, params...)
202 | 	if err != nil {
203 | 		return "", fmt.Errorf("query execution failed: %w", err)
204 | 	}
205 | 	defer func() {
206 | 		// Only report the close error when no earlier error is already returned.
207 | 		if closeErr := rows.Close(); closeErr != nil && err == nil {
208 | 			err = fmt.Errorf("error closing rows: %w", closeErr)
209 | 		}
210 | 	}()
211 | 
212 | 	// Process results into a readable format
213 | 	columns, err := rows.Columns()
214 | 	if err != nil {
215 | 		return "", fmt.Errorf("failed to get column names: %w", err)
216 | 	}
217 | 
218 | 	// Format results as text
219 | 	var resultText strings.Builder
220 | 	resultText.WriteString("Results:\n\n")
221 | 	resultText.WriteString(strings.Join(columns, "\t") + "\n")
222 | 	resultText.WriteString(strings.Repeat("-", 80) + "\n")
223 | 
224 | 	// Prepare for scanning
225 | 	values := make([]interface{}, len(columns))
226 | 	valuePtrs := make([]interface{}, len(columns))
227 | 	for i := range columns {
228 | 		valuePtrs[i] = &values[i]
229 | 	}
230 | 
231 | 	// Process rows
232 | 	rowCount := 0
233 | 	for rows.Next() {
234 | 		rowCount++
235 | 		scanErr := rows.Scan(valuePtrs...)
236 | 		if scanErr != nil {
237 | 			return "", fmt.Errorf("failed to scan row: %w", scanErr)
238 | 		}
239 | 
240 | 		// Convert to strings and print
241 | 		var rowText []string
242 | 		for i := range columns {
243 | 			val := values[i]
244 | 			if val == nil {
245 | 				rowText = append(rowText, "NULL")
246 | 			} else {
247 | 				switch v := val.(type) {
248 | 				case []byte:
249 | 					rowText = append(rowText, string(v))
250 | 				default:
251 | 					rowText = append(rowText, fmt.Sprintf("%v", v))
252 | 				}
253 | 			}
254 | 		}
255 | 		resultText.WriteString(strings.Join(rowText, "\t") + "\n")
256 | 	}
257 | 
258 | 	if err = rows.Err(); err != nil {
259 | 		return "", fmt.Errorf("error reading rows: %w", err)
260 | 	}
261 | 
262 | 	resultText.WriteString(fmt.Sprintf("\nTotal rows: %d", rowCount))
263 | 	return resultText.String(), nil
264 | }
261 |
262 | // ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE)
263 | func (uc *DatabaseUseCase) ExecuteStatement(ctx context.Context, dbID, statement string, params []interface{}) (string, error) {
264 | 	db, err := uc.repo.GetDatabase(dbID)
265 | 	if err != nil {
266 | 		return "", fmt.Errorf("failed to get database: %w", err)
267 | 	}
268 | 
269 | 	// Execute statement
270 | 	result, err := db.Exec(ctx, statement, params...)
271 | 	if err != nil {
272 | 		return "", fmt.Errorf("statement execution failed: %w", err)
273 | 	}
274 | 
275 | 	// Get rows affected
276 | 	rowsAffected, err := result.RowsAffected()
277 | 	if err != nil {
278 | 		rowsAffected = 0 // not all drivers support RowsAffected; report 0 rather than failing
279 | 	}
280 | 
281 | 	// Get last insert ID (if applicable)
282 | 	lastInsertID, err := result.LastInsertId()
283 | 	if err != nil {
284 | 		lastInsertID = 0 // some drivers (e.g. lib/pq) do not support LastInsertId; 0 is a sentinel, not a real ID
285 | 	}
286 | 
287 | 	return fmt.Sprintf("Statement executed successfully.\nRows affected: %d\nLast insert ID: %d", rowsAffected, lastInsertID), nil
288 | }
289 |
290 | // ExecuteTransaction executes operations in a transaction
291 | func (uc *DatabaseUseCase) ExecuteTransaction(ctx context.Context, dbID, action string, txID string,
292 | 	statement string, params []interface{}, readOnly bool) (string, map[string]interface{}, error) {
293 | 	// NOTE(review): placeholder implementation — "begin" commits immediately and no tx state is stored, so commit/rollback/execute are effectively no-ops.
294 | 	switch action {
295 | 	case "begin":
296 | 		db, err := uc.repo.GetDatabase(dbID)
297 | 		if err != nil {
298 | 			return "", nil, fmt.Errorf("failed to get database: %w", err)
299 | 		}
300 | 
301 | 		// Start a new transaction
302 | 		txOpts := &domain.TxOptions{ReadOnly: readOnly}
303 | 		tx, err := db.Begin(ctx, txOpts)
304 | 		if err != nil {
305 | 			return "", nil, fmt.Errorf("failed to start transaction: %w", err)
306 | 		}
307 | 
308 | 		// In a real implementation, we would store the transaction for later use
309 | 		// For now, we just commit right away to avoid the unused variable warning
310 | 		if err := tx.Commit(); err != nil {
311 | 			return "", nil, fmt.Errorf("failed to commit transaction: %w", err)
312 | 		}
313 | 
314 | 		// Generate transaction ID
315 | 		newTxID := fmt.Sprintf("tx_%s_%d", dbID, timeNowUnix()) // pseudo-unique; collisions possible within one second
316 | 
317 | 		return "Transaction started", map[string]interface{}{"transactionId": newTxID}, nil
318 | 
319 | 	case "commit":
320 | 		// Implement commit logic (would need access to stored transaction)
321 | 		return "Transaction committed", nil, nil
322 | 
323 | 	case "rollback":
324 | 		// Implement rollback logic (would need access to stored transaction)
325 | 		return "Transaction rolled back", nil, nil
326 | 
327 | 	case "execute":
328 | 		// Implement execute within transaction logic (would need access to stored transaction)
329 | 		return "Statement executed in transaction", nil, nil
330 | 
331 | 	default:
332 | 		return "", nil, fmt.Errorf("invalid transaction action: %s", action)
333 | 	}
334 | }
335 |
// timeNowUnix reports the current wall-clock time as seconds since the Unix epoch.
func timeNowUnix() int64 {
	now := time.Now()
	return now.Unix()
}
340 |
341 | // GetDatabaseType returns the type of a database by ID
342 | func (uc *DatabaseUseCase) GetDatabaseType(dbID string) (string, error) {
343 | return uc.repo.GetDatabaseType(dbID)
344 | }
345 |
```
--------------------------------------------------------------------------------
/cmd/server/main.go:
--------------------------------------------------------------------------------
```go
1 | package main
2 |
3 | // TODO: Refactor main.go to separate server initialization logic from configuration loading
4 | // TODO: Create dedicated server setup package for better separation of concerns
5 | // TODO: Implement structured logging instead of using standard log package
6 | // TODO: Consider using a configuration management library like Viper for better config handling
7 |
8 | import (
9 | "context"
10 | "flag"
11 | "fmt"
12 | "log"
13 | "os"
14 | "os/signal"
15 | "path/filepath"
16 | "syscall"
17 | "time"
18 |
19 | "github.com/FreePeak/cortex/pkg/server"
20 |
21 | "github.com/FreePeak/db-mcp-server/internal/config"
22 | "github.com/FreePeak/db-mcp-server/internal/delivery/mcp"
23 | "github.com/FreePeak/db-mcp-server/internal/logger"
24 | "github.com/FreePeak/db-mcp-server/internal/repository"
25 | "github.com/FreePeak/db-mcp-server/internal/usecase"
26 | "github.com/FreePeak/db-mcp-server/pkg/dbtools"
27 | pkgLogger "github.com/FreePeak/db-mcp-server/pkg/logger"
28 | )
29 |
30 | // findConfigFile attempts to find config.json in the current directory or parent directories
31 | func findConfigFile() string {
32 | // Default config file name
33 | const defaultConfigFile = "config.json"
34 |
35 | // Check if the file exists in current directory
36 | if _, err := os.Stat(defaultConfigFile); err == nil {
37 | return defaultConfigFile
38 | }
39 |
40 | // Get current working directory
41 | cwd, err := os.Getwd()
42 | if err != nil {
43 | logger.Error("Error getting current directory: %v", err)
44 | return defaultConfigFile
45 | }
46 |
47 | // Try up to 3 parent directories
48 | for i := 0; i < 3; i++ {
49 | cwd = filepath.Dir(cwd)
50 | configPath := filepath.Join(cwd, defaultConfigFile)
51 | if _, err := os.Stat(configPath); err == nil {
52 | return configPath
53 | }
54 | }
55 |
56 | // Fall back to default if not found
57 | return defaultConfigFile
58 | }
59 |
60 | func main() {
61 | // Parse command-line arguments
62 | configFile := flag.String("c", "config.json", "Database configuration file")
63 | configPath := flag.String("config", "config.json", "Database configuration file (alternative)")
64 | transportMode := flag.String("t", "sse", "Transport mode (stdio or sse)")
65 | serverPort := flag.Int("p", 9092, "Server port for SSE transport")
66 | serverHost := flag.String("h", "localhost", "Server host for SSE transport")
67 | dbConfigJSON := flag.String("db-config", "", "JSON string with database configuration")
68 | logLevel := flag.String("log-level", "info", "Log level (debug, info, warn, error)")
69 | flag.Parse()
70 |
71 | // Initialize logger
72 | logger.Initialize(*logLevel)
73 | pkgLogger.Initialize(*logLevel)
74 |
75 | // Prioritize flags with actual values
76 | finalConfigPath := *configFile
77 | if finalConfigPath == "config.json" && *configPath != "config.json" {
78 | finalConfigPath = *configPath
79 | }
80 |
81 | // If no specific config path was provided, try to find a config file
82 | if finalConfigPath == "config.json" {
83 | possibleConfigPath := findConfigFile()
84 | if possibleConfigPath != "config.json" {
85 | logger.Info("Found config file at: %s", possibleConfigPath)
86 | finalConfigPath = possibleConfigPath
87 | }
88 | }
89 |
90 | finalServerPort := *serverPort
91 | // Set environment variables from command line arguments if provided
92 | if finalConfigPath != "config.json" {
93 | if err := os.Setenv("CONFIG_PATH", finalConfigPath); err != nil {
94 | logger.Warn("Warning: failed to set CONFIG_PATH env: %v", err)
95 | }
96 | }
97 | if *transportMode != "sse" {
98 | if err := os.Setenv("TRANSPORT_MODE", *transportMode); err != nil {
99 | logger.Warn("Warning: failed to set TRANSPORT_MODE env: %v", err)
100 | }
101 | }
102 | if finalServerPort != 9092 {
103 | if err := os.Setenv("SERVER_PORT", fmt.Sprintf("%d", finalServerPort)); err != nil {
104 | logger.Warn("Warning: failed to set SERVER_PORT env: %v", err)
105 | }
106 | }
107 | // Set DB_CONFIG environment variable if provided via flag
108 | if *dbConfigJSON != "" {
109 | if err := os.Setenv("DB_CONFIG", *dbConfigJSON); err != nil {
110 | logger.Warn("Warning: failed to set DB_CONFIG env: %v", err)
111 | }
112 | }
113 |
114 | // Load configuration after environment variables are set
115 | cfg, err := config.LoadConfig()
116 | if err != nil {
117 | logger.Warn("Warning: Failed to load configuration: %v", err)
118 | // Create a default config if loading fails
119 | cfg = &config.Config{
120 | ServerPort: finalServerPort,
121 | TransportMode: *transportMode,
122 | ConfigPath: finalConfigPath,
123 | }
124 | }
125 |
126 | // Initialize database connection from config
127 | dbConfig := &dbtools.Config{
128 | ConfigFile: cfg.ConfigPath,
129 | }
130 |
131 | // Ensure database configuration exists
132 | logger.Info("Using database configuration from: %s", cfg.ConfigPath)
133 |
134 | // Try to initialize database from config
135 | if err := dbtools.InitDatabase(dbConfig); err != nil {
136 | logger.Warn("Warning: Failed to initialize database: %v", err)
137 | }
138 |
139 | // Set up signal handling for clean shutdown
140 | stop := make(chan os.Signal, 1)
141 | signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
142 |
143 | // Create mcp-go server with our logger's standard logger (compatibility layer)
144 | mcpServer := server.NewMCPServer(
145 | "DB MCP Server", // Server name
146 | "1.0.0", // Server version
147 | nil, // Use default logger
148 | )
149 |
150 | // Set up Clean Architecture layers
151 | dbRepo := repository.NewDatabaseRepository()
152 | dbUseCase := usecase.NewDatabaseUseCase(dbRepo)
153 | toolRegistry := mcp.NewToolRegistry(mcpServer)
154 |
155 | // Set the database use case in the tool registry
156 | ctx := context.Background()
157 |
158 | // Debug log: Check database connections before registering tools
159 | dbIDs := dbUseCase.ListDatabases()
160 | if len(dbIDs) > 0 {
161 | logger.Info("Detected %d database connections: %v", len(dbIDs), dbIDs)
162 | logger.Info("Will dynamically generate database tools for each connection")
163 | } else {
164 | logger.Info("No database connections detected")
165 | }
166 |
167 | // Register tools
168 | if err := toolRegistry.RegisterAllTools(ctx, dbUseCase); err != nil {
169 | logger.Warn("Warning: error registering tools: %v", err)
170 | // If there was an error registering tools, register mock tools as fallback
171 | logger.Info("Registering mock tools as fallback due to error...")
172 | if err := toolRegistry.RegisterMockTools(ctx); err != nil {
173 | logger.Warn("Warning: error registering mock tools: %v", err)
174 | }
175 | }
176 | logger.Info("Finished registering tools")
177 |
178 | // If we have databases, display the available tools
179 | if len(dbIDs) > 0 {
180 | logger.Info("Available database tools:")
181 | for _, dbID := range dbIDs {
182 | logger.Info(" Database %s:", dbID)
183 | logger.Info(" - query_%s: Execute SQL queries", dbID)
184 | logger.Info(" - execute_%s: Execute SQL statements", dbID)
185 | logger.Info(" - transaction_%s: Manage transactions", dbID)
186 | logger.Info(" - performance_%s: Analyze query performance", dbID)
187 | logger.Info(" - schema_%s: Get database schema", dbID)
188 | }
189 | logger.Info(" Common tools:")
190 | logger.Info(" - list_databases: List all available databases")
191 | }
192 |
193 | // If no database connections, register mock tools to ensure at least some tools are available
194 | if len(dbIDs) == 0 {
195 | logger.Info("No database connections available. Adding mock tools...")
196 | if err := toolRegistry.RegisterMockTools(ctx); err != nil {
197 | logger.Warn("Warning: error registering mock tools: %v", err)
198 | }
199 | }
200 |
201 | // Create a session store to track valid sessions
202 | sessions := make(map[string]bool)
203 |
204 | // Create a default session for easier testing
205 | defaultSessionID := "default-session"
206 | sessions[defaultSessionID] = true
207 | logger.Info("Created default session: %s", defaultSessionID)
208 |
209 | // Handle transport mode
210 | switch cfg.TransportMode {
211 | case "sse":
212 | logger.Info("Starting SSE server on port %d", cfg.ServerPort)
213 |
214 | // Configure base URL with explicit protocol
215 | baseURL := fmt.Sprintf("http://%s:%d", *serverHost, cfg.ServerPort)
216 | logger.Info("Using base URL: %s", baseURL)
217 |
218 | // Set logging mode based on configuration
219 | if cfg.DisableLogging {
220 | logger.Info("Logging in SSE transport is disabled")
221 | // Redirect standard output to null device if logging is disabled
222 | // This only works on Unix-like systems
223 | if err := os.Setenv("MCP_DISABLE_LOGGING", "true"); err != nil {
224 | logger.Warn("Warning: failed to set MCP_DISABLE_LOGGING env: %v", err)
225 | }
226 | }
227 | // Set the server address
228 | mcpServer.SetAddress(fmt.Sprintf(":%d", cfg.ServerPort))
229 |
230 | // Start the server
231 | errCh := make(chan error, 1)
232 | go func() {
233 | logger.Info("Starting server...")
234 | errCh <- mcpServer.ServeHTTP()
235 | }()
236 |
237 | // Wait for interrupt or error
238 | select {
239 | case err := <-errCh:
240 | logger.Error("Server error: %v", err)
241 | os.Exit(1)
242 | case <-stop:
243 | logger.Info("Shutting down server...")
244 |
245 | // Create shutdown context
246 | shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
247 | defer shutdownCancel()
248 |
249 | // Shutdown the server
250 | if err := mcpServer.Shutdown(shutdownCtx); err != nil {
251 | logger.Error("Error during server shutdown: %v", err)
252 | }
253 |
254 | // Close database connections
255 | if err := dbtools.CloseDatabase(); err != nil {
256 | logger.Error("Error closing database connections: %v", err)
257 | }
258 | }
259 |
260 | case "stdio":
261 | // We can only log to stderr in stdio mode - NEVER stdout
262 | fmt.Fprintln(os.Stderr, "Starting STDIO server - all logging redirected to log files")
263 |
264 | // Create logs directory if not exists
265 | logsDir := "logs"
266 | if err := os.MkdirAll(logsDir, 0755); err != nil {
267 | // Can't use logger.Warn as it might go to stdout
268 | fmt.Fprintf(os.Stderr, "Failed to create logs directory: %v\n", err)
269 | }
270 |
271 | // Export environment variables for the stdio server
272 | os.Setenv("MCP_DISABLE_LOGGING", "true")
273 | os.Setenv("DISABLE_LOGGING", "true")
274 |
275 | // Ensure standard logger doesn't output to stdout for any imported libraries
276 | // that use the standard log package
277 | logFile, err := os.OpenFile(filepath.Join(logsDir, "stdio-server.log"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
278 | if err != nil {
279 | fmt.Fprintf(os.Stderr, "Failed to create stdio server log file: %v\n", err)
280 | } else {
281 | // Redirect all standard logging to the file
282 | log.SetOutput(logFile)
283 | }
284 |
285 | // Critical: Use ServeStdio WITHOUT any console output to stdout
286 | if err := mcpServer.ServeStdio(); err != nil {
287 | // Log error to stderr only - never stdout
288 | fmt.Fprintf(os.Stderr, "STDIO server error: %v\n", err)
289 | os.Exit(1)
290 | }
291 |
292 | default:
293 | logger.Error("Invalid transport mode: %s", cfg.TransportMode)
294 | }
295 |
296 | logger.Info("Server shutdown complete")
297 | }
298 |
```
--------------------------------------------------------------------------------
/docs/TIMESCALEDB_IMPLEMENTATION.md:
--------------------------------------------------------------------------------
```markdown
1 | # TimescaleDB Integration: Engineering Implementation Document
2 |
3 | ## 1. Introduction
4 |
5 | This document provides detailed technical specifications and implementation guidance for integrating TimescaleDB with the DB-MCP-Server. It outlines the architecture, code structures, and specific tasks required to implement the features described in the PRD document.
6 |
7 | ## 2. Technical Background
8 |
9 | ### 2.1 TimescaleDB Overview
10 |
11 | TimescaleDB is an open-source time-series database built as an extension to PostgreSQL. It provides:
12 |
13 | - Automatic partitioning of time-series data ("chunks") for better query performance
14 | - Retention policies for automatic data management
15 | - Compression features for efficient storage
16 | - Continuous aggregates for optimized analytics
17 | - Advanced time-series functions and operators
18 | - Full SQL compatibility with PostgreSQL
19 |
20 | TimescaleDB operates as a transparent extension to PostgreSQL, meaning existing PostgreSQL applications can use TimescaleDB with minimal modifications.
21 |
22 | ### 2.2 Current Architecture
23 |
24 | The DB-MCP-Server currently supports multiple database types through a common interface in the `pkg/db` package. PostgreSQL support is already implemented, which provides a foundation for TimescaleDB integration (as TimescaleDB is a PostgreSQL extension).
25 |
26 | Key components in the existing architecture:
27 |
28 | - `pkg/db/db.go`: Core database interface and implementations
29 | - `Config` struct: Database configuration parameters
30 | - Database connection management
31 | - Query execution functions
32 | - Multi-database support through configuration
33 |
34 | ## 3. Architecture Changes
35 |
36 | ### 3.1 Component Additions
37 |
38 | New components to be added:
39 |
40 | 1. **TimescaleDB Connection Manager**
41 | - Extended PostgreSQL connection with TimescaleDB-specific configuration options
42 | - Support for hypertable management and time-series operations
43 |
44 | 2. **Hypertable Management Tools**
45 | - Tools for creating and managing hypertables
46 | - Functions for configuring chunks, dimensions, and compression
47 |
48 | 3. **Time-Series Query Utilities**
49 | - Functions for building and executing time-series queries
50 | - Support for time bucket operations and continuous aggregates
51 |
52 | 4. **Context Provider**
53 | - Enhanced information about TimescaleDB objects for user code context
54 | - Schema awareness for hypertables
55 |
56 | ### 3.2 Integration Points
57 |
58 | The TimescaleDB integration will hook into the existing system at these points:
59 |
60 | 1. **Configuration System**
61 | - Extend the database configuration to include TimescaleDB-specific options
62 | - Add support for chunk time intervals, retention policies, and compression settings
63 |
64 | 2. **Database Connection Management**
65 | - Extend the PostgreSQL connection to detect and utilize TimescaleDB features
66 | - Register TimescaleDB-specific connection parameters
67 |
68 | 3. **Tool Registry**
69 | - Register new tools for TimescaleDB operations
70 | - Add TimescaleDB-specific functionality to existing PostgreSQL tools
71 |
72 | 4. **Context Engine**
73 | - Add TimescaleDB-specific context information to editor context
74 | - Provide hypertable schema information
75 |
76 | ## 4. Implementation Details
77 |
78 | ### 4.1 Configuration Extensions
79 |
80 | Extend the existing `Config` struct in `pkg/db/db.go` to include TimescaleDB-specific options:
81 |
82 | ### 4.2 Connection Management
83 |
84 | Create a new package `pkg/db/timescale` with TimescaleDB-specific connection management:
85 |
86 | ### 4.3 Hypertable Management
87 |
88 | Create a new file `pkg/db/timescale/hypertable.go` for hypertable management:
89 |
90 | ### 4.4 Time-Series Query Functions
91 |
92 | Create a new file `pkg/db/timescale/query.go` for time-series query utilities:
93 |
94 | ### 4.5 Tool Registration
95 |
96 | Extend the tool registry in `internal/delivery/mcp` to add TimescaleDB-specific tools:
97 |
98 | ### 4.6 Editor Context Integration
99 |
100 | Extend the editor context provider to include TimescaleDB-specific information:
101 |
102 | ## 5. Implementation Tasks
103 |
104 | ### 5.1 Core Infrastructure Tasks
105 |
106 | | Task ID | Description | Estimated Effort | Dependencies | Status |
107 | |---------|-------------|------------------|--------------|--------|
108 | | INFRA-1 | Update database configuration structures for TimescaleDB | 2 days | None | Completed |
109 | | INFRA-2 | Create TimescaleDB connection manager package | 3 days | INFRA-1 | Completed |
110 | | INFRA-3 | Implement hypertable management functions | 3 days | INFRA-2 | Completed |
111 | | INFRA-4 | Implement time-series query builder | 4 days | INFRA-2 | Completed |
112 | | INFRA-5 | Add compression and retention policy management | 2 days | INFRA-3 | Completed |
113 | | INFRA-6 | Create schema detection and metadata functions | 2 days | INFRA-3 | Completed |
114 |
115 | ### 5.2 Tool Integration Tasks
116 |
117 | | Task ID | Description | Estimated Effort | Dependencies | Status |
118 | |---------|-------------|------------------|--------------|--------|
119 | | TOOL-1 | Register TimescaleDB tool category | 1 day | INFRA-2 | Completed |
120 | | TOOL-2 | Implement hypertable creation tool | 2 days | INFRA-3, TOOL-1 | Completed |
121 | | TOOL-3 | Implement hypertable listing tool | 1 day | INFRA-3, TOOL-1 | Completed |
122 | | TOOL-4 | Implement compression policy tools | 2 days | INFRA-5, TOOL-1 | Completed |
123 | | TOOL-5 | Implement retention policy tools | 2 days | INFRA-5, TOOL-1 | Completed |
124 | | TOOL-6 | Implement time-series query tools | 3 days | INFRA-4, TOOL-1 | Completed |
125 | | TOOL-7 | Implement continuous aggregate tools | 3 days | INFRA-3, TOOL-1 | Completed |
126 |
127 | ### 5.3 Context Integration Tasks
128 |
129 | | Task ID | Description | Estimated Effort | Dependencies | Status |
130 | |---------|-------------|------------------|--------------|--------|
131 | | CTX-1 | Add TimescaleDB detection to editor context | 2 days | INFRA-2 | Completed |
132 | | CTX-2 | Add hypertable schema information to context | 2 days | INFRA-3, CTX-1 | Completed |
133 | | CTX-3 | Implement code completion for TimescaleDB functions | 3 days | CTX-1 | Completed |
134 | | CTX-4 | Create documentation for TimescaleDB functions | 3 days | None | Completed |
135 | | CTX-5 | Implement query suggestion features | 4 days | INFRA-4, CTX-2 | Completed |
136 |
137 | ### 5.4 Testing and Documentation Tasks
138 |
139 | | Task ID | Description | Estimated Effort | Dependencies | Status |
140 | |---------|-------------|------------------|--------------|--------|
141 | | TEST-1 | Create TimescaleDB Docker setup for testing | 1 day | None | Completed |
142 | | TEST-2 | Write unit tests for TimescaleDB connection | 2 days | INFRA-2, TEST-1 | Completed |
143 | | TEST-3 | Write integration tests for hypertable management | 2 days | INFRA-3, TEST-1 | Completed |
144 | | TEST-4 | Write tests for time-series query functions | 2 days | INFRA-4, TEST-1 | Completed |
145 | | TEST-5 | Write tests for compression and retention | 2 days | INFRA-5, TEST-1 | Completed |
146 | | TEST-6 | Write end-to-end tests for all tools | 3 days | All TOOL tasks, TEST-1 | Pending |
147 | | DOC-1 | Update configuration documentation | 1 day | INFRA-1 | Pending |
148 | | DOC-2 | Create user guide for TimescaleDB features | 2 days | All TOOL tasks | Pending |
149 | | DOC-3 | Document TimescaleDB best practices | 2 days | All implementation | Pending |
150 | | DOC-4 | Create code samples and tutorials | 3 days | All implementation | Pending |
151 |
152 | ### 5.5 Deployment and Release Tasks
153 |
154 | | Task ID | Description | Estimated Effort | Dependencies | Status |
155 | |---------|-------------|------------------|--------------|--------|
156 | | REL-1 | Create TimescaleDB Docker Compose example | 1 day | TEST-1 | Completed |
157 | | REL-2 | Update CI/CD pipeline for TimescaleDB testing | 1 day | TEST-1 | Pending |
158 | | REL-3 | Create release notes and migration guide | 1 day | All implementation | Pending |
159 | | REL-4 | Performance testing and optimization | 3 days | All implementation | Pending |
160 |
161 | ### 5.6 Implementation Progress Summary
162 |
163 | As of the current codebase status:
164 |
165 | - **Core Infrastructure (100% Complete)**: All core TimescaleDB infrastructure components have been implemented, including configuration structures, connection management, hypertable management, time-series query builder, and policy management.
166 |
167 | - **Tool Integration (100% Complete)**: All TimescaleDB tool types have been registered and implemented. This includes hypertable creation and listing tools, compression and retention policy tools, time-series query tools, and continuous aggregate tools. All tools have comprehensive test coverage.
168 |
169 | - **Context Integration (100% Complete)**: All the context integration features have been implemented, including TimescaleDB detection, hypertable schema information, code completion for TimescaleDB functions, documentation for TimescaleDB functions, and query suggestion features.
170 |
171 | - **Testing (90% Complete)**: Unit tests for connection, hypertable management, policy features, compression and retention policy tools, time-series query tools, continuous aggregate tools, and context features have been implemented. The TimescaleDB Docker setup for testing has been completed. End-to-end tool tests are still pending.
172 |
173 | - **Documentation (25% Complete)**: Documentation for TimescaleDB functions has been created, but documentation for other features, best practices, and usage examples is still pending.
174 |
175 | - **Deployment (25% Complete)**: TimescaleDB Docker setup has been completed and a Docker Compose example is provided. CI/CD integration and performance testing are still pending.
176 |
177 | **Overall Progress**: Approximately 92% of the planned work has been completed, focusing on the core infrastructure layer, tool integration, context integration features, and testing infrastructure. The remaining work is primarily related to comprehensive documentation, end-to-end testing, and CI/CD integration.
178 |
179 | ## 6. Timeline
180 |
181 | Estimated total effort: 65 person-days
182 |
183 | Minimum viable implementation (Phase 1 - Core Features):
184 | - INFRA-1, INFRA-2, INFRA-3, TOOL-1, TOOL-2, TOOL-3, TEST-1, TEST-2, DOC-1
185 | - Timeline: 2-3 weeks
186 |
187 | Complete implementation (All Phases):
188 | - All tasks
189 | - Timeline: 8-10 weeks
190 |
191 | ## 7. Risk Assessment
192 |
193 | | Risk | Impact | Likelihood | Mitigation |
194 | |------|--------|------------|------------|
195 | | TimescaleDB version compatibility issues | High | Medium | Test with multiple versions, clear version requirements |
196 | | Performance impacts with large datasets | High | Medium | Performance testing with representative datasets |
197 | | Complex query builder challenges | Medium | Medium | Start with core functions, expand iteratively |
198 | | Integration with existing PostgreSQL tools | Medium | Low | Clear separation of concerns, thorough testing |
199 | | Security concerns with new database features | High | Low | Security review of all new code, follow established patterns |
```
--------------------------------------------------------------------------------
/pkg/db/timescale/hypertable.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 | )
8 |
// HypertableConfig defines configuration for creating a hypertable
// via DB.CreateHypertable.
type HypertableConfig struct {
	TableName          string // Table to convert into a hypertable (required)
	TimeColumn         string // Column used as the time dimension (required)
	ChunkTimeInterval  string // Chunk interval, e.g. "7 days"; passed as INTERVAL to create_hypertable
	PartitioningColumn string // Optional column for additional space partitioning
	CreateIfNotExists  bool   // NOTE(review): not read by CreateHypertable — IfNotExists is the flag actually consulted; confirm intent
	SpacePartitions    int    // Number of space partitions (for multi-dimensional partitioning)
	IfNotExists        bool   // If true, don't error if table is already a hypertable
	MigrateData        bool   // If true, migrate existing data to chunks
}

// Hypertable represents a TimescaleDB hypertable as reported by the
// catalog queries in ListHypertables/GetHypertable.
type Hypertable struct {
	TableName          string // Hypertable name
	SchemaName         string // Schema containing the hypertable
	TimeColumn         string // Primary time dimension column
	SpaceColumn        string // Space-partitioning column, empty when absent
	NumDimensions      int    // Total number of dimensions (time + space)
	CompressionEnabled bool   // True when compression settings exist for the table
	RetentionEnabled   bool   // True when a retention policy job exists for the table
}
31 |
32 | // CreateHypertable converts a regular PostgreSQL table to a TimescaleDB hypertable
33 | func (t *DB) CreateHypertable(ctx context.Context, config HypertableConfig) error {
34 | if !t.isTimescaleDB {
35 | return fmt.Errorf("TimescaleDB extension not available")
36 | }
37 |
38 | // Construct the create_hypertable call
39 | var queryBuilder strings.Builder
40 | queryBuilder.WriteString("SELECT create_hypertable(")
41 |
42 | // Table name and time column are required
43 | queryBuilder.WriteString(fmt.Sprintf("'%s', '%s'", config.TableName, config.TimeColumn))
44 |
45 | // Optional parameters
46 | if config.PartitioningColumn != "" {
47 | queryBuilder.WriteString(fmt.Sprintf(", partition_column => '%s'", config.PartitioningColumn))
48 | }
49 |
50 | if config.ChunkTimeInterval != "" {
51 | queryBuilder.WriteString(fmt.Sprintf(", chunk_time_interval => INTERVAL '%s'", config.ChunkTimeInterval))
52 | }
53 |
54 | if config.SpacePartitions > 0 {
55 | queryBuilder.WriteString(fmt.Sprintf(", number_partitions => %d", config.SpacePartitions))
56 | }
57 |
58 | if config.IfNotExists {
59 | queryBuilder.WriteString(", if_not_exists => TRUE")
60 | }
61 |
62 | if config.MigrateData {
63 | queryBuilder.WriteString(", migrate_data => TRUE")
64 | }
65 |
66 | queryBuilder.WriteString(")")
67 |
68 | // Execute the query
69 | _, err := t.ExecuteSQLWithoutParams(ctx, queryBuilder.String())
70 | if err != nil {
71 | return fmt.Errorf("failed to create hypertable: %w", err)
72 | }
73 |
74 | return nil
75 | }
76 |
77 | // AddDimension adds a new dimension (partitioning key) to a hypertable
78 | func (t *DB) AddDimension(ctx context.Context, tableName, columnName string, numPartitions int) error {
79 | if !t.isTimescaleDB {
80 | return fmt.Errorf("TimescaleDB extension not available")
81 | }
82 |
83 | query := fmt.Sprintf("SELECT add_dimension('%s', '%s', number_partitions => %d)",
84 | tableName, columnName, numPartitions)
85 |
86 | _, err := t.ExecuteSQLWithoutParams(ctx, query)
87 | if err != nil {
88 | return fmt.Errorf("failed to add dimension: %w", err)
89 | }
90 |
91 | return nil
92 | }
93 |
94 | // ListHypertables returns a list of all hypertables in the database
95 | func (t *DB) ListHypertables(ctx context.Context) ([]Hypertable, error) {
96 | if !t.isTimescaleDB {
97 | return nil, fmt.Errorf("TimescaleDB extension not available")
98 | }
99 |
100 | query := `
101 | SELECT h.table_name, h.schema_name, d.column_name as time_column,
102 | count(d.id) as num_dimensions,
103 | (
104 | SELECT column_name FROM _timescaledb_catalog.dimension
105 | WHERE hypertable_id = h.id AND column_type != 'TIMESTAMP'
106 | AND column_type != 'TIMESTAMPTZ'
107 | LIMIT 1
108 | ) as space_column
109 | FROM _timescaledb_catalog.hypertable h
110 | JOIN _timescaledb_catalog.dimension d ON h.id = d.hypertable_id
111 | GROUP BY h.id, h.table_name, h.schema_name
112 | `
113 |
114 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
115 | if err != nil {
116 | return nil, fmt.Errorf("failed to list hypertables: %w", err)
117 | }
118 |
119 | rows, ok := result.([]map[string]interface{})
120 | if !ok {
121 | return nil, fmt.Errorf("unexpected result type from database query")
122 | }
123 |
124 | var hypertables []Hypertable
125 | for _, row := range rows {
126 | ht := Hypertable{
127 | TableName: fmt.Sprintf("%v", row["table_name"]),
128 | SchemaName: fmt.Sprintf("%v", row["schema_name"]),
129 | TimeColumn: fmt.Sprintf("%v", row["time_column"]),
130 | }
131 |
132 | // Handle nullable columns
133 | if row["space_column"] != nil {
134 | ht.SpaceColumn = fmt.Sprintf("%v", row["space_column"])
135 | }
136 |
137 | // Convert numeric dimensions
138 | if dimensions, ok := row["num_dimensions"].(int64); ok {
139 | ht.NumDimensions = int(dimensions)
140 | } else if dimensions, ok := row["num_dimensions"].(int); ok {
141 | ht.NumDimensions = dimensions
142 | }
143 |
144 | // Check if compression is enabled
145 | compQuery := fmt.Sprintf(
146 | "SELECT count(*) > 0 as is_compressed FROM timescaledb_information.compression_settings WHERE hypertable_name = '%s'",
147 | ht.TableName,
148 | )
149 | compResult, err := t.ExecuteSQLWithoutParams(ctx, compQuery)
150 | if err == nil {
151 | if compRows, ok := compResult.([]map[string]interface{}); ok && len(compRows) > 0 {
152 | if isCompressed, ok := compRows[0]["is_compressed"].(bool); ok {
153 | ht.CompressionEnabled = isCompressed
154 | }
155 | }
156 | }
157 |
158 | // Check if retention policy is enabled
159 | retQuery := fmt.Sprintf(
160 | "SELECT count(*) > 0 as has_retention FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'",
161 | ht.TableName,
162 | )
163 | retResult, err := t.ExecuteSQLWithoutParams(ctx, retQuery)
164 | if err == nil {
165 | if retRows, ok := retResult.([]map[string]interface{}); ok && len(retRows) > 0 {
166 | if hasRetention, ok := retRows[0]["has_retention"].(bool); ok {
167 | ht.RetentionEnabled = hasRetention
168 | }
169 | }
170 | }
171 |
172 | hypertables = append(hypertables, ht)
173 | }
174 |
175 | return hypertables, nil
176 | }
177 |
// GetHypertable gets information about a specific hypertable.
//
// It queries the TimescaleDB catalog for the table's schema, time column,
// dimension count and (if present) space-partitioning column, then enriches
// the result with compression and retention status via best-effort lookups.
// Returns an error when the extension is missing or the table is not a
// hypertable.
//
// NOTE(review): tableName is interpolated into the SQL unescaped — callers
// must pass trusted identifiers only.
func (t *DB) GetHypertable(ctx context.Context, tableName string) (*Hypertable, error) {
	if !t.isTimescaleDB {
		return nil, fmt.Errorf("TimescaleDB extension not available")
	}

	// The subquery picks the first non-timestamp dimension column as the
	// space-partitioning column (NULL when the table is time-only).
	query := fmt.Sprintf(`
		SELECT h.table_name, h.schema_name, d.column_name as time_column, 
			   count(d.id) as num_dimensions,
			   (
				   SELECT column_name FROM _timescaledb_catalog.dimension 
				   WHERE hypertable_id = h.id AND column_type != 'TIMESTAMP' 
				   AND column_type != 'TIMESTAMPTZ' 
				   LIMIT 1
			   ) as space_column
		FROM _timescaledb_catalog.hypertable h
		JOIN _timescaledb_catalog.dimension d ON h.id = d.hypertable_id
		WHERE h.table_name = '%s'
		GROUP BY h.id, h.table_name, h.schema_name
	`, tableName)

	result, err := t.ExecuteSQLWithoutParams(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to get hypertable information: %w", err)
	}

	// No rows means the table exists but is not registered as a hypertable
	// (or does not exist at all).
	rows, ok := result.([]map[string]interface{})
	if !ok || len(rows) == 0 {
		return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
	}

	row := rows[0]
	ht := &Hypertable{
		TableName:  fmt.Sprintf("%v", row["table_name"]),
		SchemaName: fmt.Sprintf("%v", row["schema_name"]),
		TimeColumn: fmt.Sprintf("%v", row["time_column"]),
	}

	// Handle nullable columns
	if row["space_column"] != nil {
		ht.SpaceColumn = fmt.Sprintf("%v", row["space_column"])
	}

	// Convert numeric dimensions (driver may return int64 or int)
	if dimensions, ok := row["num_dimensions"].(int64); ok {
		ht.NumDimensions = int(dimensions)
	} else if dimensions, ok := row["num_dimensions"].(int); ok {
		ht.NumDimensions = dimensions
	}

	// Check if compression is enabled; lookup errors are deliberately
	// ignored so enrichment failure doesn't hide the hypertable itself.
	compQuery := fmt.Sprintf(
		"SELECT count(*) > 0 as is_compressed FROM timescaledb_information.compression_settings WHERE hypertable_name = '%s'",
		ht.TableName,
	)
	compResult, err := t.ExecuteSQLWithoutParams(ctx, compQuery)
	if err == nil {
		if compRows, ok := compResult.([]map[string]interface{}); ok && len(compRows) > 0 {
			if isCompressed, ok := compRows[0]["is_compressed"].(bool); ok {
				ht.CompressionEnabled = isCompressed
			}
		}
	}

	// Check if retention policy is enabled; same best-effort semantics as
	// the compression lookup above.
	retQuery := fmt.Sprintf(
		"SELECT count(*) > 0 as has_retention FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'",
		ht.TableName,
	)
	retResult, err := t.ExecuteSQLWithoutParams(ctx, retQuery)
	if err == nil {
		if retRows, ok := retResult.([]map[string]interface{}); ok && len(retRows) > 0 {
			if hasRetention, ok := retRows[0]["has_retention"].(bool); ok {
				ht.RetentionEnabled = hasRetention
			}
		}
	}

	return ht, nil
}
258 |
259 | // DropHypertable drops a hypertable
260 | func (t *DB) DropHypertable(ctx context.Context, tableName string, cascade bool) error {
261 | if !t.isTimescaleDB {
262 | return fmt.Errorf("TimescaleDB extension not available")
263 | }
264 |
265 | // Build the DROP TABLE query
266 | query := fmt.Sprintf("DROP TABLE %s", tableName)
267 | if cascade {
268 | query += " CASCADE"
269 | }
270 |
271 | _, err := t.ExecuteSQLWithoutParams(ctx, query)
272 | if err != nil {
273 | return fmt.Errorf("failed to drop hypertable: %w", err)
274 | }
275 |
276 | return nil
277 | }
278 |
279 | // CheckIfHypertable checks if a table is a hypertable
280 | func (t *DB) CheckIfHypertable(ctx context.Context, tableName string) (bool, error) {
281 | if !t.isTimescaleDB {
282 | return false, fmt.Errorf("TimescaleDB extension not available")
283 | }
284 |
285 | query := fmt.Sprintf(`
286 | SELECT EXISTS (
287 | SELECT 1
288 | FROM _timescaledb_catalog.hypertable
289 | WHERE table_name = '%s'
290 | ) as is_hypertable
291 | `, tableName)
292 |
293 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
294 | if err != nil {
295 | return false, fmt.Errorf("failed to check if table is a hypertable: %w", err)
296 | }
297 |
298 | rows, ok := result.([]map[string]interface{})
299 | if !ok || len(rows) == 0 {
300 | return false, fmt.Errorf("unexpected result from database query")
301 | }
302 |
303 | isHypertable, ok := rows[0]["is_hypertable"].(bool)
304 | if !ok {
305 | return false, fmt.Errorf("unexpected result type for is_hypertable")
306 | }
307 |
308 | return isHypertable, nil
309 | }
310 |
311 | // RecentChunks returns information about the most recent chunks
312 | func (t *DB) RecentChunks(ctx context.Context, tableName string, limit int) (interface{}, error) {
313 | if !t.isTimescaleDB {
314 | return nil, fmt.Errorf("TimescaleDB extension not available")
315 | }
316 |
317 | // Use default limit of 10 if not specified
318 | if limit <= 0 {
319 | limit = 10
320 | }
321 |
322 | query := fmt.Sprintf(`
323 | SELECT
324 | chunk_name,
325 | range_start,
326 | range_end
327 | FROM timescaledb_information.chunks
328 | WHERE hypertable_name = '%s'
329 | ORDER BY range_end DESC
330 | LIMIT %d
331 | `, tableName, limit)
332 |
333 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
334 | if err != nil {
335 | return nil, fmt.Errorf("failed to get recent chunks: %w", err)
336 | }
337 |
338 | return result, nil
339 | }
340 |
341 | // CreateHypertable is a helper function to create a hypertable with the given configuration options.
342 | // This is exported for use by other packages.
343 | func CreateHypertable(ctx context.Context, db *DB, table, timeColumn string, opts ...HypertableOption) error {
344 | config := HypertableConfig{
345 | TableName: table,
346 | TimeColumn: timeColumn,
347 | }
348 |
349 | // Apply all options
350 | for _, opt := range opts {
351 | opt(&config)
352 | }
353 |
354 | return db.CreateHypertable(ctx, config)
355 | }
356 |
357 | // HypertableOption is a functional option for CreateHypertable
358 | type HypertableOption func(*HypertableConfig)
359 |
360 | // WithChunkInterval sets the chunk time interval
361 | func WithChunkInterval(interval string) HypertableOption {
362 | return func(config *HypertableConfig) {
363 | config.ChunkTimeInterval = interval
364 | }
365 | }
366 |
367 | // WithPartitioningColumn sets the space partitioning column
368 | func WithPartitioningColumn(column string) HypertableOption {
369 | return func(config *HypertableConfig) {
370 | config.PartitioningColumn = column
371 | }
372 | }
373 |
374 | // WithIfNotExists sets the if_not_exists flag
375 | func WithIfNotExists(ifNotExists bool) HypertableOption {
376 | return func(config *HypertableConfig) {
377 | config.IfNotExists = ifNotExists
378 | }
379 | }
380 |
381 | // WithMigrateData sets the migrate_data flag
382 | func WithMigrateData(migrateData bool) HypertableOption {
383 | return func(config *HypertableConfig) {
384 | config.MigrateData = migrateData
385 | }
386 | }
387 |
```
--------------------------------------------------------------------------------
/pkg/dbtools/performance.go:
--------------------------------------------------------------------------------
```go
1 | package dbtools
2 |
import (
	"context"
	"fmt"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/FreePeak/db-mcp-server/pkg/logger"
)
12 |
// QueryMetrics stores performance metrics for a database query
type QueryMetrics struct {
	Query         string        // SQL query text
	Count         int           // Number of times the query was executed
	TotalDuration time.Duration // Total execution time
	MinDuration   time.Duration // Minimum execution time
	MaxDuration   time.Duration // Maximum execution time
	AvgDuration   time.Duration // Average execution time
	LastExecuted  time.Time     // When the query was last executed
}

// PerformanceAnalyzer tracks query performance and provides optimization suggestions.
// NOTE(review): fields are mutated without synchronization (see TrackQuery);
// confirm the analyzer is only used from a single goroutine.
type PerformanceAnalyzer struct {
	slowThreshold time.Duration // queries taking at least this long are logged as slow
	queryHistory  []QueryRecord // rolling window of recent executions, oldest dropped first
	maxHistory    int           // maximum number of records retained in queryHistory
}

// QueryRecord stores information about a query execution
type QueryRecord struct {
	Query      string        `json:"query"`
	Params     []interface{} `json:"params"`
	Duration   time.Duration `json:"duration"`
	StartTime  time.Time     `json:"startTime"`
	Error      string        `json:"error,omitempty"`
	Optimized  bool          `json:"optimized"`
	Suggestion string        `json:"suggestion,omitempty"`
}

// SQLIssueDetector detects potential issues in SQL queries
type SQLIssueDetector struct {
	patterns map[string]*regexp.Regexp // issue name -> case-insensitive detection regex
}
46 |
47 | // singleton instance
48 | var performanceAnalyzer *PerformanceAnalyzer
49 |
50 | // GetPerformanceAnalyzer returns the singleton performance analyzer
51 | func GetPerformanceAnalyzer() *PerformanceAnalyzer {
52 | if performanceAnalyzer == nil {
53 | performanceAnalyzer = NewPerformanceAnalyzer()
54 | }
55 | return performanceAnalyzer
56 | }
57 |
58 | // NewPerformanceAnalyzer creates a new performance analyzer
59 | func NewPerformanceAnalyzer() *PerformanceAnalyzer {
60 | return &PerformanceAnalyzer{
61 | slowThreshold: 500 * time.Millisecond, // Default: 500ms
62 | queryHistory: make([]QueryRecord, 0),
63 | maxHistory: 100, // Default: store last 100 queries
64 | }
65 | }
66 |
67 | // LogSlowQuery logs a warning if a query takes longer than the slow query threshold
68 | func (pa *PerformanceAnalyzer) LogSlowQuery(query string, params []interface{}, duration time.Duration) {
69 | if duration >= pa.slowThreshold {
70 | paramStr := formatParams(params)
71 | logger.Warn("Slow query detected (%.2fms): %s [params: %s]",
72 | float64(duration.Microseconds())/1000.0,
73 | query,
74 | paramStr)
75 | }
76 | }
77 |
78 | // TrackQuery tracks the execution of a query and logs slow queries
79 | func (pa *PerformanceAnalyzer) TrackQuery(ctx context.Context, query string, params []interface{}, exec func() (interface{}, error)) (interface{}, error) {
80 | startTime := time.Now()
81 | result, err := exec()
82 | duration := time.Since(startTime)
83 |
84 | // Create query record
85 | record := QueryRecord{
86 | Query: query,
87 | Params: params,
88 | Duration: duration,
89 | StartTime: startTime,
90 | }
91 |
92 | // Check if query is slow
93 | if duration >= pa.slowThreshold {
94 | pa.LogSlowQuery(query, params, duration)
95 | record.Suggestion = "Query execution time exceeds threshold"
96 | }
97 |
98 | // Record error if any
99 | if err != nil {
100 | record.Error = err.Error()
101 | }
102 |
103 | // Add to history (keeping max size)
104 | pa.queryHistory = append(pa.queryHistory, record)
105 | if len(pa.queryHistory) > pa.maxHistory {
106 | pa.queryHistory = pa.queryHistory[1:]
107 | }
108 |
109 | return result, err
110 | }
111 |
112 | // SQLIssueDetector methods
113 |
114 | // NewSQLIssueDetector creates a new SQL issue detector
115 | func NewSQLIssueDetector() *SQLIssueDetector {
116 | detector := &SQLIssueDetector{
117 | patterns: make(map[string]*regexp.Regexp),
118 | }
119 |
120 | // Add known issue patterns
121 | detector.AddPattern("cartesian-join", `SELECT.*FROM\s+(\w+)\s*,\s*(\w+)`)
122 | detector.AddPattern("select-star", `SELECT\s+\*\s+FROM`)
123 | detector.AddPattern("missing-where", `(DELETE\s+FROM|UPDATE)\s+\w+\s+(?:SET\s+(?:\w+\s*=\s*[^,]+)(?:\s*,\s*\w+\s*=\s*[^,]+)*\s*)*(;|\z)`)
124 | detector.AddPattern("or-in-where", `WHERE.*\s+OR\s+`)
125 | detector.AddPattern("in-with-many-items", `IN\s*\(\s*(?:'[^']*'\s*,\s*){10,}`)
126 | detector.AddPattern("not-in", `NOT\s+IN\s*\(`)
127 | detector.AddPattern("is-null", `IS\s+NULL`)
128 | detector.AddPattern("function-on-column", `WHERE\s+\w+\s*\(\s*\w+\s*\)`)
129 | detector.AddPattern("order-by-rand", `ORDER\s+BY\s+RAND\(\)`)
130 | detector.AddPattern("group-by-number", `GROUP\s+BY\s+\d+`)
131 | detector.AddPattern("having-without-group", `HAVING.*(?:(?:GROUP\s+BY.*$)|(?:$))`)
132 |
133 | return detector
134 | }
135 |
136 | // AddPattern adds a pattern for detecting SQL issues
137 | func (d *SQLIssueDetector) AddPattern(name, pattern string) {
138 | re, err := regexp.Compile("(?i)" + pattern)
139 | if err != nil {
140 | logger.Error("Error compiling regex pattern '%s': %v", pattern, err)
141 | return
142 | }
143 | d.patterns[name] = re
144 | }
145 |
146 | // DetectIssues detects issues in a SQL query
147 | func (d *SQLIssueDetector) DetectIssues(query string) map[string]string {
148 | issues := make(map[string]string)
149 |
150 | for name, pattern := range d.patterns {
151 | if pattern.MatchString(query) {
152 | issues[name] = d.getSuggestionForIssue(name)
153 | }
154 | }
155 |
156 | return issues
157 | }
158 |
159 | // getSuggestionForIssue returns a suggestion for a detected issue
160 | func (d *SQLIssueDetector) getSuggestionForIssue(issue string) string {
161 | suggestions := map[string]string{
162 | "cartesian-join": "Use explicit JOIN statements instead of comma-syntax joins to avoid unintended Cartesian products.",
163 | "select-star": "Specify exact columns needed instead of SELECT * to reduce network traffic and improve query execution.",
164 | "missing-where": "Add a WHERE clause to avoid affecting all rows in the table.",
165 | "or-in-where": "Consider using IN instead of multiple OR conditions for better performance.",
166 | "in-with-many-items": "Too many items in IN clause; consider a temporary table or a JOIN instead.",
167 | "not-in": "NOT IN with subqueries can be slow. Consider using NOT EXISTS or LEFT JOIN/IS NULL pattern.",
168 | "is-null": "IS NULL conditions prevent index usage. Consider redesigning to avoid NULL values if possible.",
169 | "function-on-column": "Applying functions to columns in WHERE clauses prevents index usage. Restructure if possible.",
170 | "order-by-rand": "ORDER BY RAND() causes full table scan and sort. Consider alternative randomization methods.",
171 | "group-by-number": "Using column position numbers in GROUP BY can be error-prone. Use explicit column names.",
172 | "having-without-group": "HAVING without GROUP BY may indicate a logical error in query structure.",
173 | }
174 |
175 | if suggestion, ok := suggestions[issue]; ok {
176 | return suggestion
177 | }
178 |
179 | return "Potential issue detected; review query for optimization opportunities."
180 | }
181 |
182 | // Helper functions
183 |
184 | // formatParams converts query parameters to a readable string format
185 | func formatParams(params []interface{}) string {
186 | if len(params) == 0 {
187 | return "none"
188 | }
189 |
190 | paramStrings := make([]string, len(params))
191 | for i, param := range params {
192 | if param == nil {
193 | paramStrings[i] = "NULL"
194 | } else {
195 | paramStrings[i] = fmt.Sprintf("%v", param)
196 | }
197 | }
198 |
199 | return "[" + strings.Join(paramStrings, ", ") + "]"
200 | }
201 |
202 | // Placeholder for future implementation
203 | // Currently not used but kept for reference
204 | /*
205 | func handlePerformanceAnalyzer(ctx context.Context, params map[string]interface{}) (interface{}, error) {
206 | // This function will be implemented in the future
207 | return nil, fmt.Errorf("not implemented")
208 | }
209 | */
210 |
// Precompiled comment-stripping patterns. Compiling once at package init
// avoids the per-call regexp.Compile overhead (and the now-unreachable error
// paths, since these constant patterns are known to be valid).
var (
	sqlBlockCommentRe = regexp.MustCompile(`/\*[\s\S]*?\*/`)
	// (?m) makes $ match at the end of every line, so -- comments in the
	// middle of a multi-line query are stripped too. The previous pattern
	// `--.*$` only matched a comment on the final line of the input.
	sqlLineCommentRe = regexp.MustCompile(`(?m)--.*$`)
)

// StripComments removes SQL comments from a query string: /* ... */ block
// comments first, then -- line comments.
func StripComments(input string) string {
	withoutBlock := sqlBlockCommentRe.ReplaceAllString(input, "")
	return sqlLineCommentRe.ReplaceAllString(withoutBlock, "")
}
229 |
// GetAllMetrics returns all collected metrics, aggregated per normalized
// query (literals replaced by placeholders), so repeated executions of the
// same statement shape are reported as a single entry.
func (pa *PerformanceAnalyzer) GetAllMetrics() []*QueryMetrics {
	// Group query history by normalized query text
	queryMap := make(map[string]*QueryMetrics)

	for _, record := range pa.queryHistory {
		normalizedQuery := normalizeQuery(record.Query)

		metrics, exists := queryMap[normalizedQuery]
		if !exists {
			// First sighting: seed min/max/last-executed from this record.
			// The first record's concrete (non-normalized) text is kept as
			// the representative Query.
			metrics = &QueryMetrics{
				Query:        record.Query,
				MinDuration:  record.Duration,
				MaxDuration:  record.Duration,
				LastExecuted: record.StartTime,
			}
			queryMap[normalizedQuery] = metrics
		}

		// Update metrics
		metrics.Count++
		metrics.TotalDuration += record.Duration

		if record.Duration < metrics.MinDuration {
			metrics.MinDuration = record.Duration
		}

		if record.Duration > metrics.MaxDuration {
			metrics.MaxDuration = record.Duration
		}

		if record.StartTime.After(metrics.LastExecuted) {
			metrics.LastExecuted = record.StartTime
		}
	}

	// Calculate averages and convert to slice
	// (Count is always >= 1 here, so the division is safe.)
	metrics := make([]*QueryMetrics, 0, len(queryMap))
	for _, m := range queryMap {
		m.AvgDuration = time.Duration(int64(m.TotalDuration) / int64(m.Count))
		metrics = append(metrics, m)
	}

	return metrics
}
275 |
// Reset clears all collected metrics.
// The slow-query threshold is left unchanged.
func (pa *PerformanceAnalyzer) Reset() {
	pa.queryHistory = make([]QueryRecord, 0)
}

// GetSlowThreshold returns the current slow query threshold
func (pa *PerformanceAnalyzer) GetSlowThreshold() time.Duration {
	return pa.slowThreshold
}

// SetSlowThreshold sets the slow query threshold.
// Queries at or above this duration are logged as slow by LogSlowQuery.
func (pa *PerformanceAnalyzer) SetSlowThreshold(threshold time.Duration) {
	pa.slowThreshold = threshold
}
290 |
291 | // AnalyzeQuery analyzes a SQL query and returns optimization suggestions
292 | func AnalyzeQuery(query string) []string {
293 | // Create detector and get suggestions
294 | detector := NewSQLIssueDetector()
295 | issues := detector.DetectIssues(query)
296 |
297 | suggestions := make([]string, 0, len(issues))
298 | for _, suggestion := range issues {
299 | suggestions = append(suggestions, suggestion)
300 | }
301 |
302 | // Add default suggestions for query patterns
303 | if strings.Contains(strings.ToUpper(query), "SELECT *") {
304 | suggestions = append(suggestions, "Avoid using SELECT * - specify only the columns you need")
305 | }
306 |
307 | if !strings.Contains(strings.ToUpper(query), "WHERE") &&
308 | !strings.Contains(strings.ToUpper(query), "JOIN") {
309 | suggestions = append(suggestions, "Consider adding a WHERE clause to limit the result set")
310 | }
311 |
312 | if strings.Contains(strings.ToUpper(query), "JOIN") &&
313 | !strings.Contains(strings.ToUpper(query), "ON") {
314 | suggestions = append(suggestions, "Ensure all JOINs have proper conditions")
315 | }
316 |
317 | if strings.Contains(strings.ToUpper(query), "ORDER BY") {
318 | suggestions = append(suggestions, "Verify that ORDER BY columns are properly indexed")
319 | }
320 |
321 | if strings.Contains(query, "(SELECT") {
322 | suggestions = append(suggestions, "Consider replacing subqueries with JOINs where possible")
323 | }
324 |
325 | return suggestions
326 | }
327 |
// Precompiled normalization patterns. normalizeQuery is called once per
// history record by GetAllMetrics, so compiling these on every call was
// wasted work; package-level MustCompile runs once at init.
var (
	normalizeWSRe       = regexp.MustCompile(`\s+`)
	normalizeNumRe      = regexp.MustCompile(`\b\d+\b`)
	normalizeStrRe      = regexp.MustCompile(`'[^']*'`)
	normalizeDblQuoteRe = regexp.MustCompile(`"[^"]*"`)
)

// normalizeQuery standardizes SQL queries for comparison by collapsing
// whitespace and replacing numeric, single-quoted, and double-quoted
// literals with placeholders.
func normalizeQuery(query string) string {
	// Trim and normalize whitespace
	query = strings.TrimSpace(query)
	query = normalizeWSRe.ReplaceAllString(query, " ")

	// Replace numeric literals
	query = normalizeNumRe.ReplaceAllString(query, "?")

	// Replace string literals in single quotes
	query = normalizeStrRe.ReplaceAllString(query, "'?'")

	// Replace string literals in double quotes
	query = normalizeDblQuoteRe.ReplaceAllString(query, "\"?\"")

	return query
}
349 |
350 | // TODO: Implement more sophisticated performance metrics and query analysis
351 | // TODO: Add support for query plan visualization
352 | // TODO: Consider using time-series storage for long-term performance tracking
353 | // TODO: Implement anomaly detection for query performance
354 | // TODO: Add integration with external monitoring systems
355 | // TODO: Implement periodic background performance analysis
356 |
```
--------------------------------------------------------------------------------
/pkg/db/timescale/mocks_test.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"
	"sort"
	"strings"
	"testing"
)
12 |
// MockDB simulates a database for testing purposes.
// It records every executed query and serves canned results registered via
// RegisterQueryResult or SetQueryResult.
type MockDB struct {
	mockResults   map[string]MockExecuteResult // canned results, keyed by query (and "partial:"+query)
	lastQuery     string                       // most recently executed query text
	lastQueryArgs []interface{}                // args of the most recent query (nil for param-less calls)
	queryHistory  []string                     // every query executed, in order
	connectCalled bool                         // whether Connect was invoked
	connectError  error                        // error returned by Connect, if set
	closeCalled   bool                         // whether Close was invoked
	closeError    error                        // error returned by Close, if set
	isTimescaleDB bool                         // whether TimescaleDB features are simulated as available
	// Store the expected scan value for QueryRow
	queryScanValues map[string]interface{}
	// Added for additional mock methods
	queryResult []map[string]interface{} // blanket result for ExecuteSQL* when set
	err         error                    // blanket error for ExecuteSQL* when set
}

// MockExecuteResult is used for mocking query results:
// either a canned Result value or an Error to return.
type MockExecuteResult struct {
	Result interface{}
	Error  error
}
36 |
37 | // NewMockDB creates a new mock database for testing
38 | func NewMockDB() *MockDB {
39 | return &MockDB{
40 | mockResults: make(map[string]MockExecuteResult),
41 | queryScanValues: make(map[string]interface{}),
42 | queryHistory: make([]string, 0),
43 | isTimescaleDB: true, // Default is true
44 | }
45 | }
46 |
47 | // RegisterQueryResult registers a result to be returned for a specific query
48 | func (m *MockDB) RegisterQueryResult(query string, result interface{}, err error) {
49 | // Store the result for exact matching
50 | m.mockResults[query] = MockExecuteResult{
51 | Result: result,
52 | Error: err,
53 | }
54 |
55 | // Also store the result for partial matching
56 | if result != nil || err != nil {
57 | m.mockResults["partial:"+query] = MockExecuteResult{
58 | Result: result,
59 | Error: err,
60 | }
61 | }
62 |
63 | // Also store the result as a scan value for QueryRow
64 | m.queryScanValues[query] = result
65 | }
66 |
67 | // getQueryResult tries to find a matching result for a query
68 | // First tries exact match, then partial match
69 | func (m *MockDB) getQueryResult(query string) (MockExecuteResult, bool) {
70 | // Try exact match first
71 | if result, ok := m.mockResults[query]; ok {
72 | return result, true
73 | }
74 |
75 | // Try partial match
76 | for k, v := range m.mockResults {
77 | if strings.HasPrefix(k, "partial:") && strings.Contains(query, k[8:]) {
78 | return v, true
79 | }
80 | }
81 |
82 | return MockExecuteResult{}, false
83 | }
84 |
// GetLastQuery returns the last executed query and args.
func (m *MockDB) GetLastQuery() (string, []interface{}) {
	return m.lastQuery, m.lastQueryArgs
}

// Connect implements db.Database.Connect.
// It records the call and returns the error configured via SetConnectError.
func (m *MockDB) Connect() error {
	m.connectCalled = true
	return m.connectError
}

// SetConnectError sets an error to be returned from Connect()
func (m *MockDB) SetConnectError(err error) {
	m.connectError = err
}

// Close implements db.Database.Close.
// It records the call and returns the error configured via SetCloseError.
func (m *MockDB) Close() error {
	m.closeCalled = true
	return m.closeError
}

// SetCloseError sets an error to be returned from Close()
func (m *MockDB) SetCloseError(err error) {
	m.closeError = err
}

// SetTimescaleAvailable sets whether TimescaleDB is available for this mock.
// When false, ExecuteSQL/ExecuteSQLWithoutParams reject TimescaleDB-specific queries.
func (m *MockDB) SetTimescaleAvailable(available bool) {
	m.isTimescaleDB = available
}
116 |
117 | // Exec implements db.Database.Exec
118 | func (m *MockDB) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
119 | m.lastQuery = query
120 | m.lastQueryArgs = args
121 |
122 | if result, found := m.getQueryResult(query); found {
123 | if result.Error != nil {
124 | return nil, result.Error
125 | }
126 | if sqlResult, ok := result.Result.(sql.Result); ok {
127 | return sqlResult, nil
128 | }
129 | }
130 |
131 | return &MockResult{}, nil
132 | }
133 |
// Query implements db.Database.Query. A registered *sql.Rows result is
// returned directly; otherwise the query is routed through the mock driver
// so callers get scannable rows.
// NOTE(review): the sql.DB created by sql.OpenDB is never closed and ctx is
// ignored — acceptable for short-lived tests, but confirm if reused broadly.
func (m *MockDB) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	m.lastQuery = query
	m.lastQueryArgs = args

	if result, found := m.getQueryResult(query); found {
		if result.Error != nil {
			return nil, result.Error
		}
		if rows, ok := result.Result.(*sql.Rows); ok {
			return rows, nil
		}
	}

	// Create a MockRows for the test
	return sql.OpenDB(&MockConnector{mockDB: m, query: query}).Query(query)
}

// QueryRow implements db.Database.QueryRow by backing a real *sql.Row with
// the mock driver, so Scan works against registered results.
// NOTE(review): the sql.DB created here is never closed — fine for tests.
func (m *MockDB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row {
	m.lastQuery = query
	m.lastQueryArgs = args

	// Use a custom connector to create a sql.DB that's backed by our mock
	db := sql.OpenDB(&MockConnector{mockDB: m, query: query})
	return db.QueryRow(query)
}
161 |
// BeginTx implements db.Database.BeginTx.
// No-op in the mock: transactions are not simulated.
func (m *MockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
	return nil, nil
}

// Prepare implements db.Database.Prepare.
// No-op in the mock: prepared statements are not simulated.
func (m *MockDB) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
	return nil, nil
}

// Ping implements db.Database.Ping. The mock is always reachable.
func (m *MockDB) Ping(ctx context.Context) error {
	return nil
}

// DB implements db.Database.DB.
// NOTE(review): returns nil — callers that dereference the raw *sql.DB will panic.
func (m *MockDB) DB() *sql.DB {
	return nil
}

// ConnectionString implements db.Database.ConnectionString with a fixed fake DSN.
func (m *MockDB) ConnectionString() string {
	return "mock://localhost/testdb"
}

// DriverName implements db.Database.DriverName; the mock poses as postgres.
func (m *MockDB) DriverName() string {
	return "postgres"
}

// QueryTimeout implements db.Database.QueryTimeout with a fixed 30 (seconds).
func (m *MockDB) QueryTimeout() int {
	return 30
}
196 |
// MockResult implements sql.Result with fixed success values.
type MockResult struct{}

// LastInsertId implements sql.Result.LastInsertId; always reports id 1.
func (r *MockResult) LastInsertId() (int64, error) {
	return 1, nil
}

// RowsAffected implements sql.Result.RowsAffected; always reports 1 row.
func (r *MockResult) RowsAffected() (int64, error) {
	return 1, nil
}
209 |
210 | // MockTimescaleDB creates a TimescaleDB instance with mocked database for testing
211 | func MockTimescaleDB(t testing.TB) (*DB, *MockDB) {
212 | mockDB := NewMockDB()
213 | mockDB.SetTimescaleAvailable(true)
214 |
215 | tsdb := &DB{
216 | Database: mockDB,
217 | isTimescaleDB: true,
218 | extVersion: "2.8.0",
219 | config: DBConfig{
220 | UseTimescaleDB: true,
221 | },
222 | }
223 | return tsdb, mockDB
224 | }
225 |
// AssertQueryContains checks if the query contains the expected substrings,
// reporting a test error for each one that is missing.
func AssertQueryContains(t testing.TB, query string, substrings ...string) {
	for _, want := range substrings {
		if strings.Contains(query, want) {
			continue
		}
		t.Errorf("Expected query to contain '%s', but got: %s", want, query)
	}
}
234 |
// contains reports whether substr occurs within s.
func contains(s, substr string) bool {
	return strings.Index(s, substr) >= 0
}
239 |
240 | // Mock driver implementation to support sql.Row
241 |
// MockConnector implements driver.Connector, carrying the mock DB and the
// query text so sql.OpenDB-produced rows can consult registered results.
type MockConnector struct {
	mockDB *MockDB
	query  string
}

// Connect implements driver.Connector.
func (c *MockConnector) Connect(_ context.Context) (driver.Conn, error) {
	return &MockConn{mockDB: c.mockDB, query: c.query}, nil
}

// Driver implements driver.Connector.
func (c *MockConnector) Driver() driver.Driver {
	return &MockDriver{}
}

// MockDriver implements driver.Driver.
type MockDriver struct{}

// Open implements driver.Driver.
// Note: the Conn returned here has no mockDB, so it serves empty results.
func (d *MockDriver) Open(name string) (driver.Conn, error) {
	return &MockConn{}, nil
}

// MockConn implements driver.Conn.
type MockConn struct {
	mockDB *MockDB
	query  string
}

// Prepare implements driver.Conn.
// The statement keeps the connection's original query text, not the one
// passed in, wired through to registered mock results.
func (c *MockConn) Prepare(query string) (driver.Stmt, error) {
	return &MockStmt{mockDB: c.mockDB, query: query}, nil
}

// Close implements driver.Conn. No resources to release.
func (c *MockConn) Close() error {
	return nil
}

// Begin implements driver.Conn. Transactions are not supported by the mock.
func (c *MockConn) Begin() (driver.Tx, error) {
	return nil, nil
}

// MockStmt implements driver.Stmt.
type MockStmt struct {
	mockDB *MockDB
	query  string
}

// Close implements driver.Stmt. No resources to release.
func (s *MockStmt) Close() error {
	return nil
}

// NumInput implements driver.Stmt; the mock accepts no placeholders.
func (s *MockStmt) NumInput() int {
	return 0
}

// Exec implements driver.Stmt. Not used by the mock's read paths.
func (s *MockStmt) Exec(args []driver.Value) (driver.Result, error) {
	return nil, nil
}
307 |
308 | // Query implements driver.Stmt
309 | func (s *MockStmt) Query(args []driver.Value) (driver.Rows, error) {
310 | // Return the registered result for this query
311 | if s.mockDB != nil {
312 | if result, found := s.mockDB.getQueryResult(s.query); found {
313 | if result.Error != nil {
314 | return nil, result.Error
315 | }
316 | return &MockRows{value: result.Result}, nil
317 | }
318 | }
319 | return &MockRows{}, nil
320 | }
321 |
// MockRows implements driver.Rows over a registered mock result, which is
// either a []map[string]interface{} (multi-row) or a single scalar value.
type MockRows struct {
	value       interface{} // registered result backing this row set
	currentRow  int         // cursor into the row set
	columnNames []string    // cached column names (lazily derived in Columns)
}
328 |
329 | // Columns implements driver.Rows
330 | func (r *MockRows) Columns() []string {
331 | // If we have a slice of maps, extract column names from the first map
332 | if rows, ok := r.value.([]map[string]interface{}); ok && len(rows) > 0 {
333 | if r.columnNames == nil {
334 | r.columnNames = make([]string, 0, len(rows[0]))
335 | for k := range rows[0] {
336 | r.columnNames = append(r.columnNames, k)
337 | }
338 | }
339 | return r.columnNames
340 | }
341 | return []string{"value"}
342 | }
343 |
// Close implements driver.Rows. No resources to release.
func (r *MockRows) Close() error {
	return nil
}

// Next implements driver.Rows. It copies the next row's values into dest —
// in the order given by Columns() for map-backed results, or the single
// scalar value for plain results — and returns io.EOF when exhausted.
func (r *MockRows) Next(dest []driver.Value) error {
	// Handle slice of maps (multiple rows of data)
	if rows, ok := r.value.([]map[string]interface{}); ok {
		if r.currentRow < len(rows) {
			row := rows[r.currentRow]
			r.currentRow++

			// Find column values in the expected order
			columns := r.Columns()
			for i, col := range columns {
				if i < len(dest) {
					dest[i] = row[col]
				}
			}
			return nil
		}
		return io.EOF
	}

	// Handle simple string value
	// (served exactly once; subsequent calls report EOF)
	if r.currentRow == 0 && r.value != nil {
		r.currentRow++
		dest[0] = r.value
		return nil
	}
	return io.EOF
}
377 |
378 | // RunQueryTest executes a mock query test against the DB
379 | func RunQueryTest(t *testing.T, testFunc func(*DB) error) {
380 | tsdb, _ := MockTimescaleDB(t)
381 | err := testFunc(tsdb)
382 | if err != nil {
383 | t.Errorf("Test failed: %v", err)
384 | }
385 | }
386 |
// SetQueryResult sets the mock result for queries.
// When non-nil it is returned by ExecuteSQL/ExecuteSQLWithoutParams.
func (m *MockDB) SetQueryResult(result []map[string]interface{}) {
	m.queryResult = result
}

// SetError sets the mock error.
// When set, ExecuteSQL/ExecuteSQLWithoutParams fail with this error.
func (m *MockDB) SetError(errMsg string) {
	m.err = fmt.Errorf("%s", errMsg)
}

// LastQuery returns the last executed query
func (m *MockDB) LastQuery() string {
	return m.lastQuery
}

// QueryContains checks if the last query contains a substring
func (m *MockDB) QueryContains(substring string) bool {
	return strings.Contains(m.lastQuery, substring)
}
406 |
407 | // ExecuteSQL implements db.Database.ExecuteSQL
408 | func (m *MockDB) ExecuteSQL(ctx context.Context, query string, args ...interface{}) (interface{}, error) {
409 | m.lastQuery = query
410 | m.lastQueryArgs = args
411 | m.queryHistory = append(m.queryHistory, query)
412 |
413 | if m.err != nil {
414 | return nil, m.err
415 | }
416 |
417 | // If TimescaleDB is not available and the query is for TimescaleDB specific features
418 | if !m.isTimescaleDB && (strings.Contains(query, "time_bucket") ||
419 | strings.Contains(query, "hypertable") ||
420 | strings.Contains(query, "continuous_aggregate") ||
421 | strings.Contains(query, "timescaledb_information")) {
422 | return nil, fmt.Errorf("TimescaleDB extension not available")
423 | }
424 |
425 | if m.queryResult != nil {
426 | return m.queryResult, nil
427 | }
428 |
429 | // Return an empty result set by default
430 | return []map[string]interface{}{}, nil
431 | }
432 |
433 | // ExecuteSQLWithoutParams implements db.Database.ExecuteSQLWithoutParams
434 | func (m *MockDB) ExecuteSQLWithoutParams(ctx context.Context, query string) (interface{}, error) {
435 | m.lastQuery = query
436 | m.lastQueryArgs = nil
437 | m.queryHistory = append(m.queryHistory, query)
438 |
439 | if m.err != nil {
440 | return nil, m.err
441 | }
442 |
443 | // If TimescaleDB is not available and the query is for TimescaleDB specific features
444 | if !m.isTimescaleDB && (strings.Contains(query, "time_bucket") ||
445 | strings.Contains(query, "hypertable") ||
446 | strings.Contains(query, "continuous_aggregate") ||
447 | strings.Contains(query, "timescaledb_information")) {
448 | return nil, fmt.Errorf("TimescaleDB extension not available")
449 | }
450 |
451 | if m.queryResult != nil {
452 | return m.queryResult, nil
453 | }
454 |
455 | // Return an empty result set by default
456 | return []map[string]interface{}{}, nil
457 | }
458 |
459 | // QueryHistory returns the complete history of queries executed
460 | func (m *MockDB) QueryHistory() []string {
461 | return m.queryHistory
462 | }
463 |
464 | // AnyQueryContains checks if any query in the history contains the given substring
465 | func (m *MockDB) AnyQueryContains(substring string) bool {
466 | for _, query := range m.queryHistory {
467 | if strings.Contains(query, substring) {
468 | return true
469 | }
470 | }
471 | return false
472 | }
473 |
```