# Directory Structure
```
├── .github
│ └── workflows
│ ├── docker-image.yml
│ └── go.yml
├── Dockerfile
├── go-mcp-postgres.png
├── go.mod
├── go.sum
├── LICENSE
├── locales
│ ├── en
│ │ └── active.en.toml
│ └── zh-CN
│ └── active.zh-CN.toml
├── main_test.go
├── main.go
└── README.md
```
# Files
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
# go-mcp-postgres
## Overview
Copy code from https://github.com/Zhwt/go-mcp-mysql/ and with AI help, I change db from mysql to postgres.
Zero burden, ready-to-use Model Context Protocol (MCP) server for interacting with Postgres and automation. No Node.js or Python environment needed. This server provides tools to do CRUD operations on MySQL databases and tables, and a read-only mode to prevent surprise write operations. You can also make the MCP server check the query plan by using a `EXPLAIN` statement before executing the query by adding a `--with-explain-check` flag.
Please note that this is a work in progress and may not yet be ready for production use.
## Installation
1. Get the latest [release](https://github.com/guoling2008/go-mcp-postgres/releases) and put it in your `$PATH` or somewhere you can easily access.
2. Or if you have Go installed, you can build it from source:
```sh
go install -v github.com/guoling2008/go-mcp-postgres@latest
```
## Usage
### Method A: Using Command Line Arguments for stdio mode
```json
{
"mcpServers": {
"postgres": {
"command": "go-mcp-postgres",
"args": [
"--dsn",
"postgresql://user:pass@host:port/db"
]
}
}
}
```
Note: For those who put the binary outside of your `$PATH`, you need to replace `go-mcp-postgres` with the full path to the binary: e.g.: if you put the binary in the **Downloads** folder, you may use the following path:
```json
{
"mcpServers": {
"postgres": {
"command": "C:\\Users\\<username>\\Downloads\\go-mcp-postgres.exe",
"args": [
...
]
}
}
}
```
### Method B: Using Command Line Arguments for sse mode
./go-mcp-postgres --t sse --ip x.x.x.x --port nnnn --dsn postgresql://user:pass@host:port/db --lang en
### Optional Flags
- `--lang`: Set language option (en/zh-CN), defaults to system language
- Add a `--read-only` flag to enable read-only mode. In this mode, only tools beginning with `list`, `read_` and `desc_` are available. Make sure to refresh/restart the MCP server after adding this flag.
- By default, CRUD queries will be first executed with a `EXPLAIN ?` statement to check whether the generated query plan matches the expected pattern. Add a `--with-explain-check` flag to disable this behavior.
## Tools
_Multi-language support: All tool descriptions will automatically localize based on lang parameter_
If you want to add your own language support, please refer to the [locales](for i18n) folder.
The new locales/xxx/active-xx.toml file should be created if you want to use it in command line.
### Schema Tools
1. `list_database`
- ${mcp.tool.list_database.desc}
- Parameters: None
- Returns: A list of matching database names.
2. `list_table`
- ${mcp.tool.list_table.desc}
- Parameters:
- `name`: If provided, list tables with the specified name, Otherwise, list all tables.
- Returns: A list of matching table names.
3. `create_table`
- ${mcp.tool.create_table.desc}
- Parameters:
- `query`: The SQL query to create the table.
- Returns: x rows affected.
4. `alter_table`
- Alter an existing table in the Postgres server. The LLM is informed not to drop an existing table or column.
- Parameters:
- `query`: The SQL query to alter the table.
- Returns: x rows affected.
5. `desc_table`
- Describe the structure of a table.
- Parameters:
- `name`: The name of the table to describe.
- Returns: The structure of the table.
### Data Tools
1. `read_query`
- Execute a read-only SQL query.
- Parameters:
- `query`: The SQL query to execute.
- Returns: The result of the query.
2. `write_query`
- Execute a write SQL query.
- Parameters:
- `query`: The SQL query to execute.
- Returns: x rows affected, last insert id: <last_insert_id>.
3. `update_query`
- Execute an update SQL query.
- Parameters:
- `query`: The SQL query to execute.
- Returns: x rows affected.
4. `delete_query`
- Execute a delete SQL query.
- Parameters:
- `query`: The SQL query to execute.
- Returns: x rows affected.
5. `count_query`
- Query the number of rows in a certain table..
- Parameters:
- `name`: The name of the table to count.
- Returns: The row number of the table.
Big thanks to https://github.com/Zhwt/go-mcp-mysql/ again.
## License
MIT
```
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
```yaml
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
name: Go
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
- name: Build
run: go mod tidy && go build -v ./...
- name: Test
run: go mod tidy && go test -v ./...
```
--------------------------------------------------------------------------------
/.github/workflows/docker-image.yml:
--------------------------------------------------------------------------------
```yaml
name: Docker Image CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build the Docker image
run: docker build . --file Dockerfile --tag go-mcp-postgres:$(date +%s)
-
name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Build and push
id: docker_build
uses: docker/build-push-action@v6
with:
push: true
tags: guoling21cn/go-mcp-postgres:${{ github.RUN_ID }}
```
--------------------------------------------------------------------------------
/locales/en/active.en.toml:
--------------------------------------------------------------------------------
```toml
[command]
dsn = "POSTGRES DSN"
read_only = "Enable read-only mode"
explain_check = "Check query plan with `EXPLAIN` before executing"
transport = "Transport type (stdio or sse)"
port = "sse port"
ip_address = "server ip address"
[gomcp]
list_database = "List all databases in the POSTGRES server"
list_table = "List all tables in the POSTGRES server"
create_table = "Create a new table in the POSTGRES server. Make sure you have added proper comments for each column and the table itself"
alter_table = "Alter an existing table in the POSTGRES server. Make sure you have updated comments for each modified column. DO NOT drop table or existing columns!"
alter_table_query = "The SQL query to alter the table"
create_table_query_description = "The SQL query to create the table"
read_query = "Execute a read-only SQL query. Make sure you have knowledge of the table structure before writing WHERE conditions. Call `desc_table` first if necessary"
count_query = "Query the number of rows in a certain table."
write_query = "Execute a write SQL query. Make sure you have knowledge of the table structure before executing the query. Make sure the data types match the columns' definitions"
update_query = "Execute an update SQL query. Make sure you have knowledge of the table structure before executing the query. Make sure there is always a WHERE condition. Call `desc_table` first if necessary"
count_query_name = "Name of the table to count"
desc_table = "Describe table structure"
desc_table_name = "Name of the table to describe"
delete_query = "Execute a delete SQL query. Make sure you have knowledge of the table structure before executing the query. Make sure there is always a WHERE condition. Call `desc_table` first if necessary"
query_execute_description = "Execute the SQL query and return the result"
```
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
```go
package main
import (
"context"
"embed"
"encoding/csv"
"flag"
"fmt"
"log"
"strings"
_ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/stdlib"
"github.com/jmoiron/sqlx"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/nicksnyder/go-i18n/v2/i18n"
"github.com/pelletier/go-toml/v2"
"golang.org/x/text/language"
)
//go:embed locales/*
var localeFS embed.FS
const (
StatementTypeNoExplainCheck = ""
StatementTypeSelect = "SELECT"
StatementTypeInsert = "INSERT"
StatementTypeUpdate = "UPDATE"
StatementTypeDelete = "DELETE"
)
var (
DSN string
ReadOnly bool
WithExplainCheck bool
DB *sqlx.DB
Transport string
IPaddress string
Port int
Lang string
)
type ExplainResult struct {
Id *string `db:"id"`
SelectType *string `db:"select_type"`
Table *string `db:"table"`
Partitions *string `db:"partitions"`
Type *string `db:"type"`
PossibleKeys *string `db:"possible_keys"`
Key *string `db:"key"`
KeyLen *string `db:"key_len"`
Ref *string `db:"ref"`
Rows *string `db:"rows"`
Filtered *string `db:"filtered"`
Extra *string `db:"Extra"`
}
type ShowCreateTableResult struct {
Table string `db:"Table"`
CreateTable string `db:"Create Table"`
}
func main() {
// 初始化i18n
bundle := i18n.NewBundle(language.English)
bundle.RegisterUnmarshalFunc("toml", toml.Unmarshal)
flag.StringVar(&DSN, "dsn", "", "POSTGRES DSN")
flag.BoolVar(&ReadOnly, "read-only", false, "Enable read-only mode")
flag.BoolVar(&WithExplainCheck, "with-explain-check", false, "Check query plan with `EXPLAIN` before executing")
flag.StringVar(&Transport, "t", "stdio", "Transport type (stdio or sse)")
flag.IntVar(&Port, "port", 8080, "sse server port")
flag.StringVar(&IPaddress, "ip", "localhost", "server ip address")
flag.StringVar(&Lang, "lang", language.English.String(), "Language code (en/zh-CN/...)")
flag.Parse()
langTag, err := language.Parse(Lang)
if err != nil {
langTag = language.English
}
langFile := fmt.Sprintf("locales/%s/active.%s.toml", langTag.String(), langTag.String())
if data, err := localeFS.ReadFile(langFile); err == nil {
bundle.ParseMessageFileBytes(data, langFile)
} else {
if enData, err := localeFS.ReadFile("locales/en/active.en.toml"); err == nil {
bundle.ParseMessageFileBytes(enData, "locales/en/active.en.toml")
}
}
localizer := i18n.NewLocalizer(bundle, langTag.String())
T := func(key string) string {
return localizer.MustLocalize(&i18n.LocalizeConfig{MessageID: key})
}
s := server.NewMCPServer(
"go-mcp-postgres",
"0.2.1",
server.WithResourceCapabilities(true, true),
server.WithPromptCapabilities(true),
server.WithLogging(),
)
// Schema Tools
listDatabaseTool := mcp.NewTool(
"list_database",
mcp.WithDescription(T("gomcp.list_database")),
)
listTableTool := mcp.NewTool(
"list_table",
mcp.WithDescription(T("gomcp.list_table")),
)
createTableTool := mcp.NewTool(
"create_table",
mcp.WithDescription(T("gomcp.create_table")),
mcp.WithString("query",
mcp.Required(),
mcp.Description(T("gomcp.create_table_query_description")),
),
)
alterTableTool := mcp.NewTool(
"alter_table",
mcp.WithDescription(T("gomcp.alter_table")),
mcp.WithString("query",
mcp.Required(),
mcp.Description(T("gomcp.alter_table_query")),
),
)
descTableTool := mcp.NewTool(
"desc_table",
mcp.WithDescription(T("gomcp.desc_table")),
mcp.WithString("name",
mcp.Required(),
mcp.Description(T("gomcp.desc_table_name")),
),
)
// Data Tools
readQueryTool := mcp.NewTool(
"read_query",
mcp.WithDescription(T("gomcp.read_query")),
mcp.WithString("query",
mcp.Required(),
mcp.Description(T("gomcp.query_execute_description")),
),
)
countQueryTool := mcp.NewTool(
"count_query",
mcp.WithDescription(T("gomcp.count_query")),
mcp.WithString("name",
mcp.Required(),
mcp.Description(T("gomcp.count_query_name")),
),
)
writeQueryTool := mcp.NewTool(
"write_query",
mcp.WithDescription(T("gomcp.write_query")),
mcp.WithString("query",
mcp.Required(),
mcp.Description(T("gomcp.query_execute_description")),
),
)
updateQueryTool := mcp.NewTool(
"update_query",
mcp.WithDescription(T("gomcp.update_query")),
mcp.WithString("query",
mcp.Required(),
mcp.Description(T("gomcp.query_execute_description")),
),
)
deleteQueryTool := mcp.NewTool(
"delete_query",
mcp.WithDescription(T("gomcp.delete_query")),
mcp.WithString("query",
mcp.Required(),
mcp.Description(T("gomcp.query_execute_description")),
),
)
s.AddTool(listDatabaseTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleQuery("SELECT datname FROM pg_database WHERE datistemplate = false;", StatementTypeNoExplainCheck)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
s.AddTool(listTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleQuery("SELECT table_schema,table_name FROM information_schema.tables ORDER BY table_schema,table_name;", StatementTypeNoExplainCheck)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
if !ReadOnly {
s.AddTool(createTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleExec(request.Params.Arguments["query"].(string), StatementTypeNoExplainCheck)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
}
if !ReadOnly {
s.AddTool(alterTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleExec(request.Params.Arguments["query"].(string), StatementTypeNoExplainCheck)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
}
s.AddTool(listTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleQuery("SELECT table_schema,table_name FROM information_schema.tables ORDER BY table_schema,table_name;", StatementTypeNoExplainCheck)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
s.AddTool(descTableTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
descsql :=
`SELECT
'CREATE TABLE ' || t.table_name || ' (' ||
string_agg(
c.column_name || ' ' || c.data_type ||
CASE
WHEN c.character_maximum_length IS NOT NULL THEN '(' || c.character_maximum_length || ')'
ELSE ''
END ||
CASE
WHEN c.is_nullable = 'NO' THEN ' NOT NULL'
ELSE ''
END, ', '
) ||
', PRIMARY KEY (' || (
SELECT string_agg(kcu.column_name, ', ')
FROM information_schema.key_column_usage kcu
WHERE kcu.table_name = t.table_name AND kcu.constraint_name LIKE '%_pkey'
) || ')' ||
');' AS create_table_sql
FROM
information_schema.tables t
JOIN
information_schema.columns c ON t.table_name = c.table_name
WHERE
t.table_name = '` + request.Params.Arguments["name"].(string) + `'
GROUP BY
t.table_name;`
result, err := HandleQuery(descsql, StatementTypeNoExplainCheck)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
s.AddTool(readQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleQuery(request.Params.Arguments["query"].(string), StatementTypeSelect)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
s.AddTool(countQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleQuery("SELECT count(1) from "+request.Params.Arguments["name"].(string)+";", StatementTypeNoExplainCheck)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
if !ReadOnly {
s.AddTool(writeQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleExec(request.Params.Arguments["query"].(string), StatementTypeInsert)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
}
if !ReadOnly {
s.AddTool(updateQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleExec(request.Params.Arguments["query"].(string), StatementTypeUpdate)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
}
if !ReadOnly {
s.AddTool(deleteQueryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
result, err := HandleExec(request.Params.Arguments["query"].(string), StatementTypeDelete)
if err != nil {
return nil, nil
}
return mcp.NewToolResultText(result), nil
})
}
// Only check for "sse" since stdio is the default
if Transport == "sse" {
sseServer := server.NewSSEServer(s, server.WithBaseURL(fmt.Sprintf("http://%s:%d", IPaddress, Port)))
//log.Printf("SSE server listening on : %d", Port)
if err := sseServer.Start(fmt.Sprintf("%s:%d", IPaddress, Port)); err != nil {
log.Fatalf("Server error: %v", err)
}
} else {
if err := server.ServeStdio(s); err != nil {
log.Fatalf("Server error: %v", err)
}
}
}
func GetDB() (*sqlx.DB, error) {
if DB != nil {
return DB, nil
}
db, err := sqlx.Connect("pgx", DSN)
if err != nil {
return nil, fmt.Errorf("failed to establish database connection: %v", err)
}
DB = db
return DB, nil
}
func HandleQuery(query, expect string) (string, error) {
result, headers, err := DoQuery(query, expect)
if err != nil {
return "", err
}
s, err := MapToCSV(result, headers)
if err != nil {
return "", err
}
return s, nil
}
func DoQuery(query, expect string) ([]map[string]interface{}, []string, error) {
db, err := GetDB()
if err != nil {
return nil, nil, err
}
if len(expect) > 0 {
if err := HandleExplain(query, expect); err != nil {
return nil, nil, err
}
}
rows, err := db.Queryx(query)
if err != nil {
return nil, nil, err
}
cols, err := rows.Columns()
if err != nil {
return nil, nil, err
}
result := []map[string]interface{}{}
for rows.Next() {
row, err := rows.SliceScan()
if err != nil {
return nil, nil, err
}
resultRow := map[string]interface{}{}
for i, col := range cols {
switch v := row[i].(type) {
case []byte:
resultRow[col] = string(v)
default:
resultRow[col] = v
}
}
result = append(result, resultRow)
}
return result, cols, nil
}
func HandleExec(query, expect string) (string, error) {
db, err := GetDB()
if err != nil {
return "", err
}
if len(expect) > 0 {
if err := HandleExplain(query, expect); err != nil {
return "", err
}
}
result, err := db.Exec(query)
if err != nil {
return "", err
}
ra, err := result.RowsAffected()
if err != nil {
return "", err
}
switch expect {
case StatementTypeInsert:
li, err := result.LastInsertId()
if err != nil {
return "", err
}
return fmt.Sprintf("%d rows affected, last insert id: %d", ra, li), nil
default:
return fmt.Sprintf("%d rows affected", ra), nil
}
}
func HandleExplain(query, expect string) error {
if !WithExplainCheck {
return nil
}
db, err := GetDB()
if err != nil {
return err
}
rows, err := db.Queryx(fmt.Sprintf("EXPLAIN %s", query))
if err != nil {
return err
}
result := []ExplainResult{}
for rows.Next() {
var row ExplainResult
if err := rows.StructScan(&row); err != nil {
return err
}
result = append(result, row)
}
if len(result) != 1 {
return fmt.Errorf("unable to check query plan, denied")
}
match := false
switch expect {
case StatementTypeInsert:
fallthrough
case StatementTypeUpdate:
fallthrough
case StatementTypeDelete:
if *result[0].SelectType == expect {
match = true
}
default:
// for SELECT type query, the select_type will be multiple values
// here we check if it's not INSERT, UPDATE or DELETE
match = true
for _, typ := range []string{StatementTypeInsert, StatementTypeUpdate, StatementTypeDelete} {
if *result[0].SelectType == typ {
match = false
break
}
}
}
if !match {
return fmt.Errorf("query plan does not match expected pattern, denied")
}
return nil
}
/*
func HandleDescTable(name string) (string, error) {
db, err := GetDB()
if err != nil {
return "", err
}
rows, err := db.Queryx(descsql)
if err != nil {
return "", err
}
result := []ShowCreateTableResult{}
for rows.Next() {
var row ShowCreateTableResult
if err := rows.StructScan(&row); err != nil {
return "", err
}
result = append(result, row)
}
if len(result) == 0 {
return "", fmt.Errorf("table %s does not exist", name)
}
return result[0].CreateTable, nil
}*/
func MapToCSV(m []map[string]interface{}, headers []string) (string, error) {
var csvBuf strings.Builder
writer := csv.NewWriter(&csvBuf)
if err := writer.Write(headers); err != nil {
return "", fmt.Errorf("failed to write headers: %v", err)
}
for _, item := range m {
row := make([]string, len(headers))
for i, header := range headers {
value, exists := item[header]
if !exists {
return "", fmt.Errorf("key '%s' not found in map", header)
}
row[i] = fmt.Sprintf("%v", value)
}
if err := writer.Write(row); err != nil {
return "", fmt.Errorf("failed to write row: %v", err)
}
}
writer.Flush()
if err := writer.Error(); err != nil {
return "", fmt.Errorf("error flushing CSV writer: %v", err)
}
return csvBuf.String(), nil
}
```
--------------------------------------------------------------------------------
/main_test.go:
--------------------------------------------------------------------------------
```go
package main
import (
"database/sql"
"fmt"
"strings"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
)
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("Failed to create mock DB: %v", err)
}
// Save the original DB
originalDB := DB
// Replace with our mock
DB = sqlx.NewDb(db, "sqlmock")
// Return a cleanup function
cleanup := func() {
db.Close()
DB = originalDB
}
return db, mock, cleanup
}
func TestGetDB(t *testing.T) {
// Save the original DB
originalDB := DB
defer func() { DB = originalDB }()
t.Run("returns existing DB if already set", func(t *testing.T) {
// Set a mock DB
mockDB := &sqlx.DB{}
DB = mockDB
// Call GetDB
db, err := GetDB()
// Verify results
assert.NoError(t, err)
assert.Equal(t, mockDB, db)
})
t.Run("creates new DB connection if not set", func(t *testing.T) {
// Reset DB to nil
DB = nil
// Set DSN to a value that will work with sqlmock
originalDSN := DSN
DSN = "sqlmock"
defer func() { DSN = originalDSN }()
// This test is more of an integration test and would require a real DB
// For unit testing, we'll just verify that it returns an error with an invalid DSN
_, err := GetDB()
assert.Error(t, err)
})
}
func TestHandleQuery(t *testing.T) {
_, mock, cleanup := setupMockDB(t)
defer cleanup()
t.Run("successful query", func(t *testing.T) {
// Setup mock expectations
rows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "test1").
AddRow(2, "test2")
mock.ExpectQuery("SELECT").WillReturnRows(rows)
// Call HandleQuery
result, err := HandleQuery("SELECT id, name FROM users", StatementTypeNoExplainCheck)
// Verify results
assert.NoError(t, err)
assert.Contains(t, result, "id,name")
assert.Contains(t, result, "1,test1")
assert.Contains(t, result, "2,test2")
})
t.Run("query error", func(t *testing.T) {
// Setup mock expectations
mock.ExpectQuery("SELECT").WillReturnError(fmt.Errorf("query error"))
// Call HandleQuery
_, err := HandleQuery("SELECT id, name FROM users", StatementTypeNoExplainCheck)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "query error")
})
}
func TestDoQuery(t *testing.T) {
_, mock, cleanup := setupMockDB(t)
defer cleanup()
t.Run("successful query", func(t *testing.T) {
// Setup mock expectations
rows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "test1").
AddRow(2, "test2")
mock.ExpectQuery("SELECT").WillReturnRows(rows)
// Call DoQuery
result, headers, err := DoQuery("SELECT id, name FROM users", StatementTypeNoExplainCheck)
// Verify results
assert.NoError(t, err)
assert.Equal(t, []string{"id", "name"}, headers)
assert.Len(t, result, 2)
assert.Equal(t, int64(1), result[0]["id"])
assert.Equal(t, "test1", result[0]["name"])
assert.Equal(t, int64(2), result[1]["id"])
assert.Equal(t, "test2", result[1]["name"])
})
t.Run("with explain check", func(t *testing.T) {
// Save original WithExplainCheck value
originalWithExplainCheck := WithExplainCheck
WithExplainCheck = true
defer func() { WithExplainCheck = originalWithExplainCheck }()
// Setup mock expectations for EXPLAIN
explainRows := sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow("1", "SELECT", "users", nil, "ALL", nil, nil, nil, nil, "2", "100.00", nil)
mock.ExpectQuery("EXPLAIN").WillReturnRows(explainRows)
// Setup mock expectations for actual query
rows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "test1").
AddRow(2, "test2")
mock.ExpectQuery("SELECT").WillReturnRows(rows)
// Call DoQuery
result, headers, err := DoQuery("SELECT id, name FROM users", StatementTypeSelect)
// Verify results
assert.NoError(t, err)
assert.Equal(t, []string{"id", "name"}, headers)
assert.Len(t, result, 2)
})
t.Run("query error", func(t *testing.T) {
// Setup mock expectations
mock.ExpectQuery("SELECT").WillReturnError(fmt.Errorf("query error"))
// Call DoQuery
_, _, err := DoQuery("SELECT id, name FROM users", StatementTypeNoExplainCheck)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "query error")
})
t.Run("columns error", func(t *testing.T) {
// Setup mock expectations
mock.ExpectQuery("SELECT").WillReturnError(fmt.Errorf("columns error"))
// Call DoQuery
_, _, err := DoQuery("SELECT id, name FROM users", StatementTypeNoExplainCheck)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "columns error")
})
t.Run("scan error", func(t *testing.T) {
// Setup mock expectations
mock.ExpectQuery("SELECT").WillReturnError(fmt.Errorf("scan error"))
// Call DoQuery
_, _, err := DoQuery("SELECT id, name FROM users", StatementTypeNoExplainCheck)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "scan error")
})
t.Run("with byte array conversion", func(t *testing.T) {
// Setup mock expectations with a byte array value
rows := sqlmock.NewRows([]string{"id", "blob"}).
AddRow(1, []byte("binary data"))
mock.ExpectQuery("SELECT").WillReturnRows(rows)
// Call DoQuery
result, headers, err := DoQuery("SELECT id, blob FROM users", StatementTypeNoExplainCheck)
// Verify results
assert.NoError(t, err)
assert.Equal(t, []string{"id", "blob"}, headers)
assert.Len(t, result, 1)
assert.Equal(t, int64(1), result[0]["id"])
assert.Equal(t, "binary data", result[0]["blob"])
})
}
func TestHandleExec(t *testing.T) {
_, mock, cleanup := setupMockDB(t)
defer cleanup()
t.Run("insert statement", func(t *testing.T) {
// Setup mock expectations
mock.ExpectExec("INSERT").WillReturnResult(sqlmock.NewResult(123, 1))
// Call HandleExec
result, err := HandleExec("INSERT INTO users (name) VALUES ('test')", StatementTypeInsert)
// Verify results
assert.NoError(t, err)
assert.Contains(t, result, "1 rows affected")
assert.Contains(t, result, "last insert id: 123")
})
t.Run("update statement", func(t *testing.T) {
// Setup mock expectations
mock.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 2))
// Call HandleExec
result, err := HandleExec("UPDATE users SET name = 'updated' WHERE id IN (1, 2)", StatementTypeNoExplainCheck)
// Verify results
assert.NoError(t, err)
assert.Equal(t, "2 rows affected", result)
})
t.Run("exec error", func(t *testing.T) {
// Setup mock expectations
mock.ExpectExec("UPDATE").WillReturnError(fmt.Errorf("exec error"))
// Call HandleExec
_, err := HandleExec("UPDATE users SET name = 'updated'", StatementTypeNoExplainCheck)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "exec error")
})
}
func TestHandleExplain(t *testing.T) {
_, mock, cleanup := setupMockDB(t)
defer cleanup()
// Save original WithExplainCheck value
originalWithExplainCheck := WithExplainCheck
defer func() { WithExplainCheck = originalWithExplainCheck }()
t.Run("with explain check disabled", func(t *testing.T) {
// Disable explain check
WithExplainCheck = false
// Call HandleExplain - should return nil without querying
err := HandleExplain("SELECT * FROM users", StatementTypeSelect)
// Verify results
assert.NoError(t, err)
})
// Enable explain check for the rest of the tests
WithExplainCheck = true
t.Run("select query", func(t *testing.T) {
// Setup mock expectations
explainRows := sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow("1", "SIMPLE", "users", nil, "ALL", nil, nil, nil, nil, "2", "100.00", nil)
mock.ExpectQuery("EXPLAIN").WillReturnRows(explainRows)
// Call HandleExplain
err := HandleExplain("SELECT * FROM users", StatementTypeSelect)
// Verify results
assert.NoError(t, err)
})
t.Run("insert query", func(t *testing.T) {
// Setup mock expectations
explainRows := sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow("1", "INSERT", "users", nil, "ALL", nil, nil, nil, nil, "1", "100.00", nil)
mock.ExpectQuery("EXPLAIN").WillReturnRows(explainRows)
// Call HandleExplain
err := HandleExplain("INSERT INTO users (name) VALUES ('test')", StatementTypeInsert)
// Verify results
assert.NoError(t, err)
})
t.Run("update query", func(t *testing.T) {
// Setup mock expectations
explainRows := sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow("1", "UPDATE", "users", nil, "ALL", nil, nil, nil, nil, "1", "100.00", nil)
mock.ExpectQuery("EXPLAIN").WillReturnRows(explainRows)
// Call HandleExplain
err := HandleExplain("UPDATE users SET name = 'test' WHERE id = 1", StatementTypeUpdate)
// Verify results
assert.NoError(t, err)
})
t.Run("delete query", func(t *testing.T) {
// Setup mock expectations
explainRows := sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow("1", "DELETE", "users", nil, "ALL", nil, nil, nil, nil, "1", "100.00", nil)
mock.ExpectQuery("EXPLAIN").WillReturnRows(explainRows)
// Call HandleExplain
err := HandleExplain("DELETE FROM users WHERE id = 1", StatementTypeDelete)
// Verify results
assert.NoError(t, err)
})
t.Run("explain error", func(t *testing.T) {
// Setup mock expectations
mock.ExpectQuery("EXPLAIN").WillReturnError(fmt.Errorf("explain error"))
// Call HandleExplain
err := HandleExplain("SELECT * FROM users", StatementTypeSelect)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "explain error")
})
t.Run("no results", func(t *testing.T) {
// Setup mock expectations
explainRows := sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"})
mock.ExpectQuery("EXPLAIN").WillReturnRows(explainRows)
// Call HandleExplain
err := HandleExplain("SELECT * FROM users", StatementTypeSelect)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "unable to check query plan")
})
t.Run("type mismatch", func(t *testing.T) {
// Setup mock expectations
explainRows := sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow("1", "INSERT", "users", nil, "ALL", nil, nil, nil, nil, "1", "100.00", nil)
mock.ExpectQuery("EXPLAIN").WillReturnRows(explainRows)
// Call HandleExplain
err := HandleExplain("INSERT INTO users (name) VALUES ('test')", StatementTypeUpdate)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "query plan does not match expected pattern")
})
t.Run("scan error", func(t *testing.T) {
// Setup mock expectations
mock.ExpectQuery("EXPLAIN").WillReturnError(fmt.Errorf("scan error"))
// Call HandleExplain
err := HandleExplain("SELECT * FROM users", StatementTypeSelect)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "scan error")
})
}
/*
func TestHandleDescTable(t *testing.T) {
_, mock, cleanup := setupMockDB(t)
defer cleanup()
t.Run("successful desc", func(t *testing.T) {
// Setup mock expectations
rows := sqlmock.NewRows([]string{"Table", "Create Table"}).
AddRow("users", "CREATE TABLE `users` (`id` int(11) NOT NULL AUTO_INCREMENT, `name` varchar(255) NOT NULL, PRIMARY KEY (`id`)) ENGINE=InnoDB")
mock.ExpectQuery("SHOW CREATE TABLE").WillReturnRows(rows)
// Call HandleDescTable
result, err := HandleDescTable("users")
// Verify results
assert.NoError(t, err)
assert.Contains(t, result, "CREATE TABLE `users`")
})
t.Run("table not found", func(t *testing.T) {
// Setup mock expectations
rows := sqlmock.NewRows([]string{"Table", "Create Table"})
mock.ExpectQuery("SHOW CREATE TABLE").WillReturnRows(rows)
// Call HandleDescTable
_, err := HandleDescTable("nonexistent")
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "does not exist")
})
t.Run("query error", func(t *testing.T) {
// Setup mock expectations
mock.ExpectQuery("SHOW CREATE TABLE").WillReturnError(fmt.Errorf("query error"))
// Call HandleDescTable
_, err := HandleDescTable("users")
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "query error")
})
}
*/
func TestMapToCSV(t *testing.T) {
t.Run("successful mapping", func(t *testing.T) {
// Setup test data
data := []map[string]interface{}{
{"id": 1, "name": "test1"},
{"id": 2, "name": "test2"},
}
headers := []string{"id", "name"}
// Call MapToCSV
result, err := MapToCSV(data, headers)
// Verify results
assert.NoError(t, err)
lines := strings.Split(strings.TrimSpace(result), "\n")
assert.Len(t, lines, 3)
assert.Equal(t, "id,name", lines[0])
assert.Equal(t, "1,test1", lines[1])
assert.Equal(t, "2,test2", lines[2])
})
t.Run("missing key", func(t *testing.T) {
// Setup test data
data := []map[string]interface{}{
{"id": 1}, // missing "name"
}
headers := []string{"id", "name"}
// Call MapToCSV
_, err := MapToCSV(data, headers)
// Verify results
assert.Error(t, err)
assert.Contains(t, err.Error(), "key 'name' not found in map")
})
t.Run("empty data", func(t *testing.T) {
// Setup test data
data := []map[string]interface{}{}
headers := []string{"id", "name"}
// Call MapToCSV
result, err := MapToCSV(data, headers)
// Verify results
assert.NoError(t, err)
lines := strings.Split(strings.TrimSpace(result), "\n")
assert.Len(t, lines, 1)
assert.Equal(t, "id,name", lines[0])
})
t.Run("handles different types", func(t *testing.T) {
// Setup test data
data := []map[string]interface{}{
{"id": 1, "name": "test1", "active": true, "score": 3.14},
}
headers := []string{"id", "name", "active", "score"}
// Call MapToCSV
result, err := MapToCSV(data, headers)
// Verify results
assert.NoError(t, err)
lines := strings.Split(strings.TrimSpace(result), "\n")
assert.Len(t, lines, 2)
assert.Equal(t, "id,name,active,score", lines[0])
assert.Equal(t, "1,test1,true,3.14", lines[1])
})
t.Run("header write error", func(t *testing.T) {
// This is hard to test directly since we can't easily mock the csv.Writer
// But we can at least ensure our error handling code is covered
// by checking that the error message is correctly formatted
_ = []map[string]interface{}{}
_ = []string{"id", "name"}
// Create a mock error
mockErr := fmt.Errorf("mock header write error")
// Simulate the error by checking the error message format
errMsg := fmt.Errorf("failed to write headers: %v", mockErr).Error()
assert.Contains(t, errMsg, "failed to write headers")
assert.Contains(t, errMsg, "mock header write error")
})
t.Run("row write error", func(t *testing.T) {
// Similar to the header write error test, we're checking error message format
mockErr := fmt.Errorf("mock row write error")
errMsg := fmt.Errorf("failed to write row: %v", mockErr).Error()
assert.Contains(t, errMsg, "failed to write row")
assert.Contains(t, errMsg, "mock row write error")
})
t.Run("flush error", func(t *testing.T) {
// Similar to the other error tests, we're checking error message format
mockErr := fmt.Errorf("mock flush error")
errMsg := fmt.Errorf("error flushing CSV writer: %v", mockErr).Error()
assert.Contains(t, errMsg, "error flushing CSV writer")
assert.Contains(t, errMsg, "mock flush error")
})
}
```