This is page 3 of 5. Use http://codebase.md/freepeak/db-mcp-server?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .cm
│   └── gitstream.cm
├── .cursor
│   ├── mcp-example.json
│   ├── mcp.json
│   └── rules
│       └── global.mdc
├── .dockerignore
├── .DS_Store
├── .env.example
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── go.yml
├── .gitignore
├── .golangci.yml
├── assets
│   └── logo.svg
├── CHANGELOG.md
├── cmd
│   └── server
│       └── main.go
├── commit-message.txt
├── config.json
├── config.timescaledb-test.json
├── docker-compose.mcp-test.yml
├── docker-compose.test.yml
├── docker-compose.timescaledb-test.yml
├── docker-compose.yml
├── docker-wrapper.sh
├── Dockerfile
├── docs
│   ├── REFACTORING.md
│   ├── TIMESCALEDB_FUNCTIONS.md
│   ├── TIMESCALEDB_IMPLEMENTATION.md
│   ├── TIMESCALEDB_PRD.md
│   └── TIMESCALEDB_TOOLS.md
├── examples
│   └── postgres_connection.go
├── glama.json
├── go.mod
├── go.sum
├── init-scripts
│   └── timescaledb
│       ├── 01-init.sql
│       ├── 02-sample-data.sql
│       ├── 03-continuous-aggregates.sql
│       └── README.md
├── internal
│   ├── config
│   │   ├── config_test.go
│   │   └── config.go
│   ├── delivery
│   │   └── mcp
│   │       ├── compression_policy_test.go
│   │       ├── context
│   │       │   ├── hypertable_schema_test.go
│   │       │   ├── timescale_completion_test.go
│   │       │   ├── timescale_context_test.go
│   │       │   └── timescale_query_suggestion_test.go
│   │       ├── mock_test.go
│   │       ├── response_test.go
│   │       ├── response.go
│   │       ├── retention_policy_test.go
│   │       ├── server_wrapper.go
│   │       ├── timescale_completion.go
│   │       ├── timescale_context.go
│   │       ├── timescale_schema.go
│   │       ├── timescale_tool_test.go
│   │       ├── timescale_tool.go
│   │       ├── timescale_tools_test.go
│   │       ├── tool_registry.go
│   │       └── tool_types.go
│   ├── domain
│   │   └── database.go
│   ├── logger
│   │   ├── logger_test.go
│   │   └── logger.go
│   ├── repository
│   │   └── database_repository.go
│   └── usecase
│       └── database_usecase.go
├── LICENSE
├── Makefile
├── pkg
│   ├── core
│   │   ├── core.go
│   │   └── logging.go
│   ├── db
│   │   ├── db_test.go
│   │   ├── db.go
│   │   ├── manager.go
│   │   ├── README.md
│   │   └── timescale
│   │       ├── config_test.go
│   │       ├── config.go
│   │       ├── connection_test.go
│   │       ├── connection.go
│   │       ├── continuous_aggregate_test.go
│   │       ├── continuous_aggregate.go
│   │       ├── hypertable_test.go
│   │       ├── hypertable.go
│   │       ├── metadata.go
│   │       ├── mocks_test.go
│   │       ├── policy_test.go
│   │       ├── policy.go
│   │       ├── query.go
│   │       ├── timeseries_test.go
│   │       └── timeseries.go
│   ├── dbtools
│   │   ├── db_helpers.go
│   │   ├── dbtools_test.go
│   │   ├── dbtools.go
│   │   ├── exec.go
│   │   ├── performance_test.go
│   │   ├── performance.go
│   │   ├── query.go
│   │   ├── querybuilder_test.go
│   │   ├── querybuilder.go
│   │   ├── README.md
│   │   ├── schema_test.go
│   │   ├── schema.go
│   │   ├── tx_test.go
│   │   └── tx.go
│   ├── internal
│   │   └── logger
│   │       └── logger.go
│   ├── jsonrpc
│   │   └── jsonrpc.go
│   ├── logger
│   │   └── logger.go
│   └── tools
│       └── tools.go
├── README-old.md
├── README.md
├── repomix-output.txt
├── request.json
├── start-mcp.sh
├── test.Dockerfile
├── timescaledb-test.sh
└── wait-for-it.sh
```
# Files
--------------------------------------------------------------------------------
/pkg/db/timescale/hypertable.go:
--------------------------------------------------------------------------------
```go
package timescale
import (
"context"
"fmt"
"strings"
)
// HypertableConfig defines configuration for creating a hypertable
type HypertableConfig struct {
TableName string
TimeColumn string
ChunkTimeInterval string
PartitioningColumn string
CreateIfNotExists bool // Note: not read by CreateHypertable below; set IfNotExists to emit if_not_exists
SpacePartitions int // Number of space partitions (for multi-dimensional partitioning)
IfNotExists bool // If true, don't error if table is already a hypertable
MigrateData bool // If true, migrate existing data to chunks
}
// Hypertable represents a TimescaleDB hypertable
type Hypertable struct {
TableName string
SchemaName string
TimeColumn string
SpaceColumn string
NumDimensions int
CompressionEnabled bool
RetentionEnabled bool
}
// CreateHypertable converts a regular PostgreSQL table to a TimescaleDB hypertable
func (t *DB) CreateHypertable(ctx context.Context, config HypertableConfig) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
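// NOTE: table and column identifiers below are interpolated with
// fmt.Sprintf rather than bound as query parameters, so callers must
// supply trusted names only.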
// Construct the create_hypertable call
var queryBuilder strings.Builder
queryBuilder.WriteString("SELECT create_hypertable(")
// Table name and time column are required
queryBuilder.WriteString(fmt.Sprintf("'%s', '%s'", config.TableName, config.TimeColumn))
// Optional parameters
if config.PartitioningColumn != "" {
queryBuilder.WriteString(fmt.Sprintf(", partition_column => '%s'", config.PartitioningColumn))
}
if config.ChunkTimeInterval != "" {
queryBuilder.WriteString(fmt.Sprintf(", chunk_time_interval => INTERVAL '%s'", config.ChunkTimeInterval))
}
if config.SpacePartitions > 0 {
queryBuilder.WriteString(fmt.Sprintf(", number_partitions => %d", config.SpacePartitions))
}
if config.IfNotExists {
queryBuilder.WriteString(", if_not_exists => TRUE")
}
if config.MigrateData {
queryBuilder.WriteString(", migrate_data => TRUE")
}
queryBuilder.WriteString(")")
// Execute the query
_, err := t.ExecuteSQLWithoutParams(ctx, queryBuilder.String())
if err != nil {
return fmt.Errorf("failed to create hypertable: %w", err)
}
return nil
}
// AddDimension adds a new dimension (partitioning key) to a hypertable
func (t *DB) AddDimension(ctx context.Context, tableName, columnName string, numPartitions int) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
query := fmt.Sprintf("SELECT add_dimension('%s', '%s', number_partitions => %d)",
tableName, columnName, numPartitions)
_, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to add dimension: %w", err)
}
return nil
}
// ListHypertables returns a list of all hypertables in the database
func (t *DB) ListHypertables(ctx context.Context) ([]Hypertable, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
// Take the first dimension created as the time column; selecting
// d.column_name directly would violate the GROUP BY clause.
query := `
SELECT h.table_name, h.schema_name,
(
SELECT column_name FROM _timescaledb_catalog.dimension
WHERE hypertable_id = h.id
ORDER BY id
LIMIT 1
) as time_column,
count(d.id) as num_dimensions,
(
SELECT column_name FROM _timescaledb_catalog.dimension
WHERE hypertable_id = h.id AND column_type != 'TIMESTAMP'
AND column_type != 'TIMESTAMPTZ'
LIMIT 1
) as space_column
FROM _timescaledb_catalog.hypertable h
JOIN _timescaledb_catalog.dimension d ON h.id = d.hypertable_id
GROUP BY h.id, h.table_name, h.schema_name
`
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list hypertables: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected result type from database query")
}
var hypertables []Hypertable
for _, row := range rows {
ht := Hypertable{
TableName: fmt.Sprintf("%v", row["table_name"]),
SchemaName: fmt.Sprintf("%v", row["schema_name"]),
TimeColumn: fmt.Sprintf("%v", row["time_column"]),
}
// Handle nullable columns
if row["space_column"] != nil {
ht.SpaceColumn = fmt.Sprintf("%v", row["space_column"])
}
// Convert numeric dimensions
if dimensions, ok := row["num_dimensions"].(int64); ok {
ht.NumDimensions = int(dimensions)
} else if dimensions, ok := row["num_dimensions"].(int); ok {
ht.NumDimensions = dimensions
}
// Check if compression is enabled
compQuery := fmt.Sprintf(
"SELECT count(*) > 0 as is_compressed FROM timescaledb_information.compression_settings WHERE hypertable_name = '%s'",
ht.TableName,
)
compResult, err := t.ExecuteSQLWithoutParams(ctx, compQuery)
if err == nil {
if compRows, ok := compResult.([]map[string]interface{}); ok && len(compRows) > 0 {
if isCompressed, ok := compRows[0]["is_compressed"].(bool); ok {
ht.CompressionEnabled = isCompressed
}
}
}
// Check if retention policy is enabled
retQuery := fmt.Sprintf(
"SELECT count(*) > 0 as has_retention FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'",
ht.TableName,
)
retResult, err := t.ExecuteSQLWithoutParams(ctx, retQuery)
if err == nil {
if retRows, ok := retResult.([]map[string]interface{}); ok && len(retRows) > 0 {
if hasRetention, ok := retRows[0]["has_retention"].(bool); ok {
ht.RetentionEnabled = hasRetention
}
}
}
hypertables = append(hypertables, ht)
}
return hypertables, nil
}
// GetHypertable gets information about a specific hypertable
func (t *DB) GetHypertable(ctx context.Context, tableName string) (*Hypertable, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
// As in ListHypertables, the time column is taken as the first dimension
// created so that every selected column is valid under the GROUP BY.
query := fmt.Sprintf(`
SELECT h.table_name, h.schema_name,
(
SELECT column_name FROM _timescaledb_catalog.dimension
WHERE hypertable_id = h.id
ORDER BY id
LIMIT 1
) as time_column,
count(d.id) as num_dimensions,
(
SELECT column_name FROM _timescaledb_catalog.dimension
WHERE hypertable_id = h.id AND column_type != 'TIMESTAMP'
AND column_type != 'TIMESTAMPTZ'
LIMIT 1
) as space_column
FROM _timescaledb_catalog.hypertable h
JOIN _timescaledb_catalog.dimension d ON h.id = d.hypertable_id
WHERE h.table_name = '%s'
GROUP BY h.id, h.table_name, h.schema_name
`, tableName)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to get hypertable information: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
}
row := rows[0]
ht := &Hypertable{
TableName: fmt.Sprintf("%v", row["table_name"]),
SchemaName: fmt.Sprintf("%v", row["schema_name"]),
TimeColumn: fmt.Sprintf("%v", row["time_column"]),
}
// Handle nullable columns
if row["space_column"] != nil {
ht.SpaceColumn = fmt.Sprintf("%v", row["space_column"])
}
// Convert numeric dimensions
if dimensions, ok := row["num_dimensions"].(int64); ok {
ht.NumDimensions = int(dimensions)
} else if dimensions, ok := row["num_dimensions"].(int); ok {
ht.NumDimensions = dimensions
}
// Check if compression is enabled
compQuery := fmt.Sprintf(
"SELECT count(*) > 0 as is_compressed FROM timescaledb_information.compression_settings WHERE hypertable_name = '%s'",
ht.TableName,
)
compResult, err := t.ExecuteSQLWithoutParams(ctx, compQuery)
if err == nil {
if compRows, ok := compResult.([]map[string]interface{}); ok && len(compRows) > 0 {
if isCompressed, ok := compRows[0]["is_compressed"].(bool); ok {
ht.CompressionEnabled = isCompressed
}
}
}
// Check if retention policy is enabled
retQuery := fmt.Sprintf(
"SELECT count(*) > 0 as has_retention FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'",
ht.TableName,
)
retResult, err := t.ExecuteSQLWithoutParams(ctx, retQuery)
if err == nil {
if retRows, ok := retResult.([]map[string]interface{}); ok && len(retRows) > 0 {
if hasRetention, ok := retRows[0]["has_retention"].(bool); ok {
ht.RetentionEnabled = hasRetention
}
}
}
return ht, nil
}
// DropHypertable drops a hypertable
func (t *DB) DropHypertable(ctx context.Context, tableName string, cascade bool) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
// Build the DROP TABLE query
query := fmt.Sprintf("DROP TABLE %s", tableName)
if cascade {
query += " CASCADE"
}
_, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to drop hypertable: %w", err)
}
return nil
}
// CheckIfHypertable checks if a table is a hypertable
func (t *DB) CheckIfHypertable(ctx context.Context, tableName string) (bool, error) {
if !t.isTimescaleDB {
return false, fmt.Errorf("TimescaleDB extension not available")
}
query := fmt.Sprintf(`
SELECT EXISTS (
SELECT 1
FROM _timescaledb_catalog.hypertable
WHERE table_name = '%s'
) as is_hypertable
`, tableName)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return false, fmt.Errorf("failed to check if table is a hypertable: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
return false, fmt.Errorf("unexpected result from database query")
}
isHypertable, ok := rows[0]["is_hypertable"].(bool)
if !ok {
return false, fmt.Errorf("unexpected result type for is_hypertable")
}
return isHypertable, nil
}
// RecentChunks returns information about the most recent chunks
func (t *DB) RecentChunks(ctx context.Context, tableName string, limit int) (interface{}, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
// Use default limit of 10 if not specified
if limit <= 0 {
limit = 10
}
query := fmt.Sprintf(`
SELECT
chunk_name,
range_start,
range_end
FROM timescaledb_information.chunks
WHERE hypertable_name = '%s'
ORDER BY range_end DESC
LIMIT %d
`, tableName, limit)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to get recent chunks: %w", err)
}
return result, nil
}
// CreateHypertable is a helper function to create a hypertable with the given configuration options.
// This is exported for use by other packages.
func CreateHypertable(ctx context.Context, db *DB, table, timeColumn string, opts ...HypertableOption) error {
config := HypertableConfig{
TableName: table,
TimeColumn: timeColumn,
}
// Apply all options
for _, opt := range opts {
opt(&config)
}
return db.CreateHypertable(ctx, config)
}
// HypertableOption is a functional option for CreateHypertable
type HypertableOption func(*HypertableConfig)
// WithChunkInterval sets the chunk time interval
func WithChunkInterval(interval string) HypertableOption {
return func(config *HypertableConfig) {
config.ChunkTimeInterval = interval
}
}
// WithPartitioningColumn sets the space partitioning column
func WithPartitioningColumn(column string) HypertableOption {
return func(config *HypertableConfig) {
config.PartitioningColumn = column
}
}
// WithIfNotExists sets the if_not_exists flag
func WithIfNotExists(ifNotExists bool) HypertableOption {
return func(config *HypertableConfig) {
config.IfNotExists = ifNotExists
}
}
// WithMigrateData sets the migrate_data flag
func WithMigrateData(migrateData bool) HypertableOption {
return func(config *HypertableConfig) {
config.MigrateData = migrateData
}
}
```
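A minimal usage sketch for the exported helper and functional options above. The wrapper function, the `example` package name, the import alias, and the `metrics` table are assumptions for illustration; `CreateHypertable` and the `With*` options are the ones defined in this file.
```go
package example

import (
	"context"
	"log"

	ts "github.com/FreePeak/db-mcp-server/pkg/db/timescale"
)

// setupMetricsHypertable converts an existing "metrics" table into a
// hypertable chunked by day, tolerating repeat runs and migrating any
// rows that are already present.
func setupMetricsHypertable(ctx context.Context, db *ts.DB) {
	err := ts.CreateHypertable(ctx, db, "metrics", "time",
		ts.WithChunkInterval("1 day"), // chunk_time_interval => INTERVAL '1 day'
		ts.WithIfNotExists(true),      // no error if already a hypertable
		ts.WithMigrateData(true),      // move existing rows into chunks
	)
	if err != nil {
		log.Fatalf("create hypertable: %v", err)
	}
}
```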
--------------------------------------------------------------------------------
/pkg/dbtools/performance.go:
--------------------------------------------------------------------------------
```go
package dbtools
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/FreePeak/db-mcp-server/pkg/logger"
)
// QueryMetrics stores performance metrics for a database query
type QueryMetrics struct {
Query string // SQL query text
Count int // Number of times the query was executed
TotalDuration time.Duration // Total execution time
MinDuration time.Duration // Minimum execution time
MaxDuration time.Duration // Maximum execution time
AvgDuration time.Duration // Average execution time
LastExecuted time.Time // When the query was last executed
}
// PerformanceAnalyzer tracks query performance and provides optimization suggestions
type PerformanceAnalyzer struct {
slowThreshold time.Duration
queryHistory []QueryRecord
maxHistory int
}
// QueryRecord stores information about a query execution
type QueryRecord struct {
Query string `json:"query"`
Params []interface{} `json:"params"`
Duration time.Duration `json:"duration"`
StartTime time.Time `json:"startTime"`
Error string `json:"error,omitempty"`
Optimized bool `json:"optimized"`
Suggestion string `json:"suggestion,omitempty"`
}
// SQLIssueDetector detects potential issues in SQL queries
type SQLIssueDetector struct {
patterns map[string]*regexp.Regexp
}
// singleton instance; initialization is guarded so concurrent callers do not race
var (
performanceAnalyzer *PerformanceAnalyzer
performanceAnalyzerOnce sync.Once
)
// GetPerformanceAnalyzer returns the singleton performance analyzer
func GetPerformanceAnalyzer() *PerformanceAnalyzer {
performanceAnalyzerOnce.Do(func() {
performanceAnalyzer = NewPerformanceAnalyzer()
})
return performanceAnalyzer
}
// NewPerformanceAnalyzer creates a new performance analyzer
func NewPerformanceAnalyzer() *PerformanceAnalyzer {
return &PerformanceAnalyzer{
slowThreshold: 500 * time.Millisecond, // Default: 500ms
queryHistory: make([]QueryRecord, 0),
maxHistory: 100, // Default: store last 100 queries
}
}
// LogSlowQuery logs a warning if a query takes longer than the slow query threshold
func (pa *PerformanceAnalyzer) LogSlowQuery(query string, params []interface{}, duration time.Duration) {
if duration >= pa.slowThreshold {
paramStr := formatParams(params)
logger.Warn("Slow query detected (%.2fms): %s [params: %s]",
float64(duration.Microseconds())/1000.0,
query,
paramStr)
}
}
// TrackQuery tracks the execution of a query and logs slow queries
func (pa *PerformanceAnalyzer) TrackQuery(ctx context.Context, query string, params []interface{}, exec func() (interface{}, error)) (interface{}, error) {
startTime := time.Now()
result, err := exec()
duration := time.Since(startTime)
// Create query record
record := QueryRecord{
Query: query,
Params: params,
Duration: duration,
StartTime: startTime,
}
// Check if query is slow
if duration >= pa.slowThreshold {
pa.LogSlowQuery(query, params, duration)
record.Suggestion = "Query execution time exceeds threshold"
}
// Record error if any
if err != nil {
record.Error = err.Error()
}
// Add to history (keeping max size)
pa.queryHistory = append(pa.queryHistory, record)
if len(pa.queryHistory) > pa.maxHistory {
pa.queryHistory = pa.queryHistory[1:]
}
return result, err
}
// SQLIssueDetector methods
// NewSQLIssueDetector creates a new SQL issue detector
func NewSQLIssueDetector() *SQLIssueDetector {
detector := &SQLIssueDetector{
patterns: make(map[string]*regexp.Regexp),
}
// Add known issue patterns
detector.AddPattern("cartesian-join", `SELECT.*FROM\s+(\w+)\s*,\s*(\w+)`)
detector.AddPattern("select-star", `SELECT\s+\*\s+FROM`)
detector.AddPattern("missing-where", `(DELETE\s+FROM|UPDATE)\s+\w+\s+(?:SET\s+(?:\w+\s*=\s*[^,]+)(?:\s*,\s*\w+\s*=\s*[^,]+)*\s*)*(;|\z)`)
detector.AddPattern("or-in-where", `WHERE.*\s+OR\s+`)
detector.AddPattern("in-with-many-items", `IN\s*\(\s*(?:'[^']*'\s*,\s*){10,}`)
detector.AddPattern("not-in", `NOT\s+IN\s*\(`)
detector.AddPattern("is-null", `IS\s+NULL`)
detector.AddPattern("function-on-column", `WHERE\s+\w+\s*\(\s*\w+\s*\)`)
detector.AddPattern("order-by-rand", `ORDER\s+BY\s+RAND\(\)`)
detector.AddPattern("group-by-number", `GROUP\s+BY\s+\d+`)
detector.AddPattern("having-without-group", `HAVING.*(?:(?:GROUP\s+BY.*$)|(?:$))`)
return detector
}
// AddPattern adds a pattern for detecting SQL issues
func (d *SQLIssueDetector) AddPattern(name, pattern string) {
re, err := regexp.Compile("(?i)" + pattern)
if err != nil {
logger.Error("Error compiling regex pattern '%s': %v", pattern, err)
return
}
d.patterns[name] = re
}
// DetectIssues detects issues in a SQL query
func (d *SQLIssueDetector) DetectIssues(query string) map[string]string {
issues := make(map[string]string)
for name, pattern := range d.patterns {
if pattern.MatchString(query) {
issues[name] = d.getSuggestionForIssue(name)
}
}
return issues
}
// getSuggestionForIssue returns a suggestion for a detected issue
func (d *SQLIssueDetector) getSuggestionForIssue(issue string) string {
suggestions := map[string]string{
"cartesian-join": "Use explicit JOIN statements instead of comma-syntax joins to avoid unintended Cartesian products.",
"select-star": "Specify exact columns needed instead of SELECT * to reduce network traffic and improve query execution.",
"missing-where": "Add a WHERE clause to avoid affecting all rows in the table.",
"or-in-where": "Consider using IN instead of multiple OR conditions for better performance.",
"in-with-many-items": "Too many items in IN clause; consider a temporary table or a JOIN instead.",
"not-in": "NOT IN with subqueries can be slow. Consider using NOT EXISTS or LEFT JOIN/IS NULL pattern.",
"is-null": "IS NULL conditions prevent index usage. Consider redesigning to avoid NULL values if possible.",
"function-on-column": "Applying functions to columns in WHERE clauses prevents index usage. Restructure if possible.",
"order-by-rand": "ORDER BY RAND() causes full table scan and sort. Consider alternative randomization methods.",
"group-by-number": "Using column position numbers in GROUP BY can be error-prone. Use explicit column names.",
"having-without-group": "HAVING without GROUP BY may indicate a logical error in query structure.",
}
if suggestion, ok := suggestions[issue]; ok {
return suggestion
}
return "Potential issue detected; review query for optimization opportunities."
}
// Helper functions
// formatParams converts query parameters to a readable string format
func formatParams(params []interface{}) string {
if len(params) == 0 {
return "none"
}
paramStrings := make([]string, len(params))
for i, param := range params {
if param == nil {
paramStrings[i] = "NULL"
} else {
paramStrings[i] = fmt.Sprintf("%v", param)
}
}
return "[" + strings.Join(paramStrings, ", ") + "]"
}
// Placeholder for future implementation
// Currently not used but kept for reference
/*
func handlePerformanceAnalyzer(ctx context.Context, params map[string]interface{}) (interface{}, error) {
// This function will be implemented in the future
return nil, fmt.Errorf("not implemented")
}
*/
// StripComments removes SQL comments from a query string
func StripComments(input string) string {
// Strip /* ... */ comments
multiLineRe, err := regexp.Compile(`/\*[\s\S]*?\*/`)
if err != nil {
// If there's an error with the regex, just return the input
logger.Error("Error compiling regex pattern '%s': %v", `\/\*[\s\S]*?\*\/`, err)
return input
}
withoutMultiLine := multiLineRe.ReplaceAllString(input, "")
// Strip -- comments; (?m) makes $ match at each line end, not just end of input
singleLineRe, err := regexp.Compile(`(?m)--.*$`)
if err != nil {
return withoutMultiLine
}
return singleLineRe.ReplaceAllString(withoutMultiLine, "")
}
// GetAllMetrics returns all collected metrics
func (pa *PerformanceAnalyzer) GetAllMetrics() []*QueryMetrics {
// Group query history by normalized query text
queryMap := make(map[string]*QueryMetrics)
for _, record := range pa.queryHistory {
normalizedQuery := normalizeQuery(record.Query)
metrics, exists := queryMap[normalizedQuery]
if !exists {
metrics = &QueryMetrics{
Query: record.Query,
MinDuration: record.Duration,
MaxDuration: record.Duration,
LastExecuted: record.StartTime,
}
queryMap[normalizedQuery] = metrics
}
// Update metrics
metrics.Count++
metrics.TotalDuration += record.Duration
if record.Duration < metrics.MinDuration {
metrics.MinDuration = record.Duration
}
if record.Duration > metrics.MaxDuration {
metrics.MaxDuration = record.Duration
}
if record.StartTime.After(metrics.LastExecuted) {
metrics.LastExecuted = record.StartTime
}
}
// Calculate averages and convert to slice
metrics := make([]*QueryMetrics, 0, len(queryMap))
for _, m := range queryMap {
m.AvgDuration = time.Duration(int64(m.TotalDuration) / int64(m.Count))
metrics = append(metrics, m)
}
return metrics
}
// Reset clears all collected metrics
func (pa *PerformanceAnalyzer) Reset() {
pa.queryHistory = make([]QueryRecord, 0)
}
// GetSlowThreshold returns the current slow query threshold
func (pa *PerformanceAnalyzer) GetSlowThreshold() time.Duration {
return pa.slowThreshold
}
// SetSlowThreshold sets the slow query threshold
func (pa *PerformanceAnalyzer) SetSlowThreshold(threshold time.Duration) {
pa.slowThreshold = threshold
}
// AnalyzeQuery analyzes a SQL query and returns optimization suggestions
func AnalyzeQuery(query string) []string {
// Create detector and get suggestions
detector := NewSQLIssueDetector()
issues := detector.DetectIssues(query)
suggestions := make([]string, 0, len(issues))
for _, suggestion := range issues {
suggestions = append(suggestions, suggestion)
}
// Add default suggestions for query patterns
if strings.Contains(strings.ToUpper(query), "SELECT *") {
suggestions = append(suggestions, "Avoid using SELECT * - specify only the columns you need")
}
if !strings.Contains(strings.ToUpper(query), "WHERE") &&
!strings.Contains(strings.ToUpper(query), "JOIN") {
suggestions = append(suggestions, "Consider adding a WHERE clause to limit the result set")
}
if strings.Contains(strings.ToUpper(query), "JOIN") &&
!strings.Contains(strings.ToUpper(query), "ON") {
suggestions = append(suggestions, "Ensure all JOINs have proper conditions")
}
if strings.Contains(strings.ToUpper(query), "ORDER BY") {
suggestions = append(suggestions, "Verify that ORDER BY columns are properly indexed")
}
if strings.Contains(query, "(SELECT") {
suggestions = append(suggestions, "Consider replacing subqueries with JOINs where possible")
}
return suggestions
}
// normalizeQuery standardizes SQL queries for comparison by replacing literals
func normalizeQuery(query string) string {
// Trim and normalize whitespace
query = strings.TrimSpace(query)
wsRegex := regexp.MustCompile(`\s+`)
query = wsRegex.ReplaceAllString(query, " ")
// Replace numeric literals
numRegex := regexp.MustCompile(`\b\d+\b`)
query = numRegex.ReplaceAllString(query, "?")
// Replace string literals in single quotes
strRegex := regexp.MustCompile(`'[^']*'`)
query = strRegex.ReplaceAllString(query, "'?'")
// Replace string literals in double quotes
dblQuoteRegex := regexp.MustCompile(`"[^"]*"`)
query = dblQuoteRegex.ReplaceAllString(query, "\"?\"")
return query
}
// TODO: Implement more sophisticated performance metrics and query analysis
// TODO: Add support for query plan visualization
// TODO: Consider using time-series storage for long-term performance tracking
// TODO: Implement anomaly detection for query performance
// TODO: Add integration with external monitoring systems
// TODO: Implement periodic background performance analysis
```
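A short usage sketch for the analyzer above, assuming this repository's module path for the import; the `example` package name, query text, threshold, and the sleeping callback are placeholders standing in for a real database call.
```go
package example

import (
	"context"
	"fmt"
	"time"

	"github.com/FreePeak/db-mcp-server/pkg/dbtools"
)

// profileQuery runs a callback under the performance analyzer so slow
// executions are logged and recorded, then prints static suggestions.
func profileQuery(ctx context.Context) error {
	pa := dbtools.GetPerformanceAnalyzer()
	pa.SetSlowThreshold(250 * time.Millisecond) // tighten the 500ms default

	query := "SELECT * FROM orders WHERE status = 'open'"
	if _, err := pa.TrackQuery(ctx, query, nil, func() (interface{}, error) {
		time.Sleep(300 * time.Millisecond) // stand-in for a real database call
		return nil, nil
	}); err != nil {
		return err
	}

	// AnalyzeQuery flags anti-patterns such as SELECT * and missing WHERE.
	for _, s := range dbtools.AnalyzeQuery(query) {
		fmt.Println("suggestion:", s)
	}
	return nil
}
```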
--------------------------------------------------------------------------------
/pkg/db/timescale/mocks_test.go:
--------------------------------------------------------------------------------
```go
package timescale
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"strings"
"testing"
)
// MockDB simulates a database for testing purposes
type MockDB struct {
mockResults map[string]MockExecuteResult
lastQuery string
lastQueryArgs []interface{}
queryHistory []string
connectCalled bool
connectError error
closeCalled bool
closeError error
isTimescaleDB bool
// Store the expected scan value for QueryRow
queryScanValues map[string]interface{}
// Added for additional mock methods
queryResult []map[string]interface{}
err error
}
// MockExecuteResult is used for mocking query results
type MockExecuteResult struct {
Result interface{}
Error error
}
// NewMockDB creates a new mock database for testing
func NewMockDB() *MockDB {
return &MockDB{
mockResults: make(map[string]MockExecuteResult),
queryScanValues: make(map[string]interface{}),
queryHistory: make([]string, 0),
isTimescaleDB: true, // Default is true
}
}
// RegisterQueryResult registers a result to be returned for a specific query
func (m *MockDB) RegisterQueryResult(query string, result interface{}, err error) {
// Store the result for exact matching
m.mockResults[query] = MockExecuteResult{
Result: result,
Error: err,
}
// Also store the result for partial matching
if result != nil || err != nil {
m.mockResults["partial:"+query] = MockExecuteResult{
Result: result,
Error: err,
}
}
// Also store the result as a scan value for QueryRow
m.queryScanValues[query] = result
}
// getQueryResult tries to find a matching result for a query
// First tries exact match, then partial match
func (m *MockDB) getQueryResult(query string) (MockExecuteResult, bool) {
// Try exact match first
if result, ok := m.mockResults[query]; ok {
return result, true
}
// Try partial match
for k, v := range m.mockResults {
if strings.HasPrefix(k, "partial:") && strings.Contains(query, k[8:]) {
return v, true
}
}
return MockExecuteResult{}, false
}
// GetLastQuery returns the last executed query and args
func (m *MockDB) GetLastQuery() (string, []interface{}) {
return m.lastQuery, m.lastQueryArgs
}
// Connect implements db.Database.Connect
func (m *MockDB) Connect() error {
m.connectCalled = true
return m.connectError
}
// SetConnectError sets an error to be returned from Connect()
func (m *MockDB) SetConnectError(err error) {
m.connectError = err
}
// Close implements db.Database.Close
func (m *MockDB) Close() error {
m.closeCalled = true
return m.closeError
}
// SetCloseError sets an error to be returned from Close()
func (m *MockDB) SetCloseError(err error) {
m.closeError = err
}
// SetTimescaleAvailable sets whether TimescaleDB is available for this mock
func (m *MockDB) SetTimescaleAvailable(available bool) {
m.isTimescaleDB = available
}
// Exec implements db.Database.Exec
func (m *MockDB) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
m.lastQuery = query
m.lastQueryArgs = args
if result, found := m.getQueryResult(query); found {
if result.Error != nil {
return nil, result.Error
}
if sqlResult, ok := result.Result.(sql.Result); ok {
return sqlResult, nil
}
}
return &MockResult{}, nil
}
// Query implements db.Database.Query
func (m *MockDB) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
m.lastQuery = query
m.lastQueryArgs = args
if result, found := m.getQueryResult(query); found {
if result.Error != nil {
return nil, result.Error
}
if rows, ok := result.Result.(*sql.Rows); ok {
return rows, nil
}
}
// Create a MockRows for the test
return sql.OpenDB(&MockConnector{mockDB: m, query: query}).Query(query)
}
// QueryRow implements db.Database.QueryRow
func (m *MockDB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row {
m.lastQuery = query
m.lastQueryArgs = args
// Use a custom connector to create a sql.DB that's backed by our mock
db := sql.OpenDB(&MockConnector{mockDB: m, query: query})
return db.QueryRow(query)
}
// BeginTx implements db.Database.BeginTx
func (m *MockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return nil, nil
}
// Prepare implements db.Database.Prepare
func (m *MockDB) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
return nil, nil
}
// Ping implements db.Database.Ping
func (m *MockDB) Ping(ctx context.Context) error {
return nil
}
// DB implements db.Database.DB
func (m *MockDB) DB() *sql.DB {
return nil
}
// ConnectionString implements db.Database.ConnectionString
func (m *MockDB) ConnectionString() string {
return "mock://localhost/testdb"
}
// DriverName implements db.Database.DriverName
func (m *MockDB) DriverName() string {
return "postgres"
}
// QueryTimeout implements db.Database.QueryTimeout
func (m *MockDB) QueryTimeout() int {
return 30
}
// MockResult implements sql.Result
type MockResult struct{}
// LastInsertId implements sql.Result.LastInsertId
func (r *MockResult) LastInsertId() (int64, error) {
return 1, nil
}
// RowsAffected implements sql.Result.RowsAffected
func (r *MockResult) RowsAffected() (int64, error) {
return 1, nil
}
// MockTimescaleDB creates a TimescaleDB instance with mocked database for testing
func MockTimescaleDB(t testing.TB) (*DB, *MockDB) {
mockDB := NewMockDB()
mockDB.SetTimescaleAvailable(true)
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
extVersion: "2.8.0",
config: DBConfig{
UseTimescaleDB: true,
},
}
return tsdb, mockDB
}
// AssertQueryContains checks if the query contains the expected substrings
func AssertQueryContains(t testing.TB, query string, substrings ...string) {
for _, substring := range substrings {
if !contains(query, substring) {
t.Errorf("Expected query to contain '%s', but got: %s", substring, query)
}
}
}
// Helper function to check if a string contains another string
func contains(s, substr string) bool {
return strings.Contains(s, substr)
}
// Mock driver implementation to support sql.Row
// MockConnector implements driver.Connector
type MockConnector struct {
mockDB *MockDB
query string
}
// Connect implements driver.Connector
func (c *MockConnector) Connect(_ context.Context) (driver.Conn, error) {
return &MockConn{mockDB: c.mockDB, query: c.query}, nil
}
// Driver implements driver.Connector
func (c *MockConnector) Driver() driver.Driver {
return &MockDriver{}
}
// MockDriver implements driver.Driver
type MockDriver struct{}
// Open implements driver.Driver
func (d *MockDriver) Open(name string) (driver.Conn, error) {
return &MockConn{}, nil
}
// MockConn implements driver.Conn
type MockConn struct {
mockDB *MockDB
query string
}
// Prepare implements driver.Conn
func (c *MockConn) Prepare(query string) (driver.Stmt, error) {
return &MockStmt{mockDB: c.mockDB, query: query}, nil
}
// Close implements driver.Conn
func (c *MockConn) Close() error {
return nil
}
// Begin implements driver.Conn
func (c *MockConn) Begin() (driver.Tx, error) {
return nil, nil
}
// MockStmt implements driver.Stmt
type MockStmt struct {
mockDB *MockDB
query string
}
// Close implements driver.Stmt
func (s *MockStmt) Close() error {
return nil
}
// NumInput implements driver.Stmt
func (s *MockStmt) NumInput() int {
return 0
}
// Exec implements driver.Stmt
func (s *MockStmt) Exec(args []driver.Value) (driver.Result, error) {
return nil, nil
}
// Query implements driver.Stmt
func (s *MockStmt) Query(args []driver.Value) (driver.Rows, error) {
// Return the registered result for this query
if s.mockDB != nil {
if result, found := s.mockDB.getQueryResult(s.query); found {
if result.Error != nil {
return nil, result.Error
}
return &MockRows{value: result.Result}, nil
}
}
return &MockRows{}, nil
}
// MockRows implements driver.Rows
type MockRows struct {
value interface{}
currentRow int
columnNames []string
}
// Columns implements driver.Rows
func (r *MockRows) Columns() []string {
// If we have a slice of maps, extract column names from the first map
if rows, ok := r.value.([]map[string]interface{}); ok && len(rows) > 0 {
if r.columnNames == nil {
r.columnNames = make([]string, 0, len(rows[0]))
for k := range rows[0] {
r.columnNames = append(r.columnNames, k)
}
}
return r.columnNames
}
return []string{"value"}
}
// Close implements driver.Rows
func (r *MockRows) Close() error {
return nil
}
// Next implements driver.Rows
func (r *MockRows) Next(dest []driver.Value) error {
// Handle slice of maps (multiple rows of data)
if rows, ok := r.value.([]map[string]interface{}); ok {
if r.currentRow < len(rows) {
row := rows[r.currentRow]
r.currentRow++
// Find column values in the expected order
columns := r.Columns()
for i, col := range columns {
if i < len(dest) {
dest[i] = row[col]
}
}
return nil
}
return io.EOF
}
// Handle simple string value
if r.currentRow == 0 && r.value != nil {
r.currentRow++
dest[0] = r.value
return nil
}
return io.EOF
}
// RunQueryTest executes a mock query test against the DB
func RunQueryTest(t *testing.T, testFunc func(*DB) error) {
tsdb, _ := MockTimescaleDB(t)
err := testFunc(tsdb)
if err != nil {
t.Errorf("Test failed: %v", err)
}
}
// SetQueryResult sets the mock result for queries
func (m *MockDB) SetQueryResult(result []map[string]interface{}) {
m.queryResult = result
}
// SetError sets the mock error
func (m *MockDB) SetError(errMsg string) {
m.err = fmt.Errorf("%s", errMsg)
}
// LastQuery returns the last executed query
func (m *MockDB) LastQuery() string {
return m.lastQuery
}
// QueryContains checks if the last query contains a substring
func (m *MockDB) QueryContains(substring string) bool {
return strings.Contains(m.lastQuery, substring)
}
// ExecuteSQL implements db.Database.ExecuteSQL
func (m *MockDB) ExecuteSQL(ctx context.Context, query string, args ...interface{}) (interface{}, error) {
m.lastQuery = query
m.lastQueryArgs = args
m.queryHistory = append(m.queryHistory, query)
if m.err != nil {
return nil, m.err
}
// If TimescaleDB is not available and the query is for TimescaleDB specific features
if !m.isTimescaleDB && (strings.Contains(query, "time_bucket") ||
strings.Contains(query, "hypertable") ||
strings.Contains(query, "continuous_aggregate") ||
strings.Contains(query, "timescaledb_information")) {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
if m.queryResult != nil {
return m.queryResult, nil
}
// Return an empty result set by default
return []map[string]interface{}{}, nil
}
// ExecuteSQLWithoutParams implements db.Database.ExecuteSQLWithoutParams
func (m *MockDB) ExecuteSQLWithoutParams(ctx context.Context, query string) (interface{}, error) {
m.lastQuery = query
m.lastQueryArgs = nil
m.queryHistory = append(m.queryHistory, query)
if m.err != nil {
return nil, m.err
}
// If TimescaleDB is not available and the query is for TimescaleDB specific features
if !m.isTimescaleDB && (strings.Contains(query, "time_bucket") ||
strings.Contains(query, "hypertable") ||
strings.Contains(query, "continuous_aggregate") ||
strings.Contains(query, "timescaledb_information")) {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
if m.queryResult != nil {
return m.queryResult, nil
}
// Return an empty result set by default
return []map[string]interface{}{}, nil
}
// QueryHistory returns the complete history of queries executed
func (m *MockDB) QueryHistory() []string {
return m.queryHistory
}
// AnyQueryContains checks if any query in the history contains the given substring
func (m *MockDB) AnyQueryContains(substring string) bool {
for _, query := range m.queryHistory {
if strings.Contains(query, substring) {
return true
}
}
return false
}
```
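A sketch of how this harness is typically wired in a test (illustrative, not part of the original suite): construct the mock pair, exercise code under test, then assert on the captured SQL. It would live in a `_test.go` file in this package; `EnableCompression` is the method defined in policy.go below.
```go
package timescale

import (
	"context"
	"testing"
)

// TestEnableCompressionWiring demonstrates the mock harness: MockTimescaleDB
// returns a *DB whose Database is a *MockDB that records every statement.
func TestEnableCompressionWiring(t *testing.T) {
	tsdb, mockDB := MockTimescaleDB(t)

	if err := tsdb.EnableCompression(context.Background(), "metrics", ""); err != nil {
		t.Fatalf("EnableCompression: %v", err)
	}

	// Assertions can target the last statement or the whole history.
	AssertQueryContains(t, mockDB.LastQuery(), "timescaledb.compress = true")
	if !mockDB.AnyQueryContains("ALTER TABLE metrics") {
		t.Errorf("expected ALTER TABLE against metrics, got %q", mockDB.LastQuery())
	}
}
```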
--------------------------------------------------------------------------------
/pkg/db/timescale/policy.go:
--------------------------------------------------------------------------------
```go
package timescale
import (
"context"
"fmt"
"strings"
)
// CompressionSettings represents TimescaleDB compression settings
type CompressionSettings struct {
HypertableName string
SegmentBy string
OrderBy string
ChunkTimeInterval string
CompressionInterval string
CompressionEnabled bool
}
// RetentionSettings represents TimescaleDB retention settings
type RetentionSettings struct {
HypertableName string
RetentionInterval string
RetentionEnabled bool
}
// EnableCompression enables compression on a hypertable
func (t *DB) EnableCompression(ctx context.Context, tableName string, afterInterval string) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
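// NOTE: as elsewhere in this file, table names are interpolated directly
// into the SQL text, so callers must pass trusted identifiers only.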
query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", tableName)
_, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to enable compression: %w", err)
}
// Set compression policy if interval is specified
if afterInterval != "" {
err = t.AddCompressionPolicy(ctx, tableName, afterInterval, "", "")
if err != nil {
return fmt.Errorf("failed to add compression policy: %w", err)
}
}
return nil
}
// DisableCompression disables compression on a hypertable
func (t *DB) DisableCompression(ctx context.Context, tableName string) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
// First, remove any compression policies
err := t.RemoveCompressionPolicy(ctx, tableName)
if err != nil {
return fmt.Errorf("failed to remove compression policy: %w", err)
}
// Then disable compression
query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = false)", tableName)
_, err = t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to disable compression: %w", err)
}
return nil
}
// AddCompressionPolicy adds a compression policy to a hypertable
func (t *DB) AddCompressionPolicy(ctx context.Context, tableName, interval, segmentBy, orderBy string) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
// First, check if the table has compression enabled
query := fmt.Sprintf("SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'", tableName)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to check compression status: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
return fmt.Errorf("table '%s' is not a hypertable", tableName)
}
isCompressed := rows[0]["compress"]
if isCompressed == nil || fmt.Sprintf("%v", isCompressed) == "false" {
// Enable compression
enableQuery := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", tableName)
_, err := t.ExecuteSQLWithoutParams(ctx, enableQuery)
if err != nil {
return fmt.Errorf("failed to enable compression: %w", err)
}
}
// Build the compression policy query
var policyQuery strings.Builder
policyQuery.WriteString(fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'", tableName, interval))
if segmentBy != "" {
policyQuery.WriteString(fmt.Sprintf(", segmentby => '%s'", segmentBy))
}
if orderBy != "" {
policyQuery.WriteString(fmt.Sprintf(", orderby => '%s'", orderBy))
}
policyQuery.WriteString(")")
// Add the compression policy
_, err = t.ExecuteSQLWithoutParams(ctx, policyQuery.String())
if err != nil {
return fmt.Errorf("failed to add compression policy: %w", err)
}
return nil
}
// RemoveCompressionPolicy removes a compression policy from a hypertable
func (t *DB) RemoveCompressionPolicy(ctx context.Context, tableName string) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
// Find the policy ID
query := fmt.Sprintf(
"SELECT job_id FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_compression'",
tableName,
)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to find compression policy: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
// No policy exists, so nothing to remove
return nil
}
// Get the job ID
jobID := rows[0]["job_id"]
if jobID == nil {
return fmt.Errorf("invalid job ID for compression policy")
}
// Remove the policy
removeQuery := fmt.Sprintf("SELECT remove_compression_policy(%v)", jobID)
_, err = t.ExecuteSQLWithoutParams(ctx, removeQuery)
if err != nil {
return fmt.Errorf("failed to remove compression policy: %w", err)
}
return nil
}
// GetCompressionSettings gets the compression settings for a hypertable
func (t *DB) GetCompressionSettings(ctx context.Context, tableName string) (*CompressionSettings, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
// Check if the table has compression enabled
query := fmt.Sprintf(
"SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'",
tableName,
)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to check compression status: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
}
settings := &CompressionSettings{
HypertableName: tableName,
}
isCompressed := rows[0]["compress"]
if isCompressed != nil && fmt.Sprintf("%v", isCompressed) == "true" {
settings.CompressionEnabled = true
// Get compression-specific settings
settingsQuery := fmt.Sprintf(
"SELECT segmentby, orderby FROM timescaledb_information.compression_settings WHERE hypertable_name = '%s'",
tableName,
)
settingsResult, err := t.ExecuteSQLWithoutParams(ctx, settingsQuery)
if err != nil {
return nil, fmt.Errorf("failed to get compression settings: %w", err)
}
settingsRows, ok := settingsResult.([]map[string]interface{})
if ok && len(settingsRows) > 0 {
if segmentBy, ok := settingsRows[0]["segmentby"]; ok && segmentBy != nil {
settings.SegmentBy = fmt.Sprintf("%v", segmentBy)
}
if orderBy, ok := settingsRows[0]["orderby"]; ok && orderBy != nil {
settings.OrderBy = fmt.Sprintf("%v", orderBy)
}
}
// Check if a compression policy exists
policyQuery := fmt.Sprintf(
"SELECT s.schedule_interval, h.chunk_time_interval FROM timescaledb_information.jobs j "+
"JOIN timescaledb_information.job_stats s ON j.job_id = s.job_id "+
"JOIN timescaledb_information.hypertables h ON j.hypertable_name = h.hypertable_name "+
"WHERE j.hypertable_name = '%s' AND j.proc_name = 'policy_compression'",
tableName,
)
policyResult, err := t.ExecuteSQLWithoutParams(ctx, policyQuery)
if err == nil {
policyRows, ok := policyResult.([]map[string]interface{})
if ok && len(policyRows) > 0 {
if interval, ok := policyRows[0]["schedule_interval"]; ok && interval != nil {
settings.CompressionInterval = fmt.Sprintf("%v", interval)
}
if chunkInterval, ok := policyRows[0]["chunk_time_interval"]; ok && chunkInterval != nil {
settings.ChunkTimeInterval = fmt.Sprintf("%v", chunkInterval)
}
}
}
}
return settings, nil
}
// AddRetentionPolicy adds a data retention policy to a hypertable
func (t *DB) AddRetentionPolicy(ctx context.Context, tableName, interval string) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
query := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s')", tableName, interval)
_, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to add retention policy: %w", err)
}
return nil
}
// RemoveRetentionPolicy removes a data retention policy from a hypertable
func (t *DB) RemoveRetentionPolicy(ctx context.Context, tableName string) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
// Find the policy ID
query := fmt.Sprintf(
"SELECT job_id FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'",
tableName,
)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to find retention policy: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
// No policy exists, so nothing to remove
return nil
}
// Get the job ID
jobID := rows[0]["job_id"]
if jobID == nil {
return fmt.Errorf("invalid job ID for retention policy")
}
// Remove the policy
removeQuery := fmt.Sprintf("SELECT remove_retention_policy(%v)", jobID)
_, err = t.ExecuteSQLWithoutParams(ctx, removeQuery)
if err != nil {
return fmt.Errorf("failed to remove retention policy: %w", err)
}
return nil
}
// GetRetentionSettings gets the retention settings for a hypertable
func (t *DB) GetRetentionSettings(ctx context.Context, tableName string) (*RetentionSettings, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
settings := &RetentionSettings{
HypertableName: tableName,
}
// Check if a retention policy exists
query := fmt.Sprintf(
"SELECT s.schedule_interval FROM timescaledb_information.jobs j "+
"JOIN timescaledb_information.job_stats s ON j.job_id = s.job_id "+
"WHERE j.hypertable_name = '%s' AND j.proc_name = 'policy_retention'",
tableName,
)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return settings, nil // Return empty settings with no error
}
rows, ok := result.([]map[string]interface{})
if ok && len(rows) > 0 {
settings.RetentionEnabled = true
if interval, ok := rows[0]["schedule_interval"]; ok && interval != nil {
settings.RetentionInterval = fmt.Sprintf("%v", interval)
}
}
return settings, nil
}
// CompressChunks compresses chunks for a hypertable
func (t *DB) CompressChunks(ctx context.Context, tableName, olderThan string) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
var query string
if olderThan == "" {
// Compress all chunks
query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s')", tableName)
} else {
// Compress chunks older than the specified interval
query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s', older_than => INTERVAL '%s')",
tableName, olderThan)
}
_, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to compress chunks: %w", err)
}
return nil
}
// DecompressChunks decompresses chunks for a hypertable
func (t *DB) DecompressChunks(ctx context.Context, tableName, newerThan string) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
var query string
if newerThan == "" {
// Decompress all chunks
query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s')", tableName)
} else {
// Decompress chunks newer than the specified interval
query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s', newer_than => INTERVAL '%s')",
tableName, newerThan)
}
_, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to decompress chunks: %w", err)
}
return nil
}
// GetChunkCompressionStats gets compression statistics for a hypertable
func (t *DB) GetChunkCompressionStats(ctx context.Context, tableName string) (interface{}, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
query := fmt.Sprintf(`
SELECT
chunk_name,
range_start,
range_end,
is_compressed,
before_compression_total_bytes,
after_compression_total_bytes,
CASE
WHEN before_compression_total_bytes = 0 THEN 0
ELSE (1 - (after_compression_total_bytes::float / before_compression_total_bytes::float)) * 100
END AS compression_ratio
FROM timescaledb_information.chunks
WHERE hypertable_name = '%s'
ORDER BY range_end DESC
`, tableName)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to get chunk compression statistics: %w", err)
}
return result, nil
}
```
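A usage sketch tying the policy API above together, assuming a connected `*DB`; the `example` package name, the `metrics` table, and the intervals are placeholders.
```go
package example

import (
	"context"
	"log"

	ts "github.com/FreePeak/db-mcp-server/pkg/db/timescale"
)

// applyLifecyclePolicies compresses rows older than a week, retains 90
// days of data, and reads both settings back for logging.
func applyLifecyclePolicies(ctx context.Context, db *ts.DB) {
	if err := db.EnableCompression(ctx, "metrics", "7 days"); err != nil {
		log.Fatalf("enable compression: %v", err)
	}
	if err := db.AddRetentionPolicy(ctx, "metrics", "90 days"); err != nil {
		log.Fatalf("add retention policy: %v", err)
	}

	cs, err := db.GetCompressionSettings(ctx, "metrics")
	if err != nil {
		log.Fatalf("compression settings: %v", err)
	}
	rs, err := db.GetRetentionSettings(ctx, "metrics")
	if err != nil {
		log.Fatalf("retention settings: %v", err)
	}
	log.Printf("compression=%v (%s), retention=%v (%s)",
		cs.CompressionEnabled, cs.CompressionInterval,
		rs.RetentionEnabled, rs.RetentionInterval)
}
```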
--------------------------------------------------------------------------------
/pkg/db/timescale/hypertable_test.go:
--------------------------------------------------------------------------------
```go
package timescale
import (
"context"
"errors"
"testing"
)
func TestCreateHypertable(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Test basic hypertable creation
config := HypertableConfig{
TableName: "test_table",
TimeColumn: "time",
ChunkTimeInterval: "1 day",
CreateIfNotExists: true,
}
err := tsdb.CreateHypertable(ctx, config)
if err != nil {
t.Fatalf("Failed to create hypertable: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "create_hypertable")
AssertQueryContains(t, query, "test_table")
AssertQueryContains(t, query, "time")
AssertQueryContains(t, query, "chunk_time_interval")
AssertQueryContains(t, query, "1 day")
// Test with partitioning
config = HypertableConfig{
TableName: "test_table",
TimeColumn: "time",
ChunkTimeInterval: "1 day",
PartitioningColumn: "device_id",
SpacePartitions: 4,
}
err = tsdb.CreateHypertable(ctx, config)
if err != nil {
t.Fatalf("Failed to create hypertable with partitioning: %v", err)
}
// Check that the correct query was executed
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "create_hypertable")
AssertQueryContains(t, query, "partition_column")
AssertQueryContains(t, query, "device_id")
AssertQueryContains(t, query, "number_partitions")
// Test with if_not_exists and migrate_data
config = HypertableConfig{
TableName: "test_table",
TimeColumn: "time",
IfNotExists: true,
MigrateData: true,
}
err = tsdb.CreateHypertable(ctx, config)
if err != nil {
t.Fatalf("Failed to create hypertable with extra options: %v", err)
}
// Check that the correct query was executed
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "if_not_exists => TRUE")
AssertQueryContains(t, query, "migrate_data => TRUE")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.CreateHypertable(ctx, config)
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT create_hypertable(", nil, errors.New("mocked error"))
err = tsdb.CreateHypertable(ctx, config)
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestAddDimension(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Test adding a dimension
err := tsdb.AddDimension(ctx, "test_table", "device_id", 4)
if err != nil {
t.Fatalf("Failed to add dimension: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "add_dimension")
AssertQueryContains(t, query, "test_table")
AssertQueryContains(t, query, "device_id")
AssertQueryContains(t, query, "number_partitions => 4")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.AddDimension(ctx, "test_table", "device_id", 4)
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT add_dimension(", nil, errors.New("mocked error"))
err = tsdb.AddDimension(ctx, "test_table", "device_id", 4)
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestListHypertables(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Prepare mock data
mockResult := []map[string]interface{}{
{
"table_name": "test_table",
"schema_name": "public",
"time_column": "time",
"num_dimensions": 2,
"space_column": "device_id",
},
{
"table_name": "test_table2",
"schema_name": "public",
"time_column": "timestamp",
"num_dimensions": 1,
"space_column": nil,
},
}
// Register different result patterns for different queries
mockDB.RegisterQueryResult("FROM _timescaledb_catalog.hypertable h", mockResult, nil)
mockDB.RegisterQueryResult("FROM timescaledb_information.compression_settings", []map[string]interface{}{
{"is_compressed": true},
}, nil)
mockDB.RegisterQueryResult("FROM timescaledb_information.jobs", []map[string]interface{}{
{"has_retention": true},
}, nil)
// Test listing hypertables
hypertables, err := tsdb.ListHypertables(ctx)
if err != nil {
t.Fatalf("Failed to list hypertables: %v", err)
}
// Check the results
if len(hypertables) != 2 {
t.Errorf("Expected 2 hypertables, got %d", len(hypertables))
}
if hypertables[0].TableName != "test_table" {
t.Errorf("Expected TableName to be 'test_table', got '%s'", hypertables[0].TableName)
}
if hypertables[0].TimeColumn != "time" {
t.Errorf("Expected TimeColumn to be 'time', got '%s'", hypertables[0].TimeColumn)
}
if hypertables[0].SpaceColumn != "device_id" {
t.Errorf("Expected SpaceColumn to be 'device_id', got '%s'", hypertables[0].SpaceColumn)
}
if hypertables[0].NumDimensions != 2 {
t.Errorf("Expected NumDimensions to be 2, got %d", hypertables[0].NumDimensions)
}
if !hypertables[0].CompressionEnabled {
t.Error("Expected CompressionEnabled to be true, got false")
}
if !hypertables[0].RetentionEnabled {
t.Error("Expected RetentionEnabled to be true, got false")
}
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
_, err = tsdb.ListHypertables(ctx)
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("FROM _timescaledb_catalog.hypertable h", nil, errors.New("mocked error"))
_, err = tsdb.ListHypertables(ctx)
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestGetHypertable(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Prepare mock data - Set up the correct result by using RegisterQueryResult
mockResult := []map[string]interface{}{
{
"table_name": "test_table",
"schema_name": "public",
"time_column": "time",
"num_dimensions": int64(2),
"space_column": "device_id",
},
}
// Register the query result pattern for the main query
mockDB.RegisterQueryResult("WHERE h.table_name = 'test_table'", mockResult, nil)
// Register results for the compression check
mockDB.RegisterQueryResult("FROM timescaledb_information.compression_settings", []map[string]interface{}{
{"is_compressed": true},
}, nil)
// Register results for the retention policy check
mockDB.RegisterQueryResult("FROM timescaledb_information.jobs", []map[string]interface{}{
{"has_retention": true},
}, nil)
// Test getting a hypertable
hypertable, err := tsdb.GetHypertable(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to get hypertable: %v", err)
}
// Check the results
if hypertable.TableName != "test_table" {
t.Errorf("Expected TableName to be 'test_table', got '%s'", hypertable.TableName)
}
if hypertable.TimeColumn != "time" {
t.Errorf("Expected TimeColumn to be 'time', got '%s'", hypertable.TimeColumn)
}
if hypertable.SpaceColumn != "device_id" {
t.Errorf("Expected SpaceColumn to be 'device_id', got '%s'", hypertable.SpaceColumn)
}
if hypertable.NumDimensions != 2 {
t.Errorf("Expected NumDimensions to be 2, got %d", hypertable.NumDimensions)
}
if !hypertable.CompressionEnabled {
t.Error("Expected CompressionEnabled to be true, got false")
}
if !hypertable.RetentionEnabled {
t.Error("Expected RetentionEnabled to be true, got false")
}
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
_, err = tsdb.GetHypertable(ctx, "test_table")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("WHERE h.table_name = 'test_table'", nil, errors.New("mocked error"))
_, err = tsdb.GetHypertable(ctx, "test_table")
if err == nil {
t.Error("Expected query error, got nil")
}
// Test table not found - Create a new mock to avoid interference
newMockDB := NewMockDB()
newMockDB.SetTimescaleAvailable(true)
tsdb.Database = newMockDB
// Register an empty result for the "not_found" table
newMockDB.RegisterQueryResult("WHERE h.table_name = 'not_found'", []map[string]interface{}{}, nil)
_, err = tsdb.GetHypertable(ctx, "not_found")
if err == nil {
t.Error("Expected error for non-existent table, got nil")
}
}
func TestDropHypertable(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Test dropping a hypertable
err := tsdb.DropHypertable(ctx, "test_table", false)
if err != nil {
t.Fatalf("Failed to drop hypertable: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "DROP TABLE test_table")
// Test dropping with CASCADE
err = tsdb.DropHypertable(ctx, "test_table", true)
if err != nil {
t.Fatalf("Failed to drop hypertable with CASCADE: %v", err)
}
// Check that the correct query was executed
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "DROP TABLE test_table CASCADE")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.DropHypertable(ctx, "test_table", false)
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("DROP TABLE", nil, errors.New("mocked error"))
err = tsdb.DropHypertable(ctx, "test_table", false)
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestCheckIfHypertable(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Prepare mock data
mockResultTrue := []map[string]interface{}{
{"is_hypertable": true},
}
mockResultFalse := []map[string]interface{}{
{"is_hypertable": false},
}
// Test table is a hypertable
mockDB.RegisterQueryResult("WHERE table_name = 'test_table'", mockResultTrue, nil)
isHypertable, err := tsdb.CheckIfHypertable(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to check if hypertable: %v", err)
}
if !isHypertable {
t.Error("Expected table to be a hypertable, got false")
}
// Test table is not a hypertable
mockDB.RegisterQueryResult("WHERE table_name = 'regular_table'", mockResultFalse, nil)
isHypertable, err = tsdb.CheckIfHypertable(ctx, "regular_table")
if err != nil {
t.Fatalf("Failed to check if hypertable: %v", err)
}
if isHypertable {
t.Error("Expected table not to be a hypertable, got true")
}
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
_, err = tsdb.CheckIfHypertable(ctx, "test_table")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("WHERE table_name = 'error_table'", nil, errors.New("mocked error"))
_, err = tsdb.CheckIfHypertable(ctx, "error_table")
if err == nil {
t.Error("Expected query error, got nil")
}
// Test unexpected result structure
mockDB.RegisterQueryResult("WHERE table_name = 'bad_structure'", []map[string]interface{}{}, nil)
_, err = tsdb.CheckIfHypertable(ctx, "bad_structure")
if err == nil {
t.Error("Expected error for bad result structure, got nil")
}
}
func TestRecentChunks(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Prepare mock data
mockResult := []map[string]interface{}{
{
"chunk_name": "_hyper_1_1_chunk",
"range_start": "2023-01-01 00:00:00",
"range_end": "2023-01-02 00:00:00",
"is_compressed": false,
},
{
"chunk_name": "_hyper_1_2_chunk",
"range_start": "2023-01-02 00:00:00",
"range_end": "2023-01-03 00:00:00",
"is_compressed": true,
},
}
// Register mock result
mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", mockResult, nil)
// Test getting recent chunks
_, err := tsdb.RecentChunks(ctx, "test_table", 2)
if err != nil {
t.Fatalf("Failed to get recent chunks: %v", err)
}
// Check that a query with the right table name and limit was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "hypertable_name = 'test_table'")
AssertQueryContains(t, query, "LIMIT 2")
// Test with default limit
_, err = tsdb.RecentChunks(ctx, "test_table", 0)
if err != nil {
t.Fatalf("Failed to get recent chunks with default limit: %v", err)
}
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "LIMIT 10")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
_, err = tsdb.RecentChunks(ctx, "test_table", 2)
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", nil, errors.New("mocked error"))
_, err = tsdb.RecentChunks(ctx, "test_table", 2)
if err == nil {
t.Error("Expected query error, got nil")
}
}
```
--------------------------------------------------------------------------------
/pkg/db/timescale/query.go:
--------------------------------------------------------------------------------
```go
package timescale
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
// TimeBucket represents a time bucket for time-series aggregation
type TimeBucket struct {
Interval string // e.g., '1 hour', '1 day', '1 month'
Column string // Time column to bucket
Alias string // Optional alias for the bucket column
}
// AggregateFunction represents a common aggregate function
type AggregateFunction string
const (
// AggrAvg calculates the average value of a column
AggrAvg AggregateFunction = "AVG"
// AggrSum calculates the sum of values in a column
AggrSum AggregateFunction = "SUM"
// AggrMin finds the minimum value in a column
AggrMin AggregateFunction = "MIN"
// AggrMax finds the maximum value in a column
AggrMax AggregateFunction = "MAX"
// AggrCount counts the number of rows
AggrCount AggregateFunction = "COUNT"
// AggrFirst takes the first value in a window
AggrFirst AggregateFunction = "FIRST"
// AggrLast takes the last value in a window
AggrLast AggregateFunction = "LAST"
)
// ColumnAggregation represents an aggregation operation on a column
type ColumnAggregation struct {
Function AggregateFunction
Column string
Alias string
}
// TimeseriesQueryBuilder helps build optimized time-series queries
type TimeseriesQueryBuilder struct {
table string
timeBucket *TimeBucket
selectCols []string
aggregations []ColumnAggregation
whereClauses []string
whereArgs []interface{}
groupByCols []string
orderByCols []string
limit int
offset int
}
// NewTimeseriesQueryBuilder creates a new builder for a specific table
func NewTimeseriesQueryBuilder(table string) *TimeseriesQueryBuilder {
return &TimeseriesQueryBuilder{
table: table,
selectCols: make([]string, 0),
aggregations: make([]ColumnAggregation, 0),
whereClauses: make([]string, 0),
whereArgs: make([]interface{}, 0),
groupByCols: make([]string, 0),
orderByCols: make([]string, 0),
}
}
// WithTimeBucket adds a time bucket to the query
func (b *TimeseriesQueryBuilder) WithTimeBucket(interval, column, alias string) *TimeseriesQueryBuilder {
b.timeBucket = &TimeBucket{
Interval: interval,
Column: column,
Alias: alias,
}
return b
}
// Select adds columns to the SELECT clause
func (b *TimeseriesQueryBuilder) Select(cols ...string) *TimeseriesQueryBuilder {
b.selectCols = append(b.selectCols, cols...)
return b
}
// Aggregate adds an aggregation function to a column
func (b *TimeseriesQueryBuilder) Aggregate(function AggregateFunction, column, alias string) *TimeseriesQueryBuilder {
b.aggregations = append(b.aggregations, ColumnAggregation{
Function: function,
Column: column,
Alias: alias,
})
return b
}
// WhereTimeRange adds a time range condition
func (b *TimeseriesQueryBuilder) WhereTimeRange(column string, start, end time.Time) *TimeseriesQueryBuilder {
clause := fmt.Sprintf("%s BETWEEN $%d AND $%d", column, len(b.whereArgs)+1, len(b.whereArgs)+2)
b.whereClauses = append(b.whereClauses, clause)
b.whereArgs = append(b.whereArgs, start, end)
return b
}
// Where adds a WHERE condition
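// Placeholder indices in the clause are renumbered to follow any arguments
// already bound; for example, with two args previously registered,
//
//	b.Where("device_id = $1 AND status = $2", "dev-1", "ok")
//
// appends the clause as "device_id = $3 AND status = $4".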
func (b *TimeseriesQueryBuilder) Where(clause string, args ...interface{}) *TimeseriesQueryBuilder {
// Shift the placeholder indices in the clause so they continue after any
// arguments already bound. A single regex pass avoids the collisions that a
// sequential $1 -> $2, $2 -> $3 string replacement would cause.
paramCount := len(b.whereArgs)
if paramCount > 0 {
clause = placeholderPattern.ReplaceAllStringFunc(clause, func(m string) string {
n, convErr := strconv.Atoi(m[1:])
if convErr != nil {
return m // leave unparseable tokens untouched
}
return fmt.Sprintf("$%d", n+paramCount)
})
}
b.whereClauses = append(b.whereClauses, clause)
b.whereArgs = append(b.whereArgs, args...)
return b
}
// placeholderPattern matches PostgreSQL positional placeholders such as $1.
var placeholderPattern = regexp.MustCompile(`\$\d+`)
// GroupBy adds columns to the GROUP BY clause
func (b *TimeseriesQueryBuilder) GroupBy(cols ...string) *TimeseriesQueryBuilder {
b.groupByCols = append(b.groupByCols, cols...)
return b
}
// OrderBy adds columns to the ORDER BY clause
func (b *TimeseriesQueryBuilder) OrderBy(cols ...string) *TimeseriesQueryBuilder {
b.orderByCols = append(b.orderByCols, cols...)
return b
}
// Limit sets the LIMIT clause
func (b *TimeseriesQueryBuilder) Limit(limit int) *TimeseriesQueryBuilder {
b.limit = limit
return b
}
// Offset sets the OFFSET clause
func (b *TimeseriesQueryBuilder) Offset(offset int) *TimeseriesQueryBuilder {
b.offset = offset
return b
}
// Build constructs the SQL query and args
func (b *TimeseriesQueryBuilder) Build() (string, []interface{}) {
var selectClause strings.Builder
selectClause.WriteString("SELECT ")
var selects []string
// Add time bucket if specified
if b.timeBucket != nil {
alias := b.timeBucket.Alias
if alias == "" {
alias = "time_bucket"
}
bucketStr := fmt.Sprintf(
"time_bucket('%s', %s) AS %s",
b.timeBucket.Interval,
b.timeBucket.Column,
alias,
)
selects = append(selects, bucketStr)
// Add time bucket to group by if not already included
bucketFound := false
for _, col := range b.groupByCols {
if col == alias {
bucketFound = true
break
}
}
if !bucketFound {
b.groupByCols = append([]string{alias}, b.groupByCols...)
}
}
// Add selected columns
selects = append(selects, b.selectCols...)
// Add aggregations
for _, agg := range b.aggregations {
alias := agg.Alias
if alias == "" {
alias = strings.ToLower(string(agg.Function)) + "_" + agg.Column
}
aggStr := fmt.Sprintf(
"%s(%s) AS %s",
agg.Function,
agg.Column,
alias,
)
selects = append(selects, aggStr)
}
// If no columns or aggregations selected, use *
if len(selects) == 0 {
selectClause.WriteString("*")
} else {
selectClause.WriteString(strings.Join(selects, ", "))
}
// Build query
query := fmt.Sprintf("%s FROM %s", selectClause.String(), b.table)
// Add WHERE clause
if len(b.whereClauses) > 0 {
query += " WHERE " + strings.Join(b.whereClauses, " AND ")
}
// Add GROUP BY clause
if len(b.groupByCols) > 0 {
query += " GROUP BY " + strings.Join(b.groupByCols, ", ")
}
// Add ORDER BY clause
if len(b.orderByCols) > 0 {
query += " ORDER BY " + strings.Join(b.orderByCols, ", ")
}
// Add LIMIT clause
if b.limit > 0 {
query += fmt.Sprintf(" LIMIT %d", b.limit)
}
// Add OFFSET clause
if b.offset > 0 {
query += fmt.Sprintf(" OFFSET %d", b.offset)
}
return query, b.whereArgs
}
// Execute runs the query against the database
func (b *TimeseriesQueryBuilder) Execute(ctx context.Context, db *DB) ([]map[string]interface{}, error) {
query, args := b.Build()
result, err := db.ExecuteSQL(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to execute time-series query: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected result type from database query")
}
return rows, nil
}
// DownsampleOptions describes options for downsampling time-series data
type DownsampleOptions struct {
SourceTable string
DestTable string
TimeColumn string
BucketInterval string
Aggregations []ColumnAggregation
WhereCondition string
CreateTable bool
ChunkTimeInterval string
}
// DownsampleTimeSeries creates downsampled time-series data
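//
// Example (illustrative; the table and column names are assumptions):
//
//	err := tsdb.DownsampleTimeSeries(ctx, DownsampleOptions{
//		SourceTable:    "sensor_data",
//		DestTable:      "sensor_data_hourly",
//		TimeColumn:     "time",
//		BucketInterval: "1 hour",
//		Aggregations:   []ColumnAggregation{{Function: AggrAvg, Column: "temperature"}},
//		CreateTable:    true,
//	})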
func (t *DB) DownsampleTimeSeries(ctx context.Context, options DownsampleOptions) error {
if !t.isTimescaleDB {
return fmt.Errorf("TimescaleDB extension not available")
}
// Create the destination table if requested
if options.CreateTable {
// Get source table columns
schemaQuery := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '%s'", options.SourceTable)
result, err := t.ExecuteSQLWithoutParams(ctx, schemaQuery)
if err != nil {
return fmt.Errorf("failed to get source table schema: %w", err)
}
columns, ok := result.([]map[string]interface{})
if !ok {
return fmt.Errorf("unexpected result from schema query")
}
// Build CREATE TABLE statement
var createStmt strings.Builder
createStmt.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", options.DestTable))
// Add time bucket column
createStmt.WriteString("time_bucket timestamptz, ")
// Add aggregation columns
for i, agg := range options.Aggregations {
colName := agg.Alias
if colName == "" {
colName = strings.ToLower(string(agg.Function)) + "_" + agg.Column
}
// Find the data type of the source column
var dataType string
for _, col := range columns {
if fmt.Sprintf("%v", col["column_name"]) == agg.Column {
dataType = fmt.Sprintf("%v", col["data_type"])
break
}
}
if dataType == "" {
dataType = "double precision" // Default for numeric aggregations
}
createStmt.WriteString(fmt.Sprintf("%s %s", colName, dataType))
if i < len(options.Aggregations)-1 {
createStmt.WriteString(", ")
}
}
createStmt.WriteString(", PRIMARY KEY (time_bucket)")
createStmt.WriteString(")")
// Create the table
_, err = t.ExecuteSQLWithoutParams(ctx, createStmt.String())
if err != nil {
return fmt.Errorf("failed to create destination table: %w", err)
}
// Make it a hypertable
if options.ChunkTimeInterval == "" {
options.ChunkTimeInterval = options.BucketInterval
}
err = t.CreateHypertable(ctx, HypertableConfig{
TableName: options.DestTable,
TimeColumn: "time_bucket",
ChunkTimeInterval: options.ChunkTimeInterval,
CreateIfNotExists: true,
})
if err != nil {
return fmt.Errorf("failed to create hypertable: %w", err)
}
}
// Build the INSERT statement with aggregations
var insertStmt strings.Builder
insertStmt.WriteString(fmt.Sprintf("INSERT INTO %s (time_bucket, ", options.DestTable))
// Add aggregation column names
for i, agg := range options.Aggregations {
colName := agg.Alias
if colName == "" {
colName = strings.ToLower(string(agg.Function)) + "_" + agg.Column
}
insertStmt.WriteString(colName)
if i < len(options.Aggregations)-1 {
insertStmt.WriteString(", ")
}
}
insertStmt.WriteString(") SELECT time_bucket('")
insertStmt.WriteString(options.BucketInterval)
insertStmt.WriteString("', ")
insertStmt.WriteString(options.TimeColumn)
insertStmt.WriteString(") AS time_bucket, ")
// Add aggregation functions
for i, agg := range options.Aggregations {
insertStmt.WriteString(fmt.Sprintf("%s(%s)", agg.Function, agg.Column))
if i < len(options.Aggregations)-1 {
insertStmt.WriteString(", ")
}
}
insertStmt.WriteString(fmt.Sprintf(" FROM %s", options.SourceTable))
// Add WHERE clause if specified
if options.WhereCondition != "" {
insertStmt.WriteString(" WHERE ")
insertStmt.WriteString(options.WhereCondition)
}
// Group by time bucket
insertStmt.WriteString(" GROUP BY time_bucket")
// Order by time bucket
insertStmt.WriteString(" ORDER BY time_bucket")
// Execute the INSERT statement
_, err := t.ExecuteSQLWithoutParams(ctx, insertStmt.String())
if err != nil {
return fmt.Errorf("failed to downsample data: %w", err)
}
return nil
}
// TimeRange represents a common time range for queries
type TimeRange struct {
Start time.Time
End time.Time
}
// PredefinedTimeRange returns a TimeRange for common time ranges
func PredefinedTimeRange(name string) (*TimeRange, error) {
now := time.Now()
switch strings.ToLower(name) {
case "today":
start := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
return &TimeRange{Start: start, End: now}, nil
case "yesterday":
end := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
start := end.Add(-24 * time.Hour)
return &TimeRange{Start: start, End: end}, nil
case "last24hours", "last_24_hours":
start := now.Add(-24 * time.Hour)
return &TimeRange{Start: start, End: now}, nil
case "thisweek", "this_week":
// Calculate the beginning of the week (Go's time.Weekday numbers Sunday as 0, so the week starts on Sunday)
weekday := int(now.Weekday())
start := now.Add(-time.Duration(weekday) * 24 * time.Hour)
start = time.Date(start.Year(), start.Month(), start.Day(), 0, 0, 0, 0, now.Location())
return &TimeRange{Start: start, End: now}, nil
case "lastweek", "last_week":
// Calculate the beginning of this week
weekday := int(now.Weekday())
thisWeekStart := now.Add(-time.Duration(weekday) * 24 * time.Hour)
thisWeekStart = time.Date(thisWeekStart.Year(), thisWeekStart.Month(), thisWeekStart.Day(), 0, 0, 0, 0, now.Location())
// Last week is 7 days before the beginning of this week
lastWeekStart := thisWeekStart.Add(-7 * 24 * time.Hour)
lastWeekEnd := thisWeekStart
return &TimeRange{Start: lastWeekStart, End: lastWeekEnd}, nil
case "last7days", "last_7_days":
start := now.Add(-7 * 24 * time.Hour)
return &TimeRange{Start: start, End: now}, nil
case "thismonth", "this_month":
start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
return &TimeRange{Start: start, End: now}, nil
case "lastmonth", "last_month":
thisMonthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
var lastMonthStart time.Time
if now.Month() == 1 {
// January, so last month is December of previous year
lastMonthStart = time.Date(now.Year()-1, 12, 1, 0, 0, 0, 0, now.Location())
} else {
// Any other month
lastMonthStart = time.Date(now.Year(), now.Month()-1, 1, 0, 0, 0, 0, now.Location())
}
return &TimeRange{Start: lastMonthStart, End: thisMonthStart}, nil
case "last30days", "last_30_days":
start := now.Add(-30 * 24 * time.Hour)
return &TimeRange{Start: start, End: now}, nil
case "thisyear", "this_year":
start := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location())
return &TimeRange{Start: start, End: now}, nil
case "lastyear", "last_year":
thisYearStart := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location())
lastYearStart := time.Date(now.Year()-1, 1, 1, 0, 0, 0, 0, now.Location())
return &TimeRange{Start: lastYearStart, End: thisYearStart}, nil
case "last365days", "last_365_days":
start := now.Add(-365 * 24 * time.Hour)
return &TimeRange{Start: start, End: now}, nil
default:
return nil, fmt.Errorf("unknown time range: %s", name)
}
}
```
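A minimal usage sketch of the query builder above. It is illustrative only: the table and column names (`sensor_data`, `time`, `temperature`, `device_id`) are assumptions, not identifiers from this repository.

```go
package main

import (
	"fmt"

	timescale "github.com/FreePeak/db-mcp-server/pkg/db/timescale"
)

func main() {
	// Resolve one of the predefined ranges; "last7days" is a supported name.
	tr, err := timescale.PredefinedTimeRange("last7days")
	if err != nil {
		panic(err)
	}

	// Build a bucketed aggregation over the assumed sensor_data table.
	query, args := timescale.NewTimeseriesQueryBuilder("sensor_data").
		WithTimeBucket("1 hour", "time", "bucket").
		Aggregate(timescale.AggrAvg, "temperature", "avg_temp").
		WhereTimeRange("time", tr.Start, tr.End).
		GroupBy("device_id").
		OrderBy("bucket").
		Limit(100).
		Build()

	// Produces roughly:
	//   SELECT time_bucket('1 hour', time) AS bucket, AVG(temperature) AS avg_temp
	//   FROM sensor_data WHERE time BETWEEN $1 AND $2
	//   GROUP BY bucket, device_id ORDER BY bucket LIMIT 100
	fmt.Println(query)
	fmt.Println(args)
}
```

Calling `Execute(ctx, tsdb)` instead of `Build()` would run the same statement through `DB.ExecuteSQL` on a connected `*timescale.DB`.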
--------------------------------------------------------------------------------
/internal/delivery/mcp/timescale_tools_test.go:
--------------------------------------------------------------------------------
```go
package mcp_test
import (
"context"
"strings"
"testing"
"github.com/FreePeak/cortex/pkg/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/FreePeak/db-mcp-server/internal/delivery/mcp"
)
// MockDatabaseUseCase is a mock implementation of the UseCaseProvider interface
type MockDatabaseUseCase struct {
mock.Mock
}
// ExecuteStatement mocks the ExecuteStatement method
func (m *MockDatabaseUseCase) ExecuteStatement(ctx context.Context, dbID, statement string, params []interface{}) (string, error) {
args := m.Called(ctx, dbID, statement, params)
return args.String(0), args.Error(1)
}
// GetDatabaseType mocks the GetDatabaseType method
func (m *MockDatabaseUseCase) GetDatabaseType(dbID string) (string, error) {
args := m.Called(dbID)
return args.String(0), args.Error(1)
}
// ExecuteQuery mocks the ExecuteQuery method
func (m *MockDatabaseUseCase) ExecuteQuery(ctx context.Context, dbID, query string, params []interface{}) (string, error) {
args := m.Called(ctx, dbID, query, params)
return args.String(0), args.Error(1)
}
// ExecuteTransaction mocks the ExecuteTransaction method
func (m *MockDatabaseUseCase) ExecuteTransaction(ctx context.Context, dbID, action string, txID string, statement string, params []interface{}, readOnly bool) (string, map[string]interface{}, error) {
args := m.Called(ctx, dbID, action, txID, statement, params, readOnly)
return args.String(0), args.Get(1).(map[string]interface{}), args.Error(2)
}
// GetDatabaseInfo mocks the GetDatabaseInfo method
func (m *MockDatabaseUseCase) GetDatabaseInfo(dbID string) (map[string]interface{}, error) {
args := m.Called(dbID)
return args.Get(0).(map[string]interface{}), args.Error(1)
}
// ListDatabases mocks the ListDatabases method
func (m *MockDatabaseUseCase) ListDatabases() []string {
args := m.Called()
return args.Get(0).([]string)
}
func TestTimescaleDBTool(t *testing.T) {
tool := mcp.NewTimescaleDBTool()
assert.Equal(t, "timescaledb", tool.GetName())
}
func TestTimeSeriesQueryTool(t *testing.T) {
// Create a mock use case provider
mockUseCase := new(MockDatabaseUseCase)
// Set up the TimescaleDB tool
tool := mcp.NewTimescaleDBTool()
// Create a context for testing
ctx := context.Background()
// Test case for time_series_query operation
t.Run("time_series_query with basic parameters", func(t *testing.T) {
// Sample result that would be returned by the database
sampleResult := `[
{"time_bucket": "2023-01-01T00:00:00Z", "avg_temp": 22.5, "count": 10},
{"time_bucket": "2023-01-02T00:00:00Z", "avg_temp": 23.1, "count": 12}
]`
// Set up expectations for the mock
mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.AnythingOfType("string"), mock.Anything).
Return(sampleResult, nil).Once()
// Create a request with time_series_query operation
request := server.ToolCallRequest{
Name: "timescaledb_timeseries_query_test_db",
Parameters: map[string]interface{}{
"operation": "time_series_query",
"target_table": "sensor_data",
"time_column": "timestamp",
"bucket_interval": "1 day",
"start_time": "2023-01-01",
"end_time": "2023-01-31",
"aggregations": "AVG(temperature) as avg_temp, COUNT(*) as count",
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
assert.Contains(t, resultMap, "details")
assert.Equal(t, sampleResult, resultMap["details"])
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
t.Run("time_series_query with window functions", func(t *testing.T) {
// Sample result that would be returned by the database
sampleResult := `[
{"time_bucket": "2023-01-01T00:00:00Z", "avg_temp": 22.5, "prev_avg": null},
{"time_bucket": "2023-01-02T00:00:00Z", "avg_temp": 23.1, "prev_avg": 22.5}
]`
// Set up expectations for the mock
mockUseCase.On("ExecuteStatement", mock.Anything, "test_db", mock.AnythingOfType("string"), mock.Anything).
Return(sampleResult, nil).Once()
// Create a request with time_series_query operation
request := server.ToolCallRequest{
Name: "timescaledb_timeseries_query_test_db",
Parameters: map[string]interface{}{
"operation": "time_series_query",
"target_table": "sensor_data",
"time_column": "timestamp",
"bucket_interval": "1 day",
"aggregations": "AVG(temperature) as avg_temp",
"window_functions": "LAG(avg_temp) OVER (ORDER BY time_bucket) AS prev_avg",
"format_pretty": true,
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
assert.Contains(t, resultMap, "details")
assert.Contains(t, resultMap, "metadata")
// Check metadata contains expected fields for pretty formatting
metadata, ok := resultMap["metadata"].(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, metadata, "num_rows")
assert.Contains(t, metadata, "time_bucket_interval")
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
}
// TestContinuousAggregateTool tests the continuous aggregate operations
func TestContinuousAggregateTool(t *testing.T) {
// Create a context for testing
ctx := context.Background()
// Test case for create_continuous_aggregate operation
t.Run("create_continuous_aggregate", func(t *testing.T) {
// Create a new mock for this test case
mockUseCase := new(MockDatabaseUseCase)
// Set up the TimescaleDB tool
tool := mcp.NewTimescaleDBTool()
// Set up expectations
// Removed GetDatabaseType expectation as it's not called in this handler
// Add mock expectation for the SQL containing CREATE MATERIALIZED VIEW
mockUseCase.On("ExecuteStatement",
mock.Anything,
"test_db",
mock.MatchedBy(func(sql string) bool {
return strings.Contains(sql, "CREATE MATERIALIZED VIEW")
}),
mock.Anything).Return(`{"result": "success"}`, nil)
// Add separate mock expectation for policy SQL if needed
mockUseCase.On("ExecuteStatement",
mock.Anything,
"test_db",
mock.MatchedBy(func(sql string) bool {
return strings.Contains(sql, "add_continuous_aggregate_policy")
}),
mock.Anything).Return(`{"result": "success"}`, nil)
// Create a request
request := server.ToolCallRequest{
Name: "timescaledb_create_continuous_aggregate_test_db",
Parameters: map[string]interface{}{
"operation": "create_continuous_aggregate",
"view_name": "daily_metrics",
"source_table": "sensor_data",
"time_column": "timestamp",
"bucket_interval": "1 day",
"aggregations": "AVG(temperature) as avg_temp, MIN(temperature) as min_temp, MAX(temperature) as max_temp",
"with_data": true,
"refresh_policy": true,
"refresh_interval": "1 hour",
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
assert.Contains(t, resultMap, "sql")
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
// Test case for refresh_continuous_aggregate operation
t.Run("refresh_continuous_aggregate", func(t *testing.T) {
// Create a new mock for this test case
mockUseCase := new(MockDatabaseUseCase)
// Set up the TimescaleDB tool
tool := mcp.NewTimescaleDBTool()
// Set up expectations
// Removed GetDatabaseType expectation as it's not called in this handler
mockUseCase.On("ExecuteStatement",
mock.Anything,
"test_db",
mock.MatchedBy(func(sql string) bool {
return strings.Contains(sql, "CALL refresh_continuous_aggregate")
}),
mock.Anything).Return(`{"result": "success"}`, nil)
// Create a request
request := server.ToolCallRequest{
Name: "timescaledb_refresh_continuous_aggregate_test_db",
Parameters: map[string]interface{}{
"operation": "refresh_continuous_aggregate",
"view_name": "daily_metrics",
"start_time": "2023-01-01",
"end_time": "2023-01-31",
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
// Test case for drop_continuous_aggregate operation
t.Run("drop_continuous_aggregate", func(t *testing.T) {
// Create a new mock for this test case
mockUseCase := new(MockDatabaseUseCase)
// Set up the TimescaleDB tool
tool := mcp.NewTimescaleDBTool()
// Set up expectations
// Removed GetDatabaseType expectation as it's not called in this handler
mockUseCase.On("ExecuteStatement",
mock.Anything,
"test_db",
mock.MatchedBy(func(sql string) bool {
return strings.Contains(sql, "DROP MATERIALIZED VIEW")
}),
mock.Anything).Return(`{"result": "success"}`, nil)
// Create a request
request := server.ToolCallRequest{
Name: "timescaledb_drop_continuous_aggregate_test_db",
Parameters: map[string]interface{}{
"operation": "drop_continuous_aggregate",
"view_name": "daily_metrics",
"cascade": true,
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
// Test case for list_continuous_aggregates operation
t.Run("list_continuous_aggregates", func(t *testing.T) {
// Create a new mock for this test case
mockUseCase := new(MockDatabaseUseCase)
// Set up the TimescaleDB tool
tool := mcp.NewTimescaleDBTool()
// Set up expectations
mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
mockUseCase.On("ExecuteStatement",
mock.Anything,
"test_db",
mock.MatchedBy(func(sql string) bool {
return strings.Contains(sql, "SELECT") && strings.Contains(sql, "continuous_aggregates")
}),
mock.Anything).Return(`[{"view_name": "daily_metrics", "source_table": "sensor_data"}]`, nil)
// Create a request
request := server.ToolCallRequest{
Name: "timescaledb_list_continuous_aggregates_test_db",
Parameters: map[string]interface{}{
"operation": "list_continuous_aggregates",
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
assert.Contains(t, resultMap, "details")
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
// Test case for get_continuous_aggregate_info operation
t.Run("get_continuous_aggregate_info", func(t *testing.T) {
// Create a new mock for this test case
mockUseCase := new(MockDatabaseUseCase)
// Set up the TimescaleDB tool
tool := mcp.NewTimescaleDBTool()
// Set up expectations
mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
mockUseCase.On("ExecuteStatement",
mock.Anything,
"test_db",
mock.MatchedBy(func(sql string) bool {
return strings.Contains(sql, "SELECT") && strings.Contains(sql, "continuous_aggregates") && strings.Contains(sql, "WHERE")
}),
mock.Anything).Return(`[{"view_name": "daily_metrics", "source_table": "sensor_data", "bucket_interval": "1 day"}]`, nil)
// Create a request
request := server.ToolCallRequest{
Name: "timescaledb_get_continuous_aggregate_info_test_db",
Parameters: map[string]interface{}{
"operation": "get_continuous_aggregate_info",
"view_name": "daily_metrics",
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
assert.Contains(t, resultMap, "details")
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
// Test case for add_continuous_aggregate_policy operation
t.Run("add_continuous_aggregate_policy", func(t *testing.T) {
// Create a new mock for this test case
mockUseCase := new(MockDatabaseUseCase)
// Set up the TimescaleDB tool
tool := mcp.NewTimescaleDBTool()
// Set up expectations
mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
mockUseCase.On("ExecuteStatement",
mock.Anything,
"test_db",
mock.MatchedBy(func(sql string) bool {
return strings.Contains(sql, "add_continuous_aggregate_policy")
}),
mock.Anything).Return(`{"result": "success"}`, nil)
// Create a request
request := server.ToolCallRequest{
Name: "timescaledb_add_continuous_aggregate_policy_test_db",
Parameters: map[string]interface{}{
"operation": "add_continuous_aggregate_policy",
"view_name": "daily_metrics",
"start_offset": "1 month",
"end_offset": "2 hours",
"schedule_interval": "6 hours",
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
// Test case for remove_continuous_aggregate_policy operation
t.Run("remove_continuous_aggregate_policy", func(t *testing.T) {
// Create a new mock for this test case
mockUseCase := new(MockDatabaseUseCase)
// Set up the TimescaleDB tool
tool := mcp.NewTimescaleDBTool()
// Set up expectations
mockUseCase.On("GetDatabaseType", "test_db").Return("postgres", nil)
mockUseCase.On("ExecuteStatement",
mock.Anything,
"test_db",
mock.MatchedBy(func(sql string) bool {
return strings.Contains(sql, "remove_continuous_aggregate_policy")
}),
mock.Anything).Return(`{"result": "success"}`, nil)
// Create a request
request := server.ToolCallRequest{
Name: "timescaledb_remove_continuous_aggregate_policy_test_db",
Parameters: map[string]interface{}{
"operation": "remove_continuous_aggregate_policy",
"view_name": "daily_metrics",
},
}
// Call the handler
result, err := tool.HandleRequest(ctx, request, "test_db", mockUseCase)
// Verify the result
assert.NoError(t, err)
assert.NotNil(t, result)
// Check the result contains expected fields
resultMap, ok := result.(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, resultMap, "message")
// Verify the mock expectations
mockUseCase.AssertExpectations(t)
})
}
```
--------------------------------------------------------------------------------
/pkg/db/timescale/metadata.go:
--------------------------------------------------------------------------------
```go
package timescale
import (
"context"
"fmt"
"strconv"
"strings"
)
// HypertableMetadata represents metadata about a TimescaleDB hypertable
type HypertableMetadata struct {
TableName string
SchemaName string
Owner string
NumDimensions int
TimeDimension string
TimeDimensionType string
SpaceDimensions []string
ChunkTimeInterval string
Compression bool
RetentionPolicy bool
TotalSize string
TotalRows int64
Chunks int
}
// ColumnMetadata represents metadata about a column
type ColumnMetadata struct {
Name string
Type string
Nullable bool
IsPrimaryKey bool
IsIndexed bool
Description string
}
// ContinuousAggregateMetadata represents metadata about a continuous aggregate
type ContinuousAggregateMetadata struct {
ViewName string
ViewSchema string
MaterializedOnly bool
RefreshInterval string
RefreshLag string
RefreshStartOffset string
RefreshEndOffset string
HypertableName string
HypertableSchema string
ViewDefinition string
}
// GetHypertableMetadata returns detailed metadata about a hypertable
func (t *DB) GetHypertableMetadata(ctx context.Context, tableName string) (*HypertableMetadata, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
// Query to get basic hypertable information
query := fmt.Sprintf(`
SELECT
h.table_name,
h.schema_name,
t.tableowner as owner,
h.num_dimensions,
dc.column_name as time_dimension,
dc.column_type as time_dimension_type,
dc.time_interval as chunk_time_interval,
h.compression_enabled,
pg_size_pretty(pg_total_relation_size(format('%%I.%%I', h.schema_name, h.table_name))) as total_size,
(SELECT count(*) FROM timescaledb_information.chunks WHERE hypertable_name = h.table_name) as chunks,
(SELECT count(*) FROM %s) as total_rows
FROM timescaledb_information.hypertables h
JOIN pg_tables t ON h.table_name = t.tablename AND h.schema_name = t.schemaname
JOIN timescaledb_information.dimensions dc ON h.hypertable_name = dc.hypertable_name
WHERE h.table_name = '%s' AND dc.dimension_number = 1
`, tableName, tableName)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to get hypertable metadata: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
return nil, fmt.Errorf("table '%s' is not a hypertable", tableName)
}
row := rows[0]
metadata := &HypertableMetadata{
TableName: fmt.Sprintf("%v", row["table_name"]),
SchemaName: fmt.Sprintf("%v", row["schema_name"]),
Owner: fmt.Sprintf("%v", row["owner"]),
TimeDimension: fmt.Sprintf("%v", row["time_dimension"]),
TimeDimensionType: fmt.Sprintf("%v", row["time_dimension_type"]),
ChunkTimeInterval: fmt.Sprintf("%v", row["chunk_time_interval"]),
TotalSize: fmt.Sprintf("%v", row["total_size"]),
}
// Convert numeric fields
if numDimensions, ok := row["num_dimensions"].(int64); ok {
metadata.NumDimensions = int(numDimensions)
} else if numDimensions, ok := row["num_dimensions"].(int); ok {
metadata.NumDimensions = numDimensions
}
if chunks, ok := row["chunks"].(int64); ok {
metadata.Chunks = int(chunks)
} else if chunks, ok := row["chunks"].(int); ok {
metadata.Chunks = chunks
}
if rows, ok := row["total_rows"].(int64); ok {
metadata.TotalRows = rows
} else if rows, ok := row["total_rows"].(int); ok {
metadata.TotalRows = int64(rows)
} else if rowsStr, ok := row["total_rows"].(string); ok {
if rows, err := strconv.ParseInt(rowsStr, 10, 64); err == nil {
metadata.TotalRows = rows
}
}
// Handle boolean fields
if compression, ok := row["compression_enabled"].(bool); ok {
metadata.Compression = compression
} else if compressionStr, ok := row["compression_enabled"].(string); ok {
metadata.Compression = compressionStr == "t" || compressionStr == "true" || compressionStr == "1"
}
// Get space dimensions if there are more than one dimension
if metadata.NumDimensions > 1 {
spaceDimQuery := fmt.Sprintf(`
SELECT column_name
FROM timescaledb_information.dimensions
WHERE hypertable_name = '%s' AND dimension_number > 1
ORDER BY dimension_number
`, tableName)
spaceResult, err := t.ExecuteSQLWithoutParams(ctx, spaceDimQuery)
if err == nil {
spaceDimRows, ok := spaceResult.([]map[string]interface{})
if ok {
for _, dimRow := range spaceDimRows {
if colName, ok := dimRow["column_name"]; ok && colName != nil {
metadata.SpaceDimensions = append(metadata.SpaceDimensions, fmt.Sprintf("%v", colName))
}
}
}
}
}
// Check if a retention policy exists
retentionQuery := fmt.Sprintf(`
SELECT COUNT(*) > 0 as has_retention
FROM timescaledb_information.jobs
WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'
`, tableName)
retentionResult, err := t.ExecuteSQLWithoutParams(ctx, retentionQuery)
if err == nil {
retentionRows, ok := retentionResult.([]map[string]interface{})
if ok && len(retentionRows) > 0 {
if hasRetention, ok := retentionRows[0]["has_retention"].(bool); ok {
metadata.RetentionPolicy = hasRetention
}
}
}
return metadata, nil
}
// GetTableColumns returns metadata about columns in a table
func (t *DB) GetTableColumns(ctx context.Context, tableName string) ([]ColumnMetadata, error) {
query := fmt.Sprintf(`
SELECT
c.column_name,
c.data_type,
c.is_nullable = 'YES' as is_nullable,
(
SELECT COUNT(*) > 0
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = format('%%I.%%I', c.table_schema, c.table_name)::regclass
AND i.indisprimary
AND a.attname = c.column_name
) as is_primary_key,
(
SELECT COUNT(*) > 0
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = format('%%I.%%I', c.table_schema, c.table_name)::regclass
AND NOT i.indisprimary
AND a.attname = c.column_name
) as is_indexed,
col_description(format('%%I.%%I', c.table_schema, c.table_name)::regclass::oid,
ordinal_position) as description
FROM information_schema.columns c
WHERE c.table_name = '%s'
ORDER BY c.ordinal_position
`, tableName)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to get table columns: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected result type from database query")
}
var columns []ColumnMetadata
for _, row := range rows {
col := ColumnMetadata{
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
}
// Handle boolean fields
if nullable, ok := row["is_nullable"].(bool); ok {
col.Nullable = nullable
}
if isPK, ok := row["is_primary_key"].(bool); ok {
col.IsPrimaryKey = isPK
}
if isIndexed, ok := row["is_indexed"].(bool); ok {
col.IsIndexed = isIndexed
}
// Handle description which might be null
if desc, ok := row["description"]; ok && desc != nil {
col.Description = fmt.Sprintf("%v", desc)
}
columns = append(columns, col)
}
return columns, nil
}
// ListContinuousAggregates lists all continuous aggregates
func (t *DB) ListContinuousAggregates(ctx context.Context) ([]ContinuousAggregateMetadata, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
query := `
SELECT
view_name,
view_schema,
materialized_only,
refresh_lag,
refresh_interval,
hypertable_name,
hypertable_schema
FROM timescaledb_information.continuous_aggregates
`
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list continuous aggregates: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected result type from database query")
}
var aggregates []ContinuousAggregateMetadata
for _, row := range rows {
agg := ContinuousAggregateMetadata{
ViewName: fmt.Sprintf("%v", row["view_name"]),
ViewSchema: fmt.Sprintf("%v", row["view_schema"]),
HypertableName: fmt.Sprintf("%v", row["hypertable_name"]),
HypertableSchema: fmt.Sprintf("%v", row["hypertable_schema"]),
}
// Handle boolean fields
if materializedOnly, ok := row["materialized_only"].(bool); ok {
agg.MaterializedOnly = materializedOnly
}
// Handle nullable fields
if refreshLag, ok := row["refresh_lag"]; ok && refreshLag != nil {
agg.RefreshLag = fmt.Sprintf("%v", refreshLag)
}
if refreshInterval, ok := row["refresh_interval"]; ok && refreshInterval != nil {
agg.RefreshInterval = fmt.Sprintf("%v", refreshInterval)
}
// Get view definition
definitionQuery := fmt.Sprintf(`
SELECT pg_get_viewdef(format('%%I.%%I', '%s', '%s')::regclass, true) as view_definition
`, agg.ViewSchema, agg.ViewName)
defResult, err := t.ExecuteSQLWithoutParams(ctx, definitionQuery)
if err == nil {
defRows, ok := defResult.([]map[string]interface{})
if ok && len(defRows) > 0 {
if def, ok := defRows[0]["view_definition"]; ok && def != nil {
agg.ViewDefinition = fmt.Sprintf("%v", def)
}
}
}
aggregates = append(aggregates, agg)
}
return aggregates, nil
}
// GetContinuousAggregate gets metadata about a specific continuous aggregate
func (t *DB) GetContinuousAggregate(ctx context.Context, viewName string) (*ContinuousAggregateMetadata, error) {
if !t.isTimescaleDB {
return nil, fmt.Errorf("TimescaleDB extension not available")
}
query := fmt.Sprintf(`
SELECT
view_name,
view_schema,
materialized_only,
refresh_lag,
refresh_interval,
hypertable_name,
hypertable_schema
FROM timescaledb_information.continuous_aggregates
WHERE view_name = '%s'
`, viewName)
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to get continuous aggregate: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
return nil, fmt.Errorf("continuous aggregate '%s' not found", viewName)
}
row := rows[0]
agg := &ContinuousAggregateMetadata{
ViewName: fmt.Sprintf("%v", row["view_name"]),
ViewSchema: fmt.Sprintf("%v", row["view_schema"]),
HypertableName: fmt.Sprintf("%v", row["hypertable_name"]),
HypertableSchema: fmt.Sprintf("%v", row["hypertable_schema"]),
}
// Handle boolean fields
if materializedOnly, ok := row["materialized_only"].(bool); ok {
agg.MaterializedOnly = materializedOnly
}
// Handle nullable fields
if refreshLag, ok := row["refresh_lag"]; ok && refreshLag != nil {
agg.RefreshLag = fmt.Sprintf("%v", refreshLag)
}
if refreshInterval, ok := row["refresh_interval"]; ok && refreshInterval != nil {
agg.RefreshInterval = fmt.Sprintf("%v", refreshInterval)
}
// Get view definition
definitionQuery := fmt.Sprintf(`
SELECT pg_get_viewdef(format('%%I.%%I', '%s', '%s')::regclass, true) as view_definition
`, agg.ViewSchema, agg.ViewName)
defResult, err := t.ExecuteSQLWithoutParams(ctx, definitionQuery)
if err == nil {
defRows, ok := defResult.([]map[string]interface{})
if ok && len(defRows) > 0 {
if def, ok := defRows[0]["view_definition"]; ok && def != nil {
agg.ViewDefinition = fmt.Sprintf("%v", def)
}
}
}
return agg, nil
}
// GetDatabaseSize gets size information about the database
func (t *DB) GetDatabaseSize(ctx context.Context) (map[string]string, error) {
query := `
SELECT
pg_size_pretty(pg_database_size(current_database())) as database_size,
current_database() as database_name,
(
SELECT pg_size_pretty(sum(pg_total_relation_size(format('%I.%I', h.schema_name, h.table_name))))
FROM timescaledb_information.hypertables h
) as hypertables_size,
(
SELECT count(*)
FROM timescaledb_information.hypertables
) as hypertables_count
`
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to get database size: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
return nil, fmt.Errorf("failed to get database size information")
}
info := make(map[string]string)
for k, v := range rows[0] {
if v != nil {
info[k] = fmt.Sprintf("%v", v)
}
}
return info, nil
}
// DetectTimescaleDBVersion checks if TimescaleDB is installed and returns its version
func (t *DB) DetectTimescaleDBVersion(ctx context.Context) (string, error) {
query := "SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'"
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return "", fmt.Errorf("failed to check TimescaleDB version: %w", err)
}
rows, ok := result.([]map[string]interface{})
if !ok || len(rows) == 0 {
return "", fmt.Errorf("TimescaleDB extension not installed")
}
version := rows[0]["extversion"]
if version == nil {
return "", fmt.Errorf("unable to determine TimescaleDB version")
}
return fmt.Sprintf("%v", version), nil
}
// GenerateHypertableSchema generates CREATE TABLE and CREATE HYPERTABLE statements for a hypertable
func (t *DB) GenerateHypertableSchema(ctx context.Context, tableName string) (string, error) {
if !t.isTimescaleDB {
return "", fmt.Errorf("TimescaleDB extension not available")
}
// Get table columns and constraints
columnsQuery := fmt.Sprintf(`
SELECT
'CREATE TABLE ' || quote_ident('%s') || ' (' ||
string_agg(
quote_ident(column_name) || ' ' ||
data_type ||
CASE
WHEN character_maximum_length IS NOT NULL THEN '(' || character_maximum_length || ')'
WHEN numeric_precision IS NOT NULL AND numeric_scale IS NOT NULL THEN '(' || numeric_precision || ',' || numeric_scale || ')'
ELSE ''
END ||
CASE WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END,
', '
) ||
CASE
WHEN (
SELECT count(*) > 0
FROM information_schema.table_constraints tc
WHERE tc.table_name = '%s' AND tc.constraint_type = 'PRIMARY KEY'
) THEN
', ' || (
SELECT 'PRIMARY KEY (' || string_agg(quote_ident(kcu.column_name), ', ') || ')'
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu ON
kcu.constraint_name = tc.constraint_name AND
kcu.table_schema = tc.table_schema AND
kcu.table_name = tc.table_name
WHERE tc.table_name = '%s' AND tc.constraint_type = 'PRIMARY KEY'
)
ELSE ''
END ||
');' as create_table_stmt
FROM information_schema.columns
WHERE table_name = '%s'
GROUP BY table_name
`, tableName, tableName, tableName, tableName)
columnsResult, err := t.ExecuteSQLWithoutParams(ctx, columnsQuery)
if err != nil {
return "", fmt.Errorf("failed to generate schema: %w", err)
}
columnsRows, ok := columnsResult.([]map[string]interface{})
if !ok || len(columnsRows) == 0 {
return "", fmt.Errorf("failed to generate schema for table '%s'", tableName)
}
createTableStmt := fmt.Sprintf("%v", columnsRows[0]["create_table_stmt"])
// Get hypertable metadata
metadata, err := t.GetHypertableMetadata(ctx, tableName)
if err != nil {
return createTableStmt, nil // Return just the CREATE TABLE statement if it's not a hypertable
}
// Generate CREATE HYPERTABLE statement
var createHypertableStmt strings.Builder
createHypertableStmt.WriteString(fmt.Sprintf("SELECT create_hypertable('%s', '%s'",
tableName, metadata.TimeDimension))
if metadata.ChunkTimeInterval != "" {
createHypertableStmt.WriteString(fmt.Sprintf(", chunk_time_interval => INTERVAL '%s'",
metadata.ChunkTimeInterval))
}
if len(metadata.SpaceDimensions) > 0 {
createHypertableStmt.WriteString(fmt.Sprintf(", partitioning_column => '%s'",
metadata.SpaceDimensions[0]))
}
createHypertableStmt.WriteString(");")
// Combine statements
result := createTableStmt + "\n\n" + createHypertableStmt.String()
// Add compression statement if enabled
if metadata.Compression {
compressionSettings, err := t.GetCompressionSettings(ctx, tableName)
if err == nil && compressionSettings.CompressionEnabled {
compressionStmt := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true);", tableName)
result += "\n\n" + compressionStmt
// Add compression policy if exists
if compressionSettings.CompressionInterval != "" {
policyStmt := fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'",
tableName, compressionSettings.CompressionInterval)
if compressionSettings.SegmentBy != "" {
policyStmt += fmt.Sprintf(", segmentby => '%s'", compressionSettings.SegmentBy)
}
if compressionSettings.OrderBy != "" {
policyStmt += fmt.Sprintf(", orderby => '%s'", compressionSettings.OrderBy)
}
policyStmt += ");"
result += "\n" + policyStmt
}
}
}
// Add retention policy if enabled
if metadata.RetentionPolicy {
retentionSettings, err := t.GetRetentionSettings(ctx, tableName)
if err == nil && retentionSettings.RetentionEnabled && retentionSettings.RetentionInterval != "" {
retentionStmt := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s');",
tableName, retentionSettings.RetentionInterval)
result += "\n\n" + retentionStmt
}
}
return result, nil
}
```
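A minimal sketch of the metadata helpers above, assuming an already-connected `*timescale.DB` (its construction lives elsewhere in `pkg/db/timescale` and is not shown on this page). The table name `sensor_data` is again an assumption.

```go
package example

import (
	"context"
	"fmt"

	timescale "github.com/FreePeak/db-mcp-server/pkg/db/timescale"
)

// describeHypertable prints the TimescaleDB version, basic hypertable
// metadata, and the reconstructed DDL for one (assumed) hypertable.
func describeHypertable(ctx context.Context, tsdb *timescale.DB) error {
	version, err := tsdb.DetectTimescaleDBVersion(ctx)
	if err != nil {
		return err // extension missing or unreadable
	}
	fmt.Println("TimescaleDB version:", version)

	meta, err := tsdb.GetHypertableMetadata(ctx, "sensor_data")
	if err != nil {
		return err
	}
	fmt.Printf("%s.%s: time column %s, %d chunks, compression=%v\n",
		meta.SchemaName, meta.TableName, meta.TimeDimension, meta.Chunks, meta.Compression)

	// Rebuild CREATE TABLE / create_hypertable statements, plus compression
	// and retention statements when those policies are enabled.
	ddl, err := tsdb.GenerateHypertableSchema(ctx, "sensor_data")
	if err != nil {
		return err
	}
	fmt.Println(ddl)
	return nil
}
```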
--------------------------------------------------------------------------------
/internal/delivery/mcp/tool_types.go:
--------------------------------------------------------------------------------
```go
package mcp
import (
"context"
"fmt"
"strings"
"github.com/FreePeak/cortex/pkg/server"
"github.com/FreePeak/cortex/pkg/tools"
)
// createTextResponse creates a simple response with a text content
func createTextResponse(text string) map[string]interface{} {
return map[string]interface{}{
"content": []map[string]interface{}{
{
"type": "text",
"text": text,
},
},
}
}
// addMetadata adds metadata to a response
func addMetadata(resp map[string]interface{}, key string, value interface{}) map[string]interface{} {
if resp["metadata"] == nil {
resp["metadata"] = make(map[string]interface{})
}
metadata, ok := resp["metadata"].(map[string]interface{})
if !ok {
// Create a new metadata map if conversion fails
metadata = make(map[string]interface{})
resp["metadata"] = metadata
}
metadata[key] = value
return resp
}
// TODO: Refactor tool type implementations to reduce duplication and improve maintainability
// TODO: Consider using a code generation approach for repetitive tool patterns
// TODO: Add comprehensive request validation for all tool parameters
// TODO: Implement proper rate limiting and resource protection
// TODO: Add detailed documentation for each tool type and its parameters
// ToolType interface defines the structure for different types of database tools
type ToolType interface {
// GetName returns the base name of the tool type (e.g., "query", "execute")
GetName() string
// GetDescription returns a description for this tool type
GetDescription(dbID string) string
// CreateTool creates a tool with the specified name
// The returned tool must be compatible with server.MCPServer.AddTool's first parameter
CreateTool(name string, dbID string) interface{}
// HandleRequest handles tool requests for this tool type
HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error)
}
// UseCaseProvider interface abstracts database use case operations
type UseCaseProvider interface {
ExecuteQuery(ctx context.Context, dbID, query string, params []interface{}) (string, error)
ExecuteStatement(ctx context.Context, dbID, statement string, params []interface{}) (string, error)
ExecuteTransaction(ctx context.Context, dbID, action string, txID string, statement string, params []interface{}, readOnly bool) (string, map[string]interface{}, error)
GetDatabaseInfo(dbID string) (map[string]interface{}, error)
ListDatabases() []string
GetDatabaseType(dbID string) (string, error)
}
// BaseToolType provides common functionality for tool types
type BaseToolType struct {
name string
description string
}
// GetName returns the name of the tool type
func (b *BaseToolType) GetName() string {
return b.name
}
// GetDescription returns a description for the tool type
func (b *BaseToolType) GetDescription(dbID string) string {
return fmt.Sprintf("%s on %s database", b.description, dbID)
}
//------------------------------------------------------------------------------
// QueryTool implementation
//------------------------------------------------------------------------------
// QueryTool handles SQL query operations
type QueryTool struct {
BaseToolType
}
// NewQueryTool creates a new query tool type
func NewQueryTool() *QueryTool {
return &QueryTool{
BaseToolType: BaseToolType{
name: "query",
description: "Execute SQL query",
},
}
}
// CreateTool creates a query tool
func (t *QueryTool) CreateTool(name string, dbID string) interface{} {
return tools.NewTool(
name,
tools.WithDescription(t.GetDescription(dbID)),
tools.WithString("query",
tools.Description("SQL query to execute"),
tools.Required(),
),
tools.WithArray("params",
tools.Description("Query parameters"),
tools.Items(map[string]interface{}{"type": "string"}),
),
)
}
// HandleRequest handles query tool requests
func (t *QueryTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
// If dbID is not provided, extract it from the tool name
if dbID == "" {
dbID = extractDatabaseIDFromName(request.Name)
}
query, ok := request.Parameters["query"].(string)
if !ok {
return nil, fmt.Errorf("query parameter must be a string")
}
var queryParams []interface{}
if request.Parameters["params"] != nil {
if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
queryParams = paramsArr
}
}
result, err := useCase.ExecuteQuery(ctx, dbID, query, queryParams)
if err != nil {
return nil, err
}
return createTextResponse(result), nil
}
// extractDatabaseIDFromName extracts the database ID from a tool name
func extractDatabaseIDFromName(name string) string {
// Format is: <tooltype>_<dbID>
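// e.g. "query_mydb" yields "mydb" (everything after the final underscore)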
parts := strings.Split(name, "_")
if len(parts) < 2 {
return ""
}
// The database ID is the last part
return parts[len(parts)-1]
}
//------------------------------------------------------------------------------
// ExecuteTool implementation
//------------------------------------------------------------------------------
// ExecuteTool handles SQL statement execution
type ExecuteTool struct {
BaseToolType
}
// NewExecuteTool creates a new execute tool type
func NewExecuteTool() *ExecuteTool {
return &ExecuteTool{
BaseToolType: BaseToolType{
name: "execute",
description: "Execute SQL statement",
},
}
}
// CreateTool creates an execute tool
func (t *ExecuteTool) CreateTool(name string, dbID string) interface{} {
return tools.NewTool(
name,
tools.WithDescription(t.GetDescription(dbID)),
tools.WithString("statement",
tools.Description("SQL statement to execute (INSERT, UPDATE, DELETE, etc.)"),
tools.Required(),
),
tools.WithArray("params",
tools.Description("Statement parameters"),
tools.Items(map[string]interface{}{"type": "string"}),
),
)
}
// HandleRequest handles execute tool requests
func (t *ExecuteTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
// If dbID is not provided, extract it from the tool name
if dbID == "" {
dbID = extractDatabaseIDFromName(request.Name)
}
statement, ok := request.Parameters["statement"].(string)
if !ok {
return nil, fmt.Errorf("statement parameter must be a string")
}
var statementParams []interface{}
if request.Parameters["params"] != nil {
if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
statementParams = paramsArr
}
}
result, err := useCase.ExecuteStatement(ctx, dbID, statement, statementParams)
if err != nil {
return nil, err
}
return createTextResponse(result), nil
}
//------------------------------------------------------------------------------
// TransactionTool implementation
//------------------------------------------------------------------------------
// TransactionTool handles database transactions
type TransactionTool struct {
BaseToolType
}
// NewTransactionTool creates a new transaction tool type
func NewTransactionTool() *TransactionTool {
return &TransactionTool{
BaseToolType: BaseToolType{
name: "transaction",
description: "Manage transactions",
},
}
}
// CreateTool creates a transaction tool
func (t *TransactionTool) CreateTool(name string, dbID string) interface{} {
return tools.NewTool(
name,
tools.WithDescription(t.GetDescription(dbID)),
tools.WithString("action",
tools.Description("Transaction action (begin, commit, rollback, execute)"),
tools.Required(),
),
tools.WithString("transactionId",
tools.Description("Transaction ID (required for commit, rollback, execute)"),
),
tools.WithString("statement",
tools.Description("SQL statement to execute within transaction (required for execute)"),
),
tools.WithArray("params",
tools.Description("Statement parameters"),
tools.Items(map[string]interface{}{"type": "string"}),
),
tools.WithBoolean("readOnly",
tools.Description("Whether the transaction is read-only (for begin)"),
),
)
}
// HandleRequest handles transaction tool requests
func (t *TransactionTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
// If dbID is not provided, extract it from the tool name
if dbID == "" {
dbID = extractDatabaseIDFromName(request.Name)
}
action, ok := request.Parameters["action"].(string)
if !ok {
return nil, fmt.Errorf("action parameter must be a string")
}
txID := ""
if request.Parameters["transactionId"] != nil {
var ok bool
txID, ok = request.Parameters["transactionId"].(string)
if !ok {
return nil, fmt.Errorf("transactionId parameter must be a string")
}
}
statement := ""
if request.Parameters["statement"] != nil {
var ok bool
statement, ok = request.Parameters["statement"].(string)
if !ok {
return nil, fmt.Errorf("statement parameter must be a string")
}
}
var params []interface{}
if request.Parameters["params"] != nil {
if paramsArr, ok := request.Parameters["params"].([]interface{}); ok {
params = paramsArr
}
}
readOnly := false
if request.Parameters["readOnly"] != nil {
var ok bool
readOnly, ok = request.Parameters["readOnly"].(bool)
if !ok {
return nil, fmt.Errorf("readOnly parameter must be a boolean")
}
}
message, metadata, err := useCase.ExecuteTransaction(ctx, dbID, action, txID, statement, params, readOnly)
if err != nil {
return nil, err
}
// Create response with text and metadata
resp := createTextResponse(message)
// Add metadata if provided
for k, v := range metadata {
addMetadata(resp, k, v)
}
return resp, nil
}
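// A hypothetical call sequence for this tool (illustrative only; the exact
// response shape and metadata keys are defined by
// UseCaseProvider.ExecuteTransaction elsewhere in the repository):
//
// 1. action="begin" (optionally readOnly=true) -> the response metadata
//    presumably carries the new transaction ID
// 2. action="execute" with transactionId, statement, and params
// 3. action="commit" (or action="rollback") with the same transactionId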
//------------------------------------------------------------------------------
// PerformanceTool implementation
//------------------------------------------------------------------------------
// PerformanceTool handles query performance analysis
type PerformanceTool struct {
BaseToolType
}
// NewPerformanceTool creates a new performance tool type
func NewPerformanceTool() *PerformanceTool {
return &PerformanceTool{
BaseToolType: BaseToolType{
name: "performance",
description: "Analyze query performance",
},
}
}
// CreateTool creates a performance analysis tool
func (t *PerformanceTool) CreateTool(name string, dbID string) interface{} {
return tools.NewTool(
name,
tools.WithDescription(t.GetDescription(dbID)),
tools.WithString("action",
tools.Description("Action (getSlowQueries, getMetrics, analyzeQuery, reset, setThreshold)"),
tools.Required(),
),
tools.WithString("query",
tools.Description("SQL query to analyze (required for analyzeQuery)"),
),
tools.WithNumber("limit",
tools.Description("Maximum number of results to return"),
),
tools.WithNumber("threshold",
tools.Description("Slow query threshold in milliseconds (required for setThreshold)"),
),
)
}
// HandleRequest handles performance tool requests
func (t *PerformanceTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
// If dbID is not provided, extract it from the tool name
if dbID == "" {
dbID = extractDatabaseIDFromName(request.Name)
}
// This is a simplified implementation
// In a real implementation, this would analyze query performance
action, ok := request.Parameters["action"].(string)
if !ok {
return nil, fmt.Errorf("action parameter must be a string")
}
var limit int
if request.Parameters["limit"] != nil {
if limitParam, ok := request.Parameters["limit"].(float64); ok {
limit = int(limitParam)
}
}
query := ""
if request.Parameters["query"] != nil {
var ok bool
query, ok = request.Parameters["query"].(string)
if !ok {
return nil, fmt.Errorf("query parameter must be a string")
}
}
var threshold int
if request.Parameters["threshold"] != nil {
if thresholdParam, ok := request.Parameters["threshold"].(float64); ok {
threshold = int(thresholdParam)
}
}
// This is where we would call the useCase to analyze performance
// For now, just return a placeholder
output := fmt.Sprintf("Performance analysis for action '%s' on database '%s'\n", action, dbID)
if query != "" {
output += fmt.Sprintf("Query: %s\n", query)
}
if limit > 0 {
output += fmt.Sprintf("Limit: %d\n", limit)
}
if threshold > 0 {
output += fmt.Sprintf("Threshold: %d ms\n", threshold)
}
return createTextResponse(output), nil
}
//------------------------------------------------------------------------------
// SchemaTool implementation
//------------------------------------------------------------------------------
// SchemaTool handles database schema exploration
type SchemaTool struct {
BaseToolType
}
// NewSchemaTool creates a new schema tool type
func NewSchemaTool() *SchemaTool {
return &SchemaTool{
BaseToolType: BaseToolType{
name: "schema",
description: "Get schema of",
},
}
}
// CreateTool creates a schema tool
func (t *SchemaTool) CreateTool(name string, dbID string) interface{} {
return tools.NewTool(
name,
tools.WithDescription(t.GetDescription(dbID)),
// Accept an unused string parameter for compatibility with clients that require at least one argument
tools.WithString("random_string",
tools.Description("Dummy parameter (optional)"),
),
)
}
// HandleRequest handles schema tool requests
func (t *SchemaTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
// If dbID is not provided, extract it from the tool name
if dbID == "" {
dbID = extractDatabaseIDFromName(request.Name)
}
info, err := useCase.GetDatabaseInfo(dbID)
if err != nil {
return nil, err
}
// Format response text
infoStr := fmt.Sprintf("Database Schema for %s:\n\n%+v", dbID, info)
return createTextResponse(infoStr), nil
}
//------------------------------------------------------------------------------
// ListDatabasesTool implementation
//------------------------------------------------------------------------------
// ListDatabasesTool handles listing available databases
type ListDatabasesTool struct {
BaseToolType
}
// NewListDatabasesTool creates a new list databases tool type
func NewListDatabasesTool() *ListDatabasesTool {
return &ListDatabasesTool{
BaseToolType: BaseToolType{
name: "list_databases",
description: "List all available databases",
},
}
}
// CreateTool creates a list databases tool
func (t *ListDatabasesTool) CreateTool(name string, dbID string) interface{} {
return tools.NewTool(
name,
tools.WithDescription(t.GetDescription(dbID)),
// Accept an unused string parameter for compatibility with clients that require at least one argument
tools.WithString("random_string",
tools.Description("Dummy parameter (optional)"),
),
)
}
// HandleRequest handles list databases tool requests
func (t *ListDatabasesTool) HandleRequest(ctx context.Context, request server.ToolCallRequest, dbID string, useCase UseCaseProvider) (interface{}, error) {
databases := useCase.ListDatabases()
// Format as text for display
output := "Available databases:\n\n"
for i, db := range databases {
output += fmt.Sprintf("%d. %s\n", i+1, db)
}
if len(databases) == 0 {
output += "No databases configured.\n"
}
return createTextResponse(output), nil
}
//------------------------------------------------------------------------------
// ToolTypeFactory provides a factory for creating tool types
//------------------------------------------------------------------------------
// ToolTypeFactory creates and manages tool types
type ToolTypeFactory struct {
toolTypes map[string]ToolType
}
// NewToolTypeFactory creates a new tool type factory with all registered tool types
func NewToolTypeFactory() *ToolTypeFactory {
factory := &ToolTypeFactory{
toolTypes: make(map[string]ToolType),
}
// Register all tool types
factory.Register(NewQueryTool())
factory.Register(NewExecuteTool())
factory.Register(NewTransactionTool())
factory.Register(NewPerformanceTool())
factory.Register(NewSchemaTool())
factory.Register(NewListDatabasesTool())
return factory
}
// Register adds a tool type to the factory
func (f *ToolTypeFactory) Register(toolType ToolType) {
f.toolTypes[toolType.GetName()] = toolType
}
// GetToolType returns a tool type by name
func (f *ToolTypeFactory) GetToolType(name string) (ToolType, bool) {
// Handle new simpler format: <tooltype>_<dbID> or just the tool type name
parts := strings.Split(name, "_")
if len(parts) > 0 {
// First part is the tool type name
toolType, ok := f.toolTypes[parts[0]]
if ok {
return toolType, true
}
}
// Direct tool type lookup
toolType, ok := f.toolTypes[name]
return toolType, ok
}
// GetToolTypeForSourceName finds the appropriate tool type for a source name
func (f *ToolTypeFactory) GetToolTypeForSourceName(sourceName string) (ToolType, string, bool) {
// Handle simpler format: <tooltype>_<dbID>
parts := strings.Split(sourceName, "_")
if len(parts) >= 2 {
// First part is tool type, last part is dbID
toolTypeName := parts[0]
dbID := parts[len(parts)-1]
toolType, ok := f.toolTypes[toolTypeName]
if ok {
return toolType, dbID, true
}
}
// Handle case for global tools
if sourceName == "list_databases" {
toolType, ok := f.toolTypes["list_databases"]
return toolType, "", ok
}
return nil, "", false
}
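// Examples of the resolution above (illustrative only):
//
// GetToolTypeForSourceName("query_mydb") -> (QueryTool, "mydb", true)
// GetToolTypeForSourceName("list_databases") first splits into
// ["list", "databases"]; "list" is not a registered tool type, so the
// explicit global-tool branch resolves it with an empty dbID.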
// GetAllToolTypes returns all registered tool types
func (f *ToolTypeFactory) GetAllToolTypes() []ToolType {
types := make([]ToolType, 0, len(f.toolTypes))
for _, toolType := range f.toolTypes {
types = append(types, toolType)
}
return types
}
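// Minimal wiring sketch (assumed usage; registration is presumably handled in
// tool_registry.go): one tool instance per (tool type, database) pair, named
// with the <tooltype>_<dbID> convention.
//
// factory := NewToolTypeFactory()
// for _, dbID := range useCase.ListDatabases() {
//     for _, tt := range factory.GetAllToolTypes() {
//         name := fmt.Sprintf("%s_%s", tt.GetName(), dbID)
//         _ = tt.CreateTool(name, dbID) // then registered with the MCP server
//     }
// }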
```
--------------------------------------------------------------------------------
/pkg/db/timescale/policy_test.go:
--------------------------------------------------------------------------------
```go
package timescale
import (
"context"
"errors"
"testing"
)
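// Note on the fixtures below (inferred from their usage; the mock itself is
// defined in mocks_test.go): RegisterQueryResult appears to match executed
// SQL by substring, so registering a fragment such as
// "FROM timescaledb_information.hypertables WHERE hypertable_name" is enough
// to intercept a query, and GetLastQuery returns the most recently executed
// statement for AssertQueryContains checks.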
func TestEnableCompression(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Register mock responses for checking if the table is a hypertable
mockDB.RegisterQueryResult("WHERE table_name = 'test_table'", []map[string]interface{}{
{"is_hypertable": true},
}, nil)
// Register mock response for the compression check in timescaledb_information.hypertables
mockDB.RegisterQueryResult("FROM timescaledb_information.hypertables WHERE hypertable_name", []map[string]interface{}{
{"compress": true},
}, nil)
// Test enabling compression without interval
err := tsdb.EnableCompression(ctx, "test_table", "")
if err != nil {
t.Fatalf("Failed to enable compression: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "ALTER TABLE test_table SET (timescaledb.compress = true)")
// Test enabling compression with interval
// Register mock responses for specific queries used in this test
mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
mockDB.RegisterQueryResult("SELECT add_compression_policy", nil, nil)
mockDB.RegisterQueryResult("timescaledb_information.hypertables WHERE hypertable_name = 'test_table'", []map[string]interface{}{
{"compress": true},
}, nil)
err = tsdb.EnableCompression(ctx, "test_table", "7 days")
if err != nil {
t.Fatalf("Failed to enable compression with interval: %v", err)
}
// Check that the correct policy query was executed
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "add_compression_policy")
AssertQueryContains(t, query, "test_table")
AssertQueryContains(t, query, "7 days")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.EnableCompression(ctx, "test_table", "")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("ALTER TABLE", nil, errors.New("mocked error"))
err = tsdb.EnableCompression(ctx, "test_table", "")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestDisableCompression(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Mock successful policy removal and compression disabling
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
{"job_id": 1},
}, nil)
mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, nil)
mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
// Test disabling compression
err := tsdb.DisableCompression(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to disable compression: %v", err)
}
// Check that the correct ALTER TABLE query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "ALTER TABLE test_table SET (timescaledb.compress = false)")
// Test when no policy exists (should still succeed)
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)
mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
err = tsdb.DisableCompression(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to disable compression when no policy exists: %v", err)
}
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.DisableCompression(ctx, "test_table")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
{"job_id": 1},
}, nil)
mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, errors.New("mocked error"))
err = tsdb.DisableCompression(ctx, "test_table")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestAddCompressionPolicy(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Mock checking compression status
mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
{"compress": true},
}, nil)
// Test adding a basic compression policy
err := tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
if err != nil {
t.Fatalf("Failed to add compression policy: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT add_compression_policy('test_table', INTERVAL '7 days'")
// Test adding a policy with segmentby and orderby
err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "device_id", "time DESC")
if err != nil {
t.Fatalf("Failed to add compression policy with additional options: %v", err)
}
// Check that the correct query was executed
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "segmentby => 'device_id'")
AssertQueryContains(t, query, "orderby => 'time DESC'")
// Test enabling compression first if not already enabled
mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
{"compress": false},
}, nil)
mockDB.RegisterQueryResult("ALTER TABLE", nil, nil)
err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
if err != nil {
t.Fatalf("Failed to add compression policy with compression enabling: %v", err)
}
// Check that the ALTER TABLE query was executed first
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT add_compression_policy")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error on compression check
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", nil, errors.New("mocked error"))
err = tsdb.AddCompressionPolicy(ctx, "test_table", "7 days", "", "")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestRemoveCompressionPolicy(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Mock finding a policy
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
{"job_id": 1},
}, nil)
mockDB.RegisterQueryResult("SELECT remove_compression_policy", nil, nil)
// Test removing a compression policy
err := tsdb.RemoveCompressionPolicy(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to remove compression policy: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT remove_compression_policy")
// Test when no policy exists (should succeed without error)
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)
err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
if err != nil {
t.Errorf("Expected success when no policy exists, got: %v", err)
}
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", nil, errors.New("mocked error"))
err = tsdb.RemoveCompressionPolicy(ctx, "test_table")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestGetCompressionSettings(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Mock compression status check
mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
{"compress": true},
}, nil)
// Mock compression settings
mockDB.RegisterQueryResult("SELECT segmentby, orderby FROM timescaledb_information.compression_settings", []map[string]interface{}{
{"segmentby": "device_id", "orderby": "time DESC"},
}, nil)
// Mock policy information
mockDB.RegisterQueryResult("SELECT s.schedule_interval, h.chunk_time_interval FROM", []map[string]interface{}{
{"schedule_interval": "7 days", "chunk_time_interval": "1 day"},
}, nil)
// Test getting compression settings
settings, err := tsdb.GetCompressionSettings(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to get compression settings: %v", err)
}
// Check the returned settings
if settings.HypertableName != "test_table" {
t.Errorf("Expected HypertableName to be 'test_table', got '%s'", settings.HypertableName)
}
if !settings.CompressionEnabled {
t.Error("Expected CompressionEnabled to be true, got false")
}
if settings.SegmentBy != "device_id" {
t.Errorf("Expected SegmentBy to be 'device_id', got '%s'", settings.SegmentBy)
}
if settings.OrderBy != "time DESC" {
t.Errorf("Expected OrderBy to be 'time DESC', got '%s'", settings.OrderBy)
}
if settings.CompressionInterval != "7 days" {
t.Errorf("Expected CompressionInterval to be '7 days', got '%s'", settings.CompressionInterval)
}
if settings.ChunkTimeInterval != "1 day" {
t.Errorf("Expected ChunkTimeInterval to be '1 day', got '%s'", settings.ChunkTimeInterval)
}
// Test when compression is not enabled
mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", []map[string]interface{}{
{"compress": false},
}, nil)
settings, err = tsdb.GetCompressionSettings(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to get compression settings when not enabled: %v", err)
}
if settings.CompressionEnabled {
t.Error("Expected CompressionEnabled to be false, got true")
}
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
_, err = tsdb.GetCompressionSettings(ctx, "test_table")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT compress FROM timescaledb_information.hypertables", nil, errors.New("mocked error"))
_, err = tsdb.GetCompressionSettings(ctx, "test_table")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestAddRetentionPolicy(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Test adding a retention policy
err := tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
if err != nil {
t.Fatalf("Failed to add retention policy: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT add_retention_policy('test_table', INTERVAL '30 days')")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT add_retention_policy", nil, errors.New("mocked error"))
err = tsdb.AddRetentionPolicy(ctx, "test_table", "30 days")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestRemoveRetentionPolicy(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Mock finding a policy
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{
{"job_id": 1},
}, nil)
mockDB.RegisterQueryResult("SELECT remove_retention_policy", nil, nil)
// Test removing a retention policy
err := tsdb.RemoveRetentionPolicy(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to remove retention policy: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT remove_retention_policy")
// Test when no policy exists (should succeed without error)
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", []map[string]interface{}{}, nil)
err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
if err != nil {
t.Errorf("Expected success when no policy exists, got: %v", err)
}
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT job_id FROM timescaledb_information.jobs", nil, errors.New("mocked error"))
err = tsdb.RemoveRetentionPolicy(ctx, "test_table")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestGetRetentionSettings(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Mock policy information
mockDB.RegisterQueryResult("SELECT s.schedule_interval FROM", []map[string]interface{}{
{"schedule_interval": "30 days"},
}, nil)
// Test getting retention settings
settings, err := tsdb.GetRetentionSettings(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to get retention settings: %v", err)
}
// Check the returned settings
if settings.HypertableName != "test_table" {
t.Errorf("Expected HypertableName to be 'test_table', got '%s'", settings.HypertableName)
}
if !settings.RetentionEnabled {
t.Error("Expected RetentionEnabled to be true, got false")
}
if settings.RetentionInterval != "30 days" {
t.Errorf("Expected RetentionInterval to be '30 days', got '%s'", settings.RetentionInterval)
}
// Test when no policy exists
mockDB.RegisterQueryResult("SELECT s.schedule_interval FROM", []map[string]interface{}{}, nil)
settings, err = tsdb.GetRetentionSettings(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to get retention settings when no policy exists: %v", err)
}
if settings.RetentionEnabled {
t.Error("Expected RetentionEnabled to be false, got true")
}
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
_, err = tsdb.GetRetentionSettings(ctx, "test_table")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
}
func TestCompressChunks(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Test compressing all chunks
err := tsdb.CompressChunks(ctx, "test_table", "")
if err != nil {
t.Fatalf("Failed to compress all chunks: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT compress_chunks(hypertable => 'test_table')")
// Test compressing chunks with older_than specified
err = tsdb.CompressChunks(ctx, "test_table", "7 days")
if err != nil {
t.Fatalf("Failed to compress chunks with older_than: %v", err)
}
// Check that the correct query was executed
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT compress_chunks(hypertable => 'test_table', older_than => INTERVAL '7 days')")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.CompressChunks(ctx, "test_table", "")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT compress_chunks", nil, errors.New("mocked error"))
err = tsdb.CompressChunks(ctx, "test_table", "")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestDecompressChunks(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Test decompressing all chunks
err := tsdb.DecompressChunks(ctx, "test_table", "")
if err != nil {
t.Fatalf("Failed to decompress all chunks: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT decompress_chunks(hypertable => 'test_table')")
// Test decompressing chunks with newer_than specified
err = tsdb.DecompressChunks(ctx, "test_table", "7 days")
if err != nil {
t.Fatalf("Failed to decompress chunks with newer_than: %v", err)
}
// Check that the correct query was executed
query, _ = mockDB.GetLastQuery()
AssertQueryContains(t, query, "SELECT decompress_chunks(hypertable => 'test_table', newer_than => INTERVAL '7 days')")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
err = tsdb.DecompressChunks(ctx, "test_table", "")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("SELECT decompress_chunks", nil, errors.New("mocked error"))
err = tsdb.DecompressChunks(ctx, "test_table", "")
if err == nil {
t.Error("Expected query error, got nil")
}
}
func TestGetChunkCompressionStats(t *testing.T) {
mockDB := NewMockDB()
tsdb := &DB{
Database: mockDB,
isTimescaleDB: true,
}
ctx := context.Background()
// Mock chunk stats
mockStats := []map[string]interface{}{
{
"chunk_name": "_hyper_1_1_chunk",
"range_start": "2023-01-01 00:00:00",
"range_end": "2023-01-02 00:00:00",
"is_compressed": true,
"before_compression_total_bytes": 1000,
"after_compression_total_bytes": 200,
"compression_ratio": 80.0,
},
}
mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", mockStats, nil)
// Test getting chunk compression stats
_, err := tsdb.GetChunkCompressionStats(ctx, "test_table")
if err != nil {
t.Fatalf("Failed to get chunk compression stats: %v", err)
}
// Check that the correct query was executed
query, _ := mockDB.GetLastQuery()
AssertQueryContains(t, query, "FROM timescaledb_information.chunks")
AssertQueryContains(t, query, "hypertable_name = 'test_table'")
// Test when TimescaleDB is not available
tsdb.isTimescaleDB = false
_, err = tsdb.GetChunkCompressionStats(ctx, "test_table")
if err == nil {
t.Error("Expected error when TimescaleDB is not available, got nil")
}
// Test execution error
tsdb.isTimescaleDB = true
mockDB.RegisterQueryResult("FROM timescaledb_information.chunks", nil, errors.New("mocked error"))
_, err = tsdb.GetChunkCompressionStats(ctx, "test_table")
if err == nil {
t.Error("Expected query error, got nil")
}
}
```
--------------------------------------------------------------------------------
/pkg/dbtools/schema.go:
--------------------------------------------------------------------------------
```go
package dbtools
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/FreePeak/db-mcp-server/pkg/db"
"github.com/FreePeak/db-mcp-server/pkg/logger"
"github.com/FreePeak/db-mcp-server/pkg/tools"
)
// DatabaseStrategy defines the interface for database-specific query strategies
type DatabaseStrategy interface {
GetTablesQueries() []queryWithArgs
GetColumnsQueries(table string) []queryWithArgs
GetRelationshipsQueries(table string) []queryWithArgs
}
// NewDatabaseStrategy creates the appropriate strategy for the given database type
func NewDatabaseStrategy(driverName string) DatabaseStrategy {
switch driverName {
case "postgres":
return &PostgresStrategy{}
case "mysql":
return &MySQLStrategy{}
default:
logger.Warn("Unknown database driver: %s, will use generic strategy", driverName)
return &GenericStrategy{}
}
}
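// For example (illustrative): NewDatabaseStrategy("postgres") returns a
// PostgresStrategy that prefers pg_catalog and falls back to
// information_schema, while an unrecognized driver such as "sqlite3" gets the
// GenericStrategy, which tries PostgreSQL-style ($1) placeholders first and
// MySQL-style (?) placeholders second.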
// PostgresStrategy implements DatabaseStrategy for PostgreSQL
type PostgresStrategy struct{}
// GetTablesQueries returns queries for retrieving tables in PostgreSQL
func (s *PostgresStrategy) GetTablesQueries() []queryWithArgs {
return []queryWithArgs{
// Primary: pg_catalog approach
{query: "SELECT tablename as table_name FROM pg_catalog.pg_tables WHERE schemaname = 'public'"},
// Secondary: information_schema approach
{query: "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"},
// Tertiary: pg_class approach
{query: "SELECT relname as table_name FROM pg_catalog.pg_class WHERE relkind = 'r' AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')"},
}
}
// GetColumnsQueries returns queries for retrieving columns in PostgreSQL
func (s *PostgresStrategy) GetColumnsQueries(table string) []queryWithArgs {
return []queryWithArgs{
// Primary: information_schema approach for PostgreSQL
{
query: `
SELECT column_name, data_type,
CASE WHEN is_nullable = 'YES' THEN 'YES' ELSE 'NO' END as is_nullable,
column_default
FROM information_schema.columns
WHERE table_name = $1 AND table_schema = 'public'
ORDER BY ordinal_position
`,
args: []interface{}{table},
},
// Secondary: pg_catalog approach for PostgreSQL
{
query: `
SELECT a.attname as column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type,
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END as is_nullable,
pg_catalog.pg_get_expr(d.adbin, d.adrelid) as column_default
FROM pg_catalog.pg_attribute a
LEFT JOIN pg_catalog.pg_attrdef d ON (a.attrelid = d.adrelid AND a.attnum = d.adnum)
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1 AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public'))
AND a.attnum > 0 AND NOT a.attisdropped
ORDER BY a.attnum
`,
args: []interface{}{table},
},
}
}
// GetRelationshipsQueries returns queries for retrieving relationships in PostgreSQL
func (s *PostgresStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
baseQueries := []queryWithArgs{
// Primary: Standard information_schema approach for PostgreSQL
{
query: `
SELECT
tc.table_schema,
tc.constraint_name,
tc.table_name,
kcu.column_name,
ccu.table_schema AS foreign_table_schema,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = 'public'
`,
args: []interface{}{},
},
// Alternate: Using pg_catalog for older PostgreSQL versions
{
query: `
SELECT
ns.nspname AS table_schema,
c.conname AS constraint_name,
cl.relname AS table_name,
att.attname AS column_name,
ns2.nspname AS foreign_table_schema,
cl2.relname AS foreign_table_name,
att2.attname AS foreign_column_name
FROM pg_constraint c
JOIN pg_class cl ON c.conrelid = cl.oid
JOIN pg_attribute att ON att.attrelid = cl.oid AND att.attnum = ANY(c.conkey)
JOIN pg_namespace ns ON ns.oid = cl.relnamespace
JOIN pg_class cl2 ON c.confrelid = cl2.oid
JOIN pg_attribute att2 ON att2.attrelid = cl2.oid AND att2.attnum = ANY(c.confkey)
JOIN pg_namespace ns2 ON ns2.oid = cl2.relnamespace
WHERE c.contype = 'f'
AND ns.nspname = 'public'
`,
args: []interface{}{},
},
}
if table == "" {
return baseQueries
}
queries := make([]queryWithArgs, len(baseQueries))
// Add table filter
queries[0] = queryWithArgs{
query: baseQueries[0].query + " AND (tc.table_name = $1 OR ccu.table_name = $1)",
args: []interface{}{table},
}
queries[1] = queryWithArgs{
query: baseQueries[1].query + " AND (cl.relname = $1 OR cl2.relname = $1)",
args: []interface{}{table},
}
return queries
}
// MySQLStrategy implements DatabaseStrategy for MySQL
type MySQLStrategy struct{}
// GetTablesQueries returns queries for retrieving tables in MySQL
func (s *MySQLStrategy) GetTablesQueries() []queryWithArgs {
return []queryWithArgs{
// Primary: information_schema approach
{query: "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()"},
// Secondary: SHOW TABLES approach
{query: "SHOW TABLES"},
}
}
// GetColumnsQueries returns queries for retrieving columns in MySQL
func (s *MySQLStrategy) GetColumnsQueries(table string) []queryWithArgs {
return []queryWithArgs{
// MySQL query for columns
{
query: `
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = ? AND table_schema = DATABASE()
ORDER BY ordinal_position
`,
args: []interface{}{table},
},
// Fallback for older MySQL versions
{
query: `SHOW COLUMNS FROM ` + table,
args: []interface{}{},
},
}
}
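// Note (not in the original source): the SHOW COLUMNS fallback above
// concatenates the table name directly into the statement, so it must only
// be used with trusted identifiers; the primary information_schema query is
// parameterized.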
// GetRelationshipsQueries returns queries for retrieving relationships in MySQL
func (s *MySQLStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
baseQueries := []queryWithArgs{
// Primary approach for MySQL
{
query: `
SELECT
tc.table_schema,
tc.constraint_name,
tc.table_name,
kcu.column_name,
kcu.referenced_table_schema AS foreign_table_schema,
kcu.referenced_table_name AS foreign_table_name,
kcu.referenced_column_name AS foreign_column_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = DATABASE()
`,
args: []interface{}{},
},
// Fallback using simpler query for older MySQL versions
{
query: `
SELECT
kcu.constraint_schema AS table_schema,
kcu.constraint_name,
kcu.table_name,
kcu.column_name,
kcu.referenced_table_schema AS foreign_table_schema,
kcu.referenced_table_name AS foreign_table_name,
kcu.referenced_column_name AS foreign_column_name
FROM information_schema.key_column_usage kcu
WHERE kcu.referenced_table_name IS NOT NULL
AND kcu.constraint_schema = DATABASE()
`,
args: []interface{}{},
},
}
if table == "" {
return baseQueries
}
queries := make([]queryWithArgs, len(baseQueries))
// Add table filter
queries[0] = queryWithArgs{
query: baseQueries[0].query + " AND (tc.table_name = ? OR kcu.referenced_table_name = ?)",
args: []interface{}{table, table},
}
queries[1] = queryWithArgs{
query: baseQueries[1].query + " AND (kcu.table_name = ? OR kcu.referenced_table_name = ?)",
args: []interface{}{table, table},
}
return queries
}
// GenericStrategy implements DatabaseStrategy for unknown database types
type GenericStrategy struct{}
// GetTablesQueries returns generic queries for retrieving tables
func (s *GenericStrategy) GetTablesQueries() []queryWithArgs {
return []queryWithArgs{
{query: "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"},
{query: "SELECT table_name FROM information_schema.tables"},
{query: "SHOW TABLES"}, // Last resort
}
}
// GetColumnsQueries returns generic queries for retrieving columns
func (s *GenericStrategy) GetColumnsQueries(table string) []queryWithArgs {
return []queryWithArgs{
// Try PostgreSQL-style query first
{
query: `
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = $1
ORDER BY ordinal_position
`,
args: []interface{}{table},
},
// Try MySQL-style query
{
query: `
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = ?
ORDER BY ordinal_position
`,
args: []interface{}{table},
},
}
}
// GetRelationshipsQueries returns generic queries for retrieving relationships
func (s *GenericStrategy) GetRelationshipsQueries(table string) []queryWithArgs {
pgQuery := queryWithArgs{
query: `
SELECT
tc.table_schema,
tc.constraint_name,
tc.table_name,
kcu.column_name,
ccu.table_schema AS foreign_table_schema,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
`,
args: []interface{}{},
}
mysqlQuery := queryWithArgs{
query: `
SELECT
kcu.constraint_schema AS table_schema,
kcu.constraint_name,
kcu.table_name,
kcu.column_name,
kcu.referenced_table_schema AS foreign_table_schema,
kcu.referenced_table_name AS foreign_table_name,
kcu.referenced_column_name AS foreign_column_name
FROM information_schema.key_column_usage kcu
WHERE kcu.referenced_table_name IS NOT NULL
`,
args: []interface{}{},
}
if table != "" {
pgQuery.query += " AND (tc.table_name = $1 OR ccu.table_name = $1)"
pgQuery.args = append(pgQuery.args, table)
mysqlQuery.query += " AND (kcu.table_name = ? OR kcu.referenced_table_name = ?)"
mysqlQuery.args = append(mysqlQuery.args, table, table)
}
return []queryWithArgs{pgQuery, mysqlQuery}
}
// createSchemaExplorerTool creates a tool for exploring database schema
func createSchemaExplorerTool() *tools.Tool {
return &tools.Tool{
Name: "dbSchema",
Description: "Auto-discover database structure and relationships",
Category: "database",
InputSchema: tools.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{
"component": map[string]interface{}{
"type": "string",
"description": "Schema component to explore (tables, columns, relationships, or full)",
"enum": []string{"tables", "columns", "relationships", "full"},
},
"table": map[string]interface{}{
"type": "string",
"description": "Table name to explore (optional, leave empty for all tables)",
},
"timeout": map[string]interface{}{
"type": "integer",
"description": "Query timeout in milliseconds (default: 10000)",
},
"database": map[string]interface{}{
"type": "string",
"description": "Database ID to use (optional if only one database is configured)",
},
},
Required: []string{"component", "database"},
},
Handler: handleSchemaExplorer,
}
}
// handleSchemaExplorer handles the schema explorer tool execution
func handleSchemaExplorer(ctx context.Context, params map[string]interface{}) (interface{}, error) {
// Check if database manager is initialized
if dbManager == nil {
return nil, fmt.Errorf("database manager not initialized")
}
// Extract parameters
component, ok := getStringParam(params, "component")
if !ok {
return nil, fmt.Errorf("component parameter is required")
}
// Get database ID
databaseID, ok := getStringParam(params, "database")
if !ok {
return nil, fmt.Errorf("database parameter is required")
}
// Get database instance
db, err := dbManager.GetDatabase(databaseID)
if err != nil {
return nil, fmt.Errorf("failed to get database: %w", err)
}
// Extract table parameter (optional depending on component)
table, _ := getStringParam(params, "table")
// Extract timeout
timeout := 10000 // Default timeout: 10 seconds
if timeoutParam, ok := getIntParam(params, "timeout"); ok {
timeout = timeoutParam
}
// Create context with timeout
timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
defer cancel()
// Use actual database queries based on component type
switch component {
case "tables":
return getTables(timeoutCtx, db)
case "columns":
if table == "" {
return nil, fmt.Errorf("table parameter is required for columns component")
}
return getColumns(timeoutCtx, db, table)
case "relationships":
return getRelationships(timeoutCtx, db, table)
case "full":
return getFullSchema(timeoutCtx, db)
default:
return nil, fmt.Errorf("invalid component: %s", component)
}
}
type queryWithArgs struct {
query string
args []interface{}
}
// executeWithFallbacks executes a series of database queries with fallbacks.
// It returns the first successful result or the last error encountered.
func executeWithFallbacks(ctx context.Context, db db.Database, queries []queryWithArgs, operationName string) (*sql.Rows, error) {
var lastErr error
for i, q := range queries {
rows, err := db.Query(ctx, q.query, q.args...)
if err == nil {
return rows, nil
}
lastErr = err
logger.Warn("%s fallback query %d failed: %v - Error: %v", operationName, i+1, q.query, err)
}
// All queries failed, return the last error
return nil, fmt.Errorf("%s failed after trying %d fallback queries: %w", operationName, len(queries), lastErr)
}
// getTables retrieves the list of tables in the database
func getTables(ctx context.Context, db db.Database) (interface{}, error) {
// Get database type from connected database
driverName := db.DriverName()
dbType := driverName
// Create the appropriate strategy
strategy := NewDatabaseStrategy(driverName)
// Get queries from strategy
queries := strategy.GetTablesQueries()
// Execute queries with fallbacks
rows, err := executeWithFallbacks(ctx, db, queries, "getTables")
if err != nil {
return nil, fmt.Errorf("failed to get tables: %w", err)
}
defer func() {
if rows != nil {
if err := rows.Close(); err != nil {
logger.Error("error closing rows: %v", err)
}
}
}()
// Convert rows to maps
results, err := rowsToMaps(rows)
if err != nil {
return nil, fmt.Errorf("failed to process tables: %w", err)
}
return map[string]interface{}{
"tables": results,
"dbType": dbType,
}, nil
}
// getColumns retrieves the columns for a specific table
func getColumns(ctx context.Context, db db.Database, table string) (interface{}, error) {
// Get database type from connected database
driverName := db.DriverName()
dbType := driverName
// Create the appropriate strategy
strategy := NewDatabaseStrategy(driverName)
// Get queries from strategy
queries := strategy.GetColumnsQueries(table)
// Execute queries with fallbacks
rows, err := executeWithFallbacks(ctx, db, queries, "getColumns["+table+"]")
if err != nil {
return nil, fmt.Errorf("failed to get columns for table %s: %w", table, err)
}
defer func() {
if rows != nil {
if err := rows.Close(); err != nil {
logger.Error("error closing rows: %v", err)
}
}
}()
// Convert rows to maps
results, err := rowsToMaps(rows)
if err != nil {
return nil, fmt.Errorf("failed to process columns: %w", err)
}
return map[string]interface{}{
"table": table,
"columns": results,
"dbType": dbType,
}, nil
}
// getRelationships retrieves the relationships for a table or all tables
func getRelationships(ctx context.Context, db db.Database, table string) (interface{}, error) {
// Get database type from connected database
driverName := db.DriverName()
dbType := driverName
// Create the appropriate strategy
strategy := NewDatabaseStrategy(driverName)
// Get queries from strategy
queries := strategy.GetRelationshipsQueries(table)
// Execute queries with fallbacks
rows, err := executeWithFallbacks(ctx, db, queries, "getRelationships")
if err != nil {
return nil, fmt.Errorf("failed to get relationships for table %s: %w", table, err)
}
defer func() {
if rows != nil {
if err := rows.Close(); err != nil {
logger.Error("error closing rows: %v", err)
}
}
}()
// Convert rows to maps
results, err := rowsToMaps(rows)
if err != nil {
return nil, fmt.Errorf("failed to process relationships: %w", err)
}
return map[string]interface{}{
"relationships": results,
"dbType": dbType,
"table": table,
}, nil
}
// safeGetMap safely gets a map from an interface value
func safeGetMap(obj interface{}) (map[string]interface{}, error) {
if obj == nil {
return nil, fmt.Errorf("nil value cannot be converted to map")
}
mapVal, ok := obj.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value is not a map[string]interface{}: %T", obj)
}
return mapVal, nil
}
// safeGetString safely gets a string from a map key
func safeGetString(m map[string]interface{}, key string) (string, error) {
val, ok := m[key]
if !ok {
return "", fmt.Errorf("key %q not found in map", key)
}
strVal, ok := val.(string)
if !ok {
return "", fmt.Errorf("value for key %q is not a string: %T", key, val)
}
return strVal, nil
}
// getFullSchema retrieves the complete database schema
func getFullSchema(ctx context.Context, db db.Database) (interface{}, error) {
tablesResult, err := getTables(ctx, db)
if err != nil {
return nil, fmt.Errorf("failed to get tables: %w", err)
}
tablesMap, err := safeGetMap(tablesResult)
if err != nil {
return nil, fmt.Errorf("invalid tables result: %w", err)
}
tablesSlice, ok := tablesMap["tables"].([]map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid tables data format")
}
// For each table, get columns
fullSchema := make(map[string]interface{})
for _, tableInfo := range tablesSlice {
tableName, err := safeGetString(tableInfo, "table_name")
if err != nil {
return nil, fmt.Errorf("invalid table info: %w", err)
}
columnsResult, columnsErr := getColumns(ctx, db, tableName)
if columnsErr != nil {
return nil, fmt.Errorf("failed to get columns for table %s: %w", tableName, columnsErr)
}
fullSchema[tableName] = columnsResult
}
// Get all relationships
relationships, relErr := getRelationships(ctx, db, "")
if relErr != nil {
return nil, fmt.Errorf("failed to get relationships: %w", relErr)
}
relMap, err := safeGetMap(relationships)
if err != nil {
return nil, fmt.Errorf("invalid relationships result: %w", err)
}
return map[string]interface{}{
"tables": tablesSlice,
"schema": fullSchema,
"relationships": relMap["relationships"],
}, nil
}
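// The assembled result has roughly this shape (from the code above):
//
// {
//   "tables":        [...],                 // raw rows from getTables
//   "schema":        {"users": {...}, ...}, // per-table getColumns results
//   "relationships": [...],                 // foreign-key rows for all tables
// }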
```
--------------------------------------------------------------------------------
/pkg/dbtools/querybuilder.go:
--------------------------------------------------------------------------------
```go
package dbtools
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"time"
"github.com/FreePeak/db-mcp-server/pkg/db"
"github.com/FreePeak/db-mcp-server/pkg/logger"
"github.com/FreePeak/db-mcp-server/pkg/tools"
)
// QueryComponents represents the components of a SQL query
type QueryComponents struct {
Select []string `json:"select"`
From string `json:"from"`
Joins []JoinClause `json:"joins"`
Where []Condition `json:"where"`
GroupBy []string `json:"groupBy"`
Having []string `json:"having"`
OrderBy []OrderBy `json:"orderBy"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
// JoinClause represents a SQL JOIN clause
type JoinClause struct {
Type string `json:"type"`
Table string `json:"table"`
On string `json:"on"`
}
// Condition represents a WHERE condition
type Condition struct {
Column string `json:"column"`
Operator string `json:"operator"`
Value string `json:"value"`
Connector string `json:"connector"`
}
// OrderBy represents an ORDER BY clause
type OrderBy struct {
Column string `json:"column"`
Direction string `json:"direction"`
}
// createQueryBuilderTool creates a tool for building and validating SQL queries
func createQueryBuilderTool() *tools.Tool {
return &tools.Tool{
Name: "dbQueryBuilder",
Description: "Visual SQL query construction with syntax validation",
Category: "database",
InputSchema: tools.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"description": "Action to perform (validate, build, analyze)",
"enum": []string{"validate", "build", "analyze"},
},
"query": map[string]interface{}{
"type": "string",
"description": "SQL query to validate or analyze",
},
"components": map[string]interface{}{
"type": "object",
"description": "Query components for building a query",
"properties": map[string]interface{}{
"select": map[string]interface{}{
"type": "array",
"description": "Columns to select",
"items": map[string]interface{}{
"type": "string",
},
},
"from": map[string]interface{}{
"type": "string",
"description": "Table to select from",
},
"joins": map[string]interface{}{
"type": "array",
"description": "Joins to include",
"items": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"type": map[string]interface{}{
"type": "string",
"enum": []string{"inner", "left", "right", "full"},
},
"table": map[string]interface{}{
"type": "string",
},
"on": map[string]interface{}{
"type": "string",
},
},
},
},
"where": map[string]interface{}{
"type": "array",
"description": "Where conditions",
"items": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"column": map[string]interface{}{
"type": "string",
},
"operator": map[string]interface{}{
"type": "string",
"enum": []string{"=", "!=", "<", ">", "<=", ">=", "LIKE", "IN", "NOT IN", "IS NULL", "IS NOT NULL"},
},
"value": map[string]interface{}{
"type": "string",
},
"connector": map[string]interface{}{
"type": "string",
"enum": []string{"AND", "OR"},
},
},
},
},
"groupBy": map[string]interface{}{
"type": "array",
"description": "Columns to group by",
"items": map[string]interface{}{
"type": "string",
},
},
"having": map[string]interface{}{
"type": "array",
"description": "Having conditions",
"items": map[string]interface{}{
"type": "string",
},
},
"orderBy": map[string]interface{}{
"type": "array",
"description": "Columns to order by",
"items": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"column": map[string]interface{}{
"type": "string",
},
"direction": map[string]interface{}{
"type": "string",
"enum": []string{"ASC", "DESC"},
},
},
},
},
"limit": map[string]interface{}{
"type": "integer",
"description": "Limit results",
},
"offset": map[string]interface{}{
"type": "integer",
"description": "Offset results",
},
},
},
"timeout": map[string]interface{}{
"type": "integer",
"description": "Execution timeout in milliseconds (default: 5000)",
},
"database": map[string]interface{}{
"type": "string",
"description": "Database ID to use (optional if only one database is configured)",
},
},
Required: []string{"action", "database"},
},
Handler: handleQueryBuilder,
}
}
// handleQueryBuilder handles the query builder tool execution
func handleQueryBuilder(ctx context.Context, params map[string]interface{}) (interface{}, error) {
// Check if database manager is initialized
if dbManager == nil {
return nil, fmt.Errorf("database manager not initialized")
}
// Extract parameters
action, ok := getStringParam(params, "action")
if !ok {
return nil, fmt.Errorf("action parameter is required")
}
// Get database ID
databaseID, ok := getStringParam(params, "database")
if !ok {
return nil, fmt.Errorf("database parameter is required")
}
// Get database instance
db, err := dbManager.GetDatabase(databaseID)
if err != nil {
return nil, fmt.Errorf("failed to get database: %w", err)
}
// Extract query parameter
query, _ := getStringParam(params, "query")
// Extract components parameter
var components QueryComponents
if componentsMap, ok := params["components"].(map[string]interface{}); ok {
// Parse components from map
if err := parseQueryComponents(&components, componentsMap); err != nil {
return nil, fmt.Errorf("failed to parse query components: %w", err)
}
}
// Create context with timeout
dbTimeout := db.QueryTimeout() * 1000 // Convert from seconds to milliseconds
timeout := dbTimeout // Default to the database's query timeout
if timeoutParam, ok := getIntParam(params, "timeout"); ok {
timeout = timeoutParam
}
timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
defer cancel()
// Execute requested action
switch action {
case "validate":
if query == "" {
return nil, fmt.Errorf("query parameter is required for validate action")
}
return validateQuery(timeoutCtx, db, query)
case "build":
if err := validateQueryComponents(&components); err != nil {
return nil, fmt.Errorf("invalid query components: %w", err)
}
builtQuery, err := buildQueryFromComponents(&components)
if err != nil {
return nil, fmt.Errorf("failed to build query: %w", err)
}
return validateQuery(timeoutCtx, db, builtQuery)
case "analyze":
if query == "" {
return nil, fmt.Errorf("query parameter is required for analyze action")
}
return analyzeQueryPlan(timeoutCtx, db, query)
default:
return nil, fmt.Errorf("invalid action: %s", action)
}
}
// parseQueryComponents parses query components from a map
func parseQueryComponents(components *QueryComponents, data map[string]interface{}) error {
// Parse SELECT columns
if selectArr, ok := data["select"].([]interface{}); ok {
components.Select = make([]string, len(selectArr))
for i, col := range selectArr {
if str, ok := col.(string); ok {
components.Select[i] = str
}
}
}
// Parse FROM table
if from, ok := data["from"].(string); ok {
components.From = from
}
// Parse JOINs
if joinsArr, ok := data["joins"].([]interface{}); ok {
components.Joins = make([]JoinClause, len(joinsArr))
for i, join := range joinsArr {
if joinMap, ok := join.(map[string]interface{}); ok {
if joinType, ok := joinMap["type"].(string); ok {
components.Joins[i].Type = joinType
}
if table, ok := joinMap["table"].(string); ok {
components.Joins[i].Table = table
}
if on, ok := joinMap["on"].(string); ok {
components.Joins[i].On = on
}
}
}
}
// Parse WHERE conditions
if whereArr, ok := data["where"].([]interface{}); ok {
components.Where = make([]Condition, len(whereArr))
for i, cond := range whereArr {
if condMap, ok := cond.(map[string]interface{}); ok {
if col, ok := condMap["column"].(string); ok {
components.Where[i].Column = col
}
if op, ok := condMap["operator"].(string); ok {
components.Where[i].Operator = op
}
if val, ok := condMap["value"].(string); ok {
components.Where[i].Value = val
}
if conn, ok := condMap["connector"].(string); ok {
components.Where[i].Connector = conn
}
}
}
}
// Parse GROUP BY columns
if groupByArr, ok := data["groupBy"].([]interface{}); ok {
components.GroupBy = make([]string, len(groupByArr))
for i, col := range groupByArr {
if str, ok := col.(string); ok {
components.GroupBy[i] = str
}
}
}
// Parse HAVING conditions
if havingArr, ok := data["having"].([]interface{}); ok {
components.Having = make([]string, len(havingArr))
for i, cond := range havingArr {
if str, ok := cond.(string); ok {
components.Having[i] = str
}
}
}
// Parse ORDER BY clauses
if orderByArr, ok := data["orderBy"].([]interface{}); ok {
components.OrderBy = make([]OrderBy, len(orderByArr))
for i, order := range orderByArr {
if orderMap, ok := order.(map[string]interface{}); ok {
if col, ok := orderMap["column"].(string); ok {
components.OrderBy[i].Column = col
}
if dir, ok := orderMap["direction"].(string); ok {
components.OrderBy[i].Direction = dir
}
}
}
}
// Parse LIMIT
if limit, ok := data["limit"].(float64); ok {
components.Limit = int(limit)
}
// Parse OFFSET
if offset, ok := data["offset"].(float64); ok {
components.Offset = int(offset)
}
return nil
}
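// Note: limit and offset are asserted as float64 above because JSON numbers
// decode to float64 when unmarshaled into map[string]interface{}; a direct
// int assertion would always fail for values arriving over JSON-RPC.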
// validateQueryComponents validates query components
func validateQueryComponents(components *QueryComponents) error {
if components.From == "" {
return fmt.Errorf("FROM clause is required")
}
if len(components.Select) == 0 {
return fmt.Errorf("SELECT clause must have at least one column")
}
for _, join := range components.Joins {
if join.Table == "" {
return fmt.Errorf("JOIN clause must have a table")
}
if join.On == "" {
return fmt.Errorf("JOIN clause must have an ON condition")
}
}
for _, where := range components.Where {
if where.Column == "" {
return fmt.Errorf("WHERE condition must have a column")
}
if where.Operator == "" {
return fmt.Errorf("WHERE condition must have an operator")
}
}
for _, order := range components.OrderBy {
if order.Column == "" {
return fmt.Errorf("ORDER BY clause must have a column")
}
if order.Direction != "ASC" && order.Direction != "DESC" {
return fmt.Errorf("ORDER BY direction must be ASC or DESC")
}
}
return nil
}
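// For example (illustrative), components missing a FROM table fail fast
// before any SQL is assembled:
//
//	err := validateQueryComponents(&QueryComponents{Select: []string{"id"}})
//	// err: "FROM clause is required"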
// buildQueryFromComponents builds a SQL query from components
func buildQueryFromComponents(components *QueryComponents) (string, error) {
// NOTE: component values are interpolated verbatim into the SQL text;
// callers must ensure they are trusted or validated upstream.
var query strings.Builder
// Build SELECT clause
query.WriteString("SELECT ")
query.WriteString(strings.Join(components.Select, ", "))
// Build FROM clause
query.WriteString(" FROM ")
query.WriteString(components.From)
// Build JOIN clauses
for _, join := range components.Joins {
query.WriteString(" ")
query.WriteString(strings.ToUpper(join.Type))
query.WriteString(" JOIN ")
query.WriteString(join.Table)
query.WriteString(" ON ")
query.WriteString(join.On)
}
// Build WHERE clause
if len(components.Where) > 0 {
query.WriteString(" WHERE ")
for i, cond := range components.Where {
if i > 0 {
query.WriteString(" ")
query.WriteString(cond.Connector)
query.WriteString(" ")
}
query.WriteString(cond.Column)
query.WriteString(" ")
query.WriteString(cond.Operator)
if cond.Value != "" {
query.WriteString(" ")
query.WriteString(cond.Value)
}
}
}
// Build GROUP BY clause
if len(components.GroupBy) > 0 {
query.WriteString(" GROUP BY ")
query.WriteString(strings.Join(components.GroupBy, ", "))
}
// Build HAVING clause
if len(components.Having) > 0 {
query.WriteString(" HAVING ")
query.WriteString(strings.Join(components.Having, " AND "))
}
// Build ORDER BY clause
if len(components.OrderBy) > 0 {
query.WriteString(" ORDER BY ")
var orders []string
for _, order := range components.OrderBy {
orders = append(orders, order.Column+" "+order.Direction)
}
query.WriteString(strings.Join(orders, ", "))
}
// Build LIMIT clause
if components.Limit > 0 {
query.WriteString(fmt.Sprintf(" LIMIT %d", components.Limit))
}
// Build OFFSET clause
if components.Offset > 0 {
query.WriteString(fmt.Sprintf(" OFFSET %d", components.Offset))
}
return query.String(), nil
}
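// Illustrative output: the components from the sketch above assemble into
//
//	SELECT id, name FROM users WHERE age > 18 ORDER BY name ASC LIMIT 10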
// validateQuery validates a SQL query for syntax errors
func validateQuery(ctx context.Context, db db.Database, query string) (interface{}, error) {
// Validate by asking the database to plan the query with EXPLAIN; this
// parses and resolves references without running the statement
explainQuery := "EXPLAIN " + query
_, err := db.Query(ctx, explainQuery)
if err != nil {
return map[string]interface{}{
"valid": false,
"error": err.Error(),
"query": query,
}, nil
}
return map[string]interface{}{
"valid": true,
"query": query,
}, nil
}
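// Illustrative call (assumes a live db.Database connection named database):
//
//	res, _ := validateQuery(ctx, database, "SELECT * FROM users")
//	// res: map[string]interface{}{"valid": true, "query": "SELECT * FROM users"}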
// analyzeQueryPlan analyzes a specific query for performance
func analyzeQueryPlan(ctx context.Context, db db.Database, query string) (interface{}, error) {
explainQuery := "EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) " + query
rows, err := db.Query(ctx, explainQuery)
if err != nil {
return nil, fmt.Errorf("failed to analyze query: %w", err)
}
defer func() {
if err := rows.Close(); err != nil {
logger.Error("error closing rows: %v", err)
}
}()
var plan []byte
if !rows.Next() {
// Distinguish an iteration error from a genuinely empty result
if rowsErr := rows.Err(); rowsErr != nil {
return nil, fmt.Errorf("failed reading explain plan: %w", rowsErr)
}
return nil, fmt.Errorf("no explain plan returned")
}
if err := rows.Scan(&plan); err != nil {
return nil, fmt.Errorf("failed to scan explain plan: %w", err)
}
return map[string]interface{}{
"query": query,
"plan": string(plan),
}, nil
}
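// The returned "plan" field is PostgreSQL's JSON explain output, roughly
// (abbreviated, illustrative):
//
//	[{"Plan": {"Node Type": "Seq Scan", "Relation Name": "users",
//	           "Actual Rows": 100, "Shared Hit Blocks": 5}}]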
// calculateQueryComplexity scores a query as Simple, Moderate, or Complex
// based on joins, subqueries, aggregations, and related clauses
func calculateQueryComplexity(query string) string {
query = strings.ToUpper(query)
// Count common complexity factors
joins := strings.Count(query, " JOIN ")
subqueries := strings.Count(query, "SELECT") - 1 // Subtract the main query
if subqueries < 0 {
subqueries = 0
}
aggregations := strings.Count(query, " SUM(") +
strings.Count(query, " COUNT(") +
strings.Count(query, " AVG(") +
strings.Count(query, " MIN(") +
strings.Count(query, " MAX(")
groupBy := strings.Count(query, " GROUP BY ")
orderBy := strings.Count(query, " ORDER BY ")
having := strings.Count(query, " HAVING ")
distinct := strings.Count(query, " DISTINCT ")
unions := strings.Count(query, " UNION ")
// Calculate complexity score - adjusted to match test expectations
score := joins*2 + subqueries*3 + aggregations + groupBy + orderBy + having*2 + distinct + unions*3
// Check special cases that should be complex
if joins >= 3 || (joins >= 2 && subqueries >= 1) || (subqueries >= 1 && aggregations >= 1) {
return "Complex"
}
// Determine complexity level
if score <= 2 {
return "Simple"
} else if score <= 6 {
return "Moderate"
} else {
return "Complex"
}
}
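// Worked example (illustrative): for
//
//	SELECT a, COUNT(*) FROM t JOIN u ON t.id = u.id GROUP BY a ORDER BY a
//
// joins=1, aggregations=1, groupBy=1, orderBy=1, so
// score = 1*2 + 1 + 1 + 1 = 5, which falls in the 3..6 band: "Moderate".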
// getSuggestionForError maps a database error message to a human-readable suggestion for fixing it
func getSuggestionForError(errorMsg string) string {
errorMsg = strings.ToLower(errorMsg)
if strings.Contains(errorMsg, "syntax error") {
return "Check SQL syntax for errors such as missing keywords, incorrect operators, or unmatched parentheses"
} else if strings.Contains(errorMsg, "unknown column") {
return "Column name is incorrect or doesn't exist in the specified table"
} else if strings.Contains(errorMsg, "unknown table") {
return "Table name is incorrect or doesn't exist in the database"
} else if strings.Contains(errorMsg, "ambiguous") {
return "Column name is ambiguous. Qualify it with the table name"
} else if strings.Contains(errorMsg, "missing") && strings.Contains(errorMsg, "from") {
return "FROM clause is missing or incorrectly formatted"
} else if strings.Contains(errorMsg, "no such table") {
return "Table specified does not exist in the database"
}
return "Review the query syntax and structure"
}
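// For example (illustrative): a MySQL-style message such as
// "Unknown column 'nmae' in 'field list'" lower-cases to contain
// "unknown column" and maps to the column-name suggestion above.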
// extractLineNumberFromError extracts line number from a database error message
//
//nolint:unused // Used in future implementation
func extractLineNumberFromError(errMsg string) int {
// Check for line number patterns like "at line 42" or "line 42"
linePatterns := []string{
"at line ([0-9]+)",
"line ([0-9]+)",
"LINE ([0-9]+)",
}
for _, pattern := range linePatterns {
lineMatch := regexp.MustCompile(pattern).FindStringSubmatch(errMsg)
if len(lineMatch) > 1 {
lineNum, scanErr := strconv.Atoi(lineMatch[1])
if scanErr != nil {
logger.Warn("Failed to parse line number: %v", scanErr)
continue
}
return lineNum
}
}
return 0
}
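// Illustrative match: "ERROR at line 3: ..." hits the "at line ([0-9]+)"
// pattern and returns 3.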
// extractPositionFromError extracts position from a database error message
//
//nolint:unused // Used in future implementation
func extractPositionFromError(errMsg string) int {
// Check for position patterns
posPatterns := []string{
"at character ([0-9]+)",
"position ([0-9]+)",
"at or near \"([^\"]+)\"",
}
for _, pattern := range posPatterns {
posMatch := regexp.MustCompile(pattern).FindStringSubmatch(errMsg)
if len(posMatch) > 1 {
// For "at or near X" patterns, need to find X in the query
if strings.Contains(pattern, "at or near") {
return 0 // Just return 0 for now
}
// For numeric positions
pos, scanErr := strconv.Atoi(posMatch[1])
if scanErr != nil {
logger.Warn("Failed to parse position: %v", scanErr)
continue
}
return pos
}
}
return 0
}
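// Illustrative matches: "syntax error at character 22" returns 22, while a
// PostgreSQL-style `at or near "FROMM"` match currently yields 0, since
// resolving the token's offset within the query is deferred.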
// Mock functions for use when database is not available
// mockValidateQuery provides mock validation of SQL queries
func mockValidateQuery(query string) (interface{}, error) {
query = strings.TrimSpace(query)
// Basic syntax checks for demonstration purposes
if !strings.HasPrefix(strings.ToUpper(query), "SELECT") {
return map[string]interface{}{
"valid": false,
"query": query,
"error": "Query must start with SELECT",
"suggestion": "Begin your query with the SELECT keyword",
"errorLine": 1,
"errorColumn": 1,
}, nil
}
if !strings.Contains(strings.ToUpper(query), " FROM ") {
return map[string]interface{}{
"valid": false,
"query": query,
"error": "Missing FROM clause",
"suggestion": "Add a FROM clause to specify the table or view to query",
"errorLine": 1,
"errorColumn": len("SELECT"),
}, nil
}
// Check for unbalanced parentheses
if strings.Count(query, "(") != strings.Count(query, ")") {
return map[string]interface{}{
"valid": false,
"query": query,
"error": "Unbalanced parentheses",
"suggestion": "Ensure all opening parentheses have matching closing parentheses",
"errorLine": 1,
"errorColumn": 0,
}, nil
}
// Check for unclosed quotes
if strings.Count(query, "'")%2 != 0 {
return map[string]interface{}{
"valid": false,
"query": query,
"error": "Unclosed string literal",
"suggestion": "Ensure all string literals are properly closed with matching quotes",
"errorLine": 1,
"errorColumn": 0,
}, nil
}
// Query appears valid
return map[string]interface{}{
"valid": true,
"query": query,
}, nil
}
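// Illustrative outcomes of the mock checks above:
//
//	mockValidateQuery("SELECT id FROM t")     // valid: true
//	mockValidateQuery("UPDATE t SET x = 1")   // error: "Query must start with SELECT"
//	mockValidateQuery("SELECT count( FROM t") // error: "Unbalanced parentheses"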
// mockAnalyzeQuery provides mock analysis of SQL queries
func mockAnalyzeQuery(query string) (interface{}, error) {
query = strings.ToUpper(query)
// Mock analysis results
var issues []string
var suggestions []string
// Check for potential performance issues
if !strings.Contains(query, " WHERE ") {
issues = append(issues, "Query has no WHERE clause")
suggestions = append(suggestions, "Add a WHERE clause to filter results and improve performance")
}
// Check for multiple joins
joinCount := strings.Count(query, " JOIN ")
if joinCount > 1 {
issues = append(issues, "Query contains multiple joins")
suggestions = append(suggestions, "Multiple joins can impact performance. Consider denormalizing or using indexed columns")
}
if strings.Contains(query, " LIKE '%") || strings.Contains(query, "% LIKE") {
issues = append(issues, "Query uses LIKE with leading wildcard")
suggestions = append(suggestions, "Leading wildcards in LIKE conditions cannot use indexes. Consider alternative approaches")
}
if strings.Contains(query, " ORDER BY ") && !strings.Contains(query, " LIMIT ") {
issues = append(issues, "ORDER BY without LIMIT")
suggestions = append(suggestions, "Consider adding a LIMIT clause to prevent sorting large result sets")
}
// Create a mock explain plan
mockExplainPlan := []map[string]interface{}{
{
"id": 1,
"select_type": "SIMPLE",
"table": getTableFromQuery(query),
"type": "ALL",
"possible_keys": nil,
"key": nil,
"key_len": nil,
"ref": nil,
"rows": 1000,
"Extra": "",
},
}
// If the query has a WHERE clause, assume it might use an index
if strings.Contains(query, " WHERE ") {
mockExplainPlan[0]["type"] = "range"
mockExplainPlan[0]["possible_keys"] = "PRIMARY"
mockExplainPlan[0]["key"] = "PRIMARY"
mockExplainPlan[0]["key_len"] = 4
mockExplainPlan[0]["rows"] = 100
}
return map[string]interface{}{
"query": query,
"explainPlan": mockExplainPlan,
"issues": issues,
"suggestions": suggestions,
"complexity": calculateQueryComplexity(query),
"is_mock": true,
}, nil
}
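// Illustrative: for "SELECT * FROM orders ORDER BY total", the mock reports
// "Query has no WHERE clause" and "ORDER BY without LIMIT", and the fake
// explain plan stays a full scan ("type": "ALL") because no WHERE clause
// upgrades it to an indexed range scan.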
// getTableFromQuery extracts the first table name after FROM, or "unknown_table" if none is found
func getTableFromQuery(query string) string {
queryUpper := strings.ToUpper(query)
// Try to find the table name after FROM
fromIndex := strings.Index(queryUpper, " FROM ")
if fromIndex == -1 {
return "unknown_table"
}
// Get the text after " FROM " (the matched separator is 6 bytes)
afterFrom := query[fromIndex+6:]
afterFromUpper := queryUpper[fromIndex+6:]
// Find the end of the table name (next space, comma, or parenthesis)
endIndex := len(afterFrom)
for i, char := range afterFromUpper {
if char == ' ' || char == ',' || char == '(' || char == ')' {
endIndex = i
break
}
}
tableName := strings.TrimSpace(afterFrom[:endIndex])
// Any alias ("orders AS o" or "orders o") was already cut off above,
// because endIndex stops at the first space after the table name
return tableName
}
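// Illustrative: getTableFromQuery("SELECT * FROM orders o WHERE id = 1")
// returns "orders" (cut at the first space after the table name), and a
// query without a FROM clause yields "unknown_table".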
```