#
tokens: 46210/50000 8/102 files (page 4/7)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 4 of 7. Use http://codebase.md/freepeak/db-mcp-server?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .cm
│   └── gitstream.cm
├── .cursor
│   ├── mcp-example.json
│   ├── mcp.json
│   └── rules
│       └── global.mdc
├── .dockerignore
├── .DS_Store
├── .env.example
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── go.yml
├── .gitignore
├── .golangci.yml
├── assets
│   └── logo.svg
├── CHANGELOG.md
├── cmd
│   └── server
│       └── main.go
├── commit-message.txt
├── config.json
├── config.timescaledb-test.json
├── docker-compose.mcp-test.yml
├── docker-compose.test.yml
├── docker-compose.timescaledb-test.yml
├── docker-compose.yml
├── docker-wrapper.sh
├── Dockerfile
├── docs
│   ├── REFACTORING.md
│   ├── TIMESCALEDB_FUNCTIONS.md
│   ├── TIMESCALEDB_IMPLEMENTATION.md
│   ├── TIMESCALEDB_PRD.md
│   └── TIMESCALEDB_TOOLS.md
├── examples
│   └── postgres_connection.go
├── glama.json
├── go.mod
├── go.sum
├── init-scripts
│   └── timescaledb
│       ├── 01-init.sql
│       ├── 02-sample-data.sql
│       ├── 03-continuous-aggregates.sql
│       └── README.md
├── internal
│   ├── config
│   │   ├── config_test.go
│   │   └── config.go
│   ├── delivery
│   │   └── mcp
│   │       ├── compression_policy_test.go
│   │       ├── context
│   │       │   ├── hypertable_schema_test.go
│   │       │   ├── timescale_completion_test.go
│   │       │   ├── timescale_context_test.go
│   │       │   └── timescale_query_suggestion_test.go
│   │       ├── mock_test.go
│   │       ├── response_test.go
│   │       ├── response.go
│   │       ├── retention_policy_test.go
│   │       ├── server_wrapper.go
│   │       ├── timescale_completion.go
│   │       ├── timescale_context.go
│   │       ├── timescale_schema.go
│   │       ├── timescale_tool_test.go
│   │       ├── timescale_tool.go
│   │       ├── timescale_tools_test.go
│   │       ├── tool_registry.go
│   │       └── tool_types.go
│   ├── domain
│   │   └── database.go
│   ├── logger
│   │   ├── logger_test.go
│   │   └── logger.go
│   ├── repository
│   │   └── database_repository.go
│   └── usecase
│       └── database_usecase.go
├── LICENSE
├── Makefile
├── pkg
│   ├── core
│   │   ├── core.go
│   │   └── logging.go
│   ├── db
│   │   ├── db_test.go
│   │   ├── db.go
│   │   ├── manager.go
│   │   ├── README.md
│   │   └── timescale
│   │       ├── config_test.go
│   │       ├── config.go
│   │       ├── connection_test.go
│   │       ├── connection.go
│   │       ├── continuous_aggregate_test.go
│   │       ├── continuous_aggregate.go
│   │       ├── hypertable_test.go
│   │       ├── hypertable.go
│   │       ├── metadata.go
│   │       ├── mocks_test.go
│   │       ├── policy_test.go
│   │       ├── policy.go
│   │       ├── query.go
│   │       ├── timeseries_test.go
│   │       └── timeseries.go
│   ├── dbtools
│   │   ├── db_helpers.go
│   │   ├── dbtools_test.go
│   │   ├── dbtools.go
│   │   ├── exec.go
│   │   ├── performance_test.go
│   │   ├── performance.go
│   │   ├── query.go
│   │   ├── querybuilder_test.go
│   │   ├── querybuilder.go
│   │   ├── README.md
│   │   ├── schema_test.go
│   │   ├── schema.go
│   │   ├── tx_test.go
│   │   └── tx.go
│   ├── internal
│   │   └── logger
│   │       └── logger.go
│   ├── jsonrpc
│   │   └── jsonrpc.go
│   ├── logger
│   │   └── logger.go
│   └── tools
│       └── tools.go
├── README-old.md
├── README.md
├── repomix-output.txt
├── request.json
├── start-mcp.sh
├── test.Dockerfile
├── timescaledb-test.sh
└── wait-for-it.sh
```

# Files

--------------------------------------------------------------------------------
/pkg/db/timescale/policy.go:
--------------------------------------------------------------------------------

```go
  1 | package timescale
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"fmt"
  6 | 	"strings"
  7 | )
  8 | 
// CompressionSettings represents TimescaleDB compression settings for a
// single hypertable, as surfaced by the timescaledb_information views.
type CompressionSettings struct {
	HypertableName      string // table these settings belong to
	SegmentBy           string // columns used to segment rows inside compressed chunks
	OrderBy             string // row ordering inside compressed chunks
	ChunkTimeInterval   string // chunk time interval of the hypertable
	CompressionInterval string // schedule interval of the compression policy, if one exists
	CompressionEnabled  bool   // whether timescaledb.compress is enabled on the table
}
 18 | 
// RetentionSettings represents TimescaleDB data retention settings for a
// single hypertable.
type RetentionSettings struct {
	HypertableName    string // table these settings belong to
	RetentionInterval string // schedule interval of the retention policy, if one exists
	RetentionEnabled  bool   // whether a retention policy job is attached to the table
}
 25 | 
 26 | // EnableCompression enables compression on a hypertable
 27 | func (t *DB) EnableCompression(ctx context.Context, tableName string, afterInterval string) error {
 28 | 	if !t.isTimescaleDB {
 29 | 		return fmt.Errorf("TimescaleDB extension not available")
 30 | 	}
 31 | 
 32 | 	query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", tableName)
 33 | 	_, err := t.ExecuteSQLWithoutParams(ctx, query)
 34 | 	if err != nil {
 35 | 		return fmt.Errorf("failed to enable compression: %w", err)
 36 | 	}
 37 | 
 38 | 	// Set compression policy if interval is specified
 39 | 	if afterInterval != "" {
 40 | 		err = t.AddCompressionPolicy(ctx, tableName, afterInterval, "", "")
 41 | 		if err != nil {
 42 | 			return fmt.Errorf("failed to add compression policy: %w", err)
 43 | 		}
 44 | 	}
 45 | 
 46 | 	return nil
 47 | }
 48 | 
 49 | // DisableCompression disables compression on a hypertable
 50 | func (t *DB) DisableCompression(ctx context.Context, tableName string) error {
 51 | 	if !t.isTimescaleDB {
 52 | 		return fmt.Errorf("TimescaleDB extension not available")
 53 | 	}
 54 | 
 55 | 	// First, remove any compression policies
 56 | 	err := t.RemoveCompressionPolicy(ctx, tableName)
 57 | 	if err != nil {
 58 | 		return fmt.Errorf("failed to remove compression policy: %w", err)
 59 | 	}
 60 | 
 61 | 	// Then disable compression
 62 | 	query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = false)", tableName)
 63 | 	_, err = t.ExecuteSQLWithoutParams(ctx, query)
 64 | 	if err != nil {
 65 | 		return fmt.Errorf("failed to disable compression: %w", err)
 66 | 	}
 67 | 
 68 | 	return nil
 69 | }
 70 | 
 71 | // AddCompressionPolicy adds a compression policy to a hypertable
 72 | func (t *DB) AddCompressionPolicy(ctx context.Context, tableName, interval, segmentBy, orderBy string) error {
 73 | 	if !t.isTimescaleDB {
 74 | 		return fmt.Errorf("TimescaleDB extension not available")
 75 | 	}
 76 | 
 77 | 	// First, check if the table has compression enabled
 78 | 	query := fmt.Sprintf("SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'", tableName)
 79 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
 80 | 	if err != nil {
 81 | 		return fmt.Errorf("failed to check compression status: %w", err)
 82 | 	}
 83 | 
 84 | 	rows, ok := result.([]map[string]interface{})
 85 | 	if !ok || len(rows) == 0 {
 86 | 		return fmt.Errorf("table '%s' is not a hypertable", tableName)
 87 | 	}
 88 | 
 89 | 	isCompressed := rows[0]["compress"]
 90 | 	if isCompressed == nil || fmt.Sprintf("%v", isCompressed) == "false" {
 91 | 		// Enable compression
 92 | 		enableQuery := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", tableName)
 93 | 		_, err := t.ExecuteSQLWithoutParams(ctx, enableQuery)
 94 | 		if err != nil {
 95 | 			return fmt.Errorf("failed to enable compression: %w", err)
 96 | 		}
 97 | 	}
 98 | 
 99 | 	// Build the compression policy query
100 | 	var policyQuery strings.Builder
101 | 	policyQuery.WriteString(fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'", tableName, interval))
102 | 
103 | 	if segmentBy != "" {
104 | 		policyQuery.WriteString(fmt.Sprintf(", segmentby => '%s'", segmentBy))
105 | 	}
106 | 
107 | 	if orderBy != "" {
108 | 		policyQuery.WriteString(fmt.Sprintf(", orderby => '%s'", orderBy))
109 | 	}
110 | 
111 | 	policyQuery.WriteString(")")
112 | 
113 | 	// Add the compression policy
114 | 	_, err = t.ExecuteSQLWithoutParams(ctx, policyQuery.String())
115 | 	if err != nil {
116 | 		return fmt.Errorf("failed to add compression policy: %w", err)
117 | 	}
118 | 
119 | 	return nil
120 | }
121 | 
122 | // RemoveCompressionPolicy removes a compression policy from a hypertable
123 | func (t *DB) RemoveCompressionPolicy(ctx context.Context, tableName string) error {
124 | 	if !t.isTimescaleDB {
125 | 		return fmt.Errorf("TimescaleDB extension not available")
126 | 	}
127 | 
128 | 	// Find the policy ID
129 | 	query := fmt.Sprintf(
130 | 		"SELECT job_id FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_compression'",
131 | 		tableName,
132 | 	)
133 | 
134 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
135 | 	if err != nil {
136 | 		return fmt.Errorf("failed to find compression policy: %w", err)
137 | 	}
138 | 
139 | 	rows, ok := result.([]map[string]interface{})
140 | 	if !ok || len(rows) == 0 {
141 | 		// No policy exists, so nothing to remove
142 | 		return nil
143 | 	}
144 | 
145 | 	// Get the job ID
146 | 	jobID := rows[0]["job_id"]
147 | 	if jobID == nil {
148 | 		return fmt.Errorf("invalid job ID for compression policy")
149 | 	}
150 | 
151 | 	// Remove the policy
152 | 	removeQuery := fmt.Sprintf("SELECT remove_compression_policy(%v)", jobID)
153 | 	_, err = t.ExecuteSQLWithoutParams(ctx, removeQuery)
154 | 	if err != nil {
155 | 		return fmt.Errorf("failed to remove compression policy: %w", err)
156 | 	}
157 | 
158 | 	return nil
159 | }
160 | 
// GetCompressionSettings returns the compression configuration of a
// hypertable: whether compression is enabled, its segment-by/order-by column
// settings, and — when a compression policy exists — the policy's schedule
// interval plus the table's chunk time interval.
//
// Returns an error when TimescaleDB is unavailable, when the table is not a
// hypertable, or when the compression_settings lookup fails. A failed policy
// lookup is tolerated: the interval fields are simply left empty.
func (t *DB) GetCompressionSettings(ctx context.Context, tableName string) (*CompressionSettings, error) {
	if !t.isTimescaleDB {
		return nil, fmt.Errorf("TimescaleDB extension not available")
	}

	// Check if the table has compression enabled
	query := fmt.Sprintf(
		"SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'",
		tableName,
	)

	result, err := t.ExecuteSQLWithoutParams(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to check compression status: %w", err)
	}

	rows, ok := result.([]map[string]interface{})
	if !ok || len(rows) == 0 {
		return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
	}

	settings := &CompressionSettings{
		HypertableName: tableName,
	}

	// The driver may surface the boolean in different Go types, so compare
	// its string rendering rather than type-asserting.
	isCompressed := rows[0]["compress"]
	if isCompressed != nil && fmt.Sprintf("%v", isCompressed) == "true" {
		settings.CompressionEnabled = true

		// Get compression-specific settings
		settingsQuery := fmt.Sprintf(
			"SELECT segmentby, orderby FROM timescaledb_information.compression_settings WHERE hypertable_name = '%s'",
			tableName,
		)

		settingsResult, err := t.ExecuteSQLWithoutParams(ctx, settingsQuery)
		if err != nil {
			return nil, fmt.Errorf("failed to get compression settings: %w", err)
		}

		// Missing or NULL columns simply leave the corresponding fields empty.
		settingsRows, ok := settingsResult.([]map[string]interface{})
		if ok && len(settingsRows) > 0 {
			if segmentBy, ok := settingsRows[0]["segmentby"]; ok && segmentBy != nil {
				settings.SegmentBy = fmt.Sprintf("%v", segmentBy)
			}

			if orderBy, ok := settingsRows[0]["orderby"]; ok && orderBy != nil {
				settings.OrderBy = fmt.Sprintf("%v", orderBy)
			}
		}

		// Check if a compression policy exists
		policyQuery := fmt.Sprintf(
			"SELECT s.schedule_interval, h.chunk_time_interval FROM timescaledb_information.jobs j "+
				"JOIN timescaledb_information.job_stats s ON j.job_id = s.job_id "+
				"JOIN timescaledb_information.hypertables h ON j.hypertable_name = h.hypertable_name "+
				"WHERE j.hypertable_name = '%s' AND j.proc_name = 'policy_compression'",
			tableName,
		)

		// A failed policy lookup is intentionally non-fatal: the interval
		// fields stay empty rather than surfacing an error.
		policyResult, err := t.ExecuteSQLWithoutParams(ctx, policyQuery)
		if err == nil {
			policyRows, ok := policyResult.([]map[string]interface{})
			if ok && len(policyRows) > 0 {
				if interval, ok := policyRows[0]["schedule_interval"]; ok && interval != nil {
					settings.CompressionInterval = fmt.Sprintf("%v", interval)
				}

				if chunkInterval, ok := policyRows[0]["chunk_time_interval"]; ok && chunkInterval != nil {
					settings.ChunkTimeInterval = fmt.Sprintf("%v", chunkInterval)
				}
			}
		}
	}

	return settings, nil
}
239 | 
240 | // AddRetentionPolicy adds a data retention policy to a hypertable
241 | func (t *DB) AddRetentionPolicy(ctx context.Context, tableName, interval string) error {
242 | 	if !t.isTimescaleDB {
243 | 		return fmt.Errorf("TimescaleDB extension not available")
244 | 	}
245 | 
246 | 	query := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s')", tableName, interval)
247 | 	_, err := t.ExecuteSQLWithoutParams(ctx, query)
248 | 	if err != nil {
249 | 		return fmt.Errorf("failed to add retention policy: %w", err)
250 | 	}
251 | 
252 | 	return nil
253 | }
254 | 
255 | // RemoveRetentionPolicy removes a data retention policy from a hypertable
256 | func (t *DB) RemoveRetentionPolicy(ctx context.Context, tableName string) error {
257 | 	if !t.isTimescaleDB {
258 | 		return fmt.Errorf("TimescaleDB extension not available")
259 | 	}
260 | 
261 | 	// Find the policy ID
262 | 	query := fmt.Sprintf(
263 | 		"SELECT job_id FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'",
264 | 		tableName,
265 | 	)
266 | 
267 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
268 | 	if err != nil {
269 | 		return fmt.Errorf("failed to find retention policy: %w", err)
270 | 	}
271 | 
272 | 	rows, ok := result.([]map[string]interface{})
273 | 	if !ok || len(rows) == 0 {
274 | 		// No policy exists, so nothing to remove
275 | 		return nil
276 | 	}
277 | 
278 | 	// Get the job ID
279 | 	jobID := rows[0]["job_id"]
280 | 	if jobID == nil {
281 | 		return fmt.Errorf("invalid job ID for retention policy")
282 | 	}
283 | 
284 | 	// Remove the policy
285 | 	removeQuery := fmt.Sprintf("SELECT remove_retention_policy(%v)", jobID)
286 | 	_, err = t.ExecuteSQLWithoutParams(ctx, removeQuery)
287 | 	if err != nil {
288 | 		return fmt.Errorf("failed to remove retention policy: %w", err)
289 | 	}
290 | 
291 | 	return nil
292 | }
293 | 
294 | // GetRetentionSettings gets the retention settings for a hypertable
295 | func (t *DB) GetRetentionSettings(ctx context.Context, tableName string) (*RetentionSettings, error) {
296 | 	if !t.isTimescaleDB {
297 | 		return nil, fmt.Errorf("TimescaleDB extension not available")
298 | 	}
299 | 
300 | 	settings := &RetentionSettings{
301 | 		HypertableName: tableName,
302 | 	}
303 | 
304 | 	// Check if a retention policy exists
305 | 	query := fmt.Sprintf(
306 | 		"SELECT s.schedule_interval FROM timescaledb_information.jobs j "+
307 | 			"JOIN timescaledb_information.job_stats s ON j.job_id = s.job_id "+
308 | 			"WHERE j.hypertable_name = '%s' AND j.proc_name = 'policy_retention'",
309 | 		tableName,
310 | 	)
311 | 
312 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
313 | 	if err != nil {
314 | 		return settings, nil // Return empty settings with no error
315 | 	}
316 | 
317 | 	rows, ok := result.([]map[string]interface{})
318 | 	if ok && len(rows) > 0 {
319 | 		settings.RetentionEnabled = true
320 | 		if interval, ok := rows[0]["schedule_interval"]; ok && interval != nil {
321 | 			settings.RetentionInterval = fmt.Sprintf("%v", interval)
322 | 		}
323 | 	}
324 | 
325 | 	return settings, nil
326 | }
327 | 
328 | // CompressChunks compresses chunks for a hypertable
329 | func (t *DB) CompressChunks(ctx context.Context, tableName, olderThan string) error {
330 | 	if !t.isTimescaleDB {
331 | 		return fmt.Errorf("TimescaleDB extension not available")
332 | 	}
333 | 
334 | 	var query string
335 | 	if olderThan == "" {
336 | 		// Compress all chunks
337 | 		query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s')", tableName)
338 | 	} else {
339 | 		// Compress chunks older than the specified interval
340 | 		query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s', older_than => INTERVAL '%s')",
341 | 			tableName, olderThan)
342 | 	}
343 | 
344 | 	_, err := t.ExecuteSQLWithoutParams(ctx, query)
345 | 	if err != nil {
346 | 		return fmt.Errorf("failed to compress chunks: %w", err)
347 | 	}
348 | 
349 | 	return nil
350 | }
351 | 
352 | // DecompressChunks decompresses chunks for a hypertable
353 | func (t *DB) DecompressChunks(ctx context.Context, tableName, newerThan string) error {
354 | 	if !t.isTimescaleDB {
355 | 		return fmt.Errorf("TimescaleDB extension not available")
356 | 	}
357 | 
358 | 	var query string
359 | 	if newerThan == "" {
360 | 		// Decompress all chunks
361 | 		query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s')", tableName)
362 | 	} else {
363 | 		// Decompress chunks newer than the specified interval
364 | 		query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s', newer_than => INTERVAL '%s')",
365 | 			tableName, newerThan)
366 | 	}
367 | 
368 | 	_, err := t.ExecuteSQLWithoutParams(ctx, query)
369 | 	if err != nil {
370 | 		return fmt.Errorf("failed to decompress chunks: %w", err)
371 | 	}
372 | 
373 | 	return nil
374 | }
375 | 
376 | // GetChunkCompressionStats gets compression statistics for a hypertable
377 | func (t *DB) GetChunkCompressionStats(ctx context.Context, tableName string) (interface{}, error) {
378 | 	if !t.isTimescaleDB {
379 | 		return nil, fmt.Errorf("TimescaleDB extension not available")
380 | 	}
381 | 
382 | 	query := fmt.Sprintf(`
383 | 		SELECT
384 | 			chunk_name,
385 | 			range_start,
386 | 			range_end,
387 | 			is_compressed,
388 | 			before_compression_total_bytes,
389 | 			after_compression_total_bytes,
390 | 			CASE
391 | 				WHEN before_compression_total_bytes = 0 THEN 0
392 | 				ELSE (1 - (after_compression_total_bytes::float / before_compression_total_bytes::float)) * 100
393 | 			END AS compression_ratio
394 | 		FROM timescaledb_information.chunks
395 | 		WHERE hypertable_name = '%s'
396 | 		ORDER BY range_end DESC
397 | 	`, tableName)
398 | 
399 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
400 | 	if err != nil {
401 | 		return nil, fmt.Errorf("failed to get chunk compression statistics: %w", err)
402 | 	}
403 | 
404 | 	return result, nil
405 | }
406 | 
```

--------------------------------------------------------------------------------
/pkg/db/timescale/hypertable_test.go:
--------------------------------------------------------------------------------

```go
  1 | package timescale
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"errors"
  6 | 	"testing"
  7 | )
  8 | 
// TestCreateHypertable verifies that DB.CreateHypertable issues the expected
// create_hypertable() SQL for a basic config, a space-partitioned config, and
// an if_not_exists/migrate_data config, and that it fails both when
// TimescaleDB is unavailable and when the underlying query errors.
// The sub-cases share one mock and run in order; mock state carries over.
func TestCreateHypertable(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Test basic hypertable creation
	config := HypertableConfig{
		TableName:         "test_table",
		TimeColumn:        "time",
		ChunkTimeInterval: "1 day",
		CreateIfNotExists: true,
	}

	err := tsdb.CreateHypertable(ctx, config)
	if err != nil {
		t.Fatalf("Failed to create hypertable: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "create_hypertable")
	AssertQueryContains(t, query, "test_table")
	AssertQueryContains(t, query, "time")
	AssertQueryContains(t, query, "chunk_time_interval")
	AssertQueryContains(t, query, "1 day")

	// Test with partitioning
	config = HypertableConfig{
		TableName:          "test_table",
		TimeColumn:         "time",
		ChunkTimeInterval:  "1 day",
		PartitioningColumn: "device_id",
		SpacePartitions:    4,
	}

	err = tsdb.CreateHypertable(ctx, config)
	if err != nil {
		t.Fatalf("Failed to create hypertable with partitioning: %v", err)
	}

	// Check that the correct query was executed
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "create_hypertable")
	AssertQueryContains(t, query, "partition_column")
	AssertQueryContains(t, query, "device_id")
	AssertQueryContains(t, query, "number_partitions")

	// Test with if_not_exists and migrate_data
	config = HypertableConfig{
		TableName:   "test_table",
		TimeColumn:  "time",
		IfNotExists: true,
		MigrateData: true,
	}

	err = tsdb.CreateHypertable(ctx, config)
	if err != nil {
		t.Fatalf("Failed to create hypertable with extra options: %v", err)
	}

	// Check that the correct query was executed
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "if_not_exists => TRUE")
	AssertQueryContains(t, query, "migrate_data => TRUE")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.CreateHypertable(ctx, config)
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error: registering an error result for the
	// create_hypertable pattern makes the next matching call fail.
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT create_hypertable(", nil, errors.New("mocked error"))
	err = tsdb.CreateHypertable(ctx, config)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
 93 | 
// TestAddDimension verifies that DB.AddDimension issues an add_dimension()
// call naming the table, the partitioning column, and the partition count,
// and that it fails when TimescaleDB is unavailable or the query errors.
func TestAddDimension(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Test adding a dimension
	err := tsdb.AddDimension(ctx, "test_table", "device_id", 4)
	if err != nil {
		t.Fatalf("Failed to add dimension: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "add_dimension")
	AssertQueryContains(t, query, "test_table")
	AssertQueryContains(t, query, "device_id")
	AssertQueryContains(t, query, "number_partitions => 4")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.AddDimension(ctx, "test_table", "device_id", 4)
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error: the registered error fires on the next call
	// matching the add_dimension pattern.
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT add_dimension(", nil, errors.New("mocked error"))
	err = tsdb.AddDimension(ctx, "test_table", "device_id", 4)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
131 | 
// TestListHypertables verifies that DB.ListHypertables maps catalog rows into
// HypertableInfo values (names, time/space columns, dimension count) and
// enriches them with compression/retention flags from separate lookups.
// Query results are registered per SQL-substring pattern on a shared mock,
// so the sub-cases must run in this order.
func TestListHypertables(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Prepare mock data
	mockResult := []map[string]interface{}{
		{
			"table_name":     "test_table",
			"schema_name":    "public",
			"time_column":    "time",
			"num_dimensions": 2,
			"space_column":   "device_id",
		},
		{
			"table_name":     "test_table2",
			"schema_name":    "public",
			"time_column":    "timestamp",
			"num_dimensions": 1,
			"space_column":   nil,
		},
	}

	// Register different result patterns for different queries
	mockDB.RegisterQueryResult("FROM _timescaledb_catalog.hypertable h", mockResult, nil)
	mockDB.RegisterQueryResult("FROM timescaledb_information.compression_settings", []map[string]interface{}{
		{"is_compressed": true},
	}, nil)
	mockDB.RegisterQueryResult("FROM timescaledb_information.jobs", []map[string]interface{}{
		{"has_retention": true},
	}, nil)

	// Test listing hypertables
	hypertables, err := tsdb.ListHypertables(ctx)
	if err != nil {
		t.Fatalf("Failed to list hypertables: %v", err)
	}

	// Check the results
	if len(hypertables) != 2 {
		t.Errorf("Expected 2 hypertables, got %d", len(hypertables))
	}

	if hypertables[0].TableName != "test_table" {
		t.Errorf("Expected TableName to be 'test_table', got '%s'", hypertables[0].TableName)
	}

	if hypertables[0].TimeColumn != "time" {
		t.Errorf("Expected TimeColumn to be 'time', got '%s'", hypertables[0].TimeColumn)
	}

	if hypertables[0].SpaceColumn != "device_id" {
		t.Errorf("Expected SpaceColumn to be 'device_id', got '%s'", hypertables[0].SpaceColumn)
	}

	if hypertables[0].NumDimensions != 2 {
		t.Errorf("Expected NumDimensions to be 2, got %d", hypertables[0].NumDimensions)
	}

	if !hypertables[0].CompressionEnabled {
		t.Error("Expected CompressionEnabled to be true, got false")
	}

	if !hypertables[0].RetentionEnabled {
		t.Error("Expected RetentionEnabled to be true, got false")
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	_, err = tsdb.ListHypertables(ctx)
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error: re-registering the catalog pattern with an error
	// overrides the earlier successful result.
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("FROM _timescaledb_catalog.hypertable h", nil, errors.New("mocked error"))
	_, err = tsdb.ListHypertables(ctx)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
218 | 
// TestGetHypertable verifies that DB.GetHypertable resolves a single table's
// metadata (names, columns, dimensions, compression/retention flags), and
// covers the unavailable-extension, query-error, and table-not-found paths.
// A fresh mock is swapped in for the not-found case so earlier registrations
// cannot interfere.
func TestGetHypertable(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Prepare mock data - Set up the correct result by using RegisterQueryResult
	mockResult := []map[string]interface{}{
		{
			"table_name":     "test_table",
			"schema_name":    "public",
			"time_column":    "time",
			"num_dimensions": int64(2),
			"space_column":   "device_id",
		},
	}

	// Register the query result pattern for the main query
	mockDB.RegisterQueryResult("WHERE h.table_name = 'test_table'", mockResult, nil)

	// Register results for the compression check
	mockDB.RegisterQueryResult("FROM timescaledb_information.compression_settings", []map[string]interface{}{
		{"is_compressed": true},
	}, nil)

	// Register results for the retention policy check
	mockDB.RegisterQueryResult("FROM timescaledb_information.jobs", []map[string]interface{}{
		{"has_retention": true},
	}, nil)

	// Test getting a hypertable
	hypertable, err := tsdb.GetHypertable(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to get hypertable: %v", err)
	}

	// Check the results
	if hypertable.TableName != "test_table" {
		t.Errorf("Expected TableName to be 'test_table', got '%s'", hypertable.TableName)
	}

	if hypertable.TimeColumn != "time" {
		t.Errorf("Expected TimeColumn to be 'time', got '%s'", hypertable.TimeColumn)
	}

	if hypertable.SpaceColumn != "device_id" {
		t.Errorf("Expected SpaceColumn to be 'device_id', got '%s'", hypertable.SpaceColumn)
	}

	if hypertable.NumDimensions != 2 {
		t.Errorf("Expected NumDimensions to be 2, got %d", hypertable.NumDimensions)
	}

	if !hypertable.CompressionEnabled {
		t.Error("Expected CompressionEnabled to be true, got false")
	}

	if !hypertable.RetentionEnabled {
		t.Error("Expected RetentionEnabled to be true, got false")
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	_, err = tsdb.GetHypertable(ctx, "test_table")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error: re-registering the pattern with an error
	// overrides the earlier successful result.
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("WHERE h.table_name = 'test_table'", nil, errors.New("mocked error"))
	_, err = tsdb.GetHypertable(ctx, "test_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}

	// Test table not found - Create a new mock to avoid interference
	newMockDB := NewMockDB()
	newMockDB.SetTimescaleAvailable(true)
	tsdb.Database = newMockDB

	// Register an empty result for the "not_found" table
	newMockDB.RegisterQueryResult("WHERE h.table_name = 'not_found'", []map[string]interface{}{}, nil)
	_, err = tsdb.GetHypertable(ctx, "not_found")
	if err == nil {
		t.Error("Expected error for non-existent table, got nil")
	}
}
310 | 
// TestDropHypertable verifies that DB.DropHypertable emits a plain DROP TABLE
// (with CASCADE when requested), and that it fails when TimescaleDB is
// unavailable or the underlying query errors.
func TestDropHypertable(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Test dropping a hypertable
	err := tsdb.DropHypertable(ctx, "test_table", false)
	if err != nil {
		t.Fatalf("Failed to drop hypertable: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "DROP TABLE test_table")

	// Test dropping with CASCADE
	err = tsdb.DropHypertable(ctx, "test_table", true)
	if err != nil {
		t.Fatalf("Failed to drop hypertable with CASCADE: %v", err)
	}

	// Check that the correct query was executed
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "DROP TABLE test_table CASCADE")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.DropHypertable(ctx, "test_table", false)
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error: the registered error fires on the next call
	// matching the DROP TABLE pattern.
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("DROP TABLE", nil, errors.New("mocked error"))
	err = tsdb.DropHypertable(ctx, "test_table", false)
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
355 | 
// TestCheckIfHypertable verifies DB.CheckIfHypertable across the true/false
// lookup results, the unavailable-extension path, a query error, and an
// empty (malformed) result set. Each case uses a distinct table name so its
// registered mock pattern only matches that case.
func TestCheckIfHypertable(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Prepare mock data
	mockResultTrue := []map[string]interface{}{
		{"is_hypertable": true},
	}

	mockResultFalse := []map[string]interface{}{
		{"is_hypertable": false},
	}

	// Test table is a hypertable
	mockDB.RegisterQueryResult("WHERE table_name = 'test_table'", mockResultTrue, nil)
	isHypertable, err := tsdb.CheckIfHypertable(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to check if hypertable: %v", err)
	}

	if !isHypertable {
		t.Error("Expected table to be a hypertable, got false")
	}

	// Test table is not a hypertable
	mockDB.RegisterQueryResult("WHERE table_name = 'regular_table'", mockResultFalse, nil)
	isHypertable, err = tsdb.CheckIfHypertable(ctx, "regular_table")
	if err != nil {
		t.Fatalf("Failed to check if hypertable: %v", err)
	}

	if isHypertable {
		t.Error("Expected table not to be a hypertable, got true")
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	_, err = tsdb.CheckIfHypertable(ctx, "test_table")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("WHERE table_name = 'error_table'", nil, errors.New("mocked error"))
	_, err = tsdb.CheckIfHypertable(ctx, "error_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}

	// Test unexpected result structure: an empty row set should be rejected.
	mockDB.RegisterQueryResult("WHERE table_name = 'bad_structure'", []map[string]interface{}{}, nil)
	_, err = tsdb.CheckIfHypertable(ctx, "bad_structure")
	if err == nil {
		t.Error("Expected error for bad result structure, got nil")
	}
}
418 | 
419 | func TestRecentChunks(t *testing.T) {
420 | 	mockDB := NewMockDB()
421 | 	tsdb := &DB{
422 | 		Database:      mockDB,
423 | 		isTimescaleDB: true,
424 | 	}
425 | 
426 | 	ctx := context.Background()
427 | 
428 | 	// Prepare mock data
429 | 	mockResult := []map[string]interface{}{
430 | 		{
431 | 			"chunk_name":    "_hyper_1_1_chunk",
432 | 			"range_start":   "2023-01-01 00:00:00",
433 | 			"range_end":     "2023-01-02 00:00:00",
434 | 			"is_compressed": false,
435 | 		},
436 | 		{
437 | 			"chunk_name":    "_hyper_1_2_chunk",
438 | 			"range_start":   "2023-01-02 00:00:00",
439 | 			"range_end":     "2023-01-03 00:00:00",
440 | 			"is_compressed": true,
441 | 		},
442 | 	}
443 | 
444 | 	// Register mock result
445 | 	mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", mockResult, nil)
446 | 
447 | 	// Test getting recent chunks
448 | 	_, err := tsdb.RecentChunks(ctx, "test_table", 2)
449 | 	if err != nil {
450 | 		t.Fatalf("Failed to get recent chunks: %v", err)
451 | 	}
452 | 
453 | 	// Check that a query with the right table name and limit was executed
454 | 	query, _ := mockDB.GetLastQuery()
455 | 	AssertQueryContains(t, query, "hypertable_name = 'test_table'")
456 | 	AssertQueryContains(t, query, "LIMIT 2")
457 | 
458 | 	// Test with default limit
459 | 	_, err = tsdb.RecentChunks(ctx, "test_table", 0)
460 | 	if err != nil {
461 | 		t.Fatalf("Failed to get recent chunks with default limit: %v", err)
462 | 	}
463 | 
464 | 	query, _ = mockDB.GetLastQuery()
465 | 	AssertQueryContains(t, query, "LIMIT 10")
466 | 
467 | 	// Test when TimescaleDB is not available
468 | 	tsdb.isTimescaleDB = false
469 | 	_, err = tsdb.RecentChunks(ctx, "test_table", 2)
470 | 	if err == nil {
471 | 		t.Error("Expected error when TimescaleDB is not available, got nil")
472 | 	}
473 | 
474 | 	// Test execution error
475 | 	tsdb.isTimescaleDB = true
476 | 	mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", nil, errors.New("mocked error"))
477 | 	_, err = tsdb.RecentChunks(ctx, "test_table", 2)
478 | 	if err == nil {
479 | 		t.Error("Expected query error, got nil")
480 | 	}
481 | }
482 | 
```

--------------------------------------------------------------------------------
/pkg/db/timescale/query.go:
--------------------------------------------------------------------------------

```go
  1 | package timescale
  2 | 
import (
	"context"
	"fmt"
	"regexp"
	"strconv"
	"strings"
	"time"
)
  9 | 
// TimeBucket describes a time_bucket() expression used for time-series
// aggregation queries.
type TimeBucket struct {
	Interval string // e.g., '1 hour', '1 day', '1 month'
	Column   string // Time column to bucket
	Alias    string // Optional alias for the bucket column; Build defaults to "time_bucket"
}

// AggregateFunction represents a common aggregate function.
// Values are SQL function names interpolated verbatim into queries.
type AggregateFunction string

const (
	// AggrAvg calculates the average value of a column
	AggrAvg AggregateFunction = "AVG"
	// AggrSum calculates the sum of values in a column
	AggrSum AggregateFunction = "SUM"
	// AggrMin finds the minimum value in a column
	AggrMin AggregateFunction = "MIN"
	// AggrMax finds the maximum value in a column
	AggrMax AggregateFunction = "MAX"
	// AggrCount counts the number of rows
	AggrCount AggregateFunction = "COUNT"
	// AggrFirst takes the first value in a window
	AggrFirst AggregateFunction = "FIRST"
	// AggrLast takes the last value in a window
	AggrLast AggregateFunction = "LAST"
)

// ColumnAggregation represents an aggregation operation on a column.
type ColumnAggregation struct {
	Function AggregateFunction
	Column   string
	Alias    string // Optional output alias; Build defaults to "<func>_<column>" (lower-cased)
}
 43 | 
 44 | // TimeseriesQueryBuilder helps build optimized time-series queries
 45 | type TimeseriesQueryBuilder struct {
 46 | 	table        string
 47 | 	timeBucket   *TimeBucket
 48 | 	selectCols   []string
 49 | 	aggregations []ColumnAggregation
 50 | 	whereClauses []string
 51 | 	whereArgs    []interface{}
 52 | 	groupByCols  []string
 53 | 	orderByCols  []string
 54 | 	limit        int
 55 | 	offset       int
 56 | }
 57 | 
 58 | // NewTimeseriesQueryBuilder creates a new builder for a specific table
 59 | func NewTimeseriesQueryBuilder(table string) *TimeseriesQueryBuilder {
 60 | 	return &TimeseriesQueryBuilder{
 61 | 		table:        table,
 62 | 		selectCols:   make([]string, 0),
 63 | 		aggregations: make([]ColumnAggregation, 0),
 64 | 		whereClauses: make([]string, 0),
 65 | 		whereArgs:    make([]interface{}, 0),
 66 | 		groupByCols:  make([]string, 0),
 67 | 		orderByCols:  make([]string, 0),
 68 | 	}
 69 | }
 70 | 
// WithTimeBucket adds a time_bucket() expression over column using the given
// PostgreSQL interval (e.g. "1 hour"). alias names the bucket column in the
// result set; Build substitutes "time_bucket" when alias is empty.
func (b *TimeseriesQueryBuilder) WithTimeBucket(interval, column, alias string) *TimeseriesQueryBuilder {
	b.timeBucket = &TimeBucket{
		Interval: interval,
		Column:   column,
		Alias:    alias,
	}
	return b
}

// Select adds columns to the SELECT clause.
func (b *TimeseriesQueryBuilder) Select(cols ...string) *TimeseriesQueryBuilder {
	b.selectCols = append(b.selectCols, cols...)
	return b
}

// Aggregate adds an aggregation function over a column. alias names the
// output column; Build falls back to "<func>_<column>" (lower-cased) when
// alias is empty.
func (b *TimeseriesQueryBuilder) Aggregate(function AggregateFunction, column, alias string) *TimeseriesQueryBuilder {
	b.aggregations = append(b.aggregations, ColumnAggregation{
		Function: function,
		Column:   column,
		Alias:    alias,
	})
	return b
}
 96 | 
 97 | // WhereTimeRange adds a time range condition
 98 | func (b *TimeseriesQueryBuilder) WhereTimeRange(column string, start, end time.Time) *TimeseriesQueryBuilder {
 99 | 	clause := fmt.Sprintf("%s BETWEEN $%d AND $%d", column, len(b.whereArgs)+1, len(b.whereArgs)+2)
100 | 	b.whereClauses = append(b.whereClauses, clause)
101 | 	b.whereArgs = append(b.whereArgs, start, end)
102 | 	return b
103 | }
104 | 
105 | // Where adds a WHERE condition
106 | func (b *TimeseriesQueryBuilder) Where(clause string, args ...interface{}) *TimeseriesQueryBuilder {
107 | 	// Adjust the parameter indices in the clause
108 | 	paramCount := len(b.whereArgs)
109 | 	for i := 1; i <= len(args); i++ {
110 | 		oldParam := fmt.Sprintf("$%d", i)
111 | 		newParam := fmt.Sprintf("$%d", i+paramCount)
112 | 		clause = strings.Replace(clause, oldParam, newParam, -1)
113 | 	}
114 | 
115 | 	b.whereClauses = append(b.whereClauses, clause)
116 | 	b.whereArgs = append(b.whereArgs, args...)
117 | 	return b
118 | }
119 | 
// GroupBy adds columns to the GROUP BY clause. When a time bucket is set,
// Build prepends the bucket alias automatically if it is not listed here.
func (b *TimeseriesQueryBuilder) GroupBy(cols ...string) *TimeseriesQueryBuilder {
	b.groupByCols = append(b.groupByCols, cols...)
	return b
}

// OrderBy adds columns to the ORDER BY clause.
func (b *TimeseriesQueryBuilder) OrderBy(cols ...string) *TimeseriesQueryBuilder {
	b.orderByCols = append(b.orderByCols, cols...)
	return b
}

// Limit sets the LIMIT clause; values <= 0 leave the query unlimited.
func (b *TimeseriesQueryBuilder) Limit(limit int) *TimeseriesQueryBuilder {
	b.limit = limit
	return b
}

// Offset sets the OFFSET clause; values <= 0 emit no OFFSET.
func (b *TimeseriesQueryBuilder) Offset(offset int) *TimeseriesQueryBuilder {
	b.offset = offset
	return b
}
143 | 
// Build assembles the final SELECT statement and returns it together with
// the positional arguments collected by Where/WhereTimeRange.
//
// SELECT-list order is: time bucket (if any), plain columns, aggregates;
// clauses follow in WHERE, GROUP BY, ORDER BY, LIMIT, OFFSET order.
// NOTE(review): table, column, and interval values are interpolated directly
// into the SQL text, so they must come from trusted code, never user input.
func (b *TimeseriesQueryBuilder) Build() (string, []interface{}) {
	var selectClause strings.Builder
	selectClause.WriteString("SELECT ")

	var selects []string

	// Add time bucket if specified
	if b.timeBucket != nil {
		alias := b.timeBucket.Alias
		if alias == "" {
			alias = "time_bucket"
		}

		bucketStr := fmt.Sprintf(
			"time_bucket('%s', %s) AS %s",
			b.timeBucket.Interval,
			b.timeBucket.Column,
			alias,
		)
		selects = append(selects, bucketStr)

		// Add time bucket to group by if not already included
		bucketFound := false
		for _, col := range b.groupByCols {
			if col == alias {
				bucketFound = true
				break
			}
		}

		if !bucketFound {
			// Prepend so the bucket is always the first GROUP BY key.
			b.groupByCols = append([]string{alias}, b.groupByCols...)
		}
	}

	// Add selected columns
	selects = append(selects, b.selectCols...)

	// Add aggregations
	for _, agg := range b.aggregations {
		alias := agg.Alias
		if alias == "" {
			// Default alias: "<func>_<column>", e.g. "avg_temperature".
			alias = strings.ToLower(string(agg.Function)) + "_" + agg.Column
		}

		aggStr := fmt.Sprintf(
			"%s(%s) AS %s",
			agg.Function,
			agg.Column,
			alias,
		)
		selects = append(selects, aggStr)
	}

	// If no columns or aggregations selected, use *
	if len(selects) == 0 {
		selectClause.WriteString("*")
	} else {
		selectClause.WriteString(strings.Join(selects, ", "))
	}

	// Build query
	query := fmt.Sprintf("%s FROM %s", selectClause.String(), b.table)

	// Add WHERE clause
	if len(b.whereClauses) > 0 {
		query += " WHERE " + strings.Join(b.whereClauses, " AND ")
	}

	// Add GROUP BY clause
	if len(b.groupByCols) > 0 {
		query += " GROUP BY " + strings.Join(b.groupByCols, ", ")
	}

	// Add ORDER BY clause
	if len(b.orderByCols) > 0 {
		query += " ORDER BY " + strings.Join(b.orderByCols, ", ")
	}

	// Add LIMIT clause
	if b.limit > 0 {
		query += fmt.Sprintf(" LIMIT %d", b.limit)
	}

	// Add OFFSET clause
	if b.offset > 0 {
		query += fmt.Sprintf(" OFFSET %d", b.offset)
	}

	return query, b.whereArgs
}
236 | 
237 | // Execute runs the query against the database
238 | func (b *TimeseriesQueryBuilder) Execute(ctx context.Context, db *DB) ([]map[string]interface{}, error) {
239 | 	query, args := b.Build()
240 | 	result, err := db.ExecuteSQL(ctx, query, args...)
241 | 	if err != nil {
242 | 		return nil, fmt.Errorf("failed to execute time-series query: %w", err)
243 | 	}
244 | 
245 | 	rows, ok := result.([]map[string]interface{})
246 | 	if !ok {
247 | 		return nil, fmt.Errorf("unexpected result type from database query")
248 | 	}
249 | 
250 | 	return rows, nil
251 | }
252 | 
// DownsampleOptions describes options for downsampling time-series data.
type DownsampleOptions struct {
	SourceTable       string              // table holding the raw series
	DestTable         string              // table receiving the bucketed rows
	TimeColumn        string              // time column in SourceTable
	BucketInterval    string              // time_bucket() interval, e.g. "1 hour"
	Aggregations      []ColumnAggregation // aggregates computed per bucket
	WhereCondition    string              // optional raw SQL filter on the source rows
	CreateTable       bool                // create DestTable (and hypertable) before inserting
	ChunkTimeInterval string              // chunk interval for the new hypertable; defaults to BucketInterval
}
264 | 
// DownsampleTimeSeries aggregates rows from options.SourceTable into
// per-bucket rows in options.DestTable using time_bucket(). When
// options.CreateTable is set it first creates DestTable (mirroring the
// source columns' data types for each aggregate) and converts it into a
// hypertable keyed on the "time_bucket" column.
//
// Returns an error if the TimescaleDB extension is unavailable or any of
// the schema/DDL/INSERT statements fail.
//
// NOTE(review): table names, column names, and WhereCondition are
// interpolated directly into SQL text; callers must not pass untrusted
// input here.
func (t *DB) DownsampleTimeSeries(ctx context.Context, options DownsampleOptions) error {
	if !t.isTimescaleDB {
		return fmt.Errorf("TimescaleDB extension not available")
	}

	// Create the destination table if requested
	if options.CreateTable {
		// Get source table columns
		schemaQuery := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '%s'", options.SourceTable)
		result, err := t.ExecuteSQLWithoutParams(ctx, schemaQuery)
		if err != nil {
			return fmt.Errorf("failed to get source table schema: %w", err)
		}

		columns, ok := result.([]map[string]interface{})
		if !ok {
			return fmt.Errorf("unexpected result from schema query")
		}

		// Build CREATE TABLE statement
		var createStmt strings.Builder
		createStmt.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", options.DestTable))

		// Add time bucket column
		createStmt.WriteString("time_bucket timestamptz, ")

		// Add aggregation columns
		for i, agg := range options.Aggregations {
			colName := agg.Alias
			if colName == "" {
				// Default alias mirrors Build: "<func>_<column>" lower-cased.
				colName = strings.ToLower(string(agg.Function)) + "_" + agg.Column
			}

			// Find the data type of the source column
			var dataType string
			for _, col := range columns {
				if fmt.Sprintf("%v", col["column_name"]) == agg.Column {
					dataType = fmt.Sprintf("%v", col["data_type"])
					break
				}
			}

			if dataType == "" {
				dataType = "double precision" // Default for numeric aggregations
			}

			createStmt.WriteString(fmt.Sprintf("%s %s", colName, dataType))

			if i < len(options.Aggregations)-1 {
				createStmt.WriteString(", ")
			}
		}

		createStmt.WriteString(", PRIMARY KEY (time_bucket)")
		createStmt.WriteString(")")

		// Create the table
		_, err = t.ExecuteSQLWithoutParams(ctx, createStmt.String())
		if err != nil {
			return fmt.Errorf("failed to create destination table: %w", err)
		}

		// Make it a hypertable
		if options.ChunkTimeInterval == "" {
			options.ChunkTimeInterval = options.BucketInterval
		}

		err = t.CreateHypertable(ctx, HypertableConfig{
			TableName:         options.DestTable,
			TimeColumn:        "time_bucket",
			ChunkTimeInterval: options.ChunkTimeInterval,
			IfNotExists:       true,
		})
		if err != nil {
			return fmt.Errorf("failed to create hypertable: %w", err)
		}
	}

	// Build the INSERT statement with aggregations
	var insertStmt strings.Builder
	insertStmt.WriteString(fmt.Sprintf("INSERT INTO %s (time_bucket, ", options.DestTable))

	// Add aggregation column names (must match the CREATE TABLE aliases above)
	for i, agg := range options.Aggregations {
		colName := agg.Alias
		if colName == "" {
			colName = strings.ToLower(string(agg.Function)) + "_" + agg.Column
		}

		insertStmt.WriteString(colName)

		if i < len(options.Aggregations)-1 {
			insertStmt.WriteString(", ")
		}
	}

	insertStmt.WriteString(") SELECT time_bucket('")
	insertStmt.WriteString(options.BucketInterval)
	insertStmt.WriteString("', ")
	insertStmt.WriteString(options.TimeColumn)
	insertStmt.WriteString(") AS time_bucket, ")

	// Add aggregation functions
	for i, agg := range options.Aggregations {
		insertStmt.WriteString(fmt.Sprintf("%s(%s)", agg.Function, agg.Column))

		if i < len(options.Aggregations)-1 {
			insertStmt.WriteString(", ")
		}
	}

	insertStmt.WriteString(fmt.Sprintf(" FROM %s", options.SourceTable))

	// Add WHERE clause if specified
	if options.WhereCondition != "" {
		insertStmt.WriteString(" WHERE ")
		insertStmt.WriteString(options.WhereCondition)
	}

	// Group by time bucket
	insertStmt.WriteString(" GROUP BY time_bucket")

	// Order by time bucket
	insertStmt.WriteString(" ORDER BY time_bucket")

	// Execute the INSERT statement
	_, err := t.ExecuteSQLWithoutParams(ctx, insertStmt.String())
	if err != nil {
		return fmt.Errorf("failed to downsample data: %w", err)
	}

	return nil
}
399 | 
// TimeRange represents a time window [Start, End] for queries.
type TimeRange struct {
	Start time.Time
	End   time.Time
}

// PredefinedTimeRange returns a TimeRange for common, human-friendly range
// names ("today", "last_7_days", ...). Matching is case-insensitive and
// accepts both squashed ("lastweek") and snake_case ("last_week") spellings.
// An unknown name yields an error.
//
// Calendar-aligned ranges (yesterday, weeks, months, years) are computed
// with AddDate/time.Date so boundaries stay on local midnight even across
// DST transitions, where a day is 23 or 25 hours and the previous
// fixed -24h arithmetic drifted by an hour. Rolling ranges (last24hours,
// last7days, last30days, last365days) intentionally remain fixed-duration
// windows ending at now.
func PredefinedTimeRange(name string) (*TimeRange, error) {
	now := time.Now()
	midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())

	switch strings.ToLower(name) {
	case "today":
		return &TimeRange{Start: midnight, End: now}, nil

	case "yesterday":
		return &TimeRange{Start: midnight.AddDate(0, 0, -1), End: midnight}, nil

	case "last24hours", "last_24_hours":
		return &TimeRange{Start: now.Add(-24 * time.Hour), End: now}, nil

	case "thisweek", "this_week":
		// Week starts on Sunday.
		weekStart := midnight.AddDate(0, 0, -int(now.Weekday()))
		return &TimeRange{Start: weekStart, End: now}, nil

	case "lastweek", "last_week":
		thisWeekStart := midnight.AddDate(0, 0, -int(now.Weekday()))
		return &TimeRange{Start: thisWeekStart.AddDate(0, 0, -7), End: thisWeekStart}, nil

	case "last7days", "last_7_days":
		return &TimeRange{Start: now.Add(-7 * 24 * time.Hour), End: now}, nil

	case "thismonth", "this_month":
		start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
		return &TimeRange{Start: start, End: now}, nil

	case "lastmonth", "last_month":
		thisMonthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
		// time.Date normalizes month 0 to December of the previous year, so
		// January needs no special-casing.
		lastMonthStart := time.Date(now.Year(), now.Month()-1, 1, 0, 0, 0, 0, now.Location())
		return &TimeRange{Start: lastMonthStart, End: thisMonthStart}, nil

	case "last30days", "last_30_days":
		return &TimeRange{Start: now.Add(-30 * 24 * time.Hour), End: now}, nil

	case "thisyear", "this_year":
		start := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location())
		return &TimeRange{Start: start, End: now}, nil

	case "lastyear", "last_year":
		thisYearStart := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location())
		lastYearStart := time.Date(now.Year()-1, 1, 1, 0, 0, 0, 0, now.Location())
		return &TimeRange{Start: lastYearStart, End: thisYearStart}, nil

	case "last365days", "last_365_days":
		return &TimeRange{Start: now.Add(-365 * 24 * time.Hour), End: now}, nil

	default:
		return nil, fmt.Errorf("unknown time range: %s", name)
	}
}
487 | 
```

--------------------------------------------------------------------------------
/internal/delivery/mcp/timescale_tools_test.go:
--------------------------------------------------------------------------------

```go
  1 | package mcp_test
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"strings"
  6 | 	"testing"
  7 | 
  8 | 	"github.com/FreePeak/cortex/pkg/server"
  9 | 	"github.com/stretchr/testify/assert"
 10 | 	"github.com/stretchr/testify/mock"
 11 | 
 12 | 	"github.com/FreePeak/db-mcp-server/internal/delivery/mcp"
 13 | )
 14 | 
// MockDatabaseUseCase is a mock implementation of the UseCaseProvider
// interface, used by the tests below to observe the SQL the TimescaleDB
// tool handlers generate without touching a real database.
type MockDatabaseUseCase struct {
	mock.Mock
}

// ExecuteStatement mocks the ExecuteStatement method.
func (m *MockDatabaseUseCase) ExecuteStatement(ctx context.Context, dbID, statement string, params []interface{}) (string, error) {
	args := m.Called(ctx, dbID, statement, params)
	return args.String(0), args.Error(1)
}

// GetDatabaseType mocks the GetDatabaseType method.
func (m *MockDatabaseUseCase) GetDatabaseType(dbID string) (string, error) {
	args := m.Called(dbID)
	return args.String(0), args.Error(1)
}

// ExecuteQuery mocks the ExecuteQuery method.
func (m *MockDatabaseUseCase) ExecuteQuery(ctx context.Context, dbID, query string, params []interface{}) (string, error) {
	args := m.Called(ctx, dbID, query, params)
	return args.String(0), args.Error(1)
}

// ExecuteTransaction mocks the ExecuteTransaction method.
// NOTE: the unchecked type assertions panic if a test registers a Return
// with the wrong types; that is intentional mock-setup feedback.
func (m *MockDatabaseUseCase) ExecuteTransaction(ctx context.Context, dbID, action string, txID string, statement string, params []interface{}, readOnly bool) (string, map[string]interface{}, error) {
	args := m.Called(ctx, dbID, action, txID, statement, params, readOnly)
	return args.String(0), args.Get(1).(map[string]interface{}), args.Error(2)
}

// GetDatabaseInfo mocks the GetDatabaseInfo method.
func (m *MockDatabaseUseCase) GetDatabaseInfo(dbID string) (map[string]interface{}, error) {
	args := m.Called(dbID)
	return args.Get(0).(map[string]interface{}), args.Error(1)
}

// ListDatabases mocks the ListDatabases method.
func (m *MockDatabaseUseCase) ListDatabases() []string {
	args := m.Called()
	return args.Get(0).([]string)
}
 55 | 
// TestTimescaleDBTool verifies the tool reports its canonical name.
func TestTimescaleDBTool(t *testing.T) {
	tool := mcp.NewTimescaleDBTool()
	assert.Equal(t, "timescaledb", tool.GetName())
}
 60 | 
// TestTimeSeriesQueryTool exercises the time_series_query operation: a basic
// bucketed aggregation request, and one using window functions with pretty
// formatting (which must add a metadata section to the response).
func TestTimeSeriesQueryTool(t *testing.T) {
	// Create a mock use case provider
	mockUseCase := new(MockDatabaseUseCase)

	// Set up the TimescaleDB tool
	tool := mcp.NewTimescaleDBTool()

	// Create a context for testing
	ctx := context.Background()

	// Test case for time_series_query operation
	t.Run("time_series_query with basic parameters", func(t *testing.T) {
		// Sample result that would be returned by the database
		sampleResult := `[
			{"time_bucket": "2023-01-01T00:00:00Z", "avg_temp": 22.5, "count": 10},
			{"time_bucket": "2023-01-02T00:00:00Z", "avg_temp": 23.1, "count": 12}
		]`

		// Set up expectations for the mock; .Once() ensures exactly one
		// statement is executed per request.
		mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.AnythingOfType("string"), mock.Anything).
			Return(sampleResult, nil).Once()

		// Create a request with time_series_query operation
		request := server.ToolCallRequest{
			Name: "timescaledb_timeseries_query_test_db",
			Parameters: map[string]interface{}{
				"operation":       "time_series_query",
				"target_table":    "sensor_data",
				"time_column":     "timestamp",
				"bucket_interval": "1 day",
				"start_time":      "2023-01-01",
				"end_time":        "2023-01-31",
				"aggregations":    "AVG(temperature) as avg_temp, COUNT(*) as count",
			},
		}

		// Call the handler
		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)

		// Verify the result
		assert.NoError(t, err)
		assert.NotNil(t, result)

		// Check the result contains expected fields; "details" must carry the
		// raw database result through unchanged.
		resultMap, ok := result.(map[string]interface{})
		assert.True(t, ok)
		assert.Contains(t, resultMap, "message")
		assert.Contains(t, resultMap, "details")
		assert.Equal(t, sampleResult, resultMap["details"])

		// Verify the mock expectations
		mockUseCase.AssertExpectations(t)
	})

	t.Run("time_series_query with window functions", func(t *testing.T) {
		// Sample result that would be returned by the database
		sampleResult := `[
			{"time_bucket": "2023-01-01T00:00:00Z", "avg_temp": 22.5, "prev_avg": null},
			{"time_bucket": "2023-01-02T00:00:00Z", "avg_temp": 23.1, "prev_avg": 22.5}
		]`

		// Set up expectations for the mock
		mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.AnythingOfType("string"), mock.Anything).
			Return(sampleResult, nil).Once()

		// Create a request with time_series_query operation
		request := server.ToolCallRequest{
			Name: "timescaledb_timeseries_query_test_db",
			Parameters: map[string]interface{}{
				"operation":        "time_series_query",
				"target_table":     "sensor_data",
				"time_column":      "timestamp",
				"bucket_interval":  "1 day",
				"aggregations":     "AVG(temperature) as avg_temp",
				"window_functions": "LAG(avg_temp) OVER (ORDER BY time_bucket) AS prev_avg",
				"format_pretty":    true,
			},
		}

		// Call the handler
		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)

		// Verify the result
		assert.NoError(t, err)
		assert.NotNil(t, result)

		// Check the result contains expected fields
		resultMap, ok := result.(map[string]interface{})
		assert.True(t, ok)
		assert.Contains(t, resultMap, "message")
		assert.Contains(t, resultMap, "details")
		assert.Contains(t, resultMap, "metadata")

		// Check metadata contains expected fields for pretty formatting
		metadata, ok := resultMap["metadata"].(map[string]interface{})
		assert.True(t, ok)
		assert.Contains(t, metadata, "num_rows")
		assert.Contains(t, metadata, "time_bucket_interval")

		// Verify the mock expectations
		mockUseCase.AssertExpectations(t)
	})
}
164 | 
165 | // TestContinuousAggregateTool tests the continuous aggregate operations
166 | func TestContinuousAggregateTool(t *testing.T) {
167 | 	// Create a context for testing
168 | 	ctx := context.Background()
169 | 
170 | 	// Test case for create_continuous_aggregate operation
171 | 	t.Run("create_continuous_aggregate", func(t *testing.T) {
172 | 		// Create a new mock for this test case
173 | 		mockUseCase := new(MockDatabaseUseCase)
174 | 
175 | 		// Set up the TimescaleDB tool
176 | 		tool := mcp.NewTimescaleDBTool()
177 | 
178 | 		// Set up expectations
179 | 		// Removed GetDatabaseType expectation as it's not called in this handler
180 | 
181 | 		// Add mock expectation for the SQL containing CREATE MATERIALIZED VIEW
182 | 		mockUseCase.On("ExecuteStatement",
183 | 			mock.Anything,
184 | 			"test_db",
185 | 			mock.MatchedBy(func(sql string) bool {
186 | 				return strings.Contains(sql, "CREATE MATERIALIZED VIEW")
187 | 			}),
188 | 			mock.Anything).Return(`{"result": "success"}`, nil)
189 | 
190 | 		// Add separate mock expectation for policy SQL if needed
191 | 		mockUseCase.On("ExecuteStatement",
192 | 			mock.Anything,
193 | 			"test_db",
194 | 			mock.MatchedBy(func(sql string) bool {
195 | 				return strings.Contains(sql, "add_continuous_aggregate_policy")
196 | 			}),
197 | 			mock.Anything).Return(`{"result": "success"}`, nil)
198 | 
199 | 		// Create a request
200 | 		request := server.ToolCallRequest{
201 | 			Name: "timescaledb_create_continuous_aggregate_test_db",
202 | 			Parameters: map[string]interface{}{
203 | 				"operation":        "create_continuous_aggregate",
204 | 				"view_name":        "daily_metrics",
205 | 				"source_table":     "sensor_data",
206 | 				"time_column":      "timestamp",
207 | 				"bucket_interval":  "1 day",
208 | 				"aggregations":     "AVG(temperature) as avg_temp, MIN(temperature) as min_temp, MAX(temperature) as max_temp",
209 | 				"with_data":        true,
210 | 				"refresh_policy":   true,
211 | 				"refresh_interval": "1 hour",
212 | 			},
213 | 		}
214 | 
215 | 		// Call the handler
216 | 		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
217 | 
218 | 		// Verify the result
219 | 		assert.NoError(t, err)
220 | 		assert.NotNil(t, result)
221 | 
222 | 		// Check the result contains expected fields
223 | 		resultMap, ok := result.(map[string]interface{})
224 | 		assert.True(t, ok)
225 | 		assert.Contains(t, resultMap, "message")
226 | 		assert.Contains(t, resultMap, "sql")
227 | 
228 | 		// Verify the mock expectations
229 | 		mockUseCase.AssertExpectations(t)
230 | 	})
231 | 
232 | 	// Test case for refresh_continuous_aggregate operation
233 | 	t.Run("refresh_continuous_aggregate", func(t *testing.T) {
234 | 		// Create a new mock for this test case
235 | 		mockUseCase := new(MockDatabaseUseCase)
236 | 
237 | 		// Set up the TimescaleDB tool
238 | 		tool := mcp.NewTimescaleDBTool()
239 | 
240 | 		// Set up expectations
241 | 		// Removed GetDatabaseType expectation as it's not called in this handler
242 | 		mockUseCase.On("ExecuteStatement",
243 | 			mock.Anything,
244 | 			"test_db",
245 | 			mock.MatchedBy(func(sql string) bool {
246 | 				return strings.Contains(sql, "CALL refresh_continuous_aggregate")
247 | 			}),
248 | 			mock.Anything).Return(`{"result": "success"}`, nil)
249 | 
250 | 		// Create a request
251 | 		request := server.ToolCallRequest{
252 | 			Name: "timescaledb_refresh_continuous_aggregate_test_db",
253 | 			Parameters: map[string]interface{}{
254 | 				"operation":  "refresh_continuous_aggregate",
255 | 				"view_name":  "daily_metrics",
256 | 				"start_time": "2023-01-01",
257 | 				"end_time":   "2023-01-31",
258 | 			},
259 | 		}
260 | 
261 | 		// Call the handler
262 | 		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
263 | 
264 | 		// Verify the result
265 | 		assert.NoError(t, err)
266 | 		assert.NotNil(t, result)
267 | 
268 | 		// Check the result contains expected fields
269 | 		resultMap, ok := result.(map[string]interface{})
270 | 		assert.True(t, ok)
271 | 		assert.Contains(t, resultMap, "message")
272 | 
273 | 		// Verify the mock expectations
274 | 		mockUseCase.AssertExpectations(t)
275 | 	})
276 | 
277 | 	// Test case for drop_continuous_aggregate operation
278 | 	t.Run("drop_continuous_aggregate", func(t *testing.T) {
279 | 		// Create a new mock for this test case
280 | 		mockUseCase := new(MockDatabaseUseCase)
281 | 
282 | 		// Set up the TimescaleDB tool
283 | 		tool := mcp.NewTimescaleDBTool()
284 | 
285 | 		// Set up expectations
286 | 		// Removed GetDatabaseType expectation as it's not called in this handler
287 | 		mockUseCase.On("ExecuteStatement",
288 | 			mock.Anything,
289 | 			"test_db",
290 | 			mock.MatchedBy(func(sql string) bool {
291 | 				return strings.Contains(sql, "DROP MATERIALIZED VIEW")
292 | 			}),
293 | 			mock.Anything).Return(`{"result": "success"}`, nil)
294 | 
295 | 		// Create a request
296 | 		request := server.ToolCallRequest{
297 | 			Name: "timescaledb_drop_continuous_aggregate_test_db",
298 | 			Parameters: map[string]interface{}{
299 | 				"operation": "drop_continuous_aggregate",
300 | 				"view_name": "daily_metrics",
301 | 				"cascade":   true,
302 | 			},
303 | 		}
304 | 
305 | 		// Call the handler
306 | 		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
307 | 
308 | 		// Verify the result
309 | 		assert.NoError(t, err)
310 | 		assert.NotNil(t, result)
311 | 
312 | 		// Check the result contains expected fields
313 | 		resultMap, ok := result.(map[string]interface{})
314 | 		assert.True(t, ok)
315 | 		assert.Contains(t, resultMap, "message")
316 | 
317 | 		// Verify the mock expectations
318 | 		mockUseCase.AssertExpectations(t)
319 | 	})
320 | 
321 | 	// Test case for list_continuous_aggregates operation
322 | 	t.Run("list_continuous_aggregates", func(t *testing.T) {
323 | 		// Create a new mock for this test case
324 | 		mockUseCase := new(MockDatabaseUseCase)
325 | 
326 | 		// Set up the TimescaleDB tool
327 | 		tool := mcp.NewTimescaleDBTool()
328 | 
329 | 		// Set up expectations
330 | 		mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
331 | 		mockUseCase.On("ExecuteStatement",
332 | 			mock.Anything,
333 | 			"test_db",
334 | 			mock.MatchedBy(func(sql string) bool {
335 | 				return strings.Contains(sql, "SELECT") && strings.Contains(sql, "continuous_aggregates")
336 | 			}),
337 | 			mock.Anything).Return(`[{"view_name": "daily_metrics", "source_table": "sensor_data"}]`, nil)
338 | 
339 | 		// Create a request
340 | 		request := server.ToolCallRequest{
341 | 			Name: "timescaledb_list_continuous_aggregates_test_db",
342 | 			Parameters: map[string]interface{}{
343 | 				"operation": "list_continuous_aggregates",
344 | 			},
345 | 		}
346 | 
347 | 		// Call the handler
348 | 		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
349 | 
350 | 		// Verify the result
351 | 		assert.NoError(t, err)
352 | 		assert.NotNil(t, result)
353 | 
354 | 		// Check the result contains expected fields
355 | 		resultMap, ok := result.(map[string]interface{})
356 | 		assert.True(t, ok)
357 | 		assert.Contains(t, resultMap, "message")
358 | 		assert.Contains(t, resultMap, "details")
359 | 
360 | 		// Verify the mock expectations
361 | 		mockUseCase.AssertExpectations(t)
362 | 	})
363 | 
364 | 	// Test case for get_continuous_aggregate_info operation
365 | 	t.Run("get_continuous_aggregate_info", func(t *testing.T) {
366 | 		// Create a new mock for this test case
367 | 		mockUseCase := new(MockDatabaseUseCase)
368 | 
369 | 		// Set up the TimescaleDB tool
370 | 		tool := mcp.NewTimescaleDBTool()
371 | 
372 | 		// Set up expectations
373 | 		mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
374 | 		mockUseCase.On("ExecuteStatement",
375 | 			mock.Anything,
376 | 			"test_db",
377 | 			mock.MatchedBy(func(sql string) bool {
378 | 				return strings.Contains(sql, "SELECT") && strings.Contains(sql, "continuous_aggregates") && strings.Contains(sql, "WHERE")
379 | 			}),
380 | 			mock.Anything).Return(`[{"view_name": "daily_metrics", "source_table": "sensor_data", "bucket_interval": "1 day"}]`, nil)
381 | 
382 | 		// Create a request
383 | 		request := server.ToolCallRequest{
384 | 			Name: "timescaledb_get_continuous_aggregate_info_test_db",
385 | 			Parameters: map[string]interface{}{
386 | 				"operation": "get_continuous_aggregate_info",
387 | 				"view_name": "daily_metrics",
388 | 			},
389 | 		}
390 | 
391 | 		// Call the handler
392 | 		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
393 | 
394 | 		// Verify the result
395 | 		assert.NoError(t, err)
396 | 		assert.NotNil(t, result)
397 | 
398 | 		// Check the result contains expected fields
399 | 		resultMap, ok := result.(map[string]interface{})
400 | 		assert.True(t, ok)
401 | 		assert.Contains(t, resultMap, "message")
402 | 		assert.Contains(t, resultMap, "details")
403 | 
404 | 		// Verify the mock expectations
405 | 		mockUseCase.AssertExpectations(t)
406 | 	})
407 | 
408 | 	// Test case for add_continuous_aggregate_policy operation
409 | 	t.Run("add_continuous_aggregate_policy", func(t *testing.T) {
410 | 		// Create a new mock for this test case
411 | 		mockUseCase := new(MockDatabaseUseCase)
412 | 
413 | 		// Set up the TimescaleDB tool
414 | 		tool := mcp.NewTimescaleDBTool()
415 | 
416 | 		// Set up expectations
417 | 		mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
418 | 		mockUseCase.On("ExecuteStatement",
419 | 			mock.Anything,
420 | 			"test_db",
421 | 			mock.MatchedBy(func(sql string) bool {
422 | 				return strings.Contains(sql, "add_continuous_aggregate_policy")
423 | 			}),
424 | 			mock.Anything).Return(`{"result": "success"}`, nil)
425 | 
426 | 		// Create a request
427 | 		request := server.ToolCallRequest{
428 | 			Name: "timescaledb_add_continuous_aggregate_policy_test_db",
429 | 			Parameters: map[string]interface{}{
430 | 				"operation":         "add_continuous_aggregate_policy",
431 | 				"view_name":         "daily_metrics",
432 | 				"start_offset":      "1 month",
433 | 				"end_offset":        "2 hours",
434 | 				"schedule_interval": "6 hours",
435 | 			},
436 | 		}
437 | 
438 | 		// Call the handler
439 | 		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
440 | 
441 | 		// Verify the result
442 | 		assert.NoError(t, err)
443 | 		assert.NotNil(t, result)
444 | 
445 | 		// Check the result contains expected fields
446 | 		resultMap, ok := result.(map[string]interface{})
447 | 		assert.True(t, ok)
448 | 		assert.Contains(t, resultMap, "message")
449 | 
450 | 		// Verify the mock expectations
451 | 		mockUseCase.AssertExpectations(t)
452 | 	})
453 | 
454 | 	// Test case for remove_continuous_aggregate_policy operation
455 | 	t.Run("remove_continuous_aggregate_policy", func(t *testing.T) {
456 | 		// Create a new mock for this test case
457 | 		mockUseCase := new(MockDatabaseUseCase)
458 | 
459 | 		// Set up the TimescaleDB tool
460 | 		tool := mcp.NewTimescaleDBTool()
461 | 
462 | 		// Set up expectations
463 | 		mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
464 | 		mockUseCase.On("ExecuteStatement",
465 | 			mock.Anything,
466 | 			"test_db",
467 | 			mock.MatchedBy(func(sql string) bool {
468 | 				return strings.Contains(sql, "remove_continuous_aggregate_policy")
469 | 			}),
470 | 			mock.Anything).Return(`{"result": "success"}`, nil)
471 | 
472 | 		// Create a request
473 | 		request := server.ToolCallRequest{
474 | 			Name: "timescaledb_remove_continuous_aggregate_policy_test_db",
475 | 			Parameters: map[string]interface{}{
476 | 				"operation": "remove_continuous_aggregate_policy",
477 | 				"view_name": "daily_metrics",
478 | 			},
479 | 		}
480 | 
481 | 		// Call the handler
482 | 		result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
483 | 
484 | 		// Verify the result
485 | 		assert.NoError(t, err)
486 | 		assert.NotNil(t, result)
487 | 
488 | 		// Check the result contains expected fields
489 | 		resultMap, ok := result.(map[string]interface{})
490 | 		assert.True(t, ok)
491 | 		assert.Contains(t, resultMap, "message")
492 | 
493 | 		// Verify the mock expectations
494 | 		mockUseCase.AssertExpectations(t)
495 | 	})
496 | }
497 | 
```

--------------------------------------------------------------------------------
/pkg/db/timescale/metadata.go:
--------------------------------------------------------------------------------

```go
  1 | package timescale
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"fmt"
  6 | 	"strconv"
  7 | 	"strings"
  8 | )
  9 | 
 10 | // HypertableMetadata represents metadata about a TimescaleDB hypertable
 11 | type HypertableMetadata struct {
 12 | 	TableName         string   // table name without schema qualifier
 13 | 	SchemaName        string   // schema that contains the hypertable
 14 | 	Owner             string   // table owner (pg_tables.tableowner)
 15 | 	NumDimensions     int      // total partitioning dimensions (time + space)
 16 | 	TimeDimension     string   // column serving as the primary time dimension
 17 | 	TimeDimensionType string   // data type of the time dimension column
 18 | 	SpaceDimensions   []string // space-partitioning columns (populated only when NumDimensions > 1)
 19 | 	ChunkTimeInterval string   // chunk interval configured for the time dimension
 20 | 	Compression       bool     // true when native compression is enabled
 21 | 	RetentionPolicy   bool     // true when a policy_retention job exists for the table
 22 | 	TotalSize         string   // human-readable total relation size (pg_size_pretty)
 23 | 	TotalRows         int64    // row count obtained via COUNT(*) — can be expensive on large tables
 24 | 	Chunks            int      // number of chunks backing the hypertable
 25 | }
 26 | 
 27 | // ColumnMetadata represents metadata about a column
 28 | type ColumnMetadata struct {
 29 | 	Name         string // column name
 30 | 	Type         string // data type as reported by information_schema
 31 | 	Nullable     bool   // true when the column accepts NULL
 32 | 	IsPrimaryKey bool   // true when the column is part of the primary key
 33 | 	IsIndexed    bool   // true when a non-primary index covers the column
 34 | 	Description  string // column comment from col_description; empty when absent
 35 | }
 36 | 
 37 | // ContinuousAggregateMetadata represents metadata about a continuous aggregate
 38 | type ContinuousAggregateMetadata struct {
 39 | 	ViewName           string // continuous aggregate view name
 40 | 	ViewSchema         string // schema containing the view
 41 | 	MaterializedOnly   bool   // true when queries read only materialized data
 42 | 	RefreshInterval    string // refresh interval; empty when the information view does not report it
 43 | 	RefreshLag         string // refresh lag; empty when the information view does not report it
 44 | 	RefreshStartOffset string // NOTE(review): never populated anywhere in this file — confirm whether it is still needed
 45 | 	RefreshEndOffset   string // NOTE(review): never populated anywhere in this file — confirm whether it is still needed
 46 | 	HypertableName     string // underlying hypertable name
 47 | 	HypertableSchema   string // schema of the underlying hypertable
 48 | 	ViewDefinition     string // SQL definition from pg_get_viewdef; empty when the lookup fails
 49 | }
 50 | 
 51 | // GetHypertableMetadata returns detailed metadata about a hypertable
 52 | func (t *DB) GetHypertableMetadata(ctx context.Context, tableName string) (*HypertableMetadata, error) {
 53 | 	if !t.isTimescaleDB {
 54 | 		return nil, fmt.Errorf("TimescaleDB extension not available")
 55 | 	}
 56 | 
 57 | 	// Query to get basic hypertable information
 58 | 	query := fmt.Sprintf(`
 59 | 		SELECT 
 60 | 			h.table_name,
 61 | 			h.schema_name,
 62 | 			t.tableowner as owner,
 63 | 			h.num_dimensions,
 64 | 			dc.column_name as time_dimension,
 65 | 			dc.column_type as time_dimension_type,
 66 | 			dc.time_interval as chunk_time_interval,
 67 | 			h.compression_enabled,
 68 | 			pg_size_pretty(pg_total_relation_size(format('%%I.%%I', h.schema_name, h.table_name))) as total_size,
 69 | 			(SELECT count(*) FROM timescaledb_information.chunks WHERE hypertable_name = h.table_name) as chunks,
 70 | 			(SELECT count(*) FROM %s.%s) as total_rows
 71 | 		FROM timescaledb_information.hypertables h
 72 | 		JOIN pg_tables t ON h.table_name = t.tablename AND h.schema_name = t.schemaname
 73 | 		JOIN timescaledb_information.dimensions dc ON h.hypertable_name = dc.hypertable_name
 74 | 		WHERE h.table_name = '%s' AND dc.dimension_number = 1
 75 | 	`, tableName, tableName, tableName) // NOTE(review): tableName is interpolated directly into SQL — injection risk; confirm callers validate/sanitize identifiers
 76 | 
 77 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
 78 | 	if err != nil {
 79 | 		return nil, fmt.Errorf("failed to get hypertable metadata: %w", err)
 80 | 	}
 81 | 
 82 | 	rows, ok := result.([]map[string]interface{})
 83 | 	if !ok || len(rows) == 0 {
 84 | 		return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
 85 | 	}
 86 | 
 87 | 	row := rows[0]
 88 | 	metadata := &HypertableMetadata{
 89 | 		TableName:         fmt.Sprintf("%v", row["table_name"]),
 90 | 		SchemaName:        fmt.Sprintf("%v", row["schema_name"]),
 91 | 		Owner:             fmt.Sprintf("%v", row["owner"]),
 92 | 		TimeDimension:     fmt.Sprintf("%v", row["time_dimension"]),
 93 | 		TimeDimensionType: fmt.Sprintf("%v", row["time_dimension_type"]),
 94 | 		ChunkTimeInterval: fmt.Sprintf("%v", row["chunk_time_interval"]),
 95 | 		TotalSize:         fmt.Sprintf("%v", row["total_size"]),
 96 | 	}
 97 | 
 98 | 	// Convert numeric fields; the driver may surface them as int64, int, or (for total_rows) string
 99 | 	if numDimensions, ok := row["num_dimensions"].(int64); ok {
100 | 		metadata.NumDimensions = int(numDimensions)
101 | 	} else if numDimensions, ok := row["num_dimensions"].(int); ok {
102 | 		metadata.NumDimensions = numDimensions
103 | 	}
104 | 
105 | 	if chunks, ok := row["chunks"].(int64); ok {
106 | 		metadata.Chunks = int(chunks)
107 | 	} else if chunks, ok := row["chunks"].(int); ok {
108 | 		metadata.Chunks = chunks
109 | 	}
110 | 
111 | 	if rows, ok := row["total_rows"].(int64); ok {
112 | 		metadata.TotalRows = rows
113 | 	} else if rows, ok := row["total_rows"].(int); ok {
114 | 		metadata.TotalRows = int64(rows)
115 | 	} else if rowsStr, ok := row["total_rows"].(string); ok {
116 | 		if rows, err := strconv.ParseInt(rowsStr, 10, 64); err == nil {
117 | 			metadata.TotalRows = rows
118 | 		}
119 | 	}
120 | 
121 | 	// Handle boolean fields (some drivers return booleans as "t"/"true"/"1" strings)
122 | 	if compression, ok := row["compression_enabled"].(bool); ok {
123 | 		metadata.Compression = compression
124 | 	} else if compressionStr, ok := row["compression_enabled"].(string); ok {
125 | 		metadata.Compression = compressionStr == "t" || compressionStr == "true" || compressionStr == "1"
126 | 	}
127 | 
128 | 	// Get space dimensions if there are more than one dimension
129 | 	if metadata.NumDimensions > 1 {
130 | 		spaceDimQuery := fmt.Sprintf(`
131 | 			SELECT column_name
132 | 			FROM timescaledb_information.dimensions
133 | 			WHERE hypertable_name = '%s' AND dimension_number > 1
134 | 			ORDER BY dimension_number
135 | 		`, tableName) // same identifier-interpolation caveat as the main query above
136 | 
137 | 		spaceResult, err := t.ExecuteSQLWithoutParams(ctx, spaceDimQuery)
138 | 		if err == nil { // best-effort: on error, SpaceDimensions is simply left empty
139 | 			spaceDimRows, ok := spaceResult.([]map[string]interface{})
140 | 			if ok {
141 | 				for _, dimRow := range spaceDimRows {
142 | 					if colName, ok := dimRow["column_name"]; ok && colName != nil {
143 | 						metadata.SpaceDimensions = append(metadata.SpaceDimensions, fmt.Sprintf("%v", colName))
144 | 					}
145 | 				}
146 | 			}
147 | 		}
148 | 	}
149 | 
150 | 	// Check if a retention policy exists
151 | 	retentionQuery := fmt.Sprintf(`
152 | 		SELECT COUNT(*) > 0 as has_retention
153 | 		FROM timescaledb_information.jobs
154 | 		WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'
155 | 	`, tableName)
156 | 
157 | 	retentionResult, err := t.ExecuteSQLWithoutParams(ctx, retentionQuery)
158 | 	if err == nil { // best-effort: on error, RetentionPolicy stays false
159 | 		retentionRows, ok := retentionResult.([]map[string]interface{})
160 | 		if ok && len(retentionRows) > 0 {
161 | 			if hasRetention, ok := retentionRows[0]["has_retention"].(bool); ok {
162 | 				metadata.RetentionPolicy = hasRetention
163 | 			}
164 | 		}
165 | 	}
166 | 
167 | 	return metadata, nil
168 | }
169 | 
170 | // GetTableColumns returns metadata about columns in a table
171 | func (t *DB) GetTableColumns(ctx context.Context, tableName string) ([]ColumnMetadata, error) {
172 | 	query := fmt.Sprintf(`
173 | 		SELECT 
174 | 			c.column_name, 
175 | 			c.data_type,
176 | 			c.is_nullable = 'YES' as is_nullable,
177 | 			(
178 | 				SELECT COUNT(*) > 0
179 | 				FROM pg_index i
180 | 				JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
181 | 				WHERE i.indrelid = format('%%I.%%I', c.table_schema, c.table_name)::regclass
182 | 				AND i.indisprimary
183 | 				AND a.attname = c.column_name
184 | 			) as is_primary_key,
185 | 			(
186 | 				SELECT COUNT(*) > 0
187 | 				FROM pg_index i
188 | 				JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
189 | 				WHERE i.indrelid = format('%%I.%%I', c.table_schema, c.table_name)::regclass
190 | 				AND NOT i.indisprimary
191 | 				AND a.attname = c.column_name
192 | 			) as is_indexed,
193 | 			col_description(format('%%I.%%I', c.table_schema, c.table_name)::regclass::oid, 
194 | 							ordinal_position) as description
195 | 		FROM information_schema.columns c
196 | 		WHERE c.table_name = '%s'
197 | 		ORDER BY c.ordinal_position
198 | 	`, tableName) // NOTE(review): tableName interpolated into SQL (injection risk); filter has no schema clause, so same-named tables in other schemas also match — confirm intended
199 | 
200 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
201 | 	if err != nil {
202 | 		return nil, fmt.Errorf("failed to get table columns: %w", err)
203 | 	}
204 | 
205 | 	rows, ok := result.([]map[string]interface{})
206 | 	if !ok {
207 | 		return nil, fmt.Errorf("unexpected result type from database query")
208 | 	}
209 | 
210 | 	var columns []ColumnMetadata // nil when the table has no columns / does not exist
211 | 	for _, row := range rows {
212 | 		col := ColumnMetadata{
213 | 			Name: fmt.Sprintf("%v", row["column_name"]),
214 | 			Type: fmt.Sprintf("%v", row["data_type"]),
215 | 		}
216 | 
217 | 		// Handle boolean fields
218 | 		if nullable, ok := row["is_nullable"].(bool); ok {
219 | 			col.Nullable = nullable
220 | 		}
221 | 		if isPK, ok := row["is_primary_key"].(bool); ok {
222 | 			col.IsPrimaryKey = isPK
223 | 		}
224 | 		if isIndexed, ok := row["is_indexed"].(bool); ok {
225 | 			col.IsIndexed = isIndexed
226 | 		}
227 | 
228 | 		// Handle description which might be null
229 | 		if desc, ok := row["description"]; ok && desc != nil {
230 | 			col.Description = fmt.Sprintf("%v", desc)
231 | 		}
232 | 
233 | 		columns = append(columns, col)
234 | 	}
235 | 
236 | 	return columns, nil
237 | }
238 | 
239 | // ListContinuousAggregates lists all continuous aggregates
240 | func (t *DB) ListContinuousAggregates(ctx context.Context) ([]ContinuousAggregateMetadata, error) {
241 | 	if !t.isTimescaleDB {
242 | 		return nil, fmt.Errorf("TimescaleDB extension not available")
243 | 	}
244 | 
245 | 	query := `
246 | 		SELECT 
247 | 			view_name,
248 | 			view_schema,
249 | 			materialized_only,
250 | 			refresh_lag,
251 | 			refresh_interval,
252 | 			hypertable_name,
253 | 			hypertable_schema
254 | 		FROM timescaledb_information.continuous_aggregates
255 | 	` // NOTE(review): refresh_lag/refresh_interval are TimescaleDB 1.x information-view columns — verify they exist on the targeted TimescaleDB version
256 | 
257 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
258 | 	if err != nil {
259 | 		return nil, fmt.Errorf("failed to list continuous aggregates: %w", err)
260 | 	}
261 | 
262 | 	rows, ok := result.([]map[string]interface{})
263 | 	if !ok {
264 | 		return nil, fmt.Errorf("unexpected result type from database query")
265 | 	}
266 | 
267 | 	var aggregates []ContinuousAggregateMetadata
268 | 	for _, row := range rows {
269 | 		agg := ContinuousAggregateMetadata{
270 | 			ViewName:         fmt.Sprintf("%v", row["view_name"]),
271 | 			ViewSchema:       fmt.Sprintf("%v", row["view_schema"]),
272 | 			HypertableName:   fmt.Sprintf("%v", row["hypertable_name"]),
273 | 			HypertableSchema: fmt.Sprintf("%v", row["hypertable_schema"]),
274 | 		}
275 | 
276 | 		// Handle boolean fields
277 | 		if materializedOnly, ok := row["materialized_only"].(bool); ok {
278 | 			agg.MaterializedOnly = materializedOnly
279 | 		}
280 | 
281 | 		// Handle nullable fields
282 | 		if refreshLag, ok := row["refresh_lag"]; ok && refreshLag != nil {
283 | 			agg.RefreshLag = fmt.Sprintf("%v", refreshLag)
284 | 		}
285 | 		if refreshInterval, ok := row["refresh_interval"]; ok && refreshInterval != nil {
286 | 			agg.RefreshInterval = fmt.Sprintf("%v", refreshInterval)
287 | 		}
288 | 
289 | 		// Get view definition (one extra round-trip per aggregate — N+1 pattern; best-effort)
290 | 		definitionQuery := fmt.Sprintf(`
291 | 			SELECT pg_get_viewdef(format('%%I.%%I', '%s', '%s')::regclass, true) as view_definition
292 | 		`, agg.ViewSchema, agg.ViewName)
293 | 
294 | 		defResult, err := t.ExecuteSQLWithoutParams(ctx, definitionQuery)
295 | 		if err == nil { // on error, ViewDefinition stays empty
296 | 			defRows, ok := defResult.([]map[string]interface{})
297 | 			if ok && len(defRows) > 0 {
298 | 				if def, ok := defRows[0]["view_definition"]; ok && def != nil {
299 | 					agg.ViewDefinition = fmt.Sprintf("%v", def)
300 | 				}
301 | 			}
302 | 		}
303 | 
304 | 		aggregates = append(aggregates, agg)
305 | 	}
306 | 
307 | 	return aggregates, nil
308 | }
309 | 
310 | // GetContinuousAggregate gets metadata about a specific continuous aggregate
311 | func (t *DB) GetContinuousAggregate(ctx context.Context, viewName string) (*ContinuousAggregateMetadata, error) {
312 | 	if !t.isTimescaleDB {
313 | 		return nil, fmt.Errorf("TimescaleDB extension not available")
314 | 	}
315 | 
316 | 	query := fmt.Sprintf(`
317 | 		SELECT 
318 | 			view_name,
319 | 			view_schema,
320 | 			materialized_only,
321 | 			refresh_lag,
322 | 			refresh_interval,
323 | 			hypertable_name,
324 | 			hypertable_schema
325 | 		FROM timescaledb_information.continuous_aggregates
326 | 		WHERE view_name = '%s'
327 | 	`, viewName) // NOTE(review): viewName is interpolated directly into SQL — injection risk; confirm callers validate it
328 | 
329 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
330 | 	if err != nil {
331 | 		return nil, fmt.Errorf("failed to get continuous aggregate: %w", err)
332 | 	}
333 | 
334 | 	rows, ok := result.([]map[string]interface{})
335 | 	if !ok || len(rows) == 0 {
336 | 		return nil, fmt.Errorf("continuous aggregate '%s' not found", viewName)
337 | 	}
338 | 
339 | 	row := rows[0]
340 | 	agg := &ContinuousAggregateMetadata{
341 | 		ViewName:         fmt.Sprintf("%v", row["view_name"]),
342 | 		ViewSchema:       fmt.Sprintf("%v", row["view_schema"]),
343 | 		HypertableName:   fmt.Sprintf("%v", row["hypertable_name"]),
344 | 		HypertableSchema: fmt.Sprintf("%v", row["hypertable_schema"]),
345 | 	}
346 | 
347 | 	// Handle boolean fields
348 | 	if materializedOnly, ok := row["materialized_only"].(bool); ok {
349 | 		agg.MaterializedOnly = materializedOnly
350 | 	}
351 | 
352 | 	// Handle nullable fields
353 | 	if refreshLag, ok := row["refresh_lag"]; ok && refreshLag != nil {
354 | 		agg.RefreshLag = fmt.Sprintf("%v", refreshLag)
355 | 	}
356 | 	if refreshInterval, ok := row["refresh_interval"]; ok && refreshInterval != nil {
357 | 		agg.RefreshInterval = fmt.Sprintf("%v", refreshInterval)
358 | 	}
359 | 
360 | 	// Get view definition (best-effort; ViewDefinition stays empty on failure)
361 | 	definitionQuery := fmt.Sprintf(`
362 | 		SELECT pg_get_viewdef(format('%%I.%%I', '%s', '%s')::regclass, true) as view_definition
363 | 	`, agg.ViewSchema, agg.ViewName)
364 | 
365 | 	defResult, err := t.ExecuteSQLWithoutParams(ctx, definitionQuery)
366 | 	if err == nil {
367 | 		defRows, ok := defResult.([]map[string]interface{})
368 | 		if ok && len(defRows) > 0 {
369 | 			if def, ok := defRows[0]["view_definition"]; ok && def != nil {
370 | 				agg.ViewDefinition = fmt.Sprintf("%v", def)
371 | 			}
372 | 		}
373 | 	}
374 | 
375 | 	return agg, nil
376 | }
377 | 
378 | // GetDatabaseSize gets size information about the database
379 | func (t *DB) GetDatabaseSize(ctx context.Context) (map[string]string, error) {
380 | 	query := `
381 | 		SELECT 
382 | 			pg_size_pretty(pg_database_size(current_database())) as database_size,
383 | 			current_database() as database_name,
384 | 			(
385 | 				SELECT pg_size_pretty(sum(pg_total_relation_size(format('%I.%I', h.schema_name, h.table_name))))
386 | 				FROM timescaledb_information.hypertables h
387 | 			) as hypertables_size,
388 | 			(
389 | 				SELECT count(*)
390 | 				FROM timescaledb_information.hypertables
391 | 			) as hypertables_count
392 | 	` // NOTE(review): unlike sibling methods there is no t.isTimescaleDB guard here, yet the query reads timescaledb_information views — confirm behavior on plain PostgreSQL
393 | 
394 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
395 | 	if err != nil {
396 | 		return nil, fmt.Errorf("failed to get database size: %w", err)
397 | 	}
398 | 
399 | 	rows, ok := result.([]map[string]interface{})
400 | 	if !ok || len(rows) == 0 {
401 | 		return nil, fmt.Errorf("failed to get database size information")
402 | 	}
403 | 
404 | 	info := make(map[string]string)
405 | 	for k, v := range rows[0] {
406 | 		if v != nil { // NULL columns are omitted from the result map
407 | 			info[k] = fmt.Sprintf("%v", v)
408 | 		}
409 | 	}
410 | 
411 | 	return info, nil
412 | }
413 | 
414 | // DetectTimescaleDBVersion checks if TimescaleDB is installed and returns its version
415 | func (t *DB) DetectTimescaleDBVersion(ctx context.Context) (string, error) {
416 | 	query := "SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'"
417 | 	result, err := t.ExecuteSQLWithoutParams(ctx, query)
418 | 	if err != nil {
419 | 		return "", fmt.Errorf("failed to check TimescaleDB version: %w", err)
420 | 	}
421 | 
422 | 	rows, ok := result.([]map[string]interface{})
423 | 	if !ok || len(rows) == 0 {
424 | 		return "", fmt.Errorf("TimescaleDB extension not installed") // absence is reported as an error, not as ("", nil)
425 | 	}
426 | 
427 | 	version := rows[0]["extversion"]
428 | 	if version == nil {
429 | 		return "", fmt.Errorf("unable to determine TimescaleDB version")
430 | 	}
431 | 
432 | 	return fmt.Sprintf("%v", version), nil
433 | }
434 | 
435 | // GenerateHypertableSchema generates CREATE TABLE and CREATE HYPERTABLE statements for a hypertable
436 | func (t *DB) GenerateHypertableSchema(ctx context.Context, tableName string) (string, error) {
437 | 	if !t.isTimescaleDB {
438 | 		return "", fmt.Errorf("TimescaleDB extension not available")
439 | 	}
440 | 
441 | 	// Get table columns and constraints
442 | 	columnsQuery := fmt.Sprintf(`
443 | 		SELECT 
444 | 			'CREATE TABLE ' || quote_ident('%s') || ' (' ||
445 | 			string_agg(
446 | 				quote_ident(column_name) || ' ' || 
447 | 				data_type || 
448 | 				CASE 
449 | 					WHEN character_maximum_length IS NOT NULL THEN '(' || character_maximum_length || ')'
450 | 					WHEN numeric_precision IS NOT NULL AND numeric_scale IS NOT NULL THEN '(' || numeric_precision || ',' || numeric_scale || ')'
451 | 					ELSE ''
452 | 				END ||
453 | 				CASE WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END,
454 | 				', '
455 | 			) ||
456 | 			CASE 
457 | 				WHEN (
458 | 					SELECT count(*) > 0 
459 | 					FROM information_schema.table_constraints tc
460 | 					WHERE tc.table_name = '%s' AND tc.constraint_type = 'PRIMARY KEY'
461 | 				) THEN 
462 | 					', ' || (
463 | 						SELECT 'PRIMARY KEY (' || string_agg(quote_ident(kcu.column_name), ', ') || ')'
464 | 						FROM information_schema.table_constraints tc
465 | 						JOIN information_schema.key_column_usage kcu ON 
466 | 							kcu.constraint_name = tc.constraint_name AND
467 | 							kcu.table_schema = tc.table_schema AND
468 | 							kcu.table_name = tc.table_name
469 | 						WHERE tc.table_name = '%s' AND tc.constraint_type = 'PRIMARY KEY'
470 | 					)
471 | 				ELSE ''
472 | 			END ||
473 | 			');' as create_table_stmt
474 | 		FROM information_schema.columns
475 | 		WHERE table_name = '%s'
476 | 		GROUP BY table_name
477 | 	`, tableName, tableName, tableName, tableName) // NOTE(review): tableName is interpolated directly into SQL — injection risk; confirm callers validate identifiers
478 | 
479 | 	columnsResult, err := t.ExecuteSQLWithoutParams(ctx, columnsQuery)
480 | 	if err != nil {
481 | 		return "", fmt.Errorf("failed to generate schema: %w", err)
482 | 	}
483 | 
484 | 	columnsRows, ok := columnsResult.([]map[string]interface{})
485 | 	if !ok || len(columnsRows) == 0 {
486 | 		return "", fmt.Errorf("failed to generate schema for table '%s'", tableName)
487 | 	}
488 | 
489 | 	createTableStmt := fmt.Sprintf("%v", columnsRows[0]["create_table_stmt"])
490 | 
491 | 	// Get hypertable metadata
492 | 	metadata, err := t.GetHypertableMetadata(ctx, tableName)
493 | 	if err != nil {
494 | 		return createTableStmt, nil // Return just the CREATE TABLE statement if it's not a hypertable
495 | 	}
496 | 
497 | 	// Generate CREATE HYPERTABLE statement
498 | 	var createHypertableStmt strings.Builder
499 | 	createHypertableStmt.WriteString(fmt.Sprintf("SELECT create_hypertable('%s', '%s'",
500 | 		tableName, metadata.TimeDimension))
501 | 
502 | 	if metadata.ChunkTimeInterval != "" {
503 | 		createHypertableStmt.WriteString(fmt.Sprintf(", chunk_time_interval => INTERVAL '%s'",
504 | 			metadata.ChunkTimeInterval))
505 | 	}
506 | 
507 | 	if len(metadata.SpaceDimensions) > 0 { // only the first space dimension is emitted
508 | 		createHypertableStmt.WriteString(fmt.Sprintf(", partitioning_column => '%s'",
509 | 			metadata.SpaceDimensions[0]))
510 | 	}
511 | 
512 | 	createHypertableStmt.WriteString(");")
513 | 
514 | 	// Combine statements
515 | 	result := createTableStmt + "\n\n" + createHypertableStmt.String()
516 | 
517 | 	// Add compression statement if enabled
518 | 	if metadata.Compression {
519 | 		compressionSettings, err := t.GetCompressionSettings(ctx, tableName)
520 | 		if err == nil && compressionSettings.CompressionEnabled { // best-effort: skip the compression DDL if settings cannot be read
521 | 			compressionStmt := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true);", tableName)
522 | 			result += "\n\n" + compressionStmt
523 | 
524 | 			// Add compression policy if exists. NOTE(review): add_compression_policy() takes no segmentby/orderby
525 | 			// arguments — those belong in ALTER TABLE ... SET (timescaledb.compress_segmentby/compress_orderby);
526 | 			// verify the generated DDL actually runs against TimescaleDB.
527 | 			if compressionSettings.CompressionInterval != "" {
528 | 				policyStmt := fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'",
529 | 					tableName, compressionSettings.CompressionInterval)
530 | 
531 | 				if compressionSettings.SegmentBy != "" {
532 | 					policyStmt += fmt.Sprintf(", segmentby => '%s'", compressionSettings.SegmentBy)
533 | 				}
534 | 
535 | 				if compressionSettings.OrderBy != "" {
536 | 					policyStmt += fmt.Sprintf(", orderby => '%s'", compressionSettings.OrderBy)
537 | 				}
538 | 
539 | 				policyStmt += ");"
540 | 				result += "\n" + policyStmt
541 | 			}
542 | 		}
543 | 	}
544 | 
545 | 	// Add retention policy if enabled
546 | 	if metadata.RetentionPolicy {
547 | 		retentionSettings, err := t.GetRetentionSettings(ctx, tableName)
548 | 		if err == nil && retentionSettings.RetentionEnabled && retentionSettings.RetentionInterval != "" {
549 | 			retentionStmt := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s');",
550 | 				tableName, retentionSettings.RetentionInterval)
551 | 			result += "\n\n" + retentionStmt
552 | 		}
553 | 	}
554 | 
555 | 	return result, nil
556 | }
555 | 
```

--------------------------------------------------------------------------------
/internal/delivery/mcp/tool_types.go:
--------------------------------------------------------------------------------

```go
  1 | package mcp
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"fmt"
  6 | 	"strings"
  7 | 
  8 | 	"github.com/FreePeak/cortex/pkg/server"
  9 | 	"github.com/FreePeak/cortex/pkg/tools"
 10 | )
 11 | 
 12 | // createTextResponse creates a simple response containing a single text entry in the MCP "content" payload shape.
 13 | func createTextResponse(text string) map[string]interface{} {
 14 | 	return map[string]interface{}{
 15 | 		"content": []map[string]interface{}{
 16 | 			{
 17 | 				"type": "text",
 18 | 				"text": text,
 19 | 			},
 20 | 		},
 21 | 	}
 22 | }
 23 | 
 24 | // addMetadata adds a key/value pair to the response's "metadata" map, mutating resp in place and returning it to allow chaining.
 25 | func addMetadata(resp map[string]interface{}, key string, value interface{}) map[string]interface{} {
 26 | 	if resp["metadata"] == nil {
 27 | 		resp["metadata"] = make(map[string]interface{})
 28 | 	}
 29 | 
 30 | 	metadata, ok := resp["metadata"].(map[string]interface{})
 31 | 	if !ok {
 32 | 		// Create a new metadata map if conversion fails (any existing non-map "metadata" value is discarded)
 33 | 		metadata = make(map[string]interface{})
 34 | 		resp["metadata"] = metadata
 35 | 	}
 36 | 
 37 | 	metadata[key] = value
 38 | 	return resp
 39 | }
 40 | 
 41 | // TODO: Refactor tool type implementations to reduce duplication and improve maintainability
 42 | // TODO: Consider using a code generation approach for repetitive tool patterns
 43 | // TODO: Add comprehensive request validation for all tool parameters
 44 | // TODO: Implement proper rate limiting and resource protection
 45 | // TODO: Add detailed documentation for each tool type and its parameters
 46 | 
 47 | // ToolType interface defines the structure for different types of database tools; implementations embed BaseToolType
 48 | type ToolType interface {
 49 | 	// GetName returns the base name of the tool type (e.g., "query", "execute")
 50 | 	GetName() string
 51 | 
 52 | 	// GetDescription returns a description for this tool type
 53 | 	GetDescription(dbID string) string
 54 | 
 55 | 	// CreateTool creates a tool with the specified name
 56 | 	// The returned tool must be compatible with server.MCPServer.AddTool's first parameter
 57 | 	CreateTool(name string, dbID string) interface{}
 58 | 
 59 | 	// HandleRequest handles tool requests for this tool type
 60 | 	HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error)
 61 | }
 62 | 
 63 | // UseCaseProvider interface abstracts database use case operations
 64 | type UseCaseProvider interface {
 65 | 	ExecuteQuery(ctx context.Context, dbID, query string, params []interface{}) (string, error)         // runs a query; result serialized as a string
 66 | 	ExecuteStatement(ctx context.Context, dbID, statement string, params []interface{}) (string, error) // runs a statement; result serialized as a string
 67 | 	ExecuteTransaction(ctx context.Context, dbID, action string, txID string, statement string, params []interface{}, readOnly bool) (string, map[string]interface{}, error) // transaction ops; second return carries extra metadata
 68 | 	GetDatabaseInfo(dbID string) (map[string]interface{}, error)
 69 | 	ListDatabases() []string
 70 | 	GetDatabaseType(dbID string) (string, error) // e.g. "postgres" — presumably a driver identifier; confirm against implementation
 71 | }
 72 | 
 73 | // BaseToolType provides common functionality for tool types
 74 | type BaseToolType struct {
 75 | 	name        string // base tool name (e.g. "query")
 76 | 	description string // human-readable action description used by GetDescription
 77 | }
 78 | 
 79 | // GetName returns the name of the tool type
 80 | func (b *BaseToolType) GetName() string {
 81 | 	return b.name
 82 | }
 83 | 
 84 | // GetDescription returns a description for the tool type, suffixed with the target database ID
 85 | func (b *BaseToolType) GetDescription(dbID string) string {
 86 | 	return fmt.Sprintf("%s on %s database", b.description, dbID)
 87 | }
 88 | 
 89 | //------------------------------------------------------------------------------
 90 | // QueryTool implementation
 91 | //------------------------------------------------------------------------------
 92 | 
// QueryTool handles SQL query operations (read-style requests routed
// through UseCaseProvider.ExecuteQuery).
type QueryTool struct {
	BaseToolType
}
 97 | 
 98 | // NewQueryTool creates a new query tool type
 99 | func NewQueryTool() *QueryTool {
100 | 	return &QueryTool{
101 | 		BaseToolType: BaseToolType{
102 | 			name:        "query",
103 | 			description: "Execute SQL query",
104 | 		},
105 | 	}
106 | }
107 | 
// CreateTool creates a query tool definition named name for database dbID.
// The tool accepts a required "query" string and an optional "params" array
// of strings; the interface{} return is consumed by the MCP server's AddTool.
func (t *QueryTool) CreateTool(name string, dbID string) interface{} {
	return tools.NewTool(
		name,
		tools.WithDescription(t.GetDescription(dbID)),
		tools.WithString("query",
			tools.Description("SQL query to execute"),
			tools.Required(),
		),
		tools.WithArray("params",
			tools.Description("Query parameters"),
			tools.Items(map[string]interface{}{"type": "string"}),
		),
	)
}
123 | 
// HandleRequest handles query tool requests.
// It resolves the target database (falling back to the tool-name suffix when
// dbID is empty), validates the "query" parameter, forwards optional "params",
// and wraps the query result in a text response.
// Returns an error if "query" is not a string or the query itself fails.
func (t *QueryTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
	// If dbID is not provided, extract it from the tool name
	if dbID == "" {
		dbID = extractDatabaseIDFromName(request.Name)
	}

	query, ok := request.Parameters["query"].(string)
	if !ok {
		return nil, fmt.Errorf("query parameter must be a string")
	}

	var queryParams []interface{}
	if request.Parameters["params"] != nil {
		// A non-array "params" value is silently ignored rather than rejected.
		if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
			queryParams = paramsArr
		}
	}

	result, err := useCase.ExecuteQuery(ctx, dbID, query, queryParams)
	if err != nil {
		return nil, err
	}

	return createTextResponse(result), nil
}
150 | 
// extractDatabaseIDFromName extracts the database ID from a tool name.
// Tool names follow the format "<tooltype>_<dbID>"; the segment after the
// final underscore is the database ID. Names without an underscore yield "".
func extractDatabaseIDFromName(name string) string {
	idx := strings.LastIndex(name, "_")
	if idx < 0 {
		// No separator: the name carries no database ID.
		return ""
	}
	return name[idx+1:]
}
162 | 
163 | //------------------------------------------------------------------------------
164 | // ExecuteTool implementation
165 | //------------------------------------------------------------------------------
166 | 
// ExecuteTool handles SQL statement execution
// (data-modifying statements such as INSERT, UPDATE, DELETE).
type ExecuteTool struct {
	BaseToolType
}
171 | 
172 | // NewExecuteTool creates a new execute tool type
173 | func NewExecuteTool() *ExecuteTool {
174 | 	return &ExecuteTool{
175 | 		BaseToolType: BaseToolType{
176 | 			name:        "execute",
177 | 			description: "Execute SQL statement",
178 | 		},
179 | 	}
180 | }
181 | 
// CreateTool creates an execute tool definition named name for database dbID.
// The tool accepts a required "statement" string and an optional "params"
// array of strings; the return value is consumed by the MCP server's AddTool.
func (t *ExecuteTool) CreateTool(name string, dbID string) interface{} {
	return tools.NewTool(
		name,
		tools.WithDescription(t.GetDescription(dbID)),
		tools.WithString("statement",
			tools.Description("SQL statement to execute (INSERT, UPDATE, DELETE, etc.)"),
			tools.Required(),
		),
		tools.WithArray("params",
			tools.Description("Statement parameters"),
			tools.Items(map[string]interface{}{"type": "string"}),
		),
	)
}
197 | 
// HandleRequest handles execute tool requests.
// It resolves the target database (falling back to the tool-name suffix when
// dbID is empty), validates the "statement" parameter, forwards optional
// "params", and wraps the execution result in a text response.
func (t *ExecuteTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
	// If dbID is not provided, extract it from the tool name
	if dbID == "" {
		dbID = extractDatabaseIDFromName(request.Name)
	}

	statement, ok := request.Parameters["statement"].(string)
	if !ok {
		return nil, fmt.Errorf("statement parameter must be a string")
	}

	var statementParams []interface{}
	if request.Parameters["params"] != nil {
		// A non-array "params" value is silently ignored rather than rejected.
		if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
			statementParams = paramsArr
		}
	}

	result, err := useCase.ExecuteStatement(ctx, dbID, statement, statementParams)
	if err != nil {
		return nil, err
	}

	return createTextResponse(result), nil
}
224 | 
225 | //------------------------------------------------------------------------------
226 | // TransactionTool implementation
227 | //------------------------------------------------------------------------------
228 | 
// TransactionTool handles database transactions
// (begin/commit/rollback plus statement execution inside a transaction).
type TransactionTool struct {
	BaseToolType
}
233 | 
234 | // NewTransactionTool creates a new transaction tool type
235 | func NewTransactionTool() *TransactionTool {
236 | 	return &TransactionTool{
237 | 		BaseToolType: BaseToolType{
238 | 			name:        "transaction",
239 | 			description: "Manage transactions",
240 | 		},
241 | 	}
242 | }
243 | 
// CreateTool creates a transaction tool definition named name for database dbID.
// Only "action" is required; which of the other parameters are needed depends
// on the action (see the per-parameter descriptions below). Requirements are
// enforced at request time in HandleRequest, not in the schema.
func (t *TransactionTool) CreateTool(name string, dbID string) interface{} {
	return tools.NewTool(
		name,
		tools.WithDescription(t.GetDescription(dbID)),
		tools.WithString("action",
			tools.Description("Transaction action (begin, commit, rollback, execute)"),
			tools.Required(),
		),
		tools.WithString("transactionId",
			tools.Description("Transaction ID (required for commit, rollback, execute)"),
		),
		tools.WithString("statement",
			tools.Description("SQL statement to execute within transaction (required for execute)"),
		),
		tools.WithArray("params",
			tools.Description("Statement parameters"),
			tools.Items(map[string]interface{}{"type": "string"}),
		),
		tools.WithBoolean("readOnly",
			tools.Description("Whether the transaction is read-only (for begin)"),
		),
	)
}
268 | 
// HandleRequest handles transaction tool requests.
// It validates each optional parameter's type when present (transactionId,
// statement, params, readOnly), delegates to ExecuteTransaction, and attaches
// the returned metadata (e.g. the transaction ID from "begin") to the response.
func (t *TransactionTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
	// If dbID is not provided, extract it from the tool name
	if dbID == "" {
		dbID = extractDatabaseIDFromName(request.Name)
	}

	action, ok := request.Parameters["action"].(string)
	if !ok {
		return nil, fmt.Errorf("action parameter must be a string")
	}

	txID := ""
	if request.Parameters["transactionId"] != nil {
		var ok bool
		txID, ok = request.Parameters["transactionId"].(string)
		if !ok {
			return nil, fmt.Errorf("transactionId parameter must be a string")
		}
	}

	statement := ""
	if request.Parameters["statement"] != nil {
		var ok bool
		statement, ok = request.Parameters["statement"].(string)
		if !ok {
			return nil, fmt.Errorf("statement parameter must be a string")
		}
	}

	var params []interface{}
	if request.Parameters["params"] != nil {
		// A non-array "params" value is silently ignored rather than rejected.
		if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
			params = paramsArr
		}
	}

	readOnly := false
	if request.Parameters["readOnly"] != nil {
		var ok bool
		readOnly, ok = request.Parameters["readOnly"].(bool)
		if !ok {
			return nil, fmt.Errorf("readOnly parameter must be a boolean")
		}
	}

	message, metadata, err := useCase.ExecuteTransaction(ctx, dbID, action, txID, statement, params, readOnly)
	if err != nil {
		return nil, err
	}

	// Create response with text and metadata
	resp := createTextResponse(message)

	// Add metadata if provided
	for k, v := range metadata {
		addMetadata(resp, k, v)
	}

	return resp, nil
}
330 | 
331 | //------------------------------------------------------------------------------
332 | // PerformanceTool implementation
333 | //------------------------------------------------------------------------------
334 | 
// PerformanceTool handles query performance analysis.
// NOTE: HandleRequest is currently a placeholder that echoes its inputs.
type PerformanceTool struct {
	BaseToolType
}
339 | 
340 | // NewPerformanceTool creates a new performance tool type
341 | func NewPerformanceTool() *PerformanceTool {
342 | 	return &PerformanceTool{
343 | 		BaseToolType: BaseToolType{
344 | 			name:        "performance",
345 | 			description: "Analyze query performance",
346 | 		},
347 | 	}
348 | }
349 | 
// CreateTool creates a performance analysis tool definition named name for
// database dbID. Only "action" is required; "query", "limit", and "threshold"
// apply to specific actions as noted in their descriptions.
func (t *PerformanceTool) CreateTool(name string, dbID string) interface{} {
	return tools.NewTool(
		name,
		tools.WithDescription(t.GetDescription(dbID)),
		tools.WithString("action",
			tools.Description("Action (getSlowQueries, getMetrics, analyzeQuery, reset, setThreshold)"),
			tools.Required(),
		),
		tools.WithString("query",
			tools.Description("SQL query to analyze (required for analyzeQuery)"),
		),
		tools.WithNumber("limit",
			tools.Description("Maximum number of results to return"),
		),
		tools.WithNumber("threshold",
			tools.Description("Slow query threshold in milliseconds (required for setThreshold)"),
		),
	)
}
370 | 
371 | // HandleRequest handles performance tool requests
372 | func (t *PerformanceTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
373 | 	// If dbID is not provided, extract it from the tool name
374 | 	if dbID == "" {
375 | 		dbID = extractDatabaseIDFromName(request.Name)
376 | 	}
377 | 
378 | 	// This is a simplified implementation
379 | 	// In a real implementation, this would analyze query performance
380 | 
381 | 	action, ok := request.Parameters["action"].(string)
382 | 	if !ok {
383 | 		return nil, fmt.Errorf("action parameter must be a string")
384 | 	}
385 | 
386 | 	var limit int
387 | 	if request.Parameters["limit"] != nil {
388 | 		if limitParam, ok := request.Parameters["limit"].(float64); ok {
389 | 			limit = int(limitParam)
390 | 		}
391 | 	}
392 | 
393 | 	query := ""
394 | 	if request.Parameters["query"] != nil {
395 | 		var ok bool
396 | 		query, ok = request.Parameters["query"].(string)
397 | 		if !ok {
398 | 			return nil, fmt.Errorf("query parameter must be a string")
399 | 		}
400 | 	}
401 | 
402 | 	var threshold int
403 | 	if request.Parameters["threshold"] != nil {
404 | 		if thresholdParam, ok := request.Parameters["threshold"].(float64); ok {
405 | 			threshold = int(thresholdParam)
406 | 		}
407 | 	}
408 | 
409 | 	// This is where we would call the useCase to analyze performance
410 | 	// For now, just return a placeholder
411 | 	output := fmt.Sprintf("Performance analysis for action '%s' on database '%s'\n", action, dbID)
412 | 
413 | 	if query != "" {
414 | 		output += fmt.Sprintf("Query: %s\n", query)
415 | 	}
416 | 
417 | 	if limit > 0 {
418 | 		output += fmt.Sprintf("Limit: %d\n", limit)
419 | 	}
420 | 
421 | 	if threshold > 0 {
422 | 		output += fmt.Sprintf("Threshold: %d ms\n", threshold)
423 | 	}
424 | 
425 | 	return createTextResponse(output), nil
426 | }
427 | 
428 | //------------------------------------------------------------------------------
429 | // SchemaTool implementation
430 | //------------------------------------------------------------------------------
431 | 
// SchemaTool handles database schema exploration
// by rendering the database info map from UseCaseProvider.GetDatabaseInfo.
type SchemaTool struct {
	BaseToolType
}
436 | 
437 | // NewSchemaTool creates a new schema tool type
438 | func NewSchemaTool() *SchemaTool {
439 | 	return &SchemaTool{
440 | 		BaseToolType: BaseToolType{
441 | 			name:        "schema",
442 | 			description: "Get schema of",
443 | 		},
444 | 	}
445 | }
446 | 
// CreateTool creates a schema tool definition named name for database dbID.
// The tool takes no meaningful input; a dummy optional string parameter is
// declared for compatibility with clients that require at least one parameter.
func (t *SchemaTool) CreateTool(name string, dbID string) interface{} {
	return tools.NewTool(
		name,
		tools.WithDescription(t.GetDescription(dbID)),
		// Use any string parameter for compatibility
		tools.WithString("random_string",
			tools.Description("Dummy parameter (optional)"),
		),
	)
}
458 | 
// HandleRequest handles schema tool requests.
// It resolves the target database (from dbID or the tool-name suffix), fetches
// the database info map, and renders it with %+v in a text response.
func (t *SchemaTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
	// If dbID is not provided, extract it from the tool name
	if dbID == "" {
		dbID = extractDatabaseIDFromName(request.Name)
	}

	info, err := useCase.GetDatabaseInfo(dbID)
	if err != nil {
		return nil, err
	}

	// Format response text
	infoStr := fmt.Sprintf("Database Schema for %s:\n\n%+v", dbID, info)
	return createTextResponse(infoStr), nil
}
475 | 
476 | //------------------------------------------------------------------------------
477 | // ListDatabasesTool implementation
478 | //------------------------------------------------------------------------------
479 | 
// ListDatabasesTool handles listing available databases.
// Unlike the other tool types it is global: it does not target a single dbID.
type ListDatabasesTool struct {
	BaseToolType
}
484 | 
485 | // NewListDatabasesTool creates a new list databases tool type
486 | func NewListDatabasesTool() *ListDatabasesTool {
487 | 	return &ListDatabasesTool{
488 | 		BaseToolType: BaseToolType{
489 | 			name:        "list_databases",
490 | 			description: "List all available databases",
491 | 		},
492 | 	}
493 | }
494 | 
// CreateTool creates a list databases tool definition named name.
// dbID is only used in the description; the listing itself is global.
// A dummy optional string parameter is declared for client compatibility.
func (t *ListDatabasesTool) CreateTool(name string, dbID string) interface{} {
	return tools.NewTool(
		name,
		tools.WithDescription(t.GetDescription(dbID)),
		// Use any string parameter for compatibility
		tools.WithString("random_string",
			tools.Description("Dummy parameter (optional)"),
		),
	)
}
506 | 
507 | // HandleRequest handles list databases tool requests
508 | func (t *ListDatabasesTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
509 | 	databases := useCase.ListDatabases()
510 | 
511 | 	// Format as text for display
512 | 	output := "Available databases:\n\n"
513 | 	for i, db := range databases {
514 | 		output += fmt.Sprintf("%d. %s\n", i+1, db)
515 | 	}
516 | 
517 | 	if len(databases) == 0 {
518 | 		output += "No databases configured.\n"
519 | 	}
520 | 
521 | 	return createTextResponse(output), nil
522 | }
523 | 
524 | //------------------------------------------------------------------------------
525 | // ToolTypeFactory provides a factory for creating tool types
526 | //------------------------------------------------------------------------------
527 | 
// ToolTypeFactory creates and manages tool types,
// keyed by their base name (e.g. "query", "execute", "list_databases").
type ToolTypeFactory struct {
	toolTypes map[string]ToolType
}
532 | 
533 | // NewToolTypeFactory creates a new tool type factory with all registered tool types
534 | func NewToolTypeFactory() *ToolTypeFactory {
535 | 	factory := &ToolTypeFactory{
536 | 		toolTypes: make(map[string]ToolType),
537 | 	}
538 | 
539 | 	// Register all tool types
540 | 	factory.Register(NewQueryTool())
541 | 	factory.Register(NewExecuteTool())
542 | 	factory.Register(NewTransactionTool())
543 | 	factory.Register(NewPerformanceTool())
544 | 	factory.Register(NewSchemaTool())
545 | 	factory.Register(NewListDatabasesTool())
546 | 
547 | 	return factory
548 | }
549 | 
// Register adds a tool type to the factory,
// replacing any previously registered tool type with the same name.
func (f *ToolTypeFactory) Register(toolType ToolType) {
	f.toolTypes[toolType.GetName()] = toolType
}
554 | 
555 | // GetToolType returns a tool type by name
556 | func (f *ToolTypeFactory) GetToolType(name string) (ToolType, bool) {
557 | 	// Handle new simpler format: <tooltype>_<dbID> or just the tool type name
558 | 	parts := strings.Split(name, "_")
559 | 	if len(parts) > 0 {
560 | 		// First part is the tool type name
561 | 		toolType, ok := f.toolTypes[parts[0]]
562 | 		if ok {
563 | 			return toolType, true
564 | 		}
565 | 	}
566 | 
567 | 	// Direct tool type lookup
568 | 	toolType, ok := f.toolTypes[name]
569 | 	return toolType, ok
570 | }
571 | 
// GetToolTypeForSourceName finds the appropriate tool type for a source name
// of the form "<tooltype>_<dbID>", returning the tool type, the database ID,
// and whether the lookup succeeded. The global name "list_databases" resolves
// with an empty database ID.
func (f *ToolTypeFactory) GetToolTypeForSourceName(sourceName string) (ToolType, string, bool) {
	// Handle simpler format: <tooltype>_<dbID>
	parts := strings.Split(sourceName, "_")

	if len(parts) >= 2 {
		// First part is tool type, last part is dbID
		toolTypeName := parts[0]
		dbID := parts[len(parts)-1]

		toolType, ok := f.toolTypes[toolTypeName]
		if ok {
			return toolType, dbID, true
		}
	}

	// Handle case for global tools
	// ("list_databases" splits above, but "list" is not a registered tool
	// type, so it falls through to this exact-name check).
	if sourceName == "list_databases" {
		toolType, ok := f.toolTypes["list_databases"]
		return toolType, "", ok
	}

	return nil, "", false
}
596 | 
597 | // GetAllToolTypes returns all registered tool types
598 | func (f *ToolTypeFactory) GetAllToolTypes() []ToolType {
599 | 	types := make([]ToolType, 0, len(f.toolTypes))
600 | 	for _, toolType := range f.toolTypes {
601 | 		types = append(types, toolType)
602 | 	}
603 | 	return types
604 | }
605 | 
```

--------------------------------------------------------------------------------
/pkg/db/timescale/policy_test.go:
--------------------------------------------------------------------------------

```go
  1 | package timescale
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"errors"
  6 | 	"testing"
  7 | )
  8 | 
// TestEnableCompression verifies DB.EnableCompression against a mock database:
// enabling without a policy interval, enabling with one (which also adds a
// compression policy), the error when TimescaleDB is unavailable, and a
// propagated query error.
func TestEnableCompression(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Register mock responses for checking if the table is a hypertable
	mockDB.RegisterQueryResult("WHERE table_name = 'test_table'", []map[string]interface{}{
		{"is_hypertable": true},
	}, nil)

	// Register mock response for the compression check in timescaledb_information.hypertables
	mockDB.RegisterQueryResult("FROM timescaledb_information.hypertables WHERE hypertable_name", []map[string]interface{}{
		{"compress": true},
	}, nil)

	// Test enabling compression without interval
	err := tsdb.EnableCompression(ctx, "test_table", "")
	if err != nil {
		t.Fatalf("Failed to enable compression: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "ALTER TABLE test_table SET (timescaledb.compress = true)")

	// Test enabling compression with interval
	// Register mock responses for specific queries used in this test
	mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
	mockDB.RegisterQueryResult("SELECT add_compression_policy", nil, nil)
	mockDB.RegisterQueryResult("timescaledb_information.hypertables WHERE hypertable_name = 'test_table'", []map[string]interface{}{
		{"compress": true},
	}, nil)

	err = tsdb.EnableCompression(ctx, "test_table", "7 days")
	if err != nil {
		t.Fatalf("Failed to enable compression with interval: %v", err)
	}

	// Check that the correct policy query was executed
	// (the policy is added after the ALTER TABLE, so it is the last query)
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "add_compression_policy")
	AssertQueryContains(t, query, "test_table")
	AssertQueryContains(t, query, "7 days")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.EnableCompression(ctx, "test_table", "")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("ALTER TABLE", nil, errors.New("mocked error"))
	err = tsdb.EnableCompression(ctx, "test_table", "")
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
 72 | 
// TestDisableCompression verifies DB.DisableCompression: the normal path
// (remove policy, then ALTER TABLE), the no-existing-policy path, the
// non-TimescaleDB error path, and a propagated policy-removal error.
func TestDisableCompression(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Mock successful policy removal and compression disabling
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
		{"job_id": 1},
	}, nil)
	mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, nil)
	mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)

	// Test disabling compression
	err := tsdb.DisableCompression(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to disable compression: %v", err)
	}

	// Check that the correct ALTER TABLE query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "ALTER TABLE test_table SET (timescaledb.compress = false)")

	// Test when no policy exists (should still succeed)
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)
	mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)

	err = tsdb.DisableCompression(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to disable compression when no policy exists: %v", err)
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.DisableCompression(ctx, "test_table")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	// (a policy is found, but removing it fails with the mocked error)
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
		{"job_id": 1},
	}, nil)
	mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, errors.New("mocked error"))
	err = tsdb.DisableCompression(ctx, "test_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
126 | 
// TestAddCompressionPolicy verifies DB.AddCompressionPolicy: a basic policy,
// a policy with segmentby/orderby options, automatic enabling of compression
// when it is off, the non-TimescaleDB error path, and a propagated error from
// the compression-status check.
func TestAddCompressionPolicy(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Mock checking compression status
	mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
		{"compress": true},
	}, nil)

	// Test adding a basic compression policy
	err := tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
	if err != nil {
		t.Fatalf("Failed to add compression policy: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "SELECT add_compression_policy('test_table', INTERVAL '7 days'")

	// Test adding a policy with segmentby and orderby
	err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "device_id", "time DESC")
	if err != nil {
		t.Fatalf("Failed to add compression policy with additional options: %v", err)
	}

	// Check that the correct query was executed
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "segmentby => 'device_id'")
	AssertQueryContains(t, query, "orderby => 'time DESC'")

	// Test enabling compression first if not already enabled
	mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
		{"compress": false},
	}, nil)
	mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)

	err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
	if err != nil {
		t.Fatalf("Failed to add compression policy with compression enabling: %v", err)
	}

	// The policy query runs after the ALTER TABLE that enabled compression,
	// so the last recorded query is still add_compression_policy.
	query, _ = mockDB.GetLastQuery()
	AssertQueryContains(t, query, "SELECT add_compression_policy")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error on compression check
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", nil, errors.New("mocked error"))
	err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
192 | 
// TestRemoveCompressionPolicy verifies DB.RemoveCompressionPolicy: removal of
// an existing policy, the no-policy no-op path, the non-TimescaleDB error
// path, and a propagated error from the policy lookup.
func TestRemoveCompressionPolicy(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Mock finding a policy
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
		{"job_id": 1},
	}, nil)
	mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, nil)

	// Test removing a compression policy
	err := tsdb.RemoveCompressionPolicy(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to remove compression policy: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "SELECT remove_compression_policy")

	// Test when no policy exists (should succeed without error)
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)

	err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
	if err != nil {
		t.Errorf("Expected success when no policy exists, got: %v", err)
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	// (the policy lookup itself fails)
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", nil, errors.New("mocked error"))
	err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
241 | 
// TestGetCompressionSettings verifies DB.GetCompressionSettings: a fully
// configured table (segmentby, orderby, intervals), a table with compression
// disabled, the non-TimescaleDB error path, and a propagated query error.
func TestGetCompressionSettings(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Mock compression status check
	mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
		{"compress": true},
	}, nil)

	// Mock compression settings
	mockDB.RegisterQueryResult("SELECT segmentby, orderby FROM timescaledb_information.compression_settings", []map[string]interface{}{
		{"segmentby": "device_id", "orderby": "time DESC"},
	}, nil)

	// Mock policy information
	mockDB.RegisterQueryResult("SELECT s.schedule_interval, h.chunk_time_interval FROM", []map[string]interface{}{
		{"schedule_interval": "7 days", "chunk_time_interval": "1 day"},
	}, nil)

	// Test getting compression settings
	settings, err := tsdb.GetCompressionSettings(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to get compression settings: %v", err)
	}

	// Check the returned settings, field by field, against the mocked rows
	if settings.HypertableName != "test_table" {
		t.Errorf("Expected HypertableName to be 'test_table', got '%s'", settings.HypertableName)
	}

	if !settings.CompressionEnabled {
		t.Error("Expected CompressionEnabled to be true, got false")
	}

	if settings.SegmentBy != "device_id" {
		t.Errorf("Expected SegmentBy to be 'device_id', got '%s'", settings.SegmentBy)
	}

	if settings.OrderBy != "time DESC" {
		t.Errorf("Expected OrderBy to be 'time DESC', got '%s'", settings.OrderBy)
	}

	if settings.CompressionInterval != "7 days" {
		t.Errorf("Expected CompressionInterval to be '7 days', got '%s'", settings.CompressionInterval)
	}

	if settings.ChunkTimeInterval != "1 day" {
		t.Errorf("Expected ChunkTimeInterval to be '1 day', got '%s'", settings.ChunkTimeInterval)
	}

	// Test when compression is not enabled
	mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
		{"compress": false},
	}, nil)

	settings, err = tsdb.GetCompressionSettings(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to get compression settings when not enabled: %v", err)
	}

	if settings.CompressionEnabled {
		t.Error("Expected CompressionEnabled to be false, got true")
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	_, err = tsdb.GetCompressionSettings(ctx, "test_table")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", nil, errors.New("mocked error"))
	_, err = tsdb.GetCompressionSettings(ctx, "test_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
326 | 
// TestAddRetentionPolicy verifies DB.AddRetentionPolicy: the happy path, the
// non-TimescaleDB error path, and a propagated query error.
func TestAddRetentionPolicy(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Test adding a retention policy
	err := tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
	if err != nil {
		t.Fatalf("Failed to add retention policy: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "SELECT add_retention_policy('test_table', INTERVAL '30 days')")

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT add_retention_policy", nil, errors.New("mocked error"))
	err = tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
361 | 
// TestRemoveRetentionPolicy verifies DB.RemoveRetentionPolicy: removal of an
// existing policy, the no-policy no-op path, the non-TimescaleDB error path,
// and a propagated error from the policy lookup.
func TestRemoveRetentionPolicy(t *testing.T) {
	mockDB := NewMockDB()
	tsdb := &DB{
		Database:      mockDB,
		isTimescaleDB: true,
	}

	ctx := context.Background()

	// Mock finding a policy
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
		{"job_id": 1},
	}, nil)
	mockDB.RegisterQueryResult("SELECT remove_retention_policy", nil, nil)

	// Test removing a retention policy
	err := tsdb.RemoveRetentionPolicy(ctx, "test_table")
	if err != nil {
		t.Fatalf("Failed to remove retention policy: %v", err)
	}

	// Check that the correct query was executed
	query, _ := mockDB.GetLastQuery()
	AssertQueryContains(t, query, "SELECT remove_retention_policy")

	// Test when no policy exists (should succeed without error)
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)

	err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
	if err != nil {
		t.Errorf("Expected success when no policy exists, got: %v", err)
	}

	// Test when TimescaleDB is not available
	tsdb.isTimescaleDB = false
	err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
	if err == nil {
		t.Error("Expected error when TimescaleDB is not available, got nil")
	}

	// Test execution error
	// (the policy lookup itself fails)
	tsdb.isTimescaleDB = true
	mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", nil, errors.New("mocked error"))
	err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
	if err == nil {
		t.Error("Expected query error, got nil")
	}
}
410 | 
411 | func TestGetRetentionSettings(t *testing.T) {
412 | 	mockDB := NewMockDB()
413 | 	tsdb := &DB{
414 | 		Database:      mockDB,
415 | 		isTimescaleDB: true,
416 | 	}
417 | 
418 | 	ctx := context.Background()
419 | 
420 | 	// Mock policy information
421 | 	mockDB.RegisterQueryResult("SELECT s.schedule_interval FROM", []map[string]interface{}{
422 | 		{"schedule_interval": "30 days"},
423 | 	}, nil)
424 | 
425 | 	// Test getting retention settings
426 | 	settings, err := tsdb.GetRetentionSettings(ctx, "test_table")
427 | 	if err != nil {
428 | 		t.Fatalf("Failed to get retention settings: %v", err)
429 | 	}
430 | 
431 | 	// Check the returned settings
432 | 	if settings.HypertableName != "test_table" {
433 | 		t.Errorf("Expected HypertableName to be 'test_table', got '%s'", settings.HypertableName)
434 | 	}
435 | 
436 | 	if !settings.RetentionEnabled {
437 | 		t.Error("Expected RetentionEnabled to be true, got false")
438 | 	}
439 | 
440 | 	if settings.RetentionInterval != "30 days" {
441 | 		t.Errorf("Expected RetentionInterval to be '30 days', got '%s'", settings.RetentionInterval)
442 | 	}
443 | 
444 | 	// Test when no policy exists
445 | 	mockDB.RegisterQueryResult("SELECT s.schedule_interval FROM", []map[string]interface{}{}, nil)
446 | 
447 | 	settings, err = tsdb.GetRetentionSettings(ctx, "test_table")
448 | 	if err != nil {
449 | 		t.Fatalf("Failed to get retention settings when no policy exists: %v", err)
450 | 	}
451 | 
452 | 	if settings.RetentionEnabled {
453 | 		t.Error("Expected RetentionEnabled to be false, got true")
454 | 	}
455 | 
456 | 	// Test when TimescaleDB is not available
457 | 	tsdb.isTimescaleDB = false
458 | 	_, err = tsdb.GetRetentionSettings(ctx, "test_table")
459 | 	if err == nil {
460 | 		t.Error("Expected error when TimescaleDB is not available, got nil")
461 | 	}
462 | }
463 | 
464 | func TestCompressChunks(t *testing.T) {
465 | 	mockDB := NewMockDB()
466 | 	tsdb := &DB{
467 | 		Database:      mockDB,
468 | 		isTimescaleDB: true,
469 | 	}
470 | 
471 | 	ctx := context.Background()
472 | 
473 | 	// Test compressing all chunks
474 | 	err := tsdb.CompressChunks(ctx, "test_table", "")
475 | 	if err != nil {
476 | 		t.Fatalf("Failed to compress all chunks: %v", err)
477 | 	}
478 | 
479 | 	// Check that the correct query was executed
480 | 	query, _ := mockDB.GetLastQuery()
481 | 	AssertQueryContains(t, query, "SELECT compress_chunks(hypertable => 'test_table')")
482 | 
483 | 	// Test compressing chunks with older_than specified
484 | 	err = tsdb.CompressChunks(ctx, "test_table", "7 days")
485 | 	if err != nil {
486 | 		t.Fatalf("Failed to compress chunks with older_than: %v", err)
487 | 	}
488 | 
489 | 	// Check that the correct query was executed
490 | 	query, _ = mockDB.GetLastQuery()
491 | 	AssertQueryContains(t, query, "SELECT compress_chunks(hypertable => 'test_table', older_than => INTERVAL '7 days')")
492 | 
493 | 	// Test when TimescaleDB is not available
494 | 	tsdb.isTimescaleDB = false
495 | 	err = tsdb.CompressChunks(ctx, "test_table", "")
496 | 	if err == nil {
497 | 		t.Error("Expected error when TimescaleDB is not available, got nil")
498 | 	}
499 | 
500 | 	// Test execution error
501 | 	tsdb.isTimescaleDB = true
502 | 	mockDB.RegisterQueryResult("SELECT compress_chunks", nil, errors.New("mocked error"))
503 | 	err = tsdb.CompressChunks(ctx, "test_table", "")
504 | 	if err == nil {
505 | 		t.Error("Expected query error, got nil")
506 | 	}
507 | }
508 | 
509 | func TestDecompressChunks(t *testing.T) {
510 | 	mockDB := NewMockDB()
511 | 	tsdb := &DB{
512 | 		Database:      mockDB,
513 | 		isTimescaleDB: true,
514 | 	}
515 | 
516 | 	ctx := context.Background()
517 | 
518 | 	// Test decompressing all chunks
519 | 	err := tsdb.DecompressChunks(ctx, "test_table", "")
520 | 	if err != nil {
521 | 		t.Fatalf("Failed to decompress all chunks: %v", err)
522 | 	}
523 | 
524 | 	// Check that the correct query was executed
525 | 	query, _ := mockDB.GetLastQuery()
526 | 	AssertQueryContains(t, query, "SELECT decompress_chunks(hypertable => 'test_table')")
527 | 
528 | 	// Test decompressing chunks with newer_than specified
529 | 	err = tsdb.DecompressChunks(ctx, "test_table", "7 days")
530 | 	if err != nil {
531 | 		t.Fatalf("Failed to decompress chunks with newer_than: %v", err)
532 | 	}
533 | 
534 | 	// Check that the correct query was executed
535 | 	query, _ = mockDB.GetLastQuery()
536 | 	AssertQueryContains(t, query, "SELECT decompress_chunks(hypertable => 'test_table', newer_than => INTERVAL '7 days')")
537 | 
538 | 	// Test when TimescaleDB is not available
539 | 	tsdb.isTimescaleDB = false
540 | 	err = tsdb.DecompressChunks(ctx, "test_table", "")
541 | 	if err == nil {
542 | 		t.Error("Expected error when TimescaleDB is not available, got nil")
543 | 	}
544 | 
545 | 	// Test execution error
546 | 	tsdb.isTimescaleDB = true
547 | 	mockDB.RegisterQueryResult("SELECT decompress_chunks", nil, errors.New("mocked error"))
548 | 	err = tsdb.DecompressChunks(ctx, "test_table", "")
549 | 	if err == nil {
550 | 		t.Error("Expected query error, got nil")
551 | 	}
552 | }
553 | 
554 | func TestGetChunkCompressionStats(t *testing.T) {
555 | 	mockDB := NewMockDB()
556 | 	tsdb := &DB{
557 | 		Database:      mockDB,
558 | 		isTimescaleDB: true,
559 | 	}
560 | 
561 | 	ctx := context.Background()
562 | 
563 | 	// Mock chunk stats
564 | 	mockStats := []map[string]interface{}{
565 | 		{
566 | 			"chunk_name":                     "_hyper_1_1_chunk",
567 | 			"range_start":                    "2023-01-01 00:00:00",
568 | 			"range_end":                      "2023-01-02 00:00:00",
569 | 			"is_compressed":                  true,
570 | 			"before_compression_total_bytes": 1000,
571 | 			"after_compression_total_bytes":  200,
572 | 			"compression_ratio":              80.0,
573 | 		},
574 | 	}
575 | 	mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", mockStats, nil)
576 | 
577 | 	// Test getting chunk compression stats
578 | 	_, err := tsdb.GetChunkCompressionStats(ctx, "test_table")
579 | 	if err != nil {
580 | 		t.Fatalf("Failed to get chunk compression stats: %v", err)
581 | 	}
582 | 
583 | 	// Check that the correct query was executed
584 | 	query, _ := mockDB.GetLastQuery()
585 | 	AssertQueryContains(t, query, "FROM timescaledb_information.chunks")
586 | 	AssertQueryContains(t, query, "hypertable_name = 'test_table'")
587 | 
588 | 	// Test when TimescaleDB is not available
589 | 	tsdb.isTimescaleDB = false
590 | 	_, err = tsdb.GetChunkCompressionStats(ctx, "test_table")
591 | 	if err == nil {
592 | 		t.Error("Expected error when TimescaleDB is not available, got nil")
593 | 	}
594 | 
595 | 	// Test execution error
596 | 	tsdb.isTimescaleDB = true
597 | 	mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", nil, errors.New("mocked error"))
598 | 	_, err = tsdb.GetChunkCompressionStats(ctx, "test_table")
599 | 	if err == nil {
600 | 		t.Error("Expected query error, got nil")
601 | 	}
602 | }
603 | 
```

--------------------------------------------------------------------------------
/pkg/dbtools/schema.go:
--------------------------------------------------------------------------------

```go
  1 | package dbtools
  2 | 
import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	"github.com/FreePeak/db-mcp-server/pkg/db"
	"github.com/FreePeak/db-mcp-server/pkg/logger"
	"github.com/FreePeak/db-mcp-server/pkg/tools"
)
 13 | 
// DatabaseStrategy defines the interface for database-specific query strategies.
// Each method returns an ordered list of candidate queries; callers try them in
// order until one succeeds, so the most reliable query for the target engine
// comes first and broader fallbacks follow.
type DatabaseStrategy interface {
	GetTablesQueries() []queryWithArgs
	GetColumnsQueries(table string) []queryWithArgs
	GetRelationshipsQueries(table string) []queryWithArgs
}
 20 | 
 21 | // NewDatabaseStrategy creates the appropriate strategy for the given database type
 22 | func NewDatabaseStrategy(driverName string) DatabaseStrategy {
 23 | 	switch driverName {
 24 | 	case "postgres":
 25 | 		return &PostgresStrategy{}
 26 | 	case "mysql":
 27 | 		return &MySQLStrategy{}
 28 | 	default:
 29 | 		logger.Warn("Unknown database driver: %s, will use generic strategy", driverName)
 30 | 		return &GenericStrategy{}
 31 | 	}
 32 | }
 33 | 
// PostgresStrategy implements DatabaseStrategy for PostgreSQL.
type PostgresStrategy struct{}
 36 | 
 37 | // GetTablesQueries returns queries for retrieving tables in PostgreSQL
 38 | func (s *PostgresStrategy) GetTablesQueries() []queryWithArgs {
 39 | 	return []queryWithArgs{
 40 | 		// Primary: pg_catalog approach
 41 | 		{query: "SELECT tablename as table_name FROM pg_catalog.pg_tables WHERE schemaname = 'public'"},
 42 | 		// Secondary: information_schema approach
 43 | 		{query: "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"},
 44 | 		// Tertiary: pg_class approach
 45 | 		{query: "SELECT relname as table_name FROM pg_catalog.pg_class WHERE relkind = 'r' AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')"},
 46 | 	}
 47 | }
 48 | 
// GetColumnsQueries returns queries for retrieving columns in PostgreSQL.
// The table name is bound as parameter $1; both candidates are scoped to the
// 'public' schema and ordered by the column's ordinal position.
func (s *PostgresStrategy) GetColumnsQueries(table string) []queryWithArgs {
	return []queryWithArgs{
		// Primary: information_schema approach for PostgreSQL
		{
			query: `
				SELECT column_name, data_type, 
				CASE WHEN is_nullable = 'YES' THEN 'YES' ELSE 'NO' END as is_nullable,
				column_default
				FROM information_schema.columns 
				WHERE table_name = $1 AND table_schema = 'public'
				ORDER BY ordinal_position
			`,
			args: []interface{}{table},
		},
		// Secondary: pg_catalog approach for PostgreSQL
		{
			query: `
				SELECT a.attname as column_name, 
				pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type,
				CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END as is_nullable,
				pg_catalog.pg_get_expr(d.adbin, d.adrelid) as column_default
				FROM pg_catalog.pg_attribute a
				LEFT JOIN pg_catalog.pg_attrdef d ON (a.attrelid = d.adrelid AND a.attnum = d.adnum)
				WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1 AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public'))
				AND a.attnum > 0 AND NOT a.attisdropped
				ORDER BY a.attnum
			`,
			args: []interface{}{table},
		},
	}
}
 81 | 
// GetRelationshipsQueries returns queries for retrieving foreign-key
// relationships in PostgreSQL. When table is non-empty, each candidate is
// filtered to rows where the table appears on either side of the foreign key
// (the name is bound as $1); an empty table returns schema-wide queries.
func (s *PostgresStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
	baseQueries := []queryWithArgs{
		// Primary: Standard information_schema approach for PostgreSQL
		{
			query: `
				SELECT
					tc.table_schema,
					tc.constraint_name,
					tc.table_name,
					kcu.column_name,
					ccu.table_schema AS foreign_table_schema,
					ccu.table_name AS foreign_table_name,
					ccu.column_name AS foreign_column_name
				FROM information_schema.table_constraints AS tc
				JOIN information_schema.key_column_usage AS kcu
					ON tc.constraint_name = kcu.constraint_name
					AND tc.table_schema = kcu.table_schema
				JOIN information_schema.constraint_column_usage AS ccu
					ON ccu.constraint_name = tc.constraint_name
					AND ccu.table_schema = tc.table_schema
				WHERE tc.constraint_type = 'FOREIGN KEY'
					AND tc.table_schema = 'public'
			`,
			args: []interface{}{},
		},
		// Alternate: Using pg_catalog for older PostgreSQL versions
		{
			query: `
				SELECT
					ns.nspname AS table_schema,
					c.conname AS constraint_name,
					cl.relname AS table_name,
					att.attname AS column_name,
					ns2.nspname AS foreign_table_schema,
					cl2.relname AS foreign_table_name,
					att2.attname AS foreign_column_name
				FROM pg_constraint c
				JOIN pg_class cl ON c.conrelid = cl.oid
				JOIN pg_attribute att ON att.attrelid = cl.oid AND att.attnum = ANY(c.conkey)
				JOIN pg_namespace ns ON ns.oid = cl.relnamespace
				JOIN pg_class cl2 ON c.confrelid = cl2.oid
				JOIN pg_attribute att2 ON att2.attrelid = cl2.oid AND att2.attnum = ANY(c.confkey)
				JOIN pg_namespace ns2 ON ns2.oid = cl2.relnamespace
				WHERE c.contype = 'f'
				AND ns.nspname = 'public'
			`,
			args: []interface{}{},
		},
	}

	// No table filter requested: return the schema-wide queries as-is.
	if table == "" {
		return baseQueries
	}

	queries := make([]queryWithArgs, len(baseQueries))

	// Add table filter matching the table on either end of the relationship.
	queries[0] = queryWithArgs{
		query: baseQueries[0].query + " AND (tc.table_name = $1 OR ccu.table_name = $1)",
		args:  []interface{}{table},
	}

	queries[1] = queryWithArgs{
		query: baseQueries[1].query + " AND (cl.relname = $1 OR cl2.relname = $1)",
		args:  []interface{}{table},
	}

	return queries
}
152 | 
// MySQLStrategy implements DatabaseStrategy for MySQL.
type MySQLStrategy struct{}
155 | 
156 | // GetTablesQueries returns queries for retrieving tables in MySQL
157 | func (s *MySQLStrategy) GetTablesQueries() []queryWithArgs {
158 | 	return []queryWithArgs{
159 | 		// Primary: information_schema approach
160 | 		{query: "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()"},
161 | 		// Secondary: SHOW TABLES approach
162 | 		{query: "SHOW TABLES"},
163 | 	}
164 | }
165 | 
166 | // GetColumnsQueries returns queries for retrieving columns in MySQL
167 | func (s *MySQLStrategy) GetColumnsQueries(table string) []queryWithArgs {
168 | 	return []queryWithArgs{
169 | 		// MySQL query for columns
170 | 		{
171 | 			query: `
172 | 				SELECT column_name, data_type, is_nullable, column_default
173 | 				FROM information_schema.columns
174 | 				WHERE table_name = ? AND table_schema = DATABASE()
175 | 				ORDER BY ordinal_position
176 | 			`,
177 | 			args: []interface{}{table},
178 | 		},
179 | 		// Fallback for older MySQL versions
180 | 		{
181 | 			query: `SHOW COLUMNS FROM ` + table,
182 | 			args:  []interface{}{},
183 | 		},
184 | 	}
185 | }
186 | 
// GetRelationshipsQueries returns queries for retrieving foreign-key
// relationships in MySQL. When table is non-empty, each candidate is filtered
// to rows where the table appears on either side of the foreign key (bound as
// ? parameters); an empty table returns database-wide queries.
func (s *MySQLStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
	baseQueries := []queryWithArgs{
		// Primary approach for MySQL
		{
			query: `
				SELECT
					tc.table_schema,
					tc.constraint_name,
					tc.table_name,
					kcu.column_name,
					kcu.referenced_table_schema AS foreign_table_schema,
					kcu.referenced_table_name AS foreign_table_name,
					kcu.referenced_column_name AS foreign_column_name
				FROM information_schema.table_constraints AS tc
				JOIN information_schema.key_column_usage AS kcu
					ON tc.constraint_name = kcu.constraint_name
					AND tc.table_schema = kcu.table_schema
				WHERE tc.constraint_type = 'FOREIGN KEY'
					AND tc.table_schema = DATABASE()
			`,
			args: []interface{}{},
		},
		// Fallback using simpler query for older MySQL versions
		{
			query: `
				SELECT
					kcu.constraint_schema AS table_schema,
					kcu.constraint_name,
					kcu.table_name,
					kcu.column_name,
					kcu.referenced_table_schema AS foreign_table_schema,
					kcu.referenced_table_name AS foreign_table_name,
					kcu.referenced_column_name AS foreign_column_name
				FROM information_schema.key_column_usage kcu
				WHERE kcu.referenced_table_name IS NOT NULL
					AND kcu.constraint_schema = DATABASE()
			`,
			args: []interface{}{},
		},
	}

	// No table filter requested: return the database-wide queries as-is.
	if table == "" {
		return baseQueries
	}

	queries := make([]queryWithArgs, len(baseQueries))

	// Add table filter matching the table on either end of the relationship.
	queries[0] = queryWithArgs{
		query: baseQueries[0].query + " AND (tc.table_name = ? OR kcu.referenced_table_name = ?)",
		args:  []interface{}{table, table},
	}

	queries[1] = queryWithArgs{
		query: baseQueries[1].query + " AND (kcu.table_name = ? OR kcu.referenced_table_name = ?)",
		args:  []interface{}{table, table},
	}

	return queries
}
248 | 
// GenericStrategy implements DatabaseStrategy for unknown database types.
// Its methods return PostgreSQL-style candidates first, then MySQL-style
// equivalents, so at least one is likely to succeed on common engines.
type GenericStrategy struct{}
251 | 
252 | // GetTablesQueries returns generic queries for retrieving tables
253 | func (s *GenericStrategy) GetTablesQueries() []queryWithArgs {
254 | 	return []queryWithArgs{
255 | 		{query: "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"},
256 | 		{query: "SELECT table_name FROM information_schema.tables"},
257 | 		{query: "SHOW TABLES"}, // Last resort
258 | 	}
259 | }
260 | 
// GetColumnsQueries returns generic queries for retrieving columns. The two
// candidates are identical except for the placeholder style ($1 vs ?), so
// whichever matches the driver's bind syntax will succeed.
func (s *GenericStrategy) GetColumnsQueries(table string) []queryWithArgs {
	return []queryWithArgs{
		// Try PostgreSQL-style query first
		{
			query: `
				SELECT column_name, data_type, is_nullable, column_default
				FROM information_schema.columns
				WHERE table_name = $1
				ORDER BY ordinal_position
			`,
			args: []interface{}{table},
		},
		// Try MySQL-style query
		{
			query: `
				SELECT column_name, data_type, is_nullable, column_default
				FROM information_schema.columns
				WHERE table_name = ?
				ORDER BY ordinal_position
			`,
			args: []interface{}{table},
		},
	}
}
286 | 
// GetRelationshipsQueries returns generic queries for retrieving foreign-key
// relationships: a PostgreSQL-style information_schema query followed by a
// MySQL-style key_column_usage query. When table is non-empty, both are
// filtered to rows where the table appears on either side of the key.
func (s *GenericStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
	pgQuery := queryWithArgs{
		query: `
			SELECT
				tc.table_schema,
				tc.constraint_name,
				tc.table_name,
				kcu.column_name,
				ccu.table_schema AS foreign_table_schema,
				ccu.table_name AS foreign_table_name,
				ccu.column_name AS foreign_column_name
			FROM information_schema.table_constraints AS tc
			JOIN information_schema.key_column_usage AS kcu
				ON tc.constraint_name = kcu.constraint_name
				AND tc.table_schema = kcu.table_schema
			JOIN information_schema.constraint_column_usage AS ccu
				ON ccu.constraint_name = tc.constraint_name
				AND ccu.table_schema = tc.table_schema
			WHERE tc.constraint_type = 'FOREIGN KEY'
		`,
		args: []interface{}{},
	}

	mysqlQuery := queryWithArgs{
		query: `
			SELECT
				kcu.constraint_schema AS table_schema,
				kcu.constraint_name,
				kcu.table_name,
				kcu.column_name,
				kcu.referenced_table_schema AS foreign_table_schema,
				kcu.referenced_table_name AS foreign_table_name,
				kcu.referenced_column_name AS foreign_column_name
			FROM information_schema.key_column_usage kcu
			WHERE kcu.referenced_table_name IS NOT NULL
		`,
		args: []interface{}{},
	}

	// Optional filter: match the table on either end of the relationship.
	if table != "" {
		pgQuery.query += " AND (tc.table_name = $1 OR ccu.table_name = $1)"
		pgQuery.args = append(pgQuery.args, table)

		mysqlQuery.query += " AND (kcu.table_name = ? OR kcu.referenced_table_name = ?)"
		mysqlQuery.args = append(mysqlQuery.args, table, table)
	}

	return []queryWithArgs{pgQuery, mysqlQuery}
}
337 | 
// createSchemaExplorerTool creates a tool for exploring database schema.
// The returned tool accepts a required "component" ("tables", "columns",
// "relationships", or "full") and "database" ID, plus an optional "table"
// name and query "timeout" in milliseconds; requests are dispatched to
// handleSchemaExplorer.
func createSchemaExplorerTool() *tools.Tool {
	return &tools.Tool{
		Name:        "dbSchema",
		Description: "Auto-discover database structure and relationships",
		Category:    "database",
		InputSchema: tools.ToolInputSchema{
			Type: "object",
			Properties: map[string]interface{}{
				"component": map[string]interface{}{
					"type":        "string",
					"description": "Schema component to explore (tables, columns, relationships, or full)",
					"enum":        []string{"tables", "columns", "relationships", "full"},
				},
				"table": map[string]interface{}{
					"type":        "string",
					"description": "Table name to explore (optional, leave empty for all tables)",
				},
				"timeout": map[string]interface{}{
					"type":        "integer",
					"description": "Query timeout in milliseconds (default: 10000)",
				},
				"database": map[string]interface{}{
					"type":        "string",
					"description": "Database ID to use (optional if only one database is configured)",
				},
			},
			Required: []string{"component", "database"},
		},
		Handler: handleSchemaExplorer,
	}
}
370 | 
371 | // handleSchemaExplorer handles the schema explorer tool execution
372 | func handleSchemaExplorer(ctx context.Context, params map[string]interface{}) (interface{}, error) {
373 | 	// Check if database manager is initialized
374 | 	if dbManager == nil {
375 | 		return nil, fmt.Errorf("database manager not initialized")
376 | 	}
377 | 
378 | 	// Extract parameters
379 | 	component, ok := getStringParam(params, "component")
380 | 	if !ok {
381 | 		return nil, fmt.Errorf("component parameter is required")
382 | 	}
383 | 
384 | 	// Get database ID
385 | 	databaseID, ok := getStringParam(params, "database")
386 | 	if !ok {
387 | 		return nil, fmt.Errorf("database parameter is required")
388 | 	}
389 | 
390 | 	// Get database instance
391 | 	db, err := dbManager.GetDatabase(databaseID)
392 | 	if err != nil {
393 | 		return nil, fmt.Errorf("failed to get database: %w", err)
394 | 	}
395 | 
396 | 	// Extract table parameter (optional depending on component)
397 | 	table, _ := getStringParam(params, "table")
398 | 
399 | 	// Extract timeout
400 | 	timeout := 10000 // Default timeout: 10 seconds
401 | 	if timeoutParam, ok := getIntParam(params, "timeout"); ok {
402 | 		timeout = timeoutParam
403 | 	}
404 | 
405 | 	// Create context with timeout
406 | 	timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
407 | 	defer cancel()
408 | 
409 | 	// Use actual database queries based on component type
410 | 	switch component {
411 | 	case "tables":
412 | 		return getTables(timeoutCtx, db)
413 | 	case "columns":
414 | 		if table == "" {
415 | 			return nil, fmt.Errorf("table parameter is required for columns component")
416 | 		}
417 | 		return getColumns(timeoutCtx, db, table)
418 | 	case "relationships":
419 | 		return getRelationships(timeoutCtx, db, table)
420 | 	case "full":
421 | 		return getFullSchema(timeoutCtx, db)
422 | 	default:
423 | 		return nil, fmt.Errorf("invalid component: %s", component)
424 | 	}
425 | }
426 | 
// queryWithArgs pairs a SQL query string with its positional bind arguments,
// so strategies can return ready-to-execute fallback candidates.
type queryWithArgs struct {
	query string
	args  []interface{}
}
433 | 
434 | func executeWithFallbacks(ctx context.Context, db db.Database, queries []queryWithArgs, operationName string) (*sql.Rows, error) {
435 | 	var lastErr error
436 | 
437 | 	for i, q := range queries {
438 | 		rows, err := db.Query(ctx, q.query, q.args...)
439 | 		if err == nil {
440 | 			return rows, nil
441 | 		}
442 | 
443 | 		lastErr = err
444 | 		logger.Warn("%s fallback query %d failed: %v - Error: %v", operationName, i+1, q.query, err)
445 | 	}
446 | 
447 | 	// All queries failed, return the last error
448 | 	return nil, fmt.Errorf("%s failed after trying %d fallback queries: %w", operationName, len(queries), lastErr)
449 | }
450 | 
451 | // getTables retrieves the list of tables in the database
452 | func getTables(ctx context.Context, db db.Database) (interface{}, error) {
453 | 	// Get database type from connected database
454 | 	driverName := db.DriverName()
455 | 	dbType := driverName
456 | 
457 | 	// Create the appropriate strategy
458 | 	strategy := NewDatabaseStrategy(driverName)
459 | 
460 | 	// Get queries from strategy
461 | 	queries := strategy.GetTablesQueries()
462 | 
463 | 	// Execute queries with fallbacks
464 | 	rows, err := executeWithFallbacks(ctx, db, queries, "getTables")
465 | 	if err != nil {
466 | 		return nil, fmt.Errorf("failed to get tables: %w", err)
467 | 	}
468 | 
469 | 	defer func() {
470 | 		if rows != nil {
471 | 			if err := rows.Close(); err != nil {
472 | 				logger.Error("error closing rows: %v", err)
473 | 			}
474 | 		}
475 | 	}()
476 | 
477 | 	// Convert rows to maps
478 | 	results, err := rowsToMaps(rows)
479 | 	if err != nil {
480 | 		return nil, fmt.Errorf("failed to process tables: %w", err)
481 | 	}
482 | 
483 | 	return map[string]interface{}{
484 | 		"tables": results,
485 | 		"dbType": dbType,
486 | 	}, nil
487 | }
488 | 
489 | // getColumns retrieves the columns for a specific table
490 | func getColumns(ctx context.Context, db db.Database, table string) (interface{}, error) {
491 | 	// Get database type from connected database
492 | 	driverName := db.DriverName()
493 | 	dbType := driverName
494 | 
495 | 	// Create the appropriate strategy
496 | 	strategy := NewDatabaseStrategy(driverName)
497 | 
498 | 	// Get queries from strategy
499 | 	queries := strategy.GetColumnsQueries(table)
500 | 
501 | 	// Execute queries with fallbacks
502 | 	rows, err := executeWithFallbacks(ctx, db, queries, "getColumns["+table+"]")
503 | 	if err != nil {
504 | 		return nil, fmt.Errorf("failed to get columns for table %s: %w", table, err)
505 | 	}
506 | 
507 | 	defer func() {
508 | 		if rows != nil {
509 | 			if err := rows.Close(); err != nil {
510 | 				logger.Error("error closing rows: %v", err)
511 | 			}
512 | 		}
513 | 	}()
514 | 
515 | 	// Convert rows to maps
516 | 	results, err := rowsToMaps(rows)
517 | 	if err != nil {
518 | 		return nil, fmt.Errorf("failed to process columns: %w", err)
519 | 	}
520 | 
521 | 	return map[string]interface{}{
522 | 		"table":   table,
523 | 		"columns": results,
524 | 		"dbType":  dbType,
525 | 	}, nil
526 | }
527 | 
528 | // getRelationships retrieves the relationships for a table or all tables
529 | func getRelationships(ctx context.Context, db db.Database, table string) (interface{}, error) {
530 | 	// Get database type from connected database
531 | 	driverName := db.DriverName()
532 | 	dbType := driverName
533 | 
534 | 	// Create the appropriate strategy
535 | 	strategy := NewDatabaseStrategy(driverName)
536 | 
537 | 	// Get queries from strategy
538 | 	queries := strategy.GetRelationshipsQueries(table)
539 | 
540 | 	// Execute queries with fallbacks
541 | 	rows, err := executeWithFallbacks(ctx, db, queries, "getRelationships")
542 | 	if err != nil {
543 | 		return nil, fmt.Errorf("failed to get relationships for table %s: %w", table, err)
544 | 	}
545 | 
546 | 	defer func() {
547 | 		if rows != nil {
548 | 			if err := rows.Close(); err != nil {
549 | 				logger.Error("error closing rows: %v", err)
550 | 			}
551 | 		}
552 | 	}()
553 | 
554 | 	// Convert rows to maps
555 | 	results, err := rowsToMaps(rows)
556 | 	if err != nil {
557 | 		return nil, fmt.Errorf("failed to process relationships: %w", err)
558 | 	}
559 | 
560 | 	return map[string]interface{}{
561 | 		"relationships": results,
562 | 		"dbType":        dbType,
563 | 		"table":         table,
564 | 	}, nil
565 | }
566 | 
// safeGetMap safely converts an interface value to a map[string]interface{},
// returning a descriptive error for nil or non-map values.
func safeGetMap(obj interface{}) (map[string]interface{}, error) {
	if obj == nil {
		return nil, fmt.Errorf("nil value cannot be converted to map")
	}
	if m, ok := obj.(map[string]interface{}); ok {
		return m, nil
	}
	return nil, fmt.Errorf("value is not a map[string]interface{}: %T", obj)
}
580 | 
// safeGetString safely reads a string value from a map key, returning a
// descriptive error when the key is absent or holds a non-string value.
func safeGetString(m map[string]interface{}, key string) (string, error) {
	val, ok := m[key]
	if !ok {
		return "", fmt.Errorf("key %q not found in map", key)
	}
	if s, ok := val.(string); ok {
		return s, nil
	}
	return "", fmt.Errorf("value for key %q is not a string: %T", key, val)
}
595 | 
596 | // getFullSchema retrieves the complete database schema
597 | func getFullSchema(ctx context.Context, db db.Database) (interface{}, error) {
598 | 	tablesResult, err := getTables(ctx, db)
599 | 	if err != nil {
600 | 		return nil, fmt.Errorf("failed to get tables: %w", err)
601 | 	}
602 | 
603 | 	tablesMap, err := safeGetMap(tablesResult)
604 | 	if err != nil {
605 | 		return nil, fmt.Errorf("invalid tables result: %w", err)
606 | 	}
607 | 
608 | 	tablesSlice, ok := tablesMap["tables"].([]map[string]interface{})
609 | 	if !ok {
610 | 		return nil, fmt.Errorf("invalid tables data format")
611 | 	}
612 | 
613 | 	// For each table, get columns
614 | 	fullSchema := make(map[string]interface{})
615 | 	for _, tableInfo := range tablesSlice {
616 | 		tableName, err := safeGetString(tableInfo, "table_name")
617 | 		if err != nil {
618 | 			return nil, fmt.Errorf("invalid table info: %w", err)
619 | 		}
620 | 
621 | 		columnsResult, columnsErr := getColumns(ctx, db, tableName)
622 | 		if columnsErr != nil {
623 | 			return nil, fmt.Errorf("failed to get columns for table %s: %w", tableName, columnsErr)
624 | 		}
625 | 		fullSchema[tableName] = columnsResult
626 | 	}
627 | 
628 | 	// Get all relationships
629 | 	relationships, relErr := getRelationships(ctx, db, "")
630 | 	if relErr != nil {
631 | 		return nil, fmt.Errorf("failed to get relationships: %w", relErr)
632 | 	}
633 | 
634 | 	relMap, err := safeGetMap(relationships)
635 | 	if err != nil {
636 | 		return nil, fmt.Errorf("invalid relationships result: %w", err)
637 | 	}
638 | 
639 | 	return map[string]interface{}{
640 | 		"tables":        tablesSlice,
641 | 		"schema":        fullSchema,
642 | 		"relationships": relMap["relationships"],
643 | 	}, nil
644 | }
645 | 
```
Page 4/7FirstPrevNextLast