This is page 4 of 7. Use http://codebase.md/freepeak/db-mcp-server?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .cm
│ └── gitstream.cm
├── .cursor
│ ├── mcp-example.json
│ ├── mcp.json
│ └── rules
│ └── global.mdc
├── .dockerignore
├── .DS_Store
├── .env.example
├── .github
│ ├── FUNDING.yml
│ └── workflows
│ └── go.yml
├── .gitignore
├── .golangci.yml
├── assets
│ └── logo.svg
├── CHANGELOG.md
├── cmd
│ └── server
│ └── main.go
├── commit-message.txt
├── config.json
├── config.timescaledb-test.json
├── docker-compose.mcp-test.yml
├── docker-compose.test.yml
├── docker-compose.timescaledb-test.yml
├── docker-compose.yml
├── docker-wrapper.sh
├── Dockerfile
├── docs
│ ├── REFACTORING.md
│ ├── TIMESCALEDB_FUNCTIONS.md
│ ├── TIMESCALEDB_IMPLEMENTATION.md
│ ├── TIMESCALEDB_PRD.md
│ └── TIMESCALEDB_TOOLS.md
├── examples
│ └── postgres_connection.go
├── glama.json
├── go.mod
├── go.sum
├── init-scripts
│ └── timescaledb
│ ├── 01-init.sql
│ ├── 02-sample-data.sql
│ ├── 03-continuous-aggregates.sql
│ └── README.md
├── internal
│ ├── config
│ │ ├── config_test.go
│ │ └── config.go
│ ├── delivery
│ │ └── mcp
│ │ ├── compression_policy_test.go
│ │ ├── context
│ │ │ ├── hypertable_schema_test.go
│ │ │ ├── timescale_completion_test.go
│ │ │ ├── timescale_context_test.go
│ │ │ └── timescale_query_suggestion_test.go
│ │ ├── mock_test.go
│ │ ├── response_test.go
│ │ ├── response.go
│ │ ├── retention_policy_test.go
│ │ ├── server_wrapper.go
│ │ ├── timescale_completion.go
│ │ ├── timescale_context.go
│ │ ├── timescale_schema.go
│ │ ├── timescale_tool_test.go
│ │ ├── timescale_tool.go
│ │ ├── timescale_tools_test.go
│ │ ├── tool_registry.go
│ │ └── tool_types.go
│ ├── domain
│ │ └── database.go
│ ├── logger
│ │ ├── logger_test.go
│ │ └── logger.go
│ ├── repository
│ │ └── database_repository.go
│ └── usecase
│ └── database_usecase.go
├── LICENSE
├── Makefile
├── pkg
│ ├── core
│ │ ├── core.go
│ │ └── logging.go
│ ├── db
│ │ ├── db_test.go
│ │ ├── db.go
│ │ ├── manager.go
│ │ ├── README.md
│ │ └── timescale
│ │ ├── config_test.go
│ │ ├── config.go
│ │ ├── connection_test.go
│ │ ├── connection.go
│ │ ├── continuous_aggregate_test.go
│ │ ├── continuous_aggregate.go
│ │ ├── hypertable_test.go
│ │ ├── hypertable.go
│ │ ├── metadata.go
│ │ ├── mocks_test.go
│ │ ├── policy_test.go
│ │ ├── policy.go
│ │ ├── query.go
│ │ ├── timeseries_test.go
│ │ └── timeseries.go
│ ├── dbtools
│ │ ├── db_helpers.go
│ │ ├── dbtools_test.go
│ │ ├── dbtools.go
│ │ ├── exec.go
│ │ ├── performance_test.go
│ │ ├── performance.go
│ │ ├── query.go
│ │ ├── querybuilder_test.go
│ │ ├── querybuilder.go
│ │ ├── README.md
│ │ ├── schema_test.go
│ │ ├── schema.go
│ │ ├── tx_test.go
│ │ └── tx.go
│ ├── internal
│ │ └── logger
│ │ └── logger.go
│ ├── jsonrpc
│ │ └── jsonrpc.go
│ ├── logger
│ │ └── logger.go
│ └── tools
│ └── tools.go
├── README-old.md
├── README.md
├── repomix-output.txt
├── request.json
├── start-mcp.sh
├── test.Dockerfile
├── timescaledb-test.sh
└── wait-for-it.sh
```
# Files
--------------------------------------------------------------------------------
/pkg/db/timescale/policy.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 | )
8 |
// CompressionSettings holds the compression configuration reported for one
// TimescaleDB hypertable.
type CompressionSettings struct {
	// HypertableName is the hypertable the settings describe.
	HypertableName string
	// SegmentBy and OrderBy hold the segmentby / orderby compression options
	// as text.
	SegmentBy, OrderBy string
	// ChunkTimeInterval and CompressionInterval hold interval values as text.
	ChunkTimeInterval, CompressionInterval string
	// CompressionEnabled reports whether compression is switched on.
	CompressionEnabled bool
}
18 |
// RetentionSettings holds the data-retention configuration reported for one
// TimescaleDB hypertable.
type RetentionSettings struct {
	// HypertableName is the hypertable the settings describe;
	// RetentionInterval is the policy interval as text.
	HypertableName, RetentionInterval string
	// RetentionEnabled reports whether a retention policy was found.
	RetentionEnabled bool
}
25 |
26 | // EnableCompression enables compression on a hypertable
27 | func (t *DB) EnableCompression(ctx context.Context, tableName string, afterInterval string) error {
28 | if !t.isTimescaleDB {
29 | return fmt.Errorf("TimescaleDB extension not available")
30 | }
31 |
32 | query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", tableName)
33 | _, err := t.ExecuteSQLWithoutParams(ctx, query)
34 | if err != nil {
35 | return fmt.Errorf("failed to enable compression: %w", err)
36 | }
37 |
38 | // Set compression policy if interval is specified
39 | if afterInterval != "" {
40 | err = t.AddCompressionPolicy(ctx, tableName, afterInterval, "", "")
41 | if err != nil {
42 | return fmt.Errorf("failed to add compression policy: %w", err)
43 | }
44 | }
45 |
46 | return nil
47 | }
48 |
49 | // DisableCompression disables compression on a hypertable
50 | func (t *DB) DisableCompression(ctx context.Context, tableName string) error {
51 | if !t.isTimescaleDB {
52 | return fmt.Errorf("TimescaleDB extension not available")
53 | }
54 |
55 | // First, remove any compression policies
56 | err := t.RemoveCompressionPolicy(ctx, tableName)
57 | if err != nil {
58 | return fmt.Errorf("failed to remove compression policy: %w", err)
59 | }
60 |
61 | // Then disable compression
62 | query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = false)", tableName)
63 | _, err = t.ExecuteSQLWithoutParams(ctx, query)
64 | if err != nil {
65 | return fmt.Errorf("failed to disable compression: %w", err)
66 | }
67 |
68 | return nil
69 | }
70 |
71 | // AddCompressionPolicy adds a compression policy to a hypertable
72 | func (t *DB) AddCompressionPolicy(ctx context.Context, tableName, interval, segmentBy, orderBy string) error {
73 | if !t.isTimescaleDB {
74 | return fmt.Errorf("TimescaleDB extension not available")
75 | }
76 |
77 | // First, check if the table has compression enabled
78 | query := fmt.Sprintf("SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'", tableName)
79 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
80 | if err != nil {
81 | return fmt.Errorf("failed to check compression status: %w", err)
82 | }
83 |
84 | rows, ok := result.([]map[string]interface{})
85 | if !ok || len(rows) == 0 {
86 | return fmt.Errorf("table '%s' is not a hypertable", tableName)
87 | }
88 |
89 | isCompressed := rows[0]["compress"]
90 | if isCompressed == nil || fmt.Sprintf("%v", isCompressed) == "false" {
91 | // Enable compression
92 | enableQuery := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", tableName)
93 | _, err := t.ExecuteSQLWithoutParams(ctx, enableQuery)
94 | if err != nil {
95 | return fmt.Errorf("failed to enable compression: %w", err)
96 | }
97 | }
98 |
99 | // Build the compression policy query
100 | var policyQuery strings.Builder
101 | policyQuery.WriteString(fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'", tableName, interval))
102 |
103 | if segmentBy != "" {
104 | policyQuery.WriteString(fmt.Sprintf(", segmentby => '%s'", segmentBy))
105 | }
106 |
107 | if orderBy != "" {
108 | policyQuery.WriteString(fmt.Sprintf(", orderby => '%s'", orderBy))
109 | }
110 |
111 | policyQuery.WriteString(")")
112 |
113 | // Add the compression policy
114 | _, err = t.ExecuteSQLWithoutParams(ctx, policyQuery.String())
115 | if err != nil {
116 | return fmt.Errorf("failed to add compression policy: %w", err)
117 | }
118 |
119 | return nil
120 | }
121 |
122 | // RemoveCompressionPolicy removes a compression policy from a hypertable
123 | func (t *DB) RemoveCompressionPolicy(ctx context.Context, tableName string) error {
124 | if !t.isTimescaleDB {
125 | return fmt.Errorf("TimescaleDB extension not available")
126 | }
127 |
128 | // Find the policy ID
129 | query := fmt.Sprintf(
130 | "SELECT job_id FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_compression'",
131 | tableName,
132 | )
133 |
134 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
135 | if err != nil {
136 | return fmt.Errorf("failed to find compression policy: %w", err)
137 | }
138 |
139 | rows, ok := result.([]map[string]interface{})
140 | if !ok || len(rows) == 0 {
141 | // No policy exists, so nothing to remove
142 | return nil
143 | }
144 |
145 | // Get the job ID
146 | jobID := rows[0]["job_id"]
147 | if jobID == nil {
148 | return fmt.Errorf("invalid job ID for compression policy")
149 | }
150 |
151 | // Remove the policy
152 | removeQuery := fmt.Sprintf("SELECT remove_compression_policy(%v)", jobID)
153 | _, err = t.ExecuteSQLWithoutParams(ctx, removeQuery)
154 | if err != nil {
155 | return fmt.Errorf("failed to remove compression policy: %w", err)
156 | }
157 |
158 | return nil
159 | }
160 |
161 | // GetCompressionSettings gets the compression settings for a hypertable
162 | func (t *DB) GetCompressionSettings(ctx context.Context, tableName string) (*CompressionSettings, error) {
163 | if !t.isTimescaleDB {
164 | return nil, fmt.Errorf("TimescaleDB extension not available")
165 | }
166 |
167 | // Check if the table has compression enabled
168 | query := fmt.Sprintf(
169 | "SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'",
170 | tableName,
171 | )
172 |
173 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
174 | if err != nil {
175 | return nil, fmt.Errorf("failed to check compression status: %w", err)
176 | }
177 |
178 | rows, ok := result.([]map[string]interface{})
179 | if !ok || len(rows) == 0 {
180 | return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
181 | }
182 |
183 | settings := &CompressionSettings{
184 | HypertableName: tableName,
185 | }
186 |
187 | isCompressed := rows[0]["compress"]
188 | if isCompressed != nil && fmt.Sprintf("%v", isCompressed) == "true" {
189 | settings.CompressionEnabled = true
190 |
191 | // Get compression-specific settings
192 | settingsQuery := fmt.Sprintf(
193 | "SELECT segmentby, orderby FROM timescaledb_information.compression_settings WHERE hypertable_name = '%s'",
194 | tableName,
195 | )
196 |
197 | settingsResult, err := t.ExecuteSQLWithoutParams(ctx, settingsQuery)
198 | if err != nil {
199 | return nil, fmt.Errorf("failed to get compression settings: %w", err)
200 | }
201 |
202 | settingsRows, ok := settingsResult.([]map[string]interface{})
203 | if ok && len(settingsRows) > 0 {
204 | if segmentBy, ok := settingsRows[0]["segmentby"]; ok && segmentBy != nil {
205 | settings.SegmentBy = fmt.Sprintf("%v", segmentBy)
206 | }
207 |
208 | if orderBy, ok := settingsRows[0]["orderby"]; ok && orderBy != nil {
209 | settings.OrderBy = fmt.Sprintf("%v", orderBy)
210 | }
211 | }
212 |
213 | // Check if a compression policy exists
214 | policyQuery := fmt.Sprintf(
215 | "SELECT s.schedule_interval, h.chunk_time_interval FROM timescaledb_information.jobs j "+
216 | "JOIN timescaledb_information.job_stats s ON j.job_id = s.job_id "+
217 | "JOIN timescaledb_information.hypertables h ON j.hypertable_name = h.hypertable_name "+
218 | "WHERE j.hypertable_name = '%s' AND j.proc_name = 'policy_compression'",
219 | tableName,
220 | )
221 |
222 | policyResult, err := t.ExecuteSQLWithoutParams(ctx, policyQuery)
223 | if err == nil {
224 | policyRows, ok := policyResult.([]map[string]interface{})
225 | if ok && len(policyRows) > 0 {
226 | if interval, ok := policyRows[0]["schedule_interval"]; ok && interval != nil {
227 | settings.CompressionInterval = fmt.Sprintf("%v", interval)
228 | }
229 |
230 | if chunkInterval, ok := policyRows[0]["chunk_time_interval"]; ok && chunkInterval != nil {
231 | settings.ChunkTimeInterval = fmt.Sprintf("%v", chunkInterval)
232 | }
233 | }
234 | }
235 | }
236 |
237 | return settings, nil
238 | }
239 |
240 | // AddRetentionPolicy adds a data retention policy to a hypertable
241 | func (t *DB) AddRetentionPolicy(ctx context.Context, tableName, interval string) error {
242 | if !t.isTimescaleDB {
243 | return fmt.Errorf("TimescaleDB extension not available")
244 | }
245 |
246 | query := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s')", tableName, interval)
247 | _, err := t.ExecuteSQLWithoutParams(ctx, query)
248 | if err != nil {
249 | return fmt.Errorf("failed to add retention policy: %w", err)
250 | }
251 |
252 | return nil
253 | }
254 |
255 | // RemoveRetentionPolicy removes a data retention policy from a hypertable
256 | func (t *DB) RemoveRetentionPolicy(ctx context.Context, tableName string) error {
257 | if !t.isTimescaleDB {
258 | return fmt.Errorf("TimescaleDB extension not available")
259 | }
260 |
261 | // Find the policy ID
262 | query := fmt.Sprintf(
263 | "SELECT job_id FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'",
264 | tableName,
265 | )
266 |
267 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
268 | if err != nil {
269 | return fmt.Errorf("failed to find retention policy: %w", err)
270 | }
271 |
272 | rows, ok := result.([]map[string]interface{})
273 | if !ok || len(rows) == 0 {
274 | // No policy exists, so nothing to remove
275 | return nil
276 | }
277 |
278 | // Get the job ID
279 | jobID := rows[0]["job_id"]
280 | if jobID == nil {
281 | return fmt.Errorf("invalid job ID for retention policy")
282 | }
283 |
284 | // Remove the policy
285 | removeQuery := fmt.Sprintf("SELECT remove_retention_policy(%v)", jobID)
286 | _, err = t.ExecuteSQLWithoutParams(ctx, removeQuery)
287 | if err != nil {
288 | return fmt.Errorf("failed to remove retention policy: %w", err)
289 | }
290 |
291 | return nil
292 | }
293 |
294 | // GetRetentionSettings gets the retention settings for a hypertable
295 | func (t *DB) GetRetentionSettings(ctx context.Context, tableName string) (*RetentionSettings, error) {
296 | if !t.isTimescaleDB {
297 | return nil, fmt.Errorf("TimescaleDB extension not available")
298 | }
299 |
300 | settings := &RetentionSettings{
301 | HypertableName: tableName,
302 | }
303 |
304 | // Check if a retention policy exists
305 | query := fmt.Sprintf(
306 | "SELECT s.schedule_interval FROM timescaledb_information.jobs j "+
307 | "JOIN timescaledb_information.job_stats s ON j.job_id = s.job_id "+
308 | "WHERE j.hypertable_name = '%s' AND j.proc_name = 'policy_retention'",
309 | tableName,
310 | )
311 |
312 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
313 | if err != nil {
314 | return settings, nil // Return empty settings with no error
315 | }
316 |
317 | rows, ok := result.([]map[string]interface{})
318 | if ok && len(rows) > 0 {
319 | settings.RetentionEnabled = true
320 | if interval, ok := rows[0]["schedule_interval"]; ok && interval != nil {
321 | settings.RetentionInterval = fmt.Sprintf("%v", interval)
322 | }
323 | }
324 |
325 | return settings, nil
326 | }
327 |
328 | // CompressChunks compresses chunks for a hypertable
329 | func (t *DB) CompressChunks(ctx context.Context, tableName, olderThan string) error {
330 | if !t.isTimescaleDB {
331 | return fmt.Errorf("TimescaleDB extension not available")
332 | }
333 |
334 | var query string
335 | if olderThan == "" {
336 | // Compress all chunks
337 | query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s')", tableName)
338 | } else {
339 | // Compress chunks older than the specified interval
340 | query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s', older_than => INTERVAL '%s')",
341 | tableName, olderThan)
342 | }
343 |
344 | _, err := t.ExecuteSQLWithoutParams(ctx, query)
345 | if err != nil {
346 | return fmt.Errorf("failed to compress chunks: %w", err)
347 | }
348 |
349 | return nil
350 | }
351 |
352 | // DecompressChunks decompresses chunks for a hypertable
353 | func (t *DB) DecompressChunks(ctx context.Context, tableName, newerThan string) error {
354 | if !t.isTimescaleDB {
355 | return fmt.Errorf("TimescaleDB extension not available")
356 | }
357 |
358 | var query string
359 | if newerThan == "" {
360 | // Decompress all chunks
361 | query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s')", tableName)
362 | } else {
363 | // Decompress chunks newer than the specified interval
364 | query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s', newer_than => INTERVAL '%s')",
365 | tableName, newerThan)
366 | }
367 |
368 | _, err := t.ExecuteSQLWithoutParams(ctx, query)
369 | if err != nil {
370 | return fmt.Errorf("failed to decompress chunks: %w", err)
371 | }
372 |
373 | return nil
374 | }
375 |
376 | // GetChunkCompressionStats gets compression statistics for a hypertable
377 | func (t *DB) GetChunkCompressionStats(ctx context.Context, tableName string) (interface{}, error) {
378 | if !t.isTimescaleDB {
379 | return nil, fmt.Errorf("TimescaleDB extension not available")
380 | }
381 |
382 | query := fmt.Sprintf(`
383 | SELECT
384 | chunk_name,
385 | range_start,
386 | range_end,
387 | is_compressed,
388 | before_compression_total_bytes,
389 | after_compression_total_bytes,
390 | CASE
391 | WHEN before_compression_total_bytes = 0 THEN 0
392 | ELSE (1 - (after_compression_total_bytes::float / before_compression_total_bytes::float)) * 100
393 | END AS compression_ratio
394 | FROM timescaledb_information.chunks
395 | WHERE hypertable_name = '%s'
396 | ORDER BY range_end DESC
397 | `, tableName)
398 |
399 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
400 | if err != nil {
401 | return nil, fmt.Errorf("failed to get chunk compression statistics: %w", err)
402 | }
403 |
404 | return result, nil
405 | }
406 |
```
--------------------------------------------------------------------------------
/pkg/db/timescale/hypertable_test.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "testing"
7 | )
8 |
// TestCreateHypertable checks the SQL that DB.CreateHypertable sends for:
//   - a basic config;
//   - a space-partitioned config;
//   - the IfNotExists / MigrateData flags.
//
// It also covers two error paths: TimescaleDB unavailable, and a mocked query
// failure.
//
// All steps share one mock and read GetLastQuery right after each call, so
// their order matters. The failing pattern is registered last, presumably so
// that it cannot affect the earlier successful calls.
func TestCreateHypertable(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Test basic hypertable creation
	config := HypertableConfig{
		TableName:         "test_table",
		TimeColumn:        "time",
		ChunkTimeInterval: "1 day",
		CreateIfNotExists: true,
	}

	err := tsdb.CreateHypertable(ctx, config)
	if err != nil {
		t.Fatalf("Failed to create hypertable: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "create_hypertable")
	AssertQueryContains(t, query, "test_table")
	AssertQueryContains(t, query, "time")
	AssertQueryContains(t, query, "chunk_time_interval")
	AssertQueryContains(t, query, "1 day")

	// Test with partitioning
	config = HypertableConfig{
		TableName:          "test_table",
		TimeColumn:         "time",
		ChunkTimeInterval:  "1 day",
		PartitioningColumn: "device_id",
		SpacePartitions:    4,
	}

	err = tsdb.CreateHypertable(ctx, config)
	if err != nil {
		t.Fatalf("Failed to create hypertable with partitioning: %v", err)
	}

	// Check that the correct query was executed
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "create_hypertable")
	AssertQueryContains(t, query, "partition_column")
	AssertQueryContains(t, query, "device_id")
	AssertQueryContains(t, query, "number_partitions")

	// Test with if_not_exists and migrate_data
	config = HypertableConfig{
		TableName:   "test_table",
		TimeColumn:  "time",
		IfNotExists: true,
		MigrateData: true,
	}

	err = tsdb.CreateHypertable(ctx, config)
	if err != nil {
		t.Fatalf("Failed to create hypertable with extra options: %v", err)
	}

	// Check that the correct query was executed
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "if_not_exists => TRUE")
	AssertQueryContains(t, query, "migrate_data => TRUE")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.CreateHypertable(ctx, config)
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error: the mock is told to fail queries matching this
	// pattern (matching rules live in the mock — see mocks_test.go).
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT create_hypertable(", nil, errors.New("mocked error"))
	err = tsdb.CreateHypertable(ctx, config)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
93 |
// TestAddDimension checks the SQL that DB.AddDimension sends. It also covers
// two error paths: TimescaleDB unavailable, and a mocked query failure.
//
// The steps share one mock, so the failing pattern is registered only after
// the successful call has been checked via GetLastQuery.
func TestAddDimension(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Test adding a dimension
	err := tsdb.AddDimension(ctx, "test_table", "device_id", 4)
	if err != nil {
		t.Fatalf("Failed to add dimension: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "add_dimension")
	AssertQueryContains(t, query, "test_table")
	AssertQueryContains(t, query, "device_id")
	AssertQueryContains(t, query, "number_partitions => 4")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.AddDimension(ctx, "test_table", "device_id", 4)
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT add_dimension(", nil, errors.New("mocked error"))
	err = tsdb.AddDimension(ctx, "test_table", "device_id", 4)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
131 |
132 | func TestListHypertables(t *testing.T) {
133 | mockDB := NewMockDB()
134 | tsdb := &DB{
135 | Database: mockDB,
136 | isTimescaleDB: true,
137 | }
138 |
139 | ctx := context.Background()
140 |
141 | // Prepare mock data
142 | mockResult := []map[string]interface{}{
143 | {
144 | "table_name": "test_table",
145 | "schema_name": "public",
146 | "time_column": "time",
147 | "num_dimensions": 2,
148 | "space_column": "device_id",
149 | },
150 | {
151 | "table_name": "test_table2",
152 | "schema_name": "public",
153 | "time_column": "timestamp",
154 | "num_dimensions": 1,
155 | "space_column": nil,
156 | },
157 | }
158 |
159 | // Register different result patterns for different queries
160 | mockDB.RegisterQueryResult("FROM _timescaledb_catalog.hypertable h", mockResult, nil)
161 | mockDB.RegisterQueryResult("FROM timescaledb_information.compression_settings", []map[string]interface{}{
162 | {"is_compressed": true},
163 | }, nil)
164 | mockDB.RegisterQueryResult("FROM timescaledb_information.jobs", []map[string]interface{}{
165 | {"has_retention": true},
166 | }, nil)
167 |
168 | // Test listing hypertables
169 | hypertables, err := tsdb.ListHypertables(ctx)
170 | if err != nil {
171 | t.Fatalf("Failed to list hypertables: %v", err)
172 | }
173 |
174 | // Check the results
175 | if len(hypertables) != 2 {
176 | t.Errorf("Expected 2 hypertables, got %d", len(hypertables))
177 | }
178 |
179 | if hypertables[0].TableName != "test_table" {
180 | t.Errorf("Expected TableName to be 'test_table', got '%s'", hypertables[0].TableName)
181 | }
182 |
183 | if hypertables[0].TimeColumn != "time" {
184 | t.Errorf("Expected TimeColumn to be 'time', got '%s'", hypertables[0].TimeColumn)
185 | }
186 |
187 | if hypertables[0].SpaceColumn != "device_id" {
188 | t.Errorf("Expected SpaceColumn to be 'device_id', got '%s'", hypertables[0].SpaceColumn)
189 | }
190 |
191 | if hypertables[0].NumDimensions != 2 {
192 | t.Errorf("Expected NumDimensions to be 2, got %d", hypertables[0].NumDimensions)
193 | }
194 |
195 | if !hypertables[0].CompressionEnabled {
196 | t.Error("Expected CompressionEnabled to be true, got false")
197 | }
198 |
199 | if !hypertables[0].RetentionEnabled {
200 | t.Error("Expected RetentionEnabled to be true, got false")
201 | }
202 |
203 | // Test when TimescaleDB is not available
204 | tsdb.isTimescaleDB = false
205 | _, err = tsdb.ListHypertables(ctx)
206 | if err == nil {
207 | t.Error("Expected error when TimescaleDB is not available, got nil")
208 | }
209 |
210 | // Test execution error
211 | tsdb.isTimescaleDB = true
212 | mockDB.RegisterQueryResult("FROM _timescaledb_catalog.hypertable h", nil, errors.New("mocked error"))
213 | _, err = tsdb.ListHypertables(ctx)
214 | if err == nil {
215 | t.Error("Expected query error, got nil")
216 | }
217 | }
218 |
// TestGetHypertable checks how DB.GetHypertable maps mocked rows for a single
// table. It covers:
//   - TimescaleDB unavailable;
//   - a mocked query failure;
//   - a table with no matching row, which must yield an error.
//
// The steps depend on the order of mock registrations. The not-found case
// therefore switches to a fresh mock so the error pattern registered earlier
// does not interfere.
func TestGetHypertable(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Prepare mock data - Set up the correct result by using RegisterQueryResult
	mockResult := []map[string]interface{}{
		{
			"table_name":     "test_table",
			"schema_name":    "public",
			"time_column":    "time",
			"num_dimensions": int64(2),
			"space_column":   "device_id",
		},
	}

	// Register the query result pattern for the main query
	mockDB.RegisterQueryResult("WHERE h.table_name = 'test_table'", mockResult, nil)

	// Register results for the compression check
	mockDB.RegisterQueryResult("FROM timescaledb_information.compression_settings", []map[string]interface{}{
		{"is_compressed": true},
	}, nil)

	// Register results for the retention policy check
	mockDB.RegisterQueryResult("FROM timescaledb_information.jobs", []map[string]interface{}{
		{"has_retention": true},
	}, nil)

	// Test getting a hypertable
	hypertable, err := tsdb.GetHypertable(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to get hypertable: %v", err)
	}

	// Check the results
	if hypertable.TableName != "test_table" {
		t.Errorf("Expected TableName to be 'test_table', got '%s'", hypertable.TableName)
	}

	if hypertable.TimeColumn != "time" {
		t.Errorf("Expected TimeColumn to be 'time', got '%s'", hypertable.TimeColumn)
	}

	if hypertable.SpaceColumn != "device_id" {
		t.Errorf("Expected SpaceColumn to be 'device_id', got '%s'", hypertable.SpaceColumn)
	}

	if hypertable.NumDimensions != 2 {
		t.Errorf("Expected NumDimensions to be 2, got %d", hypertable.NumDimensions)
	}

	if !hypertable.CompressionEnabled {
		t.Error("Expected CompressionEnabled to be true, got false")
	}

	if !hypertable.RetentionEnabled {
		t.Error("Expected RetentionEnabled to be true, got false")
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	_, err = tsdb.GetHypertable(ctx, "test_table")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("WHERE h.table_name = 'test_table'", nil, errors.New("mocked error"))
	_, err = tsdb.GetHypertable(ctx, "test_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}

	// Test table not found - Create a new mock to avoid interference
	newMockDB := NewMockDB()
	newMockDB.SetTimescaleAvailable(true)
	tsdb.Database = newMockDB

	// Register an empty result for the "not_found" table
	newMockDB.RegisterQueryResult("WHERE h.table_name = 'not_found'", []map[string]interface{}{}, nil)
	_, err = tsdb.GetHypertable(ctx, "not_found")
	if err == nil {
		t.Error("Expected error for non-existent table, got nil")
	}
}
310 |
// TestDropHypertable checks the DROP TABLE statement that DB.DropHypertable
// sends, both with and without CASCADE. It also covers two error paths:
// TimescaleDB unavailable, and a mocked query failure.
//
// The failing "DROP TABLE" pattern is registered only after the successful
// drops have been checked via GetLastQuery.
func TestDropHypertable(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Test dropping a hypertable
	err := tsdb.DropHypertable(ctx, "test_table", false)
	if err != nil {
		t.Fatalf("Failed to drop hypertable: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "DROP TABLE test_table")

	// Test dropping with CASCADE
	err = tsdb.DropHypertable(ctx, "test_table", true)
	if err != nil {
		t.Fatalf("Failed to drop hypertable with CASCADE: %v", err)
	}

	// Check that the correct query was executed
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "DROP TABLE test_table CASCADE")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.DropHypertable(ctx, "test_table", false)
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("DROP TABLE", nil, errors.New("mocked error"))
	err = tsdb.DropHypertable(ctx, "test_table", false)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
355 |
// TestCheckIfHypertable checks DB.CheckIfHypertable for:
//   - a hypertable and a regular table;
//   - TimescaleDB unavailable;
//   - a mocked query failure;
//   - an empty result set, which must be reported as an error.
//
// Each case uses a distinct table name, so every mock pattern targets only its
// own call.
func TestCheckIfHypertable(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Prepare mock data
	mockResultTrue := []map[string]interface{}{
		{"is_hypertable": true},
	}

	mockResultFalse := []map[string]interface{}{
		{"is_hypertable": false},
	}

	// Test table is a hypertable
	mockDB.RegisterQueryResult("WHERE table_name = 'test_table'", mockResultTrue, nil)
	isHypertable, err := tsdb.CheckIfHypertable(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to check if hypertable: %v", err)
	}

	if !isHypertable {
		t.Error("Expected table to be a hypertable, got false")
	}

	// Test table is not a hypertable
	mockDB.RegisterQueryResult("WHERE table_name = 'regular_table'", mockResultFalse, nil)
	isHypertable, err = tsdb.CheckIfHypertable(ctx, "regular_table")
	if err != nil {
		t.Fatalf("Failed to check if hypertable: %v", err)
	}

	if isHypertable {
		t.Error("Expected table not to be a hypertable, got true")
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	_, err = tsdb.CheckIfHypertable(ctx, "test_table")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("WHERE table_name = 'error_table'", nil, errors.New("mocked error"))
	_, err = tsdb.CheckIfHypertable(ctx, "error_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}

	// Test unexpected result structure
	mockDB.RegisterQueryResult("WHERE table_name = 'bad_structure'", []map[string]interface{}{}, nil)
	_, err = tsdb.CheckIfHypertable(ctx, "bad_structure")
	if err == nil {
		t.Error("Expected error for bad result structure, got nil")
	}
}
418 |
// TestRecentChunks checks the query that DB.RecentChunks sends. It covers:
//   - an explicit limit of 2;
//   - a limit of 0, which is expected to fall back to LIMIT 10;
//   - TimescaleDB unavailable;
//   - a mocked query failure.
//
// The returned rows themselves are not asserted; only the last executed query
// is inspected.
func TestRecentChunks(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Prepare mock data
	mockResult := []map[string]interface{}{
		{
			"chunk_name":    "_hyper_1_1_chunk",
			"range_start":   "2023-01-01 00:00:00",
			"range_end":     "2023-01-02 00:00:00",
			"is_compressed": false,
		},
		{
			"chunk_name":    "_hyper_1_2_chunk",
			"range_start":   "2023-01-02 00:00:00",
			"range_end":     "2023-01-03 00:00:00",
			"is_compressed": true,
		},
	}

	// Register mock result
	mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", mockResult, nil)

	// Test getting recent chunks
	_, err := tsdb.RecentChunks(ctx, "test_table", 2)
	if err != nil {
		t.Fatalf("Failed to get recent chunks: %v", err)
	}

	// Check that a query with the right table name and limit was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "hypertable_name = 'test_table'")
	AssertQueryContains(t, query, "LIMIT 2")

	// Test with default limit
	_, err = tsdb.RecentChunks(ctx, "test_table", 0)
	if err != nil {
		t.Fatalf("Failed to get recent chunks with default limit: %v", err)
	}

	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "LIMIT 10")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	_, err = tsdb.RecentChunks(ctx, "test_table", 2)
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", nil, errors.New("mocked error"))
	_, err = tsdb.RecentChunks(ctx, "test_table", 2)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
482 |
```
--------------------------------------------------------------------------------
/pkg/db/timescale/query.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 | "time"
8 | )
9 |
// TimeBucket describes a time_bucket() expression used to group time-series
// rows into fixed-width intervals.
type TimeBucket struct {
	Interval string // bucket width as an interval literal, e.g. '1 hour', '1 day', '1 month'
	Column   string // time column to bucket
	Alias    string // optional alias for the bucket column; Build uses "time_bucket" when empty
}

// AggregateFunction names a SQL aggregate function applied to a column.
type AggregateFunction string

// Supported aggregate functions.
const (
	// AggrAvg calculates the average value of a column.
	AggrAvg AggregateFunction = "AVG"
	// AggrSum calculates the sum of values in a column.
	AggrSum AggregateFunction = "SUM"
	// AggrMin finds the minimum value in a column.
	AggrMin AggregateFunction = "MIN"
	// AggrMax finds the maximum value in a column.
	AggrMax AggregateFunction = "MAX"
	// AggrCount counts the number of rows.
	AggrCount AggregateFunction = "COUNT"
	// AggrFirst takes the first value in a window.
	// NOTE(review): TimescaleDB's first() takes (value, time); pass both in
	// ColumnAggregation.Column (e.g. "temp, ts") together with an explicit Alias.
	AggrFirst AggregateFunction = "FIRST"
	// AggrLast takes the last value in a window (same argument caveat as AggrFirst).
	AggrLast AggregateFunction = "LAST"
)

// ColumnAggregation is one aggregate expression rendered as
// Function(Column) AS Alias. When Alias is empty, the alias
// "<lowercase function>_<column>" is derived.
type ColumnAggregation struct {
	Function AggregateFunction
	Column   string
	Alias    string
}

// TimeseriesQueryBuilder incrementally assembles a SELECT over one table,
// optionally grouped into time buckets. Every builder method returns the
// receiver so calls can be chained; Build renders the SQL and its positional
// arguments. Table names, column names and expressions are inserted into the
// SQL verbatim; only values passed to Where/WhereTimeRange are parameterized.
type TimeseriesQueryBuilder struct {
	table        string
	timeBucket   *TimeBucket
	selectCols   []string
	aggregations []ColumnAggregation
	whereClauses []string
	whereArgs    []interface{}
	groupByCols  []string
	orderByCols  []string
	limit        int
	offset       int
}

// NewTimeseriesQueryBuilder creates an empty builder that selects from table.
func NewTimeseriesQueryBuilder(table string) *TimeseriesQueryBuilder {
	return &TimeseriesQueryBuilder{
		table:        table,
		selectCols:   make([]string, 0),
		aggregations: make([]ColumnAggregation, 0),
		whereClauses: make([]string, 0),
		whereArgs:    make([]interface{}, 0),
		groupByCols:  make([]string, 0),
		orderByCols:  make([]string, 0),
	}
}

// WithTimeBucket sets (or replaces) the time bucket: rows are grouped by
// time_bucket(interval, column), selected under alias.
func (b *TimeseriesQueryBuilder) WithTimeBucket(interval, column, alias string) *TimeseriesQueryBuilder {
	b.timeBucket = &TimeBucket{
		Interval: interval,
		Column:   column,
		Alias:    alias,
	}
	return b
}

// Select appends plain columns or expressions to the SELECT list.
func (b *TimeseriesQueryBuilder) Select(cols ...string) *TimeseriesQueryBuilder {
	b.selectCols = append(b.selectCols, cols...)
	return b
}

// Aggregate appends function(column) AS alias to the SELECT list.
func (b *TimeseriesQueryBuilder) Aggregate(function AggregateFunction, column, alias string) *TimeseriesQueryBuilder {
	b.aggregations = append(b.aggregations, ColumnAggregation{
		Function: function,
		Column:   column,
		Alias:    alias,
	})
	return b
}

// WhereTimeRange adds "column BETWEEN start AND end" (inclusive at both ends)
// with start and end passed as the next two positional arguments.
func (b *TimeseriesQueryBuilder) WhereTimeRange(column string, start, end time.Time) *TimeseriesQueryBuilder {
	clause := fmt.Sprintf("%s BETWEEN $%d AND $%d", column, len(b.whereArgs)+1, len(b.whereArgs)+2)
	b.whereClauses = append(b.whereClauses, clause)
	b.whereArgs = append(b.whereArgs, start, end)
	return b
}

// Where adds a condition that is ANDed with the others. Placeholders in
// clause are numbered relative to this call ($1 is args[0]). They are shifted
// past the arguments already registered, so the final query numbers all
// arguments uniquely. Placeholders above len(args) are left untouched.
func (b *TimeseriesQueryBuilder) Where(clause string, args ...interface{}) *TimeseriesQueryBuilder {
	// A single pass avoids the cascading rewrites ($1->$2, then $2->$3) and
	// the "$1 inside $10" matches that repeated strings.Replace calls produce.
	clause = renumberWherePlaceholders(clause, len(args), len(b.whereArgs))
	b.whereClauses = append(b.whereClauses, clause)
	b.whereArgs = append(b.whereArgs, args...)
	return b
}

// renumberWherePlaceholders rewrites each positional placeholder $N in clause
// with 1 <= N <= argCount to $(N+offset). Everything else, including larger
// placeholder numbers and a '$' not followed by digits, is copied unchanged.
func renumberWherePlaceholders(clause string, argCount, offset int) string {
	if offset == 0 || argCount == 0 || !strings.Contains(clause, "$") {
		return clause
	}

	var out strings.Builder
	out.Grow(len(clause) + 8)
	for i := 0; i < len(clause); {
		if clause[i] != '$' {
			out.WriteByte(clause[i])
			i++
			continue
		}

		// Consume the whole digit run so "$10" is never read as "$1" + "0".
		j := i + 1
		for j < len(clause) && clause[j] >= '0' && clause[j] <= '9' {
			j++
		}

		n, valid := 0, j > i+1 && j-i-1 <= 9 // 9 digits cannot overflow int
		if valid {
			for _, c := range clause[i+1 : j] {
				n = n*10 + int(c-'0')
			}
			valid = n >= 1 && n <= argCount
		}

		if valid {
			fmt.Fprintf(&out, "$%d", n+offset)
		} else {
			out.WriteString(clause[i:j])
		}
		i = j
	}
	return out.String()
}

// GroupBy appends columns to the GROUP BY clause.
func (b *TimeseriesQueryBuilder) GroupBy(cols ...string) *TimeseriesQueryBuilder {
	b.groupByCols = append(b.groupByCols, cols...)
	return b
}

// OrderBy appends columns (optionally with ASC/DESC) to the ORDER BY clause.
func (b *TimeseriesQueryBuilder) OrderBy(cols ...string) *TimeseriesQueryBuilder {
	b.orderByCols = append(b.orderByCols, cols...)
	return b
}

// Limit sets the LIMIT clause; values <= 0 omit it.
func (b *TimeseriesQueryBuilder) Limit(limit int) *TimeseriesQueryBuilder {
	b.limit = limit
	return b
}

// Offset sets the OFFSET clause; values <= 0 omit it.
func (b *TimeseriesQueryBuilder) Offset(offset int) *TimeseriesQueryBuilder {
	b.offset = offset
	return b
}

// Build renders the SQL statement and returns it with the positional
// arguments for its placeholders.
//
// Clause order is SELECT, FROM, WHERE (conditions joined with AND), GROUP BY,
// ORDER BY, LIMIT, OFFSET; LIMIT and OFFSET appear only when positive. With a
// time bucket configured, the bucket expression is selected first, and its
// alias is prepended to GROUP BY unless it is already listed. If nothing is
// selected, "*" is used.
//
// Build does not modify the builder, so it may be called repeatedly and always
// reflects the builder's current state. The returned args slice is the
// builder's own; callers should not modify it.
func (b *TimeseriesQueryBuilder) Build() (string, []interface{}) {
	selects := make([]string, 0, 1+len(b.selectCols)+len(b.aggregations))
	groupBy := b.groupByCols

	if b.timeBucket != nil {
		alias := b.timeBucket.Alias
		if alias == "" {
			alias = "time_bucket"
		}

		// The interval is embedded as a string literal; doubling single quotes
		// keeps it from terminating the literal early.
		interval := strings.ReplaceAll(b.timeBucket.Interval, "'", "''")
		selects = append(selects, fmt.Sprintf("time_bucket('%s', %s) AS %s", interval, b.timeBucket.Column, alias))

		grouped := false
		for _, col := range groupBy {
			if col == alias {
				grouped = true
				break
			}
		}
		if !grouped {
			// append onto a fresh slice so b.groupByCols is left untouched.
			groupBy = append([]string{alias}, groupBy...)
		}
	}

	selects = append(selects, b.selectCols...)

	for _, agg := range b.aggregations {
		alias := agg.Alias
		if alias == "" {
			alias = strings.ToLower(string(agg.Function)) + "_" + agg.Column
		}
		selects = append(selects, fmt.Sprintf("%s(%s) AS %s", agg.Function, agg.Column, alias))
	}

	var q strings.Builder
	q.WriteString("SELECT ")
	if len(selects) == 0 {
		q.WriteString("*")
	} else {
		q.WriteString(strings.Join(selects, ", "))
	}
	q.WriteString(" FROM ")
	q.WriteString(b.table)

	if len(b.whereClauses) > 0 {
		q.WriteString(" WHERE ")
		q.WriteString(strings.Join(b.whereClauses, " AND "))
	}
	if len(groupBy) > 0 {
		q.WriteString(" GROUP BY ")
		q.WriteString(strings.Join(groupBy, ", "))
	}
	if len(b.orderByCols) > 0 {
		q.WriteString(" ORDER BY ")
		q.WriteString(strings.Join(b.orderByCols, ", "))
	}
	if b.limit > 0 {
		fmt.Fprintf(&q, " LIMIT %d", b.limit)
	}
	if b.offset > 0 {
		fmt.Fprintf(&q, " OFFSET %d", b.offset)
	}

	return q.String(), b.whereArgs
}
236 |
237 | // Execute runs the query against the database
238 | func (b *TimeseriesQueryBuilder) Execute(ctx context.Context, db *DB) ([]map[string]interface{}, error) {
239 | query, args := b.Build()
240 | result, err := db.ExecuteSQL(ctx, query, args...)
241 | if err != nil {
242 | return nil, fmt.Errorf("failed to execute time-series query: %w", err)
243 | }
244 |
245 | rows, ok := result.([]map[string]interface{})
246 | if !ok {
247 | return nil, fmt.Errorf("unexpected result type from database query")
248 | }
249 |
250 | return rows, nil
251 | }
252 |
253 | // DownsampleOptions describes options for downsampling time-series data
254 | type DownsampleOptions struct {
255 | SourceTable string
256 | DestTable string
257 | TimeColumn string
258 | BucketInterval string
259 | Aggregations []ColumnAggregation
260 | WhereCondition string
261 | CreateTable bool
262 | ChunkTimeInterval string
263 | }
264 |
265 | // DownsampleTimeSeries creates downsampled time-series data
266 | func (t *DB) DownsampleTimeSeries(ctx context.Context, options DownsampleOptions) error {
267 | if !t.isTimescaleDB {
268 | return fmt.Errorf("TimescaleDB extension not available")
269 | }
270 |
271 | // Create the destination table if requested
272 | if options.CreateTable {
273 | // Get source table columns
274 | schemaQuery := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '%s'", options.SourceTable)
275 | result, err := t.ExecuteSQLWithoutParams(ctx, schemaQuery)
276 | if err != nil {
277 | return fmt.Errorf("failed to get source table schema: %w", err)
278 | }
279 |
280 | columns, ok := result.([]map[string]interface{})
281 | if !ok {
282 | return fmt.Errorf("unexpected result from schema query")
283 | }
284 |
285 | // Build CREATE TABLE statement
286 | var createStmt strings.Builder
287 | createStmt.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", options.DestTable))
288 |
289 | // Add time bucket column
290 | createStmt.WriteString("time_bucket timestamptz, ")
291 |
292 | // Add aggregation columns
293 | for i, agg := range options.Aggregations {
294 | colName := agg.Alias
295 | if colName == "" {
296 | colName = strings.ToLower(string(agg.Function)) + "_" + agg.Column
297 | }
298 |
299 | // Find the data type of the source column
300 | var dataType string
301 | for _, col := range columns {
302 | if fmt.Sprintf("%v", col["column_name"]) == agg.Column {
303 | dataType = fmt.Sprintf("%v", col["data_type"])
304 | break
305 | }
306 | }
307 |
308 | if dataType == "" {
309 | dataType = "double precision" // Default for numeric aggregations
310 | }
311 |
312 | createStmt.WriteString(fmt.Sprintf("%s %s", colName, dataType))
313 |
314 | if i < len(options.Aggregations)-1 {
315 | createStmt.WriteString(", ")
316 | }
317 | }
318 |
319 | createStmt.WriteString(", PRIMARY KEY (time_bucket)")
320 | createStmt.WriteString(")")
321 |
322 | // Create the table
323 | _, err = t.ExecuteSQLWithoutParams(ctx, createStmt.String())
324 | if err != nil {
325 | return fmt.Errorf("failed to create destination table: %w", err)
326 | }
327 |
328 | // Make it a hypertable
329 | if options.ChunkTimeInterval == "" {
330 | options.ChunkTimeInterval = options.BucketInterval
331 | }
332 |
333 | err = t.CreateHypertable(ctx, HypertableConfig{
334 | TableName: options.DestTable,
335 | TimeColumn: "time_bucket",
336 | ChunkTimeInterval: options.ChunkTimeInterval,
337 | IfNotExists: true,
338 | })
339 | if err != nil {
340 | return fmt.Errorf("failed to create hypertable: %w", err)
341 | }
342 | }
343 |
344 | // Build the INSERT statement with aggregations
345 | var insertStmt strings.Builder
346 | insertStmt.WriteString(fmt.Sprintf("INSERT INTO %s (time_bucket, ", options.DestTable))
347 |
348 | // Add aggregation column names
349 | for i, agg := range options.Aggregations {
350 | colName := agg.Alias
351 | if colName == "" {
352 | colName = strings.ToLower(string(agg.Function)) + "_" + agg.Column
353 | }
354 |
355 | insertStmt.WriteString(colName)
356 |
357 | if i < len(options.Aggregations)-1 {
358 | insertStmt.WriteString(", ")
359 | }
360 | }
361 |
362 | insertStmt.WriteString(") SELECT time_bucket('")
363 | insertStmt.WriteString(options.BucketInterval)
364 | insertStmt.WriteString("', ")
365 | insertStmt.WriteString(options.TimeColumn)
366 | insertStmt.WriteString(") AS time_bucket, ")
367 |
368 | // Add aggregation functions
369 | for i, agg := range options.Aggregations {
370 | insertStmt.WriteString(fmt.Sprintf("%s(%s)", agg.Function, agg.Column))
371 |
372 | if i < len(options.Aggregations)-1 {
373 | insertStmt.WriteString(", ")
374 | }
375 | }
376 |
377 | insertStmt.WriteString(fmt.Sprintf(" FROM %s", options.SourceTable))
378 |
379 | // Add WHERE clause if specified
380 | if options.WhereCondition != "" {
381 | insertStmt.WriteString(" WHERE ")
382 | insertStmt.WriteString(options.WhereCondition)
383 | }
384 |
385 | // Group by time bucket
386 | insertStmt.WriteString(" GROUP BY time_bucket")
387 |
388 | // Order by time bucket
389 | insertStmt.WriteString(" ORDER BY time_bucket")
390 |
391 | // Execute the INSERT statement
392 | _, err := t.ExecuteSQLWithoutParams(ctx, insertStmt.String())
393 | if err != nil {
394 | return fmt.Errorf("failed to downsample data: %w", err)
395 | }
396 |
397 | return nil
398 | }
399 |
// TimeRange represents a common time range for queries, running from Start
// to End.
type TimeRange struct {
	Start time.Time
	End   time.Time
}

// PredefinedTimeRange returns the TimeRange for a named period relative to
// the current local time. Names are matched case-insensitively. Apart from
// "today" and "yesterday", each name accepts a compact and an underscored
// spelling (e.g. "thisweek" or "this_week").
//
// Calendar periods start at local midnight. "today", "thisweek", "thismonth"
// and "thisyear" end now. "yesterday", "lastweek", "lastmonth" and "lastyear"
// end where the current period begins. Weeks start on Sunday.
//
// Rolling periods ("last24hours", "last7days", "last30days", "last365days")
// end now and reach back a fixed multiple of 24 hours.
//
// An unknown name yields an error.
func PredefinedTimeRange(name string) (*TimeRange, error) {
	return predefinedTimeRangeAt(name, time.Now())
}

// predefinedTimeRangeAt implements PredefinedTimeRange relative to now, which
// keeps the calendar arithmetic testable with a fixed clock.
func predefinedTimeRangeAt(name string, now time.Time) (*TimeRange, error) {
	loc := now.Location()
	year, month, day := now.Date()

	// Calendar boundaries are built with time.Date/AddDate, which step whole
	// calendar days. Subtracting multiples of 24h instead drifts off midnight,
	// or onto the wrong day, across daylight-saving transitions.
	startOfToday := time.Date(year, month, day, 0, 0, 0, 0, loc)
	startOfWeek := time.Date(year, month, day-int(now.Weekday()), 0, 0, 0, 0, loc) // Sunday
	startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
	startOfYear := time.Date(year, 1, 1, 0, 0, 0, 0, loc)

	// rolling returns the window of exactly days*24h ending now.
	rolling := func(days int) *TimeRange {
		return &TimeRange{Start: now.Add(-time.Duration(days) * 24 * time.Hour), End: now}
	}

	switch strings.ToLower(name) {
	case "today":
		return &TimeRange{Start: startOfToday, End: now}, nil
	case "yesterday":
		return &TimeRange{Start: startOfToday.AddDate(0, 0, -1), End: startOfToday}, nil
	case "last24hours", "last_24_hours":
		return rolling(1), nil
	case "thisweek", "this_week":
		return &TimeRange{Start: startOfWeek, End: now}, nil
	case "lastweek", "last_week":
		return &TimeRange{Start: startOfWeek.AddDate(0, 0, -7), End: startOfWeek}, nil
	case "last7days", "last_7_days":
		return rolling(7), nil
	case "thismonth", "this_month":
		return &TimeRange{Start: startOfMonth, End: now}, nil
	case "lastmonth", "last_month":
		// AddDate from the 1st never overflows the month, and January rolls
		// back to December of the previous year.
		return &TimeRange{Start: startOfMonth.AddDate(0, -1, 0), End: startOfMonth}, nil
	case "last30days", "last_30_days":
		return rolling(30), nil
	case "thisyear", "this_year":
		return &TimeRange{Start: startOfYear, End: now}, nil
	case "lastyear", "last_year":
		return &TimeRange{Start: startOfYear.AddDate(-1, 0, 0), End: startOfYear}, nil
	case "last365days", "last_365_days":
		return rolling(365), nil
	default:
		return nil, fmt.Errorf("unknown time range: %s", name)
	}
}
487 |
```
--------------------------------------------------------------------------------
/internal/delivery/mcp/timescale_tools_test.go:
--------------------------------------------------------------------------------
```go
1 | package mcp_test
2 |
3 | import (
4 | "context"
5 | "strings"
6 | "testing"
7 |
8 | "github.com/FreePeak/cortex/pkg/server"
9 | "github.com/stretchr/testify/assert"
10 | "github.com/stretchr/testify/mock"
11 |
12 | "github.com/FreePeak/db-mcp-server/internal/delivery/mcp"
13 | )
14 |
15 | // MockDatabaseUseCase is a mock implementation of the UseCaseProvider interface
16 | type MockDatabaseUseCase struct {
17 | mock.Mock
18 | }
19 |
20 | // ExecuteStatement mocks the ExecuteStatement method
21 | func (m *MockDatabaseUseCase) ExecuteStatement(ctx context.Context, dbID, statement string, params []interface{}) (string, error) {
22 | args := m.Called(ctx, dbID, statement, params)
23 | return args.String(0), args.Error(1)
24 | }
25 |
26 | // GetDatabaseType mocks the GetDatabaseType method
27 | func (m *MockDatabaseUseCase) GetDatabaseType(dbID string) (string, error) {
28 | args := m.Called(dbID)
29 | return args.String(0), args.Error(1)
30 | }
31 |
32 | // ExecuteQuery mocks the ExecuteQuery method
33 | func (m *MockDatabaseUseCase) ExecuteQuery(ctx context.Context, dbID, query string, params []interface{}) (string, error) {
34 | args := m.Called(ctx, dbID, query, params)
35 | return args.String(0), args.Error(1)
36 | }
37 |
38 | // ExecuteTransaction mocks the ExecuteTransaction method
39 | func (m *MockDatabaseUseCase) ExecuteTransaction(ctx context.Context, dbID, action string, txID string, statement string, params []interface{}, readOnly bool) (string, map[string]interface{}, error) {
40 | args := m.Called(ctx, dbID, action, txID, statement, params, readOnly)
41 | return args.String(0), args.Get(1).(map[string]interface{}), args.Error(2)
42 | }
43 |
44 | // GetDatabaseInfo mocks the GetDatabaseInfo method
45 | func (m *MockDatabaseUseCase) GetDatabaseInfo(dbID string) (map[string]interface{}, error) {
46 | args := m.Called(dbID)
47 | return args.Get(0).(map[string]interface{}), args.Error(1)
48 | }
49 |
50 | // ListDatabases mocks the ListDatabases method
51 | func (m *MockDatabaseUseCase) ListDatabases() []string {
52 | args := m.Called()
53 | return args.Get(0).([]string)
54 | }
55 |
56 | func TestTimescaleDBTool(t *testing.T) {
57 | tool := mcp.NewTimescaleDBTool()
58 | assert.Equal(t, "timescaledb", tool.GetName())
59 | }
60 |
61 | func TestTimeSeriesQueryTool(t *testing.T) {
62 | // Create a mock use case provider
63 | mockUseCase := new(MockDatabaseUseCase)
64 |
65 | // Set up the TimescaleDB tool
66 | tool := mcp.NewTimescaleDBTool()
67 |
68 | // Create a context for testing
69 | ctx := context.Background()
70 |
71 | // Test case for time_series_query operation
72 | t.Run("time_series_query with basic parameters", func(t *testing.T) {
73 | // Sample result that would be returned by the database
74 | sampleResult := `[
75 | {"time_bucket": "2023-01-01T00:00:00Z", "avg_temp": 22.5, "count": 10},
76 | {"time_bucket": "2023-01-02T00:00:00Z", "avg_temp": 23.1, "count": 12}
77 | ]`
78 |
79 | // Set up expectations for the mock
80 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.AnythingOfType("string"), mock.Anything).
81 | Return(sampleResult, nil).Once()
82 |
83 | // Create a request with time_series_query operation
84 | request := server.ToolCallRequest{
85 | Name: "timescaledb_timeseries_query_test_db",
86 | Parameters: map[string]interface{}{
87 | "operation": "time_series_query",
88 | "target_table": "sensor_data",
89 | "time_column": "timestamp",
90 | "bucket_interval": "1 day",
91 | "start_time": "2023-01-01",
92 | "end_time": "2023-01-31",
93 | "aggregations": "AVG(temperature) as avg_temp, COUNT(*) as count",
94 | },
95 | }
96 |
97 | // Call the handler
98 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
99 |
100 | // Verify the result
101 | assert.NoError(t, err)
102 | assert.NotNil(t, result)
103 |
104 | // Check the result contains expected fields
105 | resultMap, ok := result.(map[string]interface{})
106 | assert.True(t, ok)
107 | assert.Contains(t, resultMap, "message")
108 | assert.Contains(t, resultMap, "details")
109 | assert.Equal(t, sampleResult, resultMap["details"])
110 |
111 | // Verify the mock expectations
112 | mockUseCase.AssertExpectations(t)
113 | })
114 |
115 | t.Run("time_series_query with window functions", func(t *testing.T) {
116 | // Sample result that would be returned by the database
117 | sampleResult := `[
118 | {"time_bucket": "2023-01-01T00:00:00Z", "avg_temp": 22.5, "prev_avg": null},
119 | {"time_bucket": "2023-01-02T00:00:00Z", "avg_temp": 23.1, "prev_avg": 22.5}
120 | ]`
121 |
122 | // Set up expectations for the mock
123 | mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.AnythingOfType("string"), mock.Anything).
124 | Return(sampleResult, nil).Once()
125 |
126 | // Create a request with time_series_query operation
127 | request := server.ToolCallRequest{
128 | Name: "timescaledb_timeseries_query_test_db",
129 | Parameters: map[string]interface{}{
130 | "operation": "time_series_query",
131 | "target_table": "sensor_data",
132 | "time_column": "timestamp",
133 | "bucket_interval": "1 day",
134 | "aggregations": "AVG(temperature) as avg_temp",
135 | "window_functions": "LAG(avg_temp) OVER (ORDER BY time_bucket) AS prev_avg",
136 | "format_pretty": true,
137 | },
138 | }
139 |
140 | // Call the handler
141 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
142 |
143 | // Verify the result
144 | assert.NoError(t, err)
145 | assert.NotNil(t, result)
146 |
147 | // Check the result contains expected fields
148 | resultMap, ok := result.(map[string]interface{})
149 | assert.True(t, ok)
150 | assert.Contains(t, resultMap, "message")
151 | assert.Contains(t, resultMap, "details")
152 | assert.Contains(t, resultMap, "metadata")
153 |
154 | // Check metadata contains expected fields for pretty formatting
155 | metadata, ok := resultMap["metadata"].(map[string]interface{})
156 | assert.True(t, ok)
157 | assert.Contains(t, metadata, "num_rows")
158 | assert.Contains(t, metadata, "time_bucket_interval")
159 |
160 | // Verify the mock expectations
161 | mockUseCase.AssertExpectations(t)
162 | })
163 | }
164 |
165 | // TestContinuousAggregateTool tests the continuous aggregate operations
166 | func TestContinuousAggregateTool(t *testing.T) {
167 | // Create a context for testing
168 | ctx := context.Background()
169 |
170 | // Test case for create_continuous_aggregate operation
171 | t.Run("create_continuous_aggregate", func(t *testing.T) {
172 | // Create a new mock for this test case
173 | mockUseCase := new(MockDatabaseUseCase)
174 |
175 | // Set up the TimescaleDB tool
176 | tool := mcp.NewTimescaleDBTool()
177 |
178 | // Set up expectations
179 | // Removed GetDatabaseType expectation as it's not called in this handler
180 |
181 | // Add mock expectation for the SQL containing CREATE MATERIALIZED VIEW
182 | mockUseCase.On("ExecuteStatement",
183 | mock.Anything,
184 | "test_db",
185 | mock.MatchedBy(func(sql string) bool {
186 | return strings.Contains(sql, "CREATE MATERIALIZED VIEW")
187 | }),
188 | mock.Anything).Return(`{"result": "success"}`, nil)
189 |
190 | // Add separate mock expectation for policy SQL if needed
191 | mockUseCase.On("ExecuteStatement",
192 | mock.Anything,
193 | "test_db",
194 | mock.MatchedBy(func(sql string) bool {
195 | return strings.Contains(sql, "add_continuous_aggregate_policy")
196 | }),
197 | mock.Anything).Return(`{"result": "success"}`, nil)
198 |
199 | // Create a request
200 | request := server.ToolCallRequest{
201 | Name: "timescaledb_create_continuous_aggregate_test_db",
202 | Parameters: map[string]interface{}{
203 | "operation": "create_continuous_aggregate",
204 | "view_name": "daily_metrics",
205 | "source_table": "sensor_data",
206 | "time_column": "timestamp",
207 | "bucket_interval": "1 day",
208 | "aggregations": "AVG(temperature) as avg_temp, MIN(temperature) as min_temp, MAX(temperature) as max_temp",
209 | "with_data": true,
210 | "refresh_policy": true,
211 | "refresh_interval": "1 hour",
212 | },
213 | }
214 |
215 | // Call the handler
216 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
217 |
218 | // Verify the result
219 | assert.NoError(t, err)
220 | assert.NotNil(t, result)
221 |
222 | // Check the result contains expected fields
223 | resultMap, ok := result.(map[string]interface{})
224 | assert.True(t, ok)
225 | assert.Contains(t, resultMap, "message")
226 | assert.Contains(t, resultMap, "sql")
227 |
228 | // Verify the mock expectations
229 | mockUseCase.AssertExpectations(t)
230 | })
231 |
232 | // Test case for refresh_continuous_aggregate operation
233 | t.Run("refresh_continuous_aggregate", func(t *testing.T) {
234 | // Create a new mock for this test case
235 | mockUseCase := new(MockDatabaseUseCase)
236 |
237 | // Set up the TimescaleDB tool
238 | tool := mcp.NewTimescaleDBTool()
239 |
240 | // Set up expectations
241 | // Removed GetDatabaseType expectation as it's not called in this handler
242 | mockUseCase.On("ExecuteStatement",
243 | mock.Anything,
244 | "test_db",
245 | mock.MatchedBy(func(sql string) bool {
246 | return strings.Contains(sql, "CALL refresh_continuous_aggregate")
247 | }),
248 | mock.Anything).Return(`{"result": "success"}`, nil)
249 |
250 | // Create a request
251 | request := server.ToolCallRequest{
252 | Name: "timescaledb_refresh_continuous_aggregate_test_db",
253 | Parameters: map[string]interface{}{
254 | "operation": "refresh_continuous_aggregate",
255 | "view_name": "daily_metrics",
256 | "start_time": "2023-01-01",
257 | "end_time": "2023-01-31",
258 | },
259 | }
260 |
261 | // Call the handler
262 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
263 |
264 | // Verify the result
265 | assert.NoError(t, err)
266 | assert.NotNil(t, result)
267 |
268 | // Check the result contains expected fields
269 | resultMap, ok := result.(map[string]interface{})
270 | assert.True(t, ok)
271 | assert.Contains(t, resultMap, "message")
272 |
273 | // Verify the mock expectations
274 | mockUseCase.AssertExpectations(t)
275 | })
276 |
277 | // Test case for drop_continuous_aggregate operation
278 | t.Run("drop_continuous_aggregate", func(t *testing.T) {
279 | // Create a new mock for this test case
280 | mockUseCase := new(MockDatabaseUseCase)
281 |
282 | // Set up the TimescaleDB tool
283 | tool := mcp.NewTimescaleDBTool()
284 |
285 | // Set up expectations
286 | // Removed GetDatabaseType expectation as it's not called in this handler
287 | mockUseCase.On("ExecuteStatement",
288 | mock.Anything,
289 | "test_db",
290 | mock.MatchedBy(func(sql string) bool {
291 | return strings.Contains(sql, "DROP MATERIALIZED VIEW")
292 | }),
293 | mock.Anything).Return(`{"result": "success"}`, nil)
294 |
295 | // Create a request
296 | request := server.ToolCallRequest{
297 | Name: "timescaledb_drop_continuous_aggregate_test_db",
298 | Parameters: map[string]interface{}{
299 | "operation": "drop_continuous_aggregate",
300 | "view_name": "daily_metrics",
301 | "cascade": true,
302 | },
303 | }
304 |
305 | // Call the handler
306 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
307 |
308 | // Verify the result
309 | assert.NoError(t, err)
310 | assert.NotNil(t, result)
311 |
312 | // Check the result contains expected fields
313 | resultMap, ok := result.(map[string]interface{})
314 | assert.True(t, ok)
315 | assert.Contains(t, resultMap, "message")
316 |
317 | // Verify the mock expectations
318 | mockUseCase.AssertExpectations(t)
319 | })
320 |
321 | // Test case for list_continuous_aggregates operation
322 | t.Run("list_continuous_aggregates", func(t *testing.T) {
323 | // Create a new mock for this test case
324 | mockUseCase := new(MockDatabaseUseCase)
325 |
326 | // Set up the TimescaleDB tool
327 | tool := mcp.NewTimescaleDBTool()
328 |
329 | // Set up expectations
330 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
331 | mockUseCase.On("ExecuteStatement",
332 | mock.Anything,
333 | "test_db",
334 | mock.MatchedBy(func(sql string) bool {
335 | return strings.Contains(sql, "SELECT") && strings.Contains(sql, "continuous_aggregates")
336 | }),
337 | mock.Anything).Return(`[{"view_name": "daily_metrics", "source_table": "sensor_data"}]`, nil)
338 |
339 | // Create a request
340 | request := server.ToolCallRequest{
341 | Name: "timescaledb_list_continuous_aggregates_test_db",
342 | Parameters: map[string]interface{}{
343 | "operation": "list_continuous_aggregates",
344 | },
345 | }
346 |
347 | // Call the handler
348 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
349 |
350 | // Verify the result
351 | assert.NoError(t, err)
352 | assert.NotNil(t, result)
353 |
354 | // Check the result contains expected fields
355 | resultMap, ok := result.(map[string]interface{})
356 | assert.True(t, ok)
357 | assert.Contains(t, resultMap, "message")
358 | assert.Contains(t, resultMap, "details")
359 |
360 | // Verify the mock expectations
361 | mockUseCase.AssertExpectations(t)
362 | })
363 |
364 | // Test case for get_continuous_aggregate_info operation
365 | t.Run("get_continuous_aggregate_info", func(t *testing.T) {
366 | // Create a new mock for this test case
367 | mockUseCase := new(MockDatabaseUseCase)
368 |
369 | // Set up the TimescaleDB tool
370 | tool := mcp.NewTimescaleDBTool()
371 |
372 | // Set up expectations
373 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
374 | mockUseCase.On("ExecuteStatement",
375 | mock.Anything,
376 | "test_db",
377 | mock.MatchedBy(func(sql string) bool {
378 | return strings.Contains(sql, "SELECT") && strings.Contains(sql, "continuous_aggregates") && strings.Contains(sql, "WHERE")
379 | }),
380 | mock.Anything).Return(`[{"view_name": "daily_metrics", "source_table": "sensor_data", "bucket_interval": "1 day"}]`, nil)
381 |
382 | // Create a request
383 | request := server.ToolCallRequest{
384 | Name: "timescaledb_get_continuous_aggregate_info_test_db",
385 | Parameters: map[string]interface{}{
386 | "operation": "get_continuous_aggregate_info",
387 | "view_name": "daily_metrics",
388 | },
389 | }
390 |
391 | // Call the handler
392 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
393 |
394 | // Verify the result
395 | assert.NoError(t, err)
396 | assert.NotNil(t, result)
397 |
398 | // Check the result contains expected fields
399 | resultMap, ok := result.(map[string]interface{})
400 | assert.True(t, ok)
401 | assert.Contains(t, resultMap, "message")
402 | assert.Contains(t, resultMap, "details")
403 |
404 | // Verify the mock expectations
405 | mockUseCase.AssertExpectations(t)
406 | })
407 |
408 | // Test case for add_continuous_aggregate_policy operation
409 | t.Run("add_continuous_aggregate_policy", func(t *testing.T) {
410 | // Create a new mock for this test case
411 | mockUseCase := new(MockDatabaseUseCase)
412 |
413 | // Set up the TimescaleDB tool
414 | tool := mcp.NewTimescaleDBTool()
415 |
416 | // Set up expectations
417 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
418 | mockUseCase.On("ExecuteStatement",
419 | mock.Anything,
420 | "test_db",
421 | mock.MatchedBy(func(sql string) bool {
422 | return strings.Contains(sql, "add_continuous_aggregate_policy")
423 | }),
424 | mock.Anything).Return(`{"result": "success"}`, nil)
425 |
426 | // Create a request
427 | request := server.ToolCallRequest{
428 | Name: "timescaledb_add_continuous_aggregate_policy_test_db",
429 | Parameters: map[string]interface{}{
430 | "operation": "add_continuous_aggregate_policy",
431 | "view_name": "daily_metrics",
432 | "start_offset": "1 month",
433 | "end_offset": "2 hours",
434 | "schedule_interval": "6 hours",
435 | },
436 | }
437 |
438 | // Call the handler
439 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
440 |
441 | // Verify the result
442 | assert.NoError(t, err)
443 | assert.NotNil(t, result)
444 |
445 | // Check the result contains expected fields
446 | resultMap, ok := result.(map[string]interface{})
447 | assert.True(t, ok)
448 | assert.Contains(t, resultMap, "message")
449 |
450 | // Verify the mock expectations
451 | mockUseCase.AssertExpectations(t)
452 | })
453 |
454 | // Test case for remove_continuous_aggregate_policy operation
455 | t.Run("remove_continuous_aggregate_policy", func(t *testing.T) {
456 | // Create a new mock for this test case
457 | mockUseCase := new(MockDatabaseUseCase)
458 |
459 | // Set up the TimescaleDB tool
460 | tool := mcp.NewTimescaleDBTool()
461 |
462 | // Set up expectations
463 | mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
464 | mockUseCase.On("ExecuteStatement",
465 | mock.Anything,
466 | "test_db",
467 | mock.MatchedBy(func(sql string) bool {
468 | return strings.Contains(sql, "remove_continuous_aggregate_policy")
469 | }),
470 | mock.Anything).Return(`{"result": "success"}`, nil)
471 |
472 | // Create a request
473 | request := server.ToolCallRequest{
474 | Name: "timescaledb_remove_continuous_aggregate_policy_test_db",
475 | Parameters: map[string]interface{}{
476 | "operation": "remove_continuous_aggregate_policy",
477 | "view_name": "daily_metrics",
478 | },
479 | }
480 |
481 | // Call the handler
482 | result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
483 |
484 | // Verify the result
485 | assert.NoError(t, err)
486 | assert.NotNil(t, result)
487 |
488 | // Check the result contains expected fields
489 | resultMap, ok := result.(map[string]interface{})
490 | assert.True(t, ok)
491 | assert.Contains(t, resultMap, "message")
492 |
493 | // Verify the mock expectations
494 | mockUseCase.AssertExpectations(t)
495 | })
496 | }
497 |
```
--------------------------------------------------------------------------------
/pkg/db/timescale/metadata.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strconv"
7 | "strings"
8 | )
9 |
// Metadata records produced by the TimescaleDB inspection helpers in this file.
type (
	// HypertableMetadata describes one TimescaleDB hypertable: its identity, its
	// partitioning dimensions, and a snapshot of its size, row and chunk counts.
	HypertableMetadata struct {
		TableName         string   // hypertable name
		SchemaName        string   // schema that contains the hypertable
		Owner             string   // owning role
		NumDimensions     int      // number of partitioning dimensions
		TimeDimension     string   // column of the first (time) dimension
		TimeDimensionType string   // SQL type of the time dimension column
		SpaceDimensions   []string // columns of the remaining dimensions, in dimension order
		ChunkTimeInterval string   // chunk interval of the time dimension, as text
		Compression       bool     // whether compression is enabled on the hypertable
		RetentionPolicy   bool     // whether a retention policy job exists
		TotalSize         string   // human-readable total size
		TotalRows         int64    // number of rows
		Chunks            int      // number of chunks
	}

	// ColumnMetadata describes one column of a table.
	ColumnMetadata struct {
		Name         string // column name
		Type         string // SQL data type as reported by information_schema
		Nullable     bool   // whether the column accepts NULL
		IsPrimaryKey bool   // whether the column is part of the primary key
		IsIndexed    bool   // whether a non-primary index covers the column
		Description  string // column comment, empty when none is set
	}

	// ContinuousAggregateMetadata describes one continuous aggregate (a
	// materialized view maintained by TimescaleDB) and the hypertable it reads.
	ContinuousAggregateMetadata struct {
		ViewName           string // name of the continuous aggregate view
		ViewSchema         string // schema of the view
		MaterializedOnly   bool   // whether queries read only materialized data
		RefreshInterval    string // refresh schedule, as text (may be empty)
		RefreshLag         string // refresh lag, as text (may be empty)
		RefreshStartOffset string // refresh window start offset, as text (may be empty)
		RefreshEndOffset   string // refresh window end offset, as text (may be empty)
		HypertableName     string // source hypertable name
		HypertableSchema   string // source hypertable schema
		ViewDefinition     string // SQL definition of the view
	}
)
50 |
51 | // GetHypertableMetadata returns detailed metadata about a hypertable
52 | func (t *DB) GetHypertableMetadata(ctx context.Context, tableName string) (*HypertableMetadata, error) {
53 | if !t.isTimescaleDB {
54 | return nil, fmt.Errorf("TimescaleDB extension not available")
55 | }
56 |
57 | // Query to get basic hypertable information
58 | query := fmt.Sprintf(`
59 | SELECT
60 | h.table_name,
61 | h.schema_name,
62 | t.tableowner as owner,
63 | h.num_dimensions,
64 | dc.column_name as time_dimension,
65 | dc.column_type as time_dimension_type,
66 | dc.time_interval as chunk_time_interval,
67 | h.compression_enabled,
68 | pg_size_pretty(pg_total_relation_size(format('%%I.%%I', h.schema_name, h.table_name))) as total_size,
69 | (SELECT count(*) FROM timescaledb_information.chunks WHERE hypertable_name = h.table_name) as chunks,
70 | (SELECT count(*) FROM %s.%s) as total_rows
71 | FROM timescaledb_information.hypertables h
72 | JOIN pg_tables t ON h.table_name = t.tablename AND h.schema_name = t.schemaname
73 | JOIN timescaledb_information.dimensions dc ON h.hypertable_name = dc.hypertable_name
74 | WHERE h.table_name = '%s' AND dc.dimension_number = 1
75 | `, tableName, tableName, tableName)
76 |
77 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
78 | if err != nil {
79 | return nil, fmt.Errorf("failed to get hypertable metadata: %w", err)
80 | }
81 |
82 | rows, ok := result.([]map[string]interface{})
83 | if !ok || len(rows) == 0 {
84 | return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
85 | }
86 |
87 | row := rows[0]
88 | metadata := &HypertableMetadata{
89 | TableName: fmt.Sprintf("%v", row["table_name"]),
90 | SchemaName: fmt.Sprintf("%v", row["schema_name"]),
91 | Owner: fmt.Sprintf("%v", row["owner"]),
92 | TimeDimension: fmt.Sprintf("%v", row["time_dimension"]),
93 | TimeDimensionType: fmt.Sprintf("%v", row["time_dimension_type"]),
94 | ChunkTimeInterval: fmt.Sprintf("%v", row["chunk_time_interval"]),
95 | TotalSize: fmt.Sprintf("%v", row["total_size"]),
96 | }
97 |
98 | // Convert numeric fields
99 | if numDimensions, ok := row["num_dimensions"].(int64); ok {
100 | metadata.NumDimensions = int(numDimensions)
101 | } else if numDimensions, ok := row["num_dimensions"].(int); ok {
102 | metadata.NumDimensions = numDimensions
103 | }
104 |
105 | if chunks, ok := row["chunks"].(int64); ok {
106 | metadata.Chunks = int(chunks)
107 | } else if chunks, ok := row["chunks"].(int); ok {
108 | metadata.Chunks = chunks
109 | }
110 |
111 | if rows, ok := row["total_rows"].(int64); ok {
112 | metadata.TotalRows = rows
113 | } else if rows, ok := row["total_rows"].(int); ok {
114 | metadata.TotalRows = int64(rows)
115 | } else if rowsStr, ok := row["total_rows"].(string); ok {
116 | if rows, err := strconv.ParseInt(rowsStr, 10, 64); err == nil {
117 | metadata.TotalRows = rows
118 | }
119 | }
120 |
121 | // Handle boolean fields
122 | if compression, ok := row["compression_enabled"].(bool); ok {
123 | metadata.Compression = compression
124 | } else if compressionStr, ok := row["compression_enabled"].(string); ok {
125 | metadata.Compression = compressionStr == "t" || compressionStr == "true" || compressionStr == "1"
126 | }
127 |
128 | // Get space dimensions if there are more than one dimension
129 | if metadata.NumDimensions > 1 {
130 | spaceDimQuery := fmt.Sprintf(`
131 | SELECT column_name
132 | FROM timescaledb_information.dimensions
133 | WHERE hypertable_name = '%s' AND dimension_number > 1
134 | ORDER BY dimension_number
135 | `, tableName)
136 |
137 | spaceResult, err := t.ExecuteSQLWithoutParams(ctx, spaceDimQuery)
138 | if err == nil {
139 | spaceDimRows, ok := spaceResult.([]map[string]interface{})
140 | if ok {
141 | for _, dimRow := range spaceDimRows {
142 | if colName, ok := dimRow["column_name"]; ok && colName != nil {
143 | metadata.SpaceDimensions = append(metadata.SpaceDimensions, fmt.Sprintf("%v", colName))
144 | }
145 | }
146 | }
147 | }
148 | }
149 |
150 | // Check if a retention policy exists
151 | retentionQuery := fmt.Sprintf(`
152 | SELECT COUNT(*) > 0 as has_retention
153 | FROM timescaledb_information.jobs
154 | WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'
155 | `, tableName)
156 |
157 | retentionResult, err := t.ExecuteSQLWithoutParams(ctx, retentionQuery)
158 | if err == nil {
159 | retentionRows, ok := retentionResult.([]map[string]interface{})
160 | if ok && len(retentionRows) > 0 {
161 | if hasRetention, ok := retentionRows[0]["has_retention"].(bool); ok {
162 | metadata.RetentionPolicy = hasRetention
163 | }
164 | }
165 | }
166 |
167 | return metadata, nil
168 | }
169 |
170 | // GetTableColumns returns metadata about columns in a table
171 | func (t *DB) GetTableColumns(ctx context.Context, tableName string) ([]ColumnMetadata, error) {
172 | query := fmt.Sprintf(`
173 | SELECT
174 | c.column_name,
175 | c.data_type,
176 | c.is_nullable = 'YES' as is_nullable,
177 | (
178 | SELECT COUNT(*) > 0
179 | FROM pg_index i
180 | JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
181 | WHERE i.indrelid = format('%%I.%%I', c.table_schema, c.table_name)::regclass
182 | AND i.indisprimary
183 | AND a.attname = c.column_name
184 | ) as is_primary_key,
185 | (
186 | SELECT COUNT(*) > 0
187 | FROM pg_index i
188 | JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
189 | WHERE i.indrelid = format('%%I.%%I', c.table_schema, c.table_name)::regclass
190 | AND NOT i.indisprimary
191 | AND a.attname = c.column_name
192 | ) as is_indexed,
193 | col_description(format('%%I.%%I', c.table_schema, c.table_name)::regclass::oid,
194 | ordinal_position) as description
195 | FROM information_schema.columns c
196 | WHERE c.table_name = '%s'
197 | ORDER BY c.ordinal_position
198 | `, tableName)
199 |
200 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
201 | if err != nil {
202 | return nil, fmt.Errorf("failed to get table columns: %w", err)
203 | }
204 |
205 | rows, ok := result.([]map[string]interface{})
206 | if !ok {
207 | return nil, fmt.Errorf("unexpected result type from database query")
208 | }
209 |
210 | var columns []ColumnMetadata
211 | for _, row := range rows {
212 | col := ColumnMetadata{
213 | Name: fmt.Sprintf("%v", row["column_name"]),
214 | Type: fmt.Sprintf("%v", row["data_type"]),
215 | }
216 |
217 | // Handle boolean fields
218 | if nullable, ok := row["is_nullable"].(bool); ok {
219 | col.Nullable = nullable
220 | }
221 | if isPK, ok := row["is_primary_key"].(bool); ok {
222 | col.IsPrimaryKey = isPK
223 | }
224 | if isIndexed, ok := row["is_indexed"].(bool); ok {
225 | col.IsIndexed = isIndexed
226 | }
227 |
228 | // Handle description which might be null
229 | if desc, ok := row["description"]; ok && desc != nil {
230 | col.Description = fmt.Sprintf("%v", desc)
231 | }
232 |
233 | columns = append(columns, col)
234 | }
235 |
236 | return columns, nil
237 | }
238 |
239 | // ListContinuousAggregates lists all continuous aggregates
240 | func (t *DB) ListContinuousAggregates(ctx context.Context) ([]ContinuousAggregateMetadata, error) {
241 | if !t.isTimescaleDB {
242 | return nil, fmt.Errorf("TimescaleDB extension not available")
243 | }
244 |
245 | query := `
246 | SELECT
247 | view_name,
248 | view_schema,
249 | materialized_only,
250 | refresh_lag,
251 | refresh_interval,
252 | hypertable_name,
253 | hypertable_schema
254 | FROM timescaledb_information.continuous_aggregates
255 | `
256 |
257 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
258 | if err != nil {
259 | return nil, fmt.Errorf("failed to list continuous aggregates: %w", err)
260 | }
261 |
262 | rows, ok := result.([]map[string]interface{})
263 | if !ok {
264 | return nil, fmt.Errorf("unexpected result type from database query")
265 | }
266 |
267 | var aggregates []ContinuousAggregateMetadata
268 | for _, row := range rows {
269 | agg := ContinuousAggregateMetadata{
270 | ViewName: fmt.Sprintf("%v", row["view_name"]),
271 | ViewSchema: fmt.Sprintf("%v", row["view_schema"]),
272 | HypertableName: fmt.Sprintf("%v", row["hypertable_name"]),
273 | HypertableSchema: fmt.Sprintf("%v", row["hypertable_schema"]),
274 | }
275 |
276 | // Handle boolean fields
277 | if materializedOnly, ok := row["materialized_only"].(bool); ok {
278 | agg.MaterializedOnly = materializedOnly
279 | }
280 |
281 | // Handle nullable fields
282 | if refreshLag, ok := row["refresh_lag"]; ok && refreshLag != nil {
283 | agg.RefreshLag = fmt.Sprintf("%v", refreshLag)
284 | }
285 | if refreshInterval, ok := row["refresh_interval"]; ok && refreshInterval != nil {
286 | agg.RefreshInterval = fmt.Sprintf("%v", refreshInterval)
287 | }
288 |
289 | // Get view definition
290 | definitionQuery := fmt.Sprintf(`
291 | SELECT pg_get_viewdef(format('%%I.%%I', '%s', '%s')::regclass, true) as view_definition
292 | `, agg.ViewSchema, agg.ViewName)
293 |
294 | defResult, err := t.ExecuteSQLWithoutParams(ctx, definitionQuery)
295 | if err == nil {
296 | defRows, ok := defResult.([]map[string]interface{})
297 | if ok && len(defRows) > 0 {
298 | if def, ok := defRows[0]["view_definition"]; ok && def != nil {
299 | agg.ViewDefinition = fmt.Sprintf("%v", def)
300 | }
301 | }
302 | }
303 |
304 | aggregates = append(aggregates, agg)
305 | }
306 |
307 | return aggregates, nil
308 | }
309 |
310 | // GetContinuousAggregate gets metadata about a specific continuous aggregate
311 | func (t *DB) GetContinuousAggregate(ctx context.Context, viewName string) (*ContinuousAggregateMetadata, error) {
312 | if !t.isTimescaleDB {
313 | return nil, fmt.Errorf("TimescaleDB extension not available")
314 | }
315 |
316 | query := fmt.Sprintf(`
317 | SELECT
318 | view_name,
319 | view_schema,
320 | materialized_only,
321 | refresh_lag,
322 | refresh_interval,
323 | hypertable_name,
324 | hypertable_schema
325 | FROM timescaledb_information.continuous_aggregates
326 | WHERE view_name = '%s'
327 | `, viewName)
328 |
329 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
330 | if err != nil {
331 | return nil, fmt.Errorf("failed to get continuous aggregate: %w", err)
332 | }
333 |
334 | rows, ok := result.([]map[string]interface{})
335 | if !ok || len(rows) == 0 {
336 | return nil, fmt.Errorf("continuous aggregate '%s' not found", viewName)
337 | }
338 |
339 | row := rows[0]
340 | agg := &ContinuousAggregateMetadata{
341 | ViewName: fmt.Sprintf("%v", row["view_name"]),
342 | ViewSchema: fmt.Sprintf("%v", row["view_schema"]),
343 | HypertableName: fmt.Sprintf("%v", row["hypertable_name"]),
344 | HypertableSchema: fmt.Sprintf("%v", row["hypertable_schema"]),
345 | }
346 |
347 | // Handle boolean fields
348 | if materializedOnly, ok := row["materialized_only"].(bool); ok {
349 | agg.MaterializedOnly = materializedOnly
350 | }
351 |
352 | // Handle nullable fields
353 | if refreshLag, ok := row["refresh_lag"]; ok && refreshLag != nil {
354 | agg.RefreshLag = fmt.Sprintf("%v", refreshLag)
355 | }
356 | if refreshInterval, ok := row["refresh_interval"]; ok && refreshInterval != nil {
357 | agg.RefreshInterval = fmt.Sprintf("%v", refreshInterval)
358 | }
359 |
360 | // Get view definition
361 | definitionQuery := fmt.Sprintf(`
362 | SELECT pg_get_viewdef(format('%%I.%%I', '%s', '%s')::regclass, true) as view_definition
363 | `, agg.ViewSchema, agg.ViewName)
364 |
365 | defResult, err := t.ExecuteSQLWithoutParams(ctx, definitionQuery)
366 | if err == nil {
367 | defRows, ok := defResult.([]map[string]interface{})
368 | if ok && len(defRows) > 0 {
369 | if def, ok := defRows[0]["view_definition"]; ok && def != nil {
370 | agg.ViewDefinition = fmt.Sprintf("%v", def)
371 | }
372 | }
373 | }
374 |
375 | return agg, nil
376 | }
377 |
378 | // GetDatabaseSize gets size information about the database
379 | func (t *DB) GetDatabaseSize(ctx context.Context) (map[string]string, error) {
380 | query := `
381 | SELECT
382 | pg_size_pretty(pg_database_size(current_database())) as database_size,
383 | current_database() as database_name,
384 | (
385 | SELECT pg_size_pretty(sum(pg_total_relation_size(format('%I.%I', h.schema_name, h.table_name))))
386 | FROM timescaledb_information.hypertables h
387 | ) as hypertables_size,
388 | (
389 | SELECT count(*)
390 | FROM timescaledb_information.hypertables
391 | ) as hypertables_count
392 | `
393 |
394 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
395 | if err != nil {
396 | return nil, fmt.Errorf("failed to get database size: %w", err)
397 | }
398 |
399 | rows, ok := result.([]map[string]interface{})
400 | if !ok || len(rows) == 0 {
401 | return nil, fmt.Errorf("failed to get database size information")
402 | }
403 |
404 | info := make(map[string]string)
405 | for k, v := range rows[0] {
406 | if v != nil {
407 | info[k] = fmt.Sprintf("%v", v)
408 | }
409 | }
410 |
411 | return info, nil
412 | }
413 |
414 | // DetectTimescaleDBVersion checks if TimescaleDB is installed and returns its version
415 | func (t *DB) DetectTimescaleDBVersion(ctx context.Context) (string, error) {
416 | query := "SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'"
417 | result, err := t.ExecuteSQLWithoutParams(ctx, query)
418 | if err != nil {
419 | return "", fmt.Errorf("failed to check TimescaleDB version: %w", err)
420 | }
421 |
422 | rows, ok := result.([]map[string]interface{})
423 | if !ok || len(rows) == 0 {
424 | return "", fmt.Errorf("TimescaleDB extension not installed")
425 | }
426 |
427 | version := rows[0]["extversion"]
428 | if version == nil {
429 | return "", fmt.Errorf("unable to determine TimescaleDB version")
430 | }
431 |
432 | return fmt.Sprintf("%v", version), nil
433 | }
434 |
435 | // GenerateHypertableSchema generates CREATE TABLE and CREATE HYPERTABLE statements for a hypertable
436 | func (t *DB) GenerateHypertableSchema(ctx context.Context, tableName string) (string, error) {
437 | if !t.isTimescaleDB {
438 | return "", fmt.Errorf("TimescaleDB extension not available")
439 | }
440 |
441 | // Get table columns and constraints
442 | columnsQuery := fmt.Sprintf(`
443 | SELECT
444 | 'CREATE TABLE ' || quote_ident('%s') || ' (' ||
445 | string_agg(
446 | quote_ident(column_name) || ' ' ||
447 | data_type ||
448 | CASE
449 | WHEN character_maximum_length IS NOT NULL THEN '(' || character_maximum_length || ')'
450 | WHEN numeric_precision IS NOT NULL AND numeric_scale IS NOT NULL THEN '(' || numeric_precision || ',' || numeric_scale || ')'
451 | ELSE ''
452 | END ||
453 | CASE WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END,
454 | ', '
455 | ) ||
456 | CASE
457 | WHEN (
458 | SELECT count(*) > 0
459 | FROM information_schema.table_constraints tc
460 | WHERE tc.table_name = '%s' AND tc.constraint_type = 'PRIMARY KEY'
461 | ) THEN
462 | ', ' || (
463 | SELECT 'PRIMARY KEY (' || string_agg(quote_ident(kcu.column_name), ', ') || ')'
464 | FROM information_schema.table_constraints tc
465 | JOIN information_schema.key_column_usage kcu ON
466 | kcu.constraint_name = tc.constraint_name AND
467 | kcu.table_schema = tc.table_schema AND
468 | kcu.table_name = tc.table_name
469 | WHERE tc.table_name = '%s' AND tc.constraint_type = 'PRIMARY KEY'
470 | )
471 | ELSE ''
472 | END ||
473 | ');' as create_table_stmt
474 | FROM information_schema.columns
475 | WHERE table_name = '%s'
476 | GROUP BY table_name
477 | `, tableName, tableName, tableName, tableName)
478 |
479 | columnsResult, err := t.ExecuteSQLWithoutParams(ctx, columnsQuery)
480 | if err != nil {
481 | return "", fmt.Errorf("failed to generate schema: %w", err)
482 | }
483 |
484 | columnsRows, ok := columnsResult.([]map[string]interface{})
485 | if !ok || len(columnsRows) == 0 {
486 | return "", fmt.Errorf("failed to generate schema for table '%s'", tableName)
487 | }
488 |
489 | createTableStmt := fmt.Sprintf("%v", columnsRows[0]["create_table_stmt"])
490 |
491 | // Get hypertable metadata
492 | metadata, err := t.GetHypertableMetadata(ctx, tableName)
493 | if err != nil {
494 | return createTableStmt, nil // Return just the CREATE TABLE statement if it's not a hypertable
495 | }
496 |
497 | // Generate CREATE HYPERTABLE statement
498 | var createHypertableStmt strings.Builder
499 | createHypertableStmt.WriteString(fmt.Sprintf("SELECT create_hypertable('%s', '%s'",
500 | tableName, metadata.TimeDimension))
501 |
502 | if metadata.ChunkTimeInterval != "" {
503 | createHypertableStmt.WriteString(fmt.Sprintf(", chunk_time_interval => INTERVAL '%s'",
504 | metadata.ChunkTimeInterval))
505 | }
506 |
507 | if len(metadata.SpaceDimensions) > 0 {
508 | createHypertableStmt.WriteString(fmt.Sprintf(", partitioning_column => '%s'",
509 | metadata.SpaceDimensions[0]))
510 | }
511 |
512 | createHypertableStmt.WriteString(");")
513 |
514 | // Combine statements
515 | result := createTableStmt + "\n\n" + createHypertableStmt.String()
516 |
517 | // Add compression statement if enabled
518 | if metadata.Compression {
519 | compressionSettings, err := t.GetCompressionSettings(ctx, tableName)
520 | if err == nil && compressionSettings.CompressionEnabled {
521 | compressionStmt := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true);", tableName)
522 | result += "\n\n" + compressionStmt
523 |
524 | // Add compression policy if exists
525 | if compressionSettings.CompressionInterval != "" {
526 | policyStmt := fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'",
527 | tableName, compressionSettings.CompressionInterval)
528 |
529 | if compressionSettings.SegmentBy != "" {
530 | policyStmt += fmt.Sprintf(", segmentby => '%s'", compressionSettings.SegmentBy)
531 | }
532 |
533 | if compressionSettings.OrderBy != "" {
534 | policyStmt += fmt.Sprintf(", orderby => '%s'", compressionSettings.OrderBy)
535 | }
536 |
537 | policyStmt += ");"
538 | result += "\n" + policyStmt
539 | }
540 | }
541 | }
542 |
543 | // Add retention policy if enabled
544 | if metadata.RetentionPolicy {
545 | retentionSettings, err := t.GetRetentionSettings(ctx, tableName)
546 | if err == nil && retentionSettings.RetentionEnabled && retentionSettings.RetentionInterval != "" {
547 | retentionStmt := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s');",
548 | tableName, retentionSettings.RetentionInterval)
549 | result += "\n\n" + retentionStmt
550 | }
551 | }
552 |
553 | return result, nil
554 | }
555 |
```
--------------------------------------------------------------------------------
/internal/delivery/mcp/tool_types.go:
--------------------------------------------------------------------------------
```go
1 | package mcp
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 |
8 | "github.com/FreePeak/cortex/pkg/server"
9 | "github.com/FreePeak/cortex/pkg/tools"
10 | )
11 |
// createTextResponse wraps text in the response map returned by the tool
// handlers. The map has a single "content" key holding a one-element list whose
// entry is {"type": "text", "text": text}.
func createTextResponse(text string) map[string]interface{} {
	entry := map[string]interface{}{
		"type": "text",
		"text": text,
	}
	content := []map[string]interface{}{entry}
	return map[string]interface{}{"content": content}
}
23 |
// addMetadata stores value under key in resp["metadata"] and returns resp.
//
// The metadata map is created when it is missing, holds a nil map, or holds a
// value that is not a map[string]interface{}; in the last case the previous
// value is replaced. resp is mutated in place. If resp is nil, a new response
// map is allocated and returned, so callers that pass nil must use the
// return value.
func addMetadata(resp map[string]interface{}, key string, value interface{}) map[string]interface{} {
	if resp == nil {
		resp = make(map[string]interface{})
	}

	metadata, ok := resp["metadata"].(map[string]interface{})
	// A typed nil map passes the type assertion but panics on write, so treat
	// it the same as a missing or mistyped entry.
	if !ok || metadata == nil {
		metadata = make(map[string]interface{})
		resp["metadata"] = metadata
	}

	metadata[key] = value
	return resp
}
40 |
41 | // TODO: Refactor tool type implementations to reduce duplication and improve maintainability
42 | // TODO: Consider using a code generation approach for repetitive tool patterns
43 | // TODO: Add comprehensive request validation for all tool parameters
44 | // TODO: Implement proper rate limiting and resource protection
45 | // TODO: Add detailed documentation for each tool type and its parameters
46 |
47 | // ToolType interface defines the structure for different types of database tools
48 | type ToolType interface {
49 | // GetName returns the base name of the tool type (e.g., "query", "execute")
50 | GetName() string
51 |
52 | // GetDescription returns a description for this tool type
53 | GetDescription(dbID string) string
54 |
55 | // CreateTool creates a tool with the specified name
56 | // The returned tool must be compatible with server.MCPServer.AddTool's first parameter
57 | CreateTool(name string, dbID string) interface{}
58 |
59 | // HandleRequest handles tool requests for this tool type
60 | HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error)
61 | }
62 |
63 | // UseCaseProvider interface abstracts database use case operations
64 | type UseCaseProvider interface {
65 | ExecuteQuery(ctx context.Context, dbID, query string, params []interface{}) (string, error)
66 | ExecuteStatement(ctx context.Context, dbID, statement string, params []interface{}) (string, error)
67 | ExecuteTransaction(ctx context.Context, dbID, action string, txID string, statement string, params []interface{}, readOnly bool) (string, map[string]interface{}, error)
68 | GetDatabaseInfo(dbID string) (map[string]interface{}, error)
69 | ListDatabases() []string
70 | GetDatabaseType(dbID string) (string, error)
71 | }
72 |
// BaseToolType carries the name and description shared by every tool type.
// Concrete tool types embed it to obtain GetName and GetDescription.
type BaseToolType struct {
	name, description string
}

// GetName reports the tool type's base name.
func (b *BaseToolType) GetName() string { return b.name }

// GetDescription qualifies the base description with the target database,
// yielding "<description> on <dbID> database".
func (b *BaseToolType) GetDescription(dbID string) string {
	const layout = "%s on %s database"
	return fmt.Sprintf(layout, b.description, dbID)
}
88 |
89 | //------------------------------------------------------------------------------
90 | // QueryTool implementation
91 | //------------------------------------------------------------------------------
92 |
93 | // QueryTool handles SQL query operations
94 | type QueryTool struct {
95 | BaseToolType
96 | }
97 |
98 | // NewQueryTool creates a new query tool type
99 | func NewQueryTool() *QueryTool {
100 | return &QueryTool{
101 | BaseToolType: BaseToolType{
102 | name: "query",
103 | description: "Execute SQL query",
104 | },
105 | }
106 | }
107 |
108 | // CreateTool creates a query tool
109 | func (t *QueryTool) CreateTool(name string, dbID string) interface{} {
110 | return tools.NewTool(
111 | name,
112 | tools.WithDescription(t.GetDescription(dbID)),
113 | tools.WithString("query",
114 | tools.Description("SQL query to execute"),
115 | tools.Required(),
116 | ),
117 | tools.WithArray("params",
118 | tools.Description("Query parameters"),
119 | tools.Items(map[string]interface{}{"type": "string"}),
120 | ),
121 | )
122 | }
123 |
124 | // HandleRequest handles query tool requests
125 | func (t *QueryTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
126 | // If dbID is not provided, extract it from the tool name
127 | if dbID == "" {
128 | dbID = extractDatabaseIDFromName(request.Name)
129 | }
130 |
131 | query, ok := request.Parameters["query"].(string)
132 | if !ok {
133 | return nil, fmt.Errorf("query parameter must be a string")
134 | }
135 |
136 | var queryParams []interface{}
137 | if request.Parameters["params"] != nil {
138 | if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
139 | queryParams = paramsArr
140 | }
141 | }
142 |
143 | result, err := useCase.ExecuteQuery(ctx, dbID, query, queryParams)
144 | if err != nil {
145 | return nil, err
146 | }
147 |
148 | return createTextResponse(result), nil
149 | }
150 |
// extractDatabaseIDFromName returns the text after the last underscore in a
// tool name of the form "<tooltype>_<dbID>", or "" when the name contains no
// underscore.
//
// NOTE(review): because only the last underscore-separated segment is
// returned, a database ID that itself contains underscores (e.g. "test_db")
// is truncated ("db"). Callers that know the dbID should pass it explicitly.
func extractDatabaseIDFromName(name string) string {
	sep := strings.LastIndex(name, "_")
	if sep < 0 {
		return ""
	}
	return name[sep+1:]
}
162 |
163 | //------------------------------------------------------------------------------
164 | // ExecuteTool implementation
165 | //------------------------------------------------------------------------------
166 |
167 | // ExecuteTool handles SQL statement execution
168 | type ExecuteTool struct {
169 | BaseToolType
170 | }
171 |
172 | // NewExecuteTool creates a new execute tool type
173 | func NewExecuteTool() *ExecuteTool {
174 | return &ExecuteTool{
175 | BaseToolType: BaseToolType{
176 | name: "execute",
177 | description: "Execute SQL statement",
178 | },
179 | }
180 | }
181 |
182 | // CreateTool creates an execute tool
183 | func (t *ExecuteTool) CreateTool(name string, dbID string) interface{} {
184 | return tools.NewTool(
185 | name,
186 | tools.WithDescription(t.GetDescription(dbID)),
187 | tools.WithString("statement",
188 | tools.Description("SQL statement to execute (INSERT, UPDATE, DELETE, etc.)"),
189 | tools.Required(),
190 | ),
191 | tools.WithArray("params",
192 | tools.Description("Statement parameters"),
193 | tools.Items(map[string]interface{}{"type": "string"}),
194 | ),
195 | )
196 | }
197 |
198 | // HandleRequest handles execute tool requests
199 | func (t *ExecuteTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
200 | // If dbID is not provided, extract it from the tool name
201 | if dbID == "" {
202 | dbID = extractDatabaseIDFromName(request.Name)
203 | }
204 |
205 | statement, ok := request.Parameters["statement"].(string)
206 | if !ok {
207 | return nil, fmt.Errorf("statement parameter must be a string")
208 | }
209 |
210 | var statementParams []interface{}
211 | if request.Parameters["params"] != nil {
212 | if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
213 | statementParams = paramsArr
214 | }
215 | }
216 |
217 | result, err := useCase.ExecuteStatement(ctx, dbID, statement, statementParams)
218 | if err != nil {
219 | return nil, err
220 | }
221 |
222 | return createTextResponse(result), nil
223 | }
224 |
225 | //------------------------------------------------------------------------------
226 | // TransactionTool implementation
227 | //------------------------------------------------------------------------------
228 |
229 | // TransactionTool handles database transactions
230 | type TransactionTool struct {
231 | BaseToolType
232 | }
233 |
234 | // NewTransactionTool creates a new transaction tool type
235 | func NewTransactionTool() *TransactionTool {
236 | return &TransactionTool{
237 | BaseToolType: BaseToolType{
238 | name: "transaction",
239 | description: "Manage transactions",
240 | },
241 | }
242 | }
243 |
244 | // CreateTool creates a transaction tool
245 | func (t *TransactionTool) CreateTool(name string, dbID string) interface{} {
246 | return tools.NewTool(
247 | name,
248 | tools.WithDescription(t.GetDescription(dbID)),
249 | tools.WithString("action",
250 | tools.Description("Transaction action (begin, commit, rollback, execute)"),
251 | tools.Required(),
252 | ),
253 | tools.WithString("transactionId",
254 | tools.Description("Transaction ID (required for commit, rollback, execute)"),
255 | ),
256 | tools.WithString("statement",
257 | tools.Description("SQL statement to execute within transaction (required for execute)"),
258 | ),
259 | tools.WithArray("params",
260 | tools.Description("Statement parameters"),
261 | tools.Items(map[string]interface{}{"type": "string"}),
262 | ),
263 | tools.WithBoolean("readOnly",
264 | tools.Description("Whether the transaction is read-only (for begin)"),
265 | ),
266 | )
267 | }
268 |
269 | // HandleRequest handles transaction tool requests
270 | func (t *TransactionTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
271 | // If dbID is not provided, extract it from the tool name
272 | if dbID == "" {
273 | dbID = extractDatabaseIDFromName(request.Name)
274 | }
275 |
276 | action, ok := request.Parameters["action"].(string)
277 | if !ok {
278 | return nil, fmt.Errorf("action parameter must be a string")
279 | }
280 |
281 | txID := ""
282 | if request.Parameters["transactionId"] != nil {
283 | var ok bool
284 | txID, ok = request.Parameters["transactionId"].(string)
285 | if !ok {
286 | return nil, fmt.Errorf("transactionId parameter must be a string")
287 | }
288 | }
289 |
290 | statement := ""
291 | if request.Parameters["statement"] != nil {
292 | var ok bool
293 | statement, ok = request.Parameters["statement"].(string)
294 | if !ok {
295 | return nil, fmt.Errorf("statement parameter must be a string")
296 | }
297 | }
298 |
299 | var params []interface{}
300 | if request.Parameters["params"] != nil {
301 | if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
302 | params = paramsArr
303 | }
304 | }
305 |
306 | readOnly := false
307 | if request.Parameters["readOnly"] != nil {
308 | var ok bool
309 | readOnly, ok = request.Parameters["readOnly"].(bool)
310 | if !ok {
311 | return nil, fmt.Errorf("readOnly parameter must be a boolean")
312 | }
313 | }
314 |
315 | message, metadata, err := useCase.ExecuteTransaction(ctx, dbID, action, txID, statement, params, readOnly)
316 | if err != nil {
317 | return nil, err
318 | }
319 |
320 | // Create response with text and metadata
321 | resp := createTextResponse(message)
322 |
323 | // Add metadata if provided
324 | for k, v := range metadata {
325 | addMetadata(resp, k, v)
326 | }
327 |
328 | return resp, nil
329 | }
330 |
331 | //------------------------------------------------------------------------------
332 | // PerformanceTool implementation
333 | //------------------------------------------------------------------------------
334 |
335 | // PerformanceTool handles query performance analysis
336 | type PerformanceTool struct {
337 | BaseToolType
338 | }
339 |
340 | // NewPerformanceTool creates a new performance tool type
341 | func NewPerformanceTool() *PerformanceTool {
342 | return &PerformanceTool{
343 | BaseToolType: BaseToolType{
344 | name: "performance",
345 | description: "Analyze query performance",
346 | },
347 | }
348 | }
349 |
350 | // CreateTool creates a performance analysis tool
351 | func (t *PerformanceTool) CreateTool(name string, dbID string) interface{} {
352 | return tools.NewTool(
353 | name,
354 | tools.WithDescription(t.GetDescription(dbID)),
355 | tools.WithString("action",
356 | tools.Description("Action (getSlowQueries, getMetrics, analyzeQuery, reset, setThreshold)"),
357 | tools.Required(),
358 | ),
359 | tools.WithString("query",
360 | tools.Description("SQL query to analyze (required for analyzeQuery)"),
361 | ),
362 | tools.WithNumber("limit",
363 | tools.Description("Maximum number of results to return"),
364 | ),
365 | tools.WithNumber("threshold",
366 | tools.Description("Slow query threshold in milliseconds (required for setThreshold)"),
367 | ),
368 | )
369 | }
370 |
371 | // HandleRequest handles performance tool requests
372 | func (t *PerformanceTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
373 | // If dbID is not provided, extract it from the tool name
374 | if dbID == "" {
375 | dbID = extractDatabaseIDFromName(request.Name)
376 | }
377 |
378 | // This is a simplified implementation
379 | // In a real implementation, this would analyze query performance
380 |
381 | action, ok := request.Parameters["action"].(string)
382 | if !ok {
383 | return nil, fmt.Errorf("action parameter must be a string")
384 | }
385 |
386 | var limit int
387 | if request.Parameters["limit"] != nil {
388 | if limitParam, ok := request.Parameters["limit"].(float64); ok {
389 | limit = int(limitParam)
390 | }
391 | }
392 |
393 | query := ""
394 | if request.Parameters["query"] != nil {
395 | var ok bool
396 | query, ok = request.Parameters["query"].(string)
397 | if !ok {
398 | return nil, fmt.Errorf("query parameter must be a string")
399 | }
400 | }
401 |
402 | var threshold int
403 | if request.Parameters["threshold"] != nil {
404 | if thresholdParam, ok := request.Parameters["threshold"].(float64); ok {
405 | threshold = int(thresholdParam)
406 | }
407 | }
408 |
409 | // This is where we would call the useCase to analyze performance
410 | // For now, just return a placeholder
411 | output := fmt.Sprintf("Performance analysis for action '%s' on database '%s'\n", action, dbID)
412 |
413 | if query != "" {
414 | output += fmt.Sprintf("Query: %s\n", query)
415 | }
416 |
417 | if limit > 0 {
418 | output += fmt.Sprintf("Limit: %d\n", limit)
419 | }
420 |
421 | if threshold > 0 {
422 | output += fmt.Sprintf("Threshold: %d ms\n", threshold)
423 | }
424 |
425 | return createTextResponse(output), nil
426 | }
427 |
428 | //------------------------------------------------------------------------------
429 | // SchemaTool implementation
430 | //------------------------------------------------------------------------------
431 |
432 | // SchemaTool handles database schema exploration
433 | type SchemaTool struct {
434 | BaseToolType
435 | }
436 |
437 | // NewSchemaTool creates a new schema tool type
438 | func NewSchemaTool() *SchemaTool {
439 | return &SchemaTool{
440 | BaseToolType: BaseToolType{
441 | name: "schema",
442 | description: "Get schema of",
443 | },
444 | }
445 | }
446 |
447 | // CreateTool creates a schema tool
448 | func (t *SchemaTool) CreateTool(name string, dbID string) interface{} {
449 | return tools.NewTool(
450 | name,
451 | tools.WithDescription(t.GetDescription(dbID)),
452 | // Use any string parameter for compatibility
453 | tools.WithString("random_string",
454 | tools.Description("Dummy parameter (optional)"),
455 | ),
456 | )
457 | }
458 |
459 | // HandleRequest handles schema tool requests
460 | func (t *SchemaTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
461 | // If dbID is not provided, extract it from the tool name
462 | if dbID == "" {
463 | dbID = extractDatabaseIDFromName(request.Name)
464 | }
465 |
466 | info, err := useCase.GetDatabaseInfo(dbID)
467 | if err != nil {
468 | return nil, err
469 | }
470 |
471 | // Format response text
472 | infoStr := fmt.Sprintf("Database Schema for %s:\n\n%+v", dbID, info)
473 | return createTextResponse(infoStr), nil
474 | }
475 |
476 | //------------------------------------------------------------------------------
477 | // ListDatabasesTool implementation
478 | //------------------------------------------------------------------------------
479 |
480 | // ListDatabasesTool handles listing available databases
481 | type ListDatabasesTool struct {
482 | BaseToolType
483 | }
484 |
485 | // NewListDatabasesTool creates a new list databases tool type
486 | func NewListDatabasesTool() *ListDatabasesTool {
487 | return &ListDatabasesTool{
488 | BaseToolType: BaseToolType{
489 | name: "list_databases",
490 | description: "List all available databases",
491 | },
492 | }
493 | }
494 |
495 | // CreateTool creates a list databases tool
496 | func (t *ListDatabasesTool) CreateTool(name string, dbID string) interface{} {
497 | return tools.NewTool(
498 | name,
499 | tools.WithDescription(t.GetDescription(dbID)),
500 | // Use any string parameter for compatibility
501 | tools.WithString("random_string",
502 | tools.Description("Dummy parameter (optional)"),
503 | ),
504 | )
505 | }
506 |
507 | // HandleRequest handles list databases tool requests
508 | func (t *ListDatabasesTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
509 | databases := useCase.ListDatabases()
510 |
511 | // Format as text for display
512 | output := "Available databases:\n\n"
513 | for i, db := range databases {
514 | output += fmt.Sprintf("%d. %s\n", i+1, db)
515 | }
516 |
517 | if len(databases) == 0 {
518 | output += "No databases configured.\n"
519 | }
520 |
521 | return createTextResponse(output), nil
522 | }
523 |
524 | //------------------------------------------------------------------------------
525 | // ToolTypeFactory provides a factory for creating tool types
526 | //------------------------------------------------------------------------------
527 |
528 | // ToolTypeFactory creates and manages tool types
529 | type ToolTypeFactory struct {
530 | toolTypes map[string]ToolType
531 | }
532 |
533 | // NewToolTypeFactory creates a new tool type factory with all registered tool types
534 | func NewToolTypeFactory() *ToolTypeFactory {
535 | factory := &ToolTypeFactory{
536 | toolTypes: make(map[string]ToolType),
537 | }
538 |
539 | // Register all tool types
540 | factory.Register(NewQueryTool())
541 | factory.Register(NewExecuteTool())
542 | factory.Register(NewTransactionTool())
543 | factory.Register(NewPerformanceTool())
544 | factory.Register(NewSchemaTool())
545 | factory.Register(NewListDatabasesTool())
546 |
547 | return factory
548 | }
549 |
550 | // Register adds a tool type to the factory
551 | func (f *ToolTypeFactory) Register(toolType ToolType) {
552 | f.toolTypes[toolType.GetName()] = toolType
553 | }
554 |
555 | // GetToolType returns a tool type by name
556 | func (f *ToolTypeFactory) GetToolType(name string) (ToolType, bool) {
557 | // Handle new simpler format: <tooltype>_<dbID> or just the tool type name
558 | parts := strings.Split(name, "_")
559 | if len(parts) > 0 {
560 | // First part is the tool type name
561 | toolType, ok := f.toolTypes[parts[0]]
562 | if ok {
563 | return toolType, true
564 | }
565 | }
566 |
567 | // Direct tool type lookup
568 | toolType, ok := f.toolTypes[name]
569 | return toolType, ok
570 | }
571 |
572 | // GetToolTypeForSourceName finds the appropriate tool type for a source name
573 | func (f *ToolTypeFactory) GetToolTypeForSourceName(sourceName string) (ToolType, string, bool) {
574 | // Handle simpler format: <tooltype>_<dbID>
575 | parts := strings.Split(sourceName, "_")
576 |
577 | if len(parts) >= 2 {
578 | // First part is tool type, last part is dbID
579 | toolTypeName := parts[0]
580 | dbID := parts[len(parts)-1]
581 |
582 | toolType, ok := f.toolTypes[toolTypeName]
583 | if ok {
584 | return toolType, dbID, true
585 | }
586 | }
587 |
588 | // Handle case for global tools
589 | if sourceName == "list_databases" {
590 | toolType, ok := f.toolTypes["list_databases"]
591 | return toolType, "", ok
592 | }
593 |
594 | return nil, "", false
595 | }
596 |
597 | // GetAllToolTypes returns all registered tool types
598 | func (f *ToolTypeFactory) GetAllToolTypes() []ToolType {
599 | types := make([]ToolType, 0, len(f.toolTypes))
600 | for _, toolType := range f.toolTypes {
601 | types = append(types, toolType)
602 | }
603 | return types
604 | }
605 |
```
--------------------------------------------------------------------------------
/pkg/db/timescale/policy_test.go:
--------------------------------------------------------------------------------
```go
1 | package timescale
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "testing"
7 | )
8 |
9 | func TestEnableCompression(t *testing.T) {
10 | mockDB := NewMockDB()
11 | tsdb := &DB{
12 | Database: mockDB,
13 | isTimescaleDB: true,
14 | }
15 |
16 | ctx := context.Background()
17 |
18 | // Register mock responses for checking if the table is a hypertable
19 | mockDB.RegisterQueryResult("WHERE table_name = 'test_table'", []map[string]interface{}{
20 | {"is_hypertable": true},
21 | }, nil)
22 |
23 | // Register mock response for the compression check in timescaledb_information.hypertables
24 | mockDB.RegisterQueryResult("FROM timescaledb_information.hypertables WHERE hypertable_name", []map[string]interface{}{
25 | {"compress": true},
26 | }, nil)
27 |
28 | // Test enabling compression without interval
29 | err := tsdb.EnableCompression(ctx, "test_table", "")
30 | if err != nil {
31 | t.Fatalf("Failed to enable compression: %v", err)
32 | }
33 |
34 | // Check that the correct query was executed
35 | query, _ := mockDB.GetLastQuery()
36 | AssertQueryContains(t, query, "ALTER TABLE test_table SET (timescaledb.compress = true)")
37 |
38 | // Test enabling compression with interval
39 | // Register mock responses for specific queries used in this test
40 | mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
41 | mockDB.RegisterQueryResult("SELECT add_compression_policy", nil, nil)
42 | mockDB.RegisterQueryResult("timescaledb_information.hypertables WHERE hypertable_name = 'test_table'", []map[string]interface{}{
43 | {"compress": true},
44 | }, nil)
45 |
46 | err = tsdb.EnableCompression(ctx, "test_table", "7 days")
47 | if err != nil {
48 | t.Fatalf("Failed to enable compression with interval: %v", err)
49 | }
50 |
51 | // Check that the correct policy query was executed
52 | query, _ = mockDB.GetLastQuery()
53 | AssertQueryContains(t, query, "add_compression_policy")
54 | AssertQueryContains(t, query, "test_table")
55 | AssertQueryContains(t, query, "7 days")
56 |
57 | // Test when TimescaleDB is not available
58 | tsdb.isTimescaleDB = false
59 | err = tsdb.EnableCompression(ctx, "test_table", "")
60 | if err == nil {
61 | t.Error("Expected error when TimescaleDB is not available, got nil")
62 | }
63 |
64 | // Test execution error
65 | tsdb.isTimescaleDB = true
66 | mockDB.RegisterQueryResult("ALTER TABLE", nil, errors.New("mocked error"))
67 | err = tsdb.EnableCompression(ctx, "test_table", "")
68 | if err == nil {
69 | t.Error("Expected query error, got nil")
70 | }
71 | }
72 |
73 | func TestDisableCompression(t *testing.T) {
74 | mockDB := NewMockDB()
75 | tsdb := &DB{
76 | Database: mockDB,
77 | isTimescaleDB: true,
78 | }
79 |
80 | ctx := context.Background()
81 |
82 | // Mock successful policy removal and compression disabling
83 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
84 | {"job_id": 1},
85 | }, nil)
86 | mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, nil)
87 | mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
88 |
89 | // Test disabling compression
90 | err := tsdb.DisableCompression(ctx, "test_table")
91 | if err != nil {
92 | t.Fatalf("Failed to disable compression: %v", err)
93 | }
94 |
95 | // Check that the correct ALTER TABLE query was executed
96 | query, _ := mockDB.GetLastQuery()
97 | AssertQueryContains(t, query, "ALTER TABLE test_table SET (timescaledb.compress = false)")
98 |
99 | // Test when no policy exists (should still succeed)
100 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)
101 | mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
102 |
103 | err = tsdb.DisableCompression(ctx, "test_table")
104 | if err != nil {
105 | t.Fatalf("Failed to disable compression when no policy exists: %v", err)
106 | }
107 |
108 | // Test when TimescaleDB is not available
109 | tsdb.isTimescaleDB = false
110 | err = tsdb.DisableCompression(ctx, "test_table")
111 | if err == nil {
112 | t.Error("Expected error when TimescaleDB is not available, got nil")
113 | }
114 |
115 | // Test execution error
116 | tsdb.isTimescaleDB = true
117 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
118 | {"job_id": 1},
119 | }, nil)
120 | mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, errors.New("mocked error"))
121 | err = tsdb.DisableCompression(ctx, "test_table")
122 | if err == nil {
123 | t.Error("Expected query error, got nil")
124 | }
125 | }
126 |
127 | func TestAddCompressionPolicy(t *testing.T) {
128 | mockDB := NewMockDB()
129 | tsdb := &DB{
130 | Database: mockDB,
131 | isTimescaleDB: true,
132 | }
133 |
134 | ctx := context.Background()
135 |
136 | // Mock checking compression status
137 | mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
138 | {"compress": true},
139 | }, nil)
140 |
141 | // Test adding a basic compression policy
142 | err := tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
143 | if err != nil {
144 | t.Fatalf("Failed to add compression policy: %v", err)
145 | }
146 |
147 | // Check that the correct query was executed
148 | query, _ := mockDB.GetLastQuery()
149 | AssertQueryContains(t, query, "SELECT add_compression_policy('test_table', INTERVAL '7 days'")
150 |
151 | // Test adding a policy with segmentby and orderby
152 | err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "device_id", "time DESC")
153 | if err != nil {
154 | t.Fatalf("Failed to add compression policy with additional options: %v", err)
155 | }
156 |
157 | // Check that the correct query was executed
158 | query, _ = mockDB.GetLastQuery()
159 | AssertQueryContains(t, query, "segmentby => 'device_id'")
160 | AssertQueryContains(t, query, "orderby => 'time DESC'")
161 |
162 | // Test enabling compression first if not already enabled
163 | mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
164 | {"compress": false},
165 | }, nil)
166 | mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
167 |
168 | err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
169 | if err != nil {
170 | t.Fatalf("Failed to add compression policy with compression enabling: %v", err)
171 | }
172 |
173 | // Check that the ALTER TABLE query was executed first
174 | query, _ = mockDB.GetLastQuery()
175 | AssertQueryContains(t, query, "SELECT add_compression_policy")
176 |
177 | // Test when TimescaleDB is not available
178 | tsdb.isTimescaleDB = false
179 | err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
180 | if err == nil {
181 | t.Error("Expected error when TimescaleDB is not available, got nil")
182 | }
183 |
184 | // Test execution error on compression check
185 | tsdb.isTimescaleDB = true
186 | mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", nil, errors.New("mocked error"))
187 | err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
188 | if err == nil {
189 | t.Error("Expected query error, got nil")
190 | }
191 | }
192 |
193 | func TestRemoveCompressionPolicy(t *testing.T) {
194 | mockDB := NewMockDB()
195 | tsdb := &DB{
196 | Database: mockDB,
197 | isTimescaleDB: true,
198 | }
199 |
200 | ctx := context.Background()
201 |
202 | // Mock finding a policy
203 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
204 | {"job_id": 1},
205 | }, nil)
206 | mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, nil)
207 |
208 | // Test removing a compression policy
209 | err := tsdb.RemoveCompressionPolicy(ctx, "test_table")
210 | if err != nil {
211 | t.Fatalf("Failed to remove compression policy: %v", err)
212 | }
213 |
214 | // Check that the correct query was executed
215 | query, _ := mockDB.GetLastQuery()
216 | AssertQueryContains(t, query, "SELECT remove_compression_policy")
217 |
218 | // Test when no policy exists (should succeed without error)
219 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)
220 |
221 | err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
222 | if err != nil {
223 | t.Errorf("Expected success when no policy exists, got: %v", err)
224 | }
225 |
226 | // Test when TimescaleDB is not available
227 | tsdb.isTimescaleDB = false
228 | err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
229 | if err == nil {
230 | t.Error("Expected error when TimescaleDB is not available, got nil")
231 | }
232 |
233 | // Test execution error
234 | tsdb.isTimescaleDB = true
235 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", nil, errors.New("mocked error"))
236 | err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
237 | if err == nil {
238 | t.Error("Expected query error, got nil")
239 | }
240 | }
241 |
242 | func TestGetCompressionSettings(t *testing.T) {
243 | mockDB := NewMockDB()
244 | tsdb := &DB{
245 | Database: mockDB,
246 | isTimescaleDB: true,
247 | }
248 |
249 | ctx := context.Background()
250 |
251 | // Mock compression status check
252 | mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
253 | {"compress": true},
254 | }, nil)
255 |
256 | // Mock compression settings
257 | mockDB.RegisterQueryResult("SELECT segmentby, orderby FROM timescaledb_information.compression_settings", []map[string]interface{}{
258 | {"segmentby": "device_id", "orderby": "time DESC"},
259 | }, nil)
260 |
261 | // Mock policy information
262 | mockDB.RegisterQueryResult("SELECT s.schedule_interval, h.chunk_time_interval FROM", []map[string]interface{}{
263 | {"schedule_interval": "7 days", "chunk_time_interval": "1 day"},
264 | }, nil)
265 |
266 | // Test getting compression settings
267 | settings, err := tsdb.GetCompressionSettings(ctx, "test_table")
268 | if err != nil {
269 | t.Fatalf("Failed to get compression settings: %v", err)
270 | }
271 |
272 | // Check the returned settings
273 | if settings.HypertableName != "test_table" {
274 | t.Errorf("Expected HypertableName to be 'test_table', got '%s'", settings.HypertableName)
275 | }
276 |
277 | if !settings.CompressionEnabled {
278 | t.Error("Expected CompressionEnabled to be true, got false")
279 | }
280 |
281 | if settings.SegmentBy != "device_id" {
282 | t.Errorf("Expected SegmentBy to be 'device_id', got '%s'", settings.SegmentBy)
283 | }
284 |
285 | if settings.OrderBy != "time DESC" {
286 | t.Errorf("Expected OrderBy to be 'time DESC', got '%s'", settings.OrderBy)
287 | }
288 |
289 | if settings.CompressionInterval != "7 days" {
290 | t.Errorf("Expected CompressionInterval to be '7 days', got '%s'", settings.CompressionInterval)
291 | }
292 |
293 | if settings.ChunkTimeInterval != "1 day" {
294 | t.Errorf("Expected ChunkTimeInterval to be '1 day', got '%s'", settings.ChunkTimeInterval)
295 | }
296 |
297 | // Test when compression is not enabled
298 | mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
299 | {"compress": false},
300 | }, nil)
301 |
302 | settings, err = tsdb.GetCompressionSettings(ctx, "test_table")
303 | if err != nil {
304 | t.Fatalf("Failed to get compression settings when not enabled: %v", err)
305 | }
306 |
307 | if settings.CompressionEnabled {
308 | t.Error("Expected CompressionEnabled to be false, got true")
309 | }
310 |
311 | // Test when TimescaleDB is not available
312 | tsdb.isTimescaleDB = false
313 | _, err = tsdb.GetCompressionSettings(ctx, "test_table")
314 | if err == nil {
315 | t.Error("Expected error when TimescaleDB is not available, got nil")
316 | }
317 |
318 | // Test execution error
319 | tsdb.isTimescaleDB = true
320 | mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", nil, errors.New("mocked error"))
321 | _, err = tsdb.GetCompressionSettings(ctx, "test_table")
322 | if err == nil {
323 | t.Error("Expected query error, got nil")
324 | }
325 | }
326 |
327 | func TestAddRetentionPolicy(t *testing.T) {
328 | mockDB := NewMockDB()
329 | tsdb := &DB{
330 | Database: mockDB,
331 | isTimescaleDB: true,
332 | }
333 |
334 | ctx := context.Background()
335 |
336 | // Test adding a retention policy
337 | err := tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
338 | if err != nil {
339 | t.Fatalf("Failed to add retention policy: %v", err)
340 | }
341 |
342 | // Check that the correct query was executed
343 | query, _ := mockDB.GetLastQuery()
344 | AssertQueryContains(t, query, "SELECT add_retention_policy('test_table', INTERVAL '30 days')")
345 |
346 | // Test when TimescaleDB is not available
347 | tsdb.isTimescaleDB = false
348 | err = tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
349 | if err == nil {
350 | t.Error("Expected error when TimescaleDB is not available, got nil")
351 | }
352 |
353 | // Test execution error
354 | tsdb.isTimescaleDB = true
355 | mockDB.RegisterQueryResult("SELECT add_retention_policy", nil, errors.New("mocked error"))
356 | err = tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
357 | if err == nil {
358 | t.Error("Expected query error, got nil")
359 | }
360 | }
361 |
362 | func TestRemoveRetentionPolicy(t *testing.T) {
363 | mockDB := NewMockDB()
364 | tsdb := &DB{
365 | Database: mockDB,
366 | isTimescaleDB: true,
367 | }
368 |
369 | ctx := context.Background()
370 |
371 | // Mock finding a policy
372 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
373 | {"job_id": 1},
374 | }, nil)
375 | mockDB.RegisterQueryResult("SELECT remove_retention_policy", nil, nil)
376 |
377 | // Test removing a retention policy
378 | err := tsdb.RemoveRetentionPolicy(ctx, "test_table")
379 | if err != nil {
380 | t.Fatalf("Failed to remove retention policy: %v", err)
381 | }
382 |
383 | // Check that the correct query was executed
384 | query, _ := mockDB.GetLastQuery()
385 | AssertQueryContains(t, query, "SELECT remove_retention_policy")
386 |
387 | // Test when no policy exists (should succeed without error)
388 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)
389 |
390 | err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
391 | if err != nil {
392 | t.Errorf("Expected success when no policy exists, got: %v", err)
393 | }
394 |
395 | // Test when TimescaleDB is not available
396 | tsdb.isTimescaleDB = false
397 | err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
398 | if err == nil {
399 | t.Error("Expected error when TimescaleDB is not available, got nil")
400 | }
401 |
402 | // Test execution error
403 | tsdb.isTimescaleDB = true
404 | mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", nil, errors.New("mocked error"))
405 | err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
406 | if err == nil {
407 | t.Error("Expected query error, got nil")
408 | }
409 | }
410 |
411 | func TestGetRetentionSettings(t *testing.T) {
412 | mockDB := NewMockDB()
413 | tsdb := &DB{
414 | Database: mockDB,
415 | isTimescaleDB: true,
416 | }
417 |
418 | ctx := context.Background()
419 |
420 | // Mock policy information
421 | mockDB.RegisterQueryResult("SELECT s.schedule_interval FROM", []map[string]interface{}{
422 | {"schedule_interval": "30 days"},
423 | }, nil)
424 |
425 | // Test getting retention settings
426 | settings, err := tsdb.GetRetentionSettings(ctx, "test_table")
427 | if err != nil {
428 | t.Fatalf("Failed to get retention settings: %v", err)
429 | }
430 |
431 | // Check the returned settings
432 | if settings.HypertableName != "test_table" {
433 | t.Errorf("Expected HypertableName to be 'test_table', got '%s'", settings.HypertableName)
434 | }
435 |
436 | if !settings.RetentionEnabled {
437 | t.Error("Expected RetentionEnabled to be true, got false")
438 | }
439 |
440 | if settings.RetentionInterval != "30 days" {
441 | t.Errorf("Expected RetentionInterval to be '30 days', got '%s'", settings.RetentionInterval)
442 | }
443 |
444 | // Test when no policy exists
445 | mockDB.RegisterQueryResult("SELECT s.schedule_interval FROM", []map[string]interface{}{}, nil)
446 |
447 | settings, err = tsdb.GetRetentionSettings(ctx, "test_table")
448 | if err != nil {
449 | t.Fatalf("Failed to get retention settings when no policy exists: %v", err)
450 | }
451 |
452 | if settings.RetentionEnabled {
453 | t.Error("Expected RetentionEnabled to be false, got true")
454 | }
455 |
456 | // Test when TimescaleDB is not available
457 | tsdb.isTimescaleDB = false
458 | _, err = tsdb.GetRetentionSettings(ctx, "test_table")
459 | if err == nil {
460 | t.Error("Expected error when TimescaleDB is not available, got nil")
461 | }
462 | }
463 |
464 | func TestCompressChunks(t *testing.T) {
465 | mockDB := NewMockDB()
466 | tsdb := &DB{
467 | Database: mockDB,
468 | isTimescaleDB: true,
469 | }
470 |
471 | ctx := context.Background()
472 |
473 | // Test compressing all chunks
474 | err := tsdb.CompressChunks(ctx, "test_table", "")
475 | if err != nil {
476 | t.Fatalf("Failed to compress all chunks: %v", err)
477 | }
478 |
479 | // Check that the correct query was executed
480 | query, _ := mockDB.GetLastQuery()
481 | AssertQueryContains(t, query, "SELECT compress_chunks(hypertable => 'test_table')")
482 |
483 | // Test compressing chunks with older_than specified
484 | err = tsdb.CompressChunks(ctx, "test_table", "7 days")
485 | if err != nil {
486 | t.Fatalf("Failed to compress chunks with older_than: %v", err)
487 | }
488 |
489 | // Check that the correct query was executed
490 | query, _ = mockDB.GetLastQuery()
491 | AssertQueryContains(t, query, "SELECT compress_chunks(hypertable => 'test_table', older_than => INTERVAL '7 days')")
492 |
493 | // Test when TimescaleDB is not available
494 | tsdb.isTimescaleDB = false
495 | err = tsdb.CompressChunks(ctx, "test_table", "")
496 | if err == nil {
497 | t.Error("Expected error when TimescaleDB is not available, got nil")
498 | }
499 |
500 | // Test execution error
501 | tsdb.isTimescaleDB = true
502 | mockDB.RegisterQueryResult("SELECT compress_chunks", nil, errors.New("mocked error"))
503 | err = tsdb.CompressChunks(ctx, "test_table", "")
504 | if err == nil {
505 | t.Error("Expected query error, got nil")
506 | }
507 | }
508 |
509 | func TestDecompressChunks(t *testing.T) {
510 | mockDB := NewMockDB()
511 | tsdb := &DB{
512 | Database: mockDB,
513 | isTimescaleDB: true,
514 | }
515 |
516 | ctx := context.Background()
517 |
518 | // Test decompressing all chunks
519 | err := tsdb.DecompressChunks(ctx, "test_table", "")
520 | if err != nil {
521 | t.Fatalf("Failed to decompress all chunks: %v", err)
522 | }
523 |
524 | // Check that the correct query was executed
525 | query, _ := mockDB.GetLastQuery()
526 | AssertQueryContains(t, query, "SELECT decompress_chunks(hypertable => 'test_table')")
527 |
528 | // Test decompressing chunks with newer_than specified
529 | err = tsdb.DecompressChunks(ctx, "test_table", "7 days")
530 | if err != nil {
531 | t.Fatalf("Failed to decompress chunks with newer_than: %v", err)
532 | }
533 |
534 | // Check that the correct query was executed
535 | query, _ = mockDB.GetLastQuery()
536 | AssertQueryContains(t, query, "SELECT decompress_chunks(hypertable => 'test_table', newer_than => INTERVAL '7 days')")
537 |
538 | // Test when TimescaleDB is not available
539 | tsdb.isTimescaleDB = false
540 | err = tsdb.DecompressChunks(ctx, "test_table", "")
541 | if err == nil {
542 | t.Error("Expected error when TimescaleDB is not available, got nil")
543 | }
544 |
545 | // Test execution error
546 | tsdb.isTimescaleDB = true
547 | mockDB.RegisterQueryResult("SELECT decompress_chunks", nil, errors.New("mocked error"))
548 | err = tsdb.DecompressChunks(ctx, "test_table", "")
549 | if err == nil {
550 | t.Error("Expected query error, got nil")
551 | }
552 | }
553 |
554 | func TestGetChunkCompressionStats(t *testing.T) {
555 | mockDB := NewMockDB()
556 | tsdb := &DB{
557 | Database: mockDB,
558 | isTimescaleDB: true,
559 | }
560 |
561 | ctx := context.Background()
562 |
563 | // Mock chunk stats
564 | mockStats := []map[string]interface{}{
565 | {
566 | "chunk_name": "_hyper_1_1_chunk",
567 | "range_start": "2023-01-01 00:00:00",
568 | "range_end": "2023-01-02 00:00:00",
569 | "is_compressed": true,
570 | "before_compression_total_bytes": 1000,
571 | "after_compression_total_bytes": 200,
572 | "compression_ratio": 80.0,
573 | },
574 | }
575 | mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", mockStats, nil)
576 |
577 | // Test getting chunk compression stats
578 | _, err := tsdb.GetChunkCompressionStats(ctx, "test_table")
579 | if err != nil {
580 | t.Fatalf("Failed to get chunk compression stats: %v", err)
581 | }
582 |
583 | // Check that the correct query was executed
584 | query, _ := mockDB.GetLastQuery()
585 | AssertQueryContains(t, query, "FROM timescaledb_information.chunks")
586 | AssertQueryContains(t, query, "hypertable_name = 'test_table'")
587 |
588 | // Test when TimescaleDB is not available
589 | tsdb.isTimescaleDB = false
590 | _, err = tsdb.GetChunkCompressionStats(ctx, "test_table")
591 | if err == nil {
592 | t.Error("Expected error when TimescaleDB is not available, got nil")
593 | }
594 |
595 | // Test execution error
596 | tsdb.isTimescaleDB = true
597 | mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", nil, errors.New("mocked error"))
598 | _, err = tsdb.GetChunkCompressionStats(ctx, "test_table")
599 | if err == nil {
600 | t.Error("Expected query error, got nil")
601 | }
602 | }
603 |
```
--------------------------------------------------------------------------------
/pkg/dbtools/schema.go:
--------------------------------------------------------------------------------
```go
1 | package dbtools
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/FreePeak/db-mcp-server/pkg/db"
10 | "github.com/FreePeak/db-mcp-server/pkg/logger"
11 | "github.com/FreePeak/db-mcp-server/pkg/tools"
12 | )
13 |
14 | // DatabaseStrategy defines the interface for database-specific query strategies
15 | type DatabaseStrategy interface {
16 | GetTablesQueries() []queryWithArgs
17 | GetColumnsQueries(table string) []queryWithArgs
18 | GetRelationshipsQueries(table string) []queryWithArgs
19 | }
20 |
21 | // NewDatabaseStrategy creates the appropriate strategy for the given database type
22 | func NewDatabaseStrategy(driverName string) DatabaseStrategy {
23 | switch driverName {
24 | case "postgres":
25 | return &PostgresStrategy{}
26 | case "mysql":
27 | return &MySQLStrategy{}
28 | default:
29 | logger.Warn("Unknown database driver: %s, will use generic strategy", driverName)
30 | return &GenericStrategy{}
31 | }
32 | }
33 |
34 | // PostgresStrategy implements DatabaseStrategy for PostgreSQL
35 | type PostgresStrategy struct{}
36 |
37 | // GetTablesQueries returns queries for retrieving tables in PostgreSQL
38 | func (s *PostgresStrategy) GetTablesQueries() []queryWithArgs {
39 | return []queryWithArgs{
40 | // Primary: pg_catalog approach
41 | {query: "SELECT tablename as table_name FROM pg_catalog.pg_tables WHERE schemaname = 'public'"},
42 | // Secondary: information_schema approach
43 | {query: "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"},
44 | // Tertiary: pg_class approach
45 | {query: "SELECT relname as table_name FROM pg_catalog.pg_class WHERE relkind = 'r' AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')"},
46 | }
47 | }
48 |
49 | // GetColumnsQueries returns queries for retrieving columns in PostgreSQL
50 | func (s *PostgresStrategy) GetColumnsQueries(table string) []queryWithArgs {
51 | return []queryWithArgs{
52 | // Primary: information_schema approach for PostgreSQL
53 | {
54 | query: `
55 | SELECT column_name, data_type,
56 | CASE WHEN is_nullable = 'YES' THEN 'YES' ELSE 'NO' END as is_nullable,
57 | column_default
58 | FROM information_schema.columns
59 | WHERE table_name = $1 AND table_schema = 'public'
60 | ORDER BY ordinal_position
61 | `,
62 | args: []interface{}{table},
63 | },
64 | // Secondary: pg_catalog approach for PostgreSQL
65 | {
66 | query: `
67 | SELECT a.attname as column_name,
68 | pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type,
69 | CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END as is_nullable,
70 | pg_catalog.pg_get_expr(d.adbin, d.adrelid) as column_default
71 | FROM pg_catalog.pg_attribute a
72 | LEFT JOIN pg_catalog.pg_attrdef d ON (a.attrelid = d.adrelid AND a.attnum = d.adnum)
73 | WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1 AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public'))
74 | AND a.attnum > 0 AND NOT a.attisdropped
75 | ORDER BY a.attnum
76 | `,
77 | args: []interface{}{table},
78 | },
79 | }
80 | }
81 |
82 | // GetRelationshipsQueries returns queries for retrieving relationships in PostgreSQL
83 | func (s *PostgresStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
84 | baseQueries := []queryWithArgs{
85 | // Primary: Standard information_schema approach for PostgreSQL
86 | {
87 | query: `
88 | SELECT
89 | tc.table_schema,
90 | tc.constraint_name,
91 | tc.table_name,
92 | kcu.column_name,
93 | ccu.table_schema AS foreign_table_schema,
94 | ccu.table_name AS foreign_table_name,
95 | ccu.column_name AS foreign_column_name
96 | FROM information_schema.table_constraints AS tc
97 | JOIN information_schema.key_column_usage AS kcu
98 | ON tc.constraint_name = kcu.constraint_name
99 | AND tc.table_schema = kcu.table_schema
100 | JOIN information_schema.constraint_column_usage AS ccu
101 | ON ccu.constraint_name = tc.constraint_name
102 | AND ccu.table_schema = tc.table_schema
103 | WHERE tc.constraint_type = 'FOREIGN KEY'
104 | AND tc.table_schema = 'public'
105 | `,
106 | args: []interface{}{},
107 | },
108 | // Alternate: Using pg_catalog for older PostgreSQL versions
109 | {
110 | query: `
111 | SELECT
112 | ns.nspname AS table_schema,
113 | c.conname AS constraint_name,
114 | cl.relname AS table_name,
115 | att.attname AS column_name,
116 | ns2.nspname AS foreign_table_schema,
117 | cl2.relname AS foreign_table_name,
118 | att2.attname AS foreign_column_name
119 | FROM pg_constraint c
120 | JOIN pg_class cl ON c.conrelid = cl.oid
121 | JOIN pg_attribute att ON att.attrelid = cl.oid AND att.attnum = ANY(c.conkey)
122 | JOIN pg_namespace ns ON ns.oid = cl.relnamespace
123 | JOIN pg_class cl2 ON c.confrelid = cl2.oid
124 | JOIN pg_attribute att2 ON att2.attrelid = cl2.oid AND att2.attnum = ANY(c.confkey)
125 | JOIN pg_namespace ns2 ON ns2.oid = cl2.relnamespace
126 | WHERE c.contype = 'f'
127 | AND ns.nspname = 'public'
128 | `,
129 | args: []interface{}{},
130 | },
131 | }
132 |
133 | if table == "" {
134 | return baseQueries
135 | }
136 |
137 | queries := make([]queryWithArgs, len(baseQueries))
138 |
139 | // Add table filter
140 | queries[0] = queryWithArgs{
141 | query: baseQueries[0].query + " AND (tc.table_name = $1 OR ccu.table_name = $1)",
142 | args: []interface{}{table},
143 | }
144 |
145 | queries[1] = queryWithArgs{
146 | query: baseQueries[1].query + " AND (cl.relname = $1 OR cl2.relname = $1)",
147 | args: []interface{}{table},
148 | }
149 |
150 | return queries
151 | }
152 |
153 | // MySQLStrategy implements DatabaseStrategy for MySQL
154 | type MySQLStrategy struct{}
155 |
156 | // GetTablesQueries returns queries for retrieving tables in MySQL
157 | func (s *MySQLStrategy) GetTablesQueries() []queryWithArgs {
158 | return []queryWithArgs{
159 | // Primary: information_schema approach
160 | {query: "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()"},
161 | // Secondary: SHOW TABLES approach
162 | {query: "SHOW TABLES"},
163 | }
164 | }
165 |
166 | // GetColumnsQueries returns queries for retrieving columns in MySQL
167 | func (s *MySQLStrategy) GetColumnsQueries(table string) []queryWithArgs {
168 | return []queryWithArgs{
169 | // MySQL query for columns
170 | {
171 | query: `
172 | SELECT column_name, data_type, is_nullable, column_default
173 | FROM information_schema.columns
174 | WHERE table_name = ? AND table_schema = DATABASE()
175 | ORDER BY ordinal_position
176 | `,
177 | args: []interface{}{table},
178 | },
179 | // Fallback for older MySQL versions
180 | {
181 | query: `SHOW COLUMNS FROM ` + table,
182 | args: []interface{}{},
183 | },
184 | }
185 | }
186 |
187 | // GetRelationshipsQueries returns queries for retrieving relationships in MySQL
188 | func (s *MySQLStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
189 | baseQueries := []queryWithArgs{
190 | // Primary approach for MySQL
191 | {
192 | query: `
193 | SELECT
194 | tc.table_schema,
195 | tc.constraint_name,
196 | tc.table_name,
197 | kcu.column_name,
198 | kcu.referenced_table_schema AS foreign_table_schema,
199 | kcu.referenced_table_name AS foreign_table_name,
200 | kcu.referenced_column_name AS foreign_column_name
201 | FROM information_schema.table_constraints AS tc
202 | JOIN information_schema.key_column_usage AS kcu
203 | ON tc.constraint_name = kcu.constraint_name
204 | AND tc.table_schema = kcu.table_schema
205 | WHERE tc.constraint_type = 'FOREIGN KEY'
206 | AND tc.table_schema = DATABASE()
207 | `,
208 | args: []interface{}{},
209 | },
210 | // Fallback using simpler query for older MySQL versions
211 | {
212 | query: `
213 | SELECT
214 | kcu.constraint_schema AS table_schema,
215 | kcu.constraint_name,
216 | kcu.table_name,
217 | kcu.column_name,
218 | kcu.referenced_table_schema AS foreign_table_schema,
219 | kcu.referenced_table_name AS foreign_table_name,
220 | kcu.referenced_column_name AS foreign_column_name
221 | FROM information_schema.key_column_usage kcu
222 | WHERE kcu.referenced_table_name IS NOT NULL
223 | AND kcu.constraint_schema = DATABASE()
224 | `,
225 | args: []interface{}{},
226 | },
227 | }
228 |
229 | if table == "" {
230 | return baseQueries
231 | }
232 |
233 | queries := make([]queryWithArgs, len(baseQueries))
234 |
235 | // Add table filter
236 | queries[0] = queryWithArgs{
237 | query: baseQueries[0].query + " AND (tc.table_name = ? OR kcu.referenced_table_name = ?)",
238 | args: []interface{}{table, table},
239 | }
240 |
241 | queries[1] = queryWithArgs{
242 | query: baseQueries[1].query + " AND (kcu.table_name = ? OR kcu.referenced_table_name = ?)",
243 | args: []interface{}{table, table},
244 | }
245 |
246 | return queries
247 | }
248 |
249 | // GenericStrategy implements DatabaseStrategy for unknown database types
250 | type GenericStrategy struct{}
251 |
252 | // GetTablesQueries returns generic queries for retrieving tables
253 | func (s *GenericStrategy) GetTablesQueries() []queryWithArgs {
254 | return []queryWithArgs{
255 | {query: "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"},
256 | {query: "SELECT table_name FROM information_schema.tables"},
257 | {query: "SHOW TABLES"}, // Last resort
258 | }
259 | }
260 |
261 | // GetColumnsQueries returns generic queries for retrieving columns
262 | func (s *GenericStrategy) GetColumnsQueries(table string) []queryWithArgs {
263 | return []queryWithArgs{
264 | // Try PostgreSQL-style query first
265 | {
266 | query: `
267 | SELECT column_name, data_type, is_nullable, column_default
268 | FROM information_schema.columns
269 | WHERE table_name = $1
270 | ORDER BY ordinal_position
271 | `,
272 | args: []interface{}{table},
273 | },
274 | // Try MySQL-style query
275 | {
276 | query: `
277 | SELECT column_name, data_type, is_nullable, column_default
278 | FROM information_schema.columns
279 | WHERE table_name = ?
280 | ORDER BY ordinal_position
281 | `,
282 | args: []interface{}{table},
283 | },
284 | }
285 | }
286 |
287 | // GetRelationshipsQueries returns generic queries for retrieving relationships
288 | func (s *GenericStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
289 | pgQuery := queryWithArgs{
290 | query: `
291 | SELECT
292 | tc.table_schema,
293 | tc.constraint_name,
294 | tc.table_name,
295 | kcu.column_name,
296 | ccu.table_schema AS foreign_table_schema,
297 | ccu.table_name AS foreign_table_name,
298 | ccu.column_name AS foreign_column_name
299 | FROM information_schema.table_constraints AS tc
300 | JOIN information_schema.key_column_usage AS kcu
301 | ON tc.constraint_name = kcu.constraint_name
302 | AND tc.table_schema = kcu.table_schema
303 | JOIN information_schema.constraint_column_usage AS ccu
304 | ON ccu.constraint_name = tc.constraint_name
305 | AND ccu.table_schema = tc.table_schema
306 | WHERE tc.constraint_type = 'FOREIGN KEY'
307 | `,
308 | args: []interface{}{},
309 | }
310 |
311 | mysqlQuery := queryWithArgs{
312 | query: `
313 | SELECT
314 | kcu.constraint_schema AS table_schema,
315 | kcu.constraint_name,
316 | kcu.table_name,
317 | kcu.column_name,
318 | kcu.referenced_table_schema AS foreign_table_schema,
319 | kcu.referenced_table_name AS foreign_table_name,
320 | kcu.referenced_column_name AS foreign_column_name
321 | FROM information_schema.key_column_usage kcu
322 | WHERE kcu.referenced_table_name IS NOT NULL
323 | `,
324 | args: []interface{}{},
325 | }
326 |
327 | if table != "" {
328 | pgQuery.query += " AND (tc.table_name = $1 OR ccu.table_name = $1)"
329 | pgQuery.args = append(pgQuery.args, table)
330 |
331 | mysqlQuery.query += " AND (kcu.table_name = ? OR kcu.referenced_table_name = ?)"
332 | mysqlQuery.args = append(mysqlQuery.args, table, table)
333 | }
334 |
335 | return []queryWithArgs{pgQuery, mysqlQuery}
336 | }
337 |
338 | // createSchemaExplorerTool creates a tool for exploring database schema
339 | func createSchemaExplorerTool() *tools.Tool {
340 | return &tools.Tool{
341 | Name: "dbSchema",
342 | Description: "Auto-discover database structure and relationships",
343 | Category: "database",
344 | InputSchema: tools.ToolInputSchema{
345 | Type: "object",
346 | Properties: map[string]interface{}{
347 | "component": map[string]interface{}{
348 | "type": "string",
349 | "description": "Schema component to explore (tables, columns, relationships, or full)",
350 | "enum": []string{"tables", "columns", "relationships", "full"},
351 | },
352 | "table": map[string]interface{}{
353 | "type": "string",
354 | "description": "Table name to explore (optional, leave empty for all tables)",
355 | },
356 | "timeout": map[string]interface{}{
357 | "type": "integer",
358 | "description": "Query timeout in milliseconds (default: 10000)",
359 | },
360 | "database": map[string]interface{}{
361 | "type": "string",
362 | "description": "Database ID to use (optional if only one database is configured)",
363 | },
364 | },
365 | Required: []string{"component", "database"},
366 | },
367 | Handler: handleSchemaExplorer,
368 | }
369 | }
370 |
371 | // handleSchemaExplorer handles the schema explorer tool execution
372 | func handleSchemaExplorer(ctx context.Context, params map[string]interface{}) (interface{}, error) {
373 | // Check if database manager is initialized
374 | if dbManager == nil {
375 | return nil, fmt.Errorf("database manager not initialized")
376 | }
377 |
378 | // Extract parameters
379 | component, ok := getStringParam(params, "component")
380 | if !ok {
381 | return nil, fmt.Errorf("component parameter is required")
382 | }
383 |
384 | // Get database ID
385 | databaseID, ok := getStringParam(params, "database")
386 | if !ok {
387 | return nil, fmt.Errorf("database parameter is required")
388 | }
389 |
390 | // Get database instance
391 | db, err := dbManager.GetDatabase(databaseID)
392 | if err != nil {
393 | return nil, fmt.Errorf("failed to get database: %w", err)
394 | }
395 |
396 | // Extract table parameter (optional depending on component)
397 | table, _ := getStringParam(params, "table")
398 |
399 | // Extract timeout
400 | timeout := 10000 // Default timeout: 10 seconds
401 | if timeoutParam, ok := getIntParam(params, "timeout"); ok {
402 | timeout = timeoutParam
403 | }
404 |
405 | // Create context with timeout
406 | timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
407 | defer cancel()
408 |
409 | // Use actual database queries based on component type
410 | switch component {
411 | case "tables":
412 | return getTables(timeoutCtx, db)
413 | case "columns":
414 | if table == "" {
415 | return nil, fmt.Errorf("table parameter is required for columns component")
416 | }
417 | return getColumns(timeoutCtx, db, table)
418 | case "relationships":
419 | return getRelationships(timeoutCtx, db, table)
420 | case "full":
421 | return getFullSchema(timeoutCtx, db)
422 | default:
423 | return nil, fmt.Errorf("invalid component: %s", component)
424 | }
425 | }
426 |
427 | // executeWithFallbacks executes a series of database queries with fallbacks
428 | // Returns the first successful result or the last error encountered
429 | type queryWithArgs struct {
430 | query string
431 | args []interface{}
432 | }
433 |
434 | func executeWithFallbacks(ctx context.Context, db db.Database, queries []queryWithArgs, operationName string) (*sql.Rows, error) {
435 | var lastErr error
436 |
437 | for i, q := range queries {
438 | rows, err := db.Query(ctx, q.query, q.args...)
439 | if err == nil {
440 | return rows, nil
441 | }
442 |
443 | lastErr = err
444 | logger.Warn("%s fallback query %d failed: %v - Error: %v", operationName, i+1, q.query, err)
445 | }
446 |
447 | // All queries failed, return the last error
448 | return nil, fmt.Errorf("%s failed after trying %d fallback queries: %w", operationName, len(queries), lastErr)
449 | }
450 |
451 | // getTables retrieves the list of tables in the database
452 | func getTables(ctx context.Context, db db.Database) (interface{}, error) {
453 | // Get database type from connected database
454 | driverName := db.DriverName()
455 | dbType := driverName
456 |
457 | // Create the appropriate strategy
458 | strategy := NewDatabaseStrategy(driverName)
459 |
460 | // Get queries from strategy
461 | queries := strategy.GetTablesQueries()
462 |
463 | // Execute queries with fallbacks
464 | rows, err := executeWithFallbacks(ctx, db, queries, "getTables")
465 | if err != nil {
466 | return nil, fmt.Errorf("failed to get tables: %w", err)
467 | }
468 |
469 | defer func() {
470 | if rows != nil {
471 | if err := rows.Close(); err != nil {
472 | logger.Error("error closing rows: %v", err)
473 | }
474 | }
475 | }()
476 |
477 | // Convert rows to maps
478 | results, err := rowsToMaps(rows)
479 | if err != nil {
480 | return nil, fmt.Errorf("failed to process tables: %w", err)
481 | }
482 |
483 | return map[string]interface{}{
484 | "tables": results,
485 | "dbType": dbType,
486 | }, nil
487 | }
488 |
489 | // getColumns retrieves the columns for a specific table
490 | func getColumns(ctx context.Context, db db.Database, table string) (interface{}, error) {
491 | // Get database type from connected database
492 | driverName := db.DriverName()
493 | dbType := driverName
494 |
495 | // Create the appropriate strategy
496 | strategy := NewDatabaseStrategy(driverName)
497 |
498 | // Get queries from strategy
499 | queries := strategy.GetColumnsQueries(table)
500 |
501 | // Execute queries with fallbacks
502 | rows, err := executeWithFallbacks(ctx, db, queries, "getColumns["+table+"]")
503 | if err != nil {
504 | return nil, fmt.Errorf("failed to get columns for table %s: %w", table, err)
505 | }
506 |
507 | defer func() {
508 | if rows != nil {
509 | if err := rows.Close(); err != nil {
510 | logger.Error("error closing rows: %v", err)
511 | }
512 | }
513 | }()
514 |
515 | // Convert rows to maps
516 | results, err := rowsToMaps(rows)
517 | if err != nil {
518 | return nil, fmt.Errorf("failed to process columns: %w", err)
519 | }
520 |
521 | return map[string]interface{}{
522 | "table": table,
523 | "columns": results,
524 | "dbType": dbType,
525 | }, nil
526 | }
527 |
528 | // getRelationships retrieves the relationships for a table or all tables
529 | func getRelationships(ctx context.Context, db db.Database, table string) (interface{}, error) {
530 | // Get database type from connected database
531 | driverName := db.DriverName()
532 | dbType := driverName
533 |
534 | // Create the appropriate strategy
535 | strategy := NewDatabaseStrategy(driverName)
536 |
537 | // Get queries from strategy
538 | queries := strategy.GetRelationshipsQueries(table)
539 |
540 | // Execute queries with fallbacks
541 | rows, err := executeWithFallbacks(ctx, db, queries, "getRelationships")
542 | if err != nil {
543 | return nil, fmt.Errorf("failed to get relationships for table %s: %w", table, err)
544 | }
545 |
546 | defer func() {
547 | if rows != nil {
548 | if err := rows.Close(); err != nil {
549 | logger.Error("error closing rows: %v", err)
550 | }
551 | }
552 | }()
553 |
554 | // Convert rows to maps
555 | results, err := rowsToMaps(rows)
556 | if err != nil {
557 | return nil, fmt.Errorf("failed to process relationships: %w", err)
558 | }
559 |
560 | return map[string]interface{}{
561 | "relationships": results,
562 | "dbType": dbType,
563 | "table": table,
564 | }, nil
565 | }
566 |
// safeGetMap returns obj as a map[string]interface{}. A nil interface and a
// value of any other dynamic type are reported as errors; a typed nil map is
// returned as-is without error.
func safeGetMap(obj interface{}) (map[string]interface{}, error) {
	switch m := obj.(type) {
	case nil:
		return nil, fmt.Errorf("nil value cannot be converted to map")
	case map[string]interface{}:
		return m, nil
	default:
		return nil, fmt.Errorf("value is not a map[string]interface{}: %T", obj)
	}
}
580 |
// safeGetString returns m[key] when the key is present and holds a string.
// A missing key and a non-string value produce distinct errors.
func safeGetString(m map[string]interface{}, key string) (string, error) {
	raw, present := m[key]
	if !present {
		return "", fmt.Errorf("key %q not found in map", key)
	}

	if s, isString := raw.(string); isString {
		return s, nil
	}
	return "", fmt.Errorf("value for key %q is not a string: %T", key, raw)
}
595 |
596 | // getFullSchema retrieves the complete database schema
597 | func getFullSchema(ctx context.Context, db db.Database) (interface{}, error) {
598 | tablesResult, err := getTables(ctx, db)
599 | if err != nil {
600 | return nil, fmt.Errorf("failed to get tables: %w", err)
601 | }
602 |
603 | tablesMap, err := safeGetMap(tablesResult)
604 | if err != nil {
605 | return nil, fmt.Errorf("invalid tables result: %w", err)
606 | }
607 |
608 | tablesSlice, ok := tablesMap["tables"].([]map[string]interface{})
609 | if !ok {
610 | return nil, fmt.Errorf("invalid tables data format")
611 | }
612 |
613 | // For each table, get columns
614 | fullSchema := make(map[string]interface{})
615 | for _, tableInfo := range tablesSlice {
616 | tableName, err := safeGetString(tableInfo, "table_name")
617 | if err != nil {
618 | return nil, fmt.Errorf("invalid table info: %w", err)
619 | }
620 |
621 | columnsResult, columnsErr := getColumns(ctx, db, tableName)
622 | if columnsErr != nil {
623 | return nil, fmt.Errorf("failed to get columns for table %s: %w", tableName, columnsErr)
624 | }
625 | fullSchema[tableName] = columnsResult
626 | }
627 |
628 | // Get all relationships
629 | relationships, relErr := getRelationships(ctx, db, "")
630 | if relErr != nil {
631 | return nil, fmt.Errorf("failed to get relationships: %w", relErr)
632 | }
633 |
634 | relMap, err := safeGetMap(relationships)
635 | if err != nil {
636 | return nil, fmt.Errorf("invalid relationships result: %w", err)
637 | }
638 |
639 | return map[string]interface{}{
640 | "tables": tablesSlice,
641 | "schema": fullSchema,
642 | "relationships": relMap["relationships"],
643 | }, nil
644 | }
645 |
```