# Directory Structure
```
├── .gitignore
├── .python-version
├── CLAUDE.md
├── glama.json
├── LICENSE
├── mcpserverdemo.jpg
├── pyproject.toml
├── pytest.ini
├── README.md
├── RELEASE_NOTES.md
├── src
│   └── mcp_server_starrocks
│       ├── __init__.py
│       ├── connection_health_checker.py
│       ├── db_client.py
│       ├── db_summary_manager.py
│       └── server.py
└── tests
    ├── __init__.py
    ├── README.md
    └── test_db_client.py
```
# Files
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
```
3.12
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
uv.lock
# IDE files
.idea/
.vscode/
# Exclude Mac generated files
.DS_Store
```
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
```markdown
# Tests for mcp-server-starrocks
## Prerequisites
1. **StarRocks cluster running on localhost** with default configuration:
- Host: localhost
- Port: 9030 (MySQL protocol)
- User: root
- Password: (empty)
- At least one BE node available
2. **Optional: Arrow Flight SQL enabled** (for Arrow Flight tests):
- Port: 9408 (or custom port)
- Add `arrow_flight_sql_port = 9408` to `fe.conf`
- Restart FE service
- Verify with: `python test_arrow_flight.py`
3. **Test dependencies installed**:
```bash
uv add --optional test pytest pytest-cov
```
## Running Tests
### Quick Connection Test
First, verify your StarRocks connection:
```bash
# Test MySQL connection and basic operations
python test_connection.py
# Test Arrow Flight SQL connectivity (if enabled)
python test_arrow_flight.py
```
The MySQL test will verify basic connectivity and table operations. The Arrow Flight test will diagnose Arrow Flight SQL availability and performance.
### Full Test Suite
Run the complete db_client test suite:
```bash
# Run all tests (MySQL only)
uv run pytest tests/test_db_client.py::TestDBClient -v
# Run Arrow Flight SQL tests (if enabled)
STARROCKS_FE_ARROW_FLIGHT_SQL_PORT=9408 uv run pytest tests/test_db_client.py::TestDBClientWithArrowFlight -v
# Run all tests (both MySQL and Arrow Flight if available)
uv run pytest tests/test_db_client.py -v
# Run specific test
uv run pytest tests/test_db_client.py::TestDBClient::test_execute_show_databases -v
```
### Test Coverage
The test suite covers:
- **Connection Management**: MySQL pooled connections and ADBC Arrow Flight SQL
- **Query Execution**: SELECT, DDL, DML operations with both success and error cases
- **Result Formats**: Raw ResultSet and pandas DataFrame outputs
- **Database Context**: Switching databases for queries
- **Error Handling**: Connection failures, invalid queries, malformed SQL
- **Resource Management**: Connection pooling, cursor cleanup, connection reset
- **Edge Cases**: Empty results, type conversion, schema operations
### Test Configuration
- **Single-node setup**: Tests create tables with `PROPERTIES ("replication_num" = "1")` (see the sketch below)
- **Temporary databases**: Tests create and clean up test databases automatically
- **Arrow Flight SQL**: Tests are skipped if `STARROCKS_FE_ARROW_FLIGHT_SQL_PORT` is not set
- **Isolation**: Each test uses a fresh DBClient instance with reset connections
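A minimal sketch of the pattern these tests follow (hypothetical table/database names; assumes a reachable local cluster and is not an actual test from the suite):
```python
# Hypothetical sketch of the test pattern; names and DDL details are illustrative.
from mcp_server_starrocks.db_client import DBClient

def test_create_and_query():
    client = DBClient()  # fresh client instance, mirroring the per-test isolation
    tmp_db = "test_mcp_tmp"
    assert client.execute(f"CREATE DATABASE IF NOT EXISTS {tmp_db}").success
    try:
        ddl = ('CREATE TABLE t1 (id INT) DISTRIBUTED BY HASH(id) BUCKETS 1 '
               'PROPERTIES ("replication_num" = "1")')  # single-node friendly
        assert client.execute(ddl, db=tmp_db).success
        result = client.execute("SELECT * FROM t1", db=tmp_db)
        assert result.success and result.rows == []
    finally:
        client.execute(f"DROP DATABASE IF EXISTS {tmp_db} FORCE")  # cleanup
```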
## Test Results
When all tests pass, you should see:
```
======================== 16 passed, 2 skipped in 1.30s =========================
```
The 2 skipped tests are Arrow Flight SQL tests that only run when the environment variable is configured.
## Troubleshooting
**Connection issues**:
- Ensure StarRocks FE is running on localhost:9030
- Check that the `root` user has no password set
- Verify at least one BE node is available
**Table creation failures**:
- Single-node clusters need `replication_num=1`
- Check StarRocks logs for detailed error messages
**Import errors**:
- Ensure you're running from the project root directory
- Check that `src/mcp_server_starrocks` is in your Python path
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
[MseeP.ai Security Assessment](https://mseep.ai/app/starrocks-mcp-server-starrocks)
# StarRocks Official MCP Server
The StarRocks MCP Server acts as a bridge between AI assistants and StarRocks databases. It allows for direct SQL execution, database exploration, data visualization via charts, and retrieving detailed schema/data overviews without requiring complex client-side setup.
<a href="https://glama.ai/mcp/servers/@StarRocks/mcp-server-starrocks">
<img width="380" height="200" src="https://glama.ai/mcp/servers/@StarRocks/mcp-server-starrocks/badge" alt="StarRocks Server MCP server" />
</a>
## Features
- **Direct SQL Execution:** Run `SELECT` queries (`read_query`) and DDL/DML commands (`write_query`).
- **Database Exploration:** List databases and tables, retrieve table schemas (`starrocks://` resources).
- **System Information:** Access internal StarRocks metrics and states via the `proc://` resource path.
- **Detailed Overviews:** Get comprehensive summaries of tables (`table_overview`) or entire databases (`db_overview`), including column definitions, row counts, and sample data.
- **Data Visualization:** Execute a query and generate a Plotly chart directly from the results (`query_and_plotly_chart`).
- **Intelligent Caching:** Table and database overviews are cached in memory to speed up repeated requests. Cache can be bypassed when needed.
- **Flexible Configuration:** Set connection details and behavior via environment variables.
## Configuration
The MCP server is typically run via an MCP host. Configuration is passed to the host, specifying how to launch the StarRocks MCP server process.
**Using Streamable HTTP (recommended):**
To start the server in Streamable HTTP mode:
First, verify that the connection works:
```
$ STARROCKS_URL=root:@localhost:9030 uv run mcp-server-starrocks --test
```
Start the server:
```
uv run mcp-server-starrocks --mode streamable-http --port 8000
```
Then configure the MCP client like this:
```json
{
"mcpServers": {
"mcp-server-starrocks": {
"url": "http://localhost:8000/mcp"
}
}
}
```
**Using `uv` with installed package (individual environment variables):**
```json
{
"mcpServers": {
"mcp-server-starrocks": {
"command": "uv",
"args": ["run", "--with", "mcp-server-starrocks", "mcp-server-starrocks"],
"env": {
"STARROCKS_HOST": "default localhost",
"STARROCKS_PORT": "default 9030",
"STARROCKS_USER": "default root",
"STARROCKS_PASSWORD": "default empty",
"STARROCKS_DB": "default empty"
}
}
}
}
```
**Using `uv` with installed package (connection URL):**
```json
{
"mcpServers": {
"mcp-server-starrocks": {
"command": "uv",
"args": ["run", "--with", "mcp-server-starrocks", "mcp-server-starrocks"],
"env": {
"STARROCKS_URL": "root:password@localhost:9030/my_database"
}
}
}
}
```
**Using `uv` with local directory (for development):**
```json
{
"mcpServers": {
"mcp-server-starrocks": {
"command": "uv",
"args": [
"--directory",
"path/to/mcp-server-starrocks", // <-- Update this path
"run",
"mcp-server-starrocks"
],
"env": {
"STARROCKS_HOST": "default localhost",
"STARROCKS_PORT": "default 9030",
"STARROCKS_USER": "default root",
"STARROCKS_PASSWORD": "default empty",
"STARROCKS_DB": "default empty"
}
}
}
}
```
**Using `uv` with local directory and connection URL:**
```json
{
"mcpServers": {
"mcp-server-starrocks": {
"command": "uv",
"args": [
"--directory",
"path/to/mcp-server-starrocks", // <-- Update this path
"run",
"mcp-server-starrocks"
],
"env": {
"STARROCKS_URL": "root:password@localhost:9030/my_database"
}
}
}
}
```
**Command-line Arguments:**
The server supports the following command-line arguments:
```bash
uv run mcp-server-starrocks --help
```
- `--mode {stdio,sse,http,streamable-http}`: Transport mode (default: stdio or MCP_TRANSPORT_MODE env var)
- `--host HOST`: Server host for HTTP modes (default: localhost)
- `--port PORT`: Server port for HTTP modes
- `--test`: Run in test mode to verify functionality
Examples:
```bash
# Start in streamable HTTP mode on custom host/port
uv run mcp-server-starrocks --mode streamable-http --host 0.0.0.0 --port 8080
# Start in stdio mode (default)
uv run mcp-server-starrocks --mode stdio
# Run test mode
uv run mcp-server-starrocks --test
```
- The `url` field should point to the Streamable HTTP endpoint of your MCP server (adjust host/port as needed).
- With this configuration, clients can interact with the server using standard JSON over HTTP POST requests; no special SDK is required.
- All tool APIs accept and return standard JSON, as described in the Components section below.
> **Note:**
> The `sse` (Server-Sent Events) mode is deprecated and no longer maintained. Please use Streamable HTTP mode for all new integrations.
**Environment Variables:**
### Connection Configuration
You can configure StarRocks connection using either individual environment variables or a single connection URL:
**Option 1: Individual Environment Variables**
- `STARROCKS_HOST`: (Optional) Hostname or IP address of the StarRocks FE service. Defaults to `localhost`.
- `STARROCKS_PORT`: (Optional) MySQL protocol port of the StarRocks FE service. Defaults to `9030`.
- `STARROCKS_USER`: (Optional) StarRocks username. Defaults to `root`.
- `STARROCKS_PASSWORD`: (Optional) StarRocks password. Defaults to empty string.
- `STARROCKS_DB`: (Optional) Default database to use if not specified in tool arguments or resource URIs. If set, the connection will attempt to `USE` this database. Tools like `table_overview` and `db_overview` will use this if the database part is omitted in their arguments. Defaults to empty (no default database).
**Option 2: Connection URL (takes precedence over individual variables)**
- `STARROCKS_URL`: (Optional) A connection URL string that contains all connection parameters in a single variable. Format: `[<schema>://]user:password@host:port/database`. The schema part is optional. When this variable is set, it takes precedence over the individual `STARROCKS_HOST`, `STARROCKS_PORT`, `STARROCKS_USER`, `STARROCKS_PASSWORD`, and `STARROCKS_DB` variables.
Examples:
- `root:mypass@localhost:9030/test_db`
- `mysql://admin:[email protected]:9030/production`
- `starrocks://user:[email protected]:9030/analytics`
### Additional Configuration
- `STARROCKS_OVERVIEW_LIMIT`: (Optional) An _approximate_ character limit for the _total_ text generated by overview tools (`table_overview`, `db_overview`) when fetching data to populate the cache. This helps prevent excessive memory usage for very large schemas or numerous tables. Defaults to `20000`.
- `STARROCKS_MYSQL_AUTH_PLUGIN`: (Optional) Specifies the authentication plugin to use when connecting to the StarRocks FE service. For example, set to `mysql_clear_password` if your StarRocks deployment requires clear text password authentication (such as when using certain LDAP or external authentication setups). Only set this if your environment specifically requires it; otherwise, the default auth_plugin is used.
- `MCP_TRANSPORT_MODE`: (Optional) Communication mode that specifies how the MCP Server exposes its services. Available options:
  - `stdio` (default): Communicates through standard input/output; suitable for hosting by an MCP host.
  - `streamable-http`: Starts a Streamable HTTP server supporting RESTful API calls.
  - `sse`: **(Deprecated, not recommended)** Starts in Server-Sent Events (SSE) streaming mode, for scenarios requiring streaming responses. **Note: SSE mode is no longer maintained; use Streamable HTTP mode instead.**
## Components
### Tools
- `read_query`
- **Description:** Execute a SELECT query or other commands that return a ResultSet (e.g., `SHOW`, `DESCRIBE`).
- **Input:**
```json
{
"query": "SQL query string",
"db": "database name (optional, uses default database if not specified)"
}
```
- **Output:** Text content containing the query results in a CSV-like format, including a header row and a row count summary. Returns an error message on failure.
- `write_query`
- **Description:** Execute a DDL (`CREATE`, `ALTER`, `DROP`), DML (`INSERT`, `UPDATE`, `DELETE`), or other StarRocks command that does not return a ResultSet.
- **Input:**
```json
{
"query": "SQL command string",
"db": "database name (optional, uses default database if not specified)"
}
```
- **Output:** Text content confirming success (e.g., "Query OK, X rows affected") or reporting an error. Changes are committed automatically on success.
- `analyze_query`
  - **Description:** Analyze a query using its query profile (when a query ID is provided) or `EXPLAIN ANALYZE` (when SQL is provided).
- **Input:**
```json
{
"uuid": "Query ID, a string composed of 32 hexadecimal digits formatted as 8-4-4-4-12",
"sql": "Query SQL to analyze",
"db": "database name (optional, uses default database if not specified)"
}
```
- **Output:** Text content containing the query analysis results. Uses `ANALYZE PROFILE FROM` if uuid is provided, otherwise uses `EXPLAIN ANALYZE` if sql is provided.
- `query_and_plotly_chart`
- **Description:** Executes a SQL query, loads the results into a Pandas DataFrame, and generates a Plotly chart using a provided Python expression. Designed for visualization in supporting UIs.
- **Input:**
```json
{
"query": "SQL query to fetch data",
"plotly_expr": "Python expression string using 'px' (Plotly Express) and 'df' (DataFrame). Example: 'px.scatter(df, x=\"col1\", y=\"col2\")'",
"db": "database name (optional, uses default database if not specified)"
}
```
- **Output:** A list containing:
1. `TextContent`: A text representation of the DataFrame and a note that the chart is for UI display.
2. `ImageContent`: The generated Plotly chart encoded as a base64 PNG image (`image/png`). Returns text error message on failure or if the query yields no data.
- `table_overview`
- **Description:** Get an overview of a specific table: columns (from `DESCRIBE`), total row count, and sample rows (`LIMIT 3`). Uses an in-memory cache unless `refresh` is true.
- **Input:**
```json
{
"table": "Table name, optionally prefixed with database name (e.g., 'db_name.table_name' or 'table_name'). If database is omitted, uses STARROCKS_DB environment variable if set.",
"refresh": false // Optional, boolean. Set to true to bypass the cache. Defaults to false.
}
```
- **Output:** Text content containing the formatted overview (columns, row count, sample data) or an error message. Cached results include previous errors if applicable.
- `db_overview`
- **Description:** Get an overview (columns, row count, sample rows) for _all_ tables within a specified database. Uses the table-level cache for each table unless `refresh` is true.
- **Input:**
```json
{
"db": "database_name", // Optional if default database is set.
"refresh": false // Optional, boolean. Set to true to bypass the cache for all tables in the DB. Defaults to false.
}
```
- **Output:** Text content containing concatenated overviews for all tables found in the database, separated by headers. Returns an error message if the database cannot be accessed or contains no tables.
### Resources
#### Direct Resources
- `starrocks:///databases`
- **Description:** Lists all databases accessible to the configured user.
- **Equivalent Query:** `SHOW DATABASES`
- **MIME Type:** `text/plain`
#### Resource Templates
- `starrocks:///{db}/{table}/schema`
- **Description:** Gets the schema definition of a specific table.
- **Equivalent Query:** `SHOW CREATE TABLE {db}.{table}`
- **MIME Type:** `text/plain`
- `starrocks:///{db}/tables`
- **Description:** Lists all tables within a specific database.
- **Equivalent Query:** `SHOW TABLES FROM {db}`
- **MIME Type:** `text/plain`
- `proc:///{+path}`
- **Description:** Accesses StarRocks internal system information, similar to Linux `/proc`. The `path` parameter specifies the desired information node.
- **Equivalent Query:** `SHOW PROC '/{path}'`
- **MIME Type:** `text/plain`
- **Common Paths:**
- `/frontends` - Information about FE nodes.
- `/backends` - Information about BE nodes (for non-cloud native deployments).
- `/compute_nodes` - Information about CN nodes (for cloud native deployments).
- `/dbs` - Information about databases.
- `/dbs/<DB_ID>` - Information about a specific database by ID.
- `/dbs/<DB_ID>/<TABLE_ID>` - Information about a specific table by ID.
- `/dbs/<DB_ID>/<TABLE_ID>/partitions` - Partition information for a table.
- `/transactions` - Transaction information grouped by database.
- `/transactions/<DB_ID>` - Transaction information for a specific database ID.
- `/transactions/<DB_ID>/running` - Running transactions for a database ID.
- `/transactions/<DB_ID>/finished` - Finished transactions for a database ID.
- `/jobs` - Information about asynchronous jobs (Schema Change, Rollup, etc.).
- `/statistic` - Statistics for each database.
- `/tasks` - Information about agent tasks.
- `/cluster_balance` - Load balance status information.
- `/routine_loads` - Information about Routine Load jobs.
- `/colocation_group` - Information about Colocation Join groups.
- `/catalog` - Information about configured catalogs (e.g., Hive, Iceberg).
### Prompts
None defined by this server.
## Caching Behavior
- The `table_overview` and `db_overview` tools utilize an in-memory cache to store the generated overview text.
- The cache key is a tuple of `(database_name, table_name)`.
- When `table_overview` is called, it checks the cache first. If a result exists and the `refresh` parameter is `false` (default), the cached result is returned immediately. Otherwise, it fetches the data from StarRocks, stores it in the cache, and then returns it.
- When `db_overview` is called, it lists all tables in the database and retrieves the overview for _each table_ using the same caching logic as `table_overview`. If `refresh` is `true` for `db_overview`, it forces a refresh for _all_ tables in that database.
- The `STARROCKS_OVERVIEW_LIMIT` environment variable provides a _soft target_ for the maximum length of the overview string generated _per table_ when populating the cache, helping to manage memory usage.
- Cached results, including any error messages encountered during the original fetch, are stored and returned on subsequent cache hits, as the sketch below illustrates.
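A minimal sketch of this caching logic (illustrative only; the helper names are hypothetical, not the server's actual internals):
```python
from typing import Dict, Tuple

def fetch_overview_from_starrocks(db: str, table: str) -> str:
    return f"overview of {db}.{table}"  # stand-in for the DESCRIBE / COUNT / LIMIT 3 queries

def list_tables(db: str) -> list:
    return ["t1", "t2"]                 # stand-in for SHOW TABLES FROM db

_overview_cache: Dict[Tuple[str, str], str] = {}

def table_overview(db: str, table: str, refresh: bool = False) -> str:
    key = (db, table)
    if not refresh and key in _overview_cache:
        return _overview_cache[key]     # cache hit (may also be a cached error message)
    text = fetch_overview_from_starrocks(db, table)
    _overview_cache[key] = text         # store for subsequent requests
    return text

def db_overview(db: str, refresh: bool = False) -> str:
    # Reuses the per-table cache; refresh=True forces a refetch for every table.
    return "\n\n".join(table_overview(db, t, refresh=refresh) for t in list_tables(db))
```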
## Debug
After starting the MCP server, you can debug it with the MCP Inspector:
```
npx @modelcontextprotocol/inspector
```
## Demo

```
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
```markdown
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
StarRocks Official MCP Server - A bridge between AI assistants and StarRocks databases, built using FastMCP framework. Enables direct SQL execution, database exploration, data visualization, and schema introspection through the Model Context Protocol (MCP).
## Development Commands
**Local Development:**
```bash
# Run the server directly for testing
uv run mcp-server-starrocks
# Run with test mode to verify table overview functionality
uv run mcp-server-starrocks --test
# Run in Streamable HTTP mode (recommended for integration)
export MCP_TRANSPORT_MODE=streamable-http
uv run mcp-server-starrocks
```
**Package Management:**
```bash
# Install dependencies (handled by uv automatically)
uv sync
# Build package
uv build
```
## Architecture Overview
### Core Components
- **`src/mcp_server_starrocks/server.py`**: Main server implementation containing all MCP tools, resources, and database connection logic
- **`src/mcp_server_starrocks/__init__.py`**: Entry point that starts the async server
### Connection Architecture
The server supports two connection modes:
- **Standard MySQL Protocol**: Default connection using `mysql.connector`
- **Arrow Flight SQL**: High-performance connection using ADBC drivers (enabled when `STARROCKS_FE_ARROW_FLIGHT_SQL_PORT` is set)
Connection management uses a global singleton pattern with automatic reconnection handling.
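A minimal sketch of this pattern (illustrative; the names are hypothetical, not the server's actual internals):
```python
import threading

class Connection:
    """Stand-in for a real database connection."""
    def ping(self) -> None:
        pass  # a real client raises here when the connection is dead

_conn = None
_lock = threading.Lock()

def get_connection() -> Connection:
    """Return the shared connection, recreating it if the health check fails."""
    global _conn
    with _lock:
        if _conn is None:
            _conn = Connection()  # lazy creation of the single shared connection
        try:
            _conn.ping()          # cheap health check
        except Exception:
            _conn = Connection()  # automatic reconnection on failure
        return _conn

def reset_connection() -> None:
    """Drop the cached connection so the next caller reconnects."""
    global _conn
    with _lock:
        _conn = None
```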
### Tool Categories
1. **Query Execution Tools**:
- `read_query`: Execute SELECT and other result-returning queries
- `write_query`: Execute DDL/DML commands
   - `analyze_query`: Query performance analysis via query profile or `EXPLAIN ANALYZE`
2. **Overview Tools with Caching**:
- `table_overview`: Get table schema, row count, and sample data (cached)
- `db_overview`: Get overview of all tables in a database (uses table cache)
3. **Visualization Tool**:
- `query_and_plotly_chart`: Execute query and generate Plotly charts from results
### Resource Endpoints
- `starrocks:///databases`: List all databases
- `starrocks:///{db}/tables`: List tables in a database
- `starrocks:///{db}/{table}/schema`: Get table CREATE statement
- `proc:///{path}`: Access StarRocks internal system information (similar to Linux /proc)
### Caching System
In-memory cache for table overviews using `(database_name, table_name)` as cache keys. Cache includes both successful results and error messages. Controlled by `STARROCKS_OVERVIEW_LIMIT` environment variable (default: 20000 characters).
## Configuration
Environment variables for database connection:
- `STARROCKS_HOST`: Database host (default: localhost)
- `STARROCKS_PORT`: MySQL port (default: 9030)
- `STARROCKS_USER`: Username (default: root)
- `STARROCKS_PASSWORD`: Password (default: empty)
- `STARROCKS_DB`: Default database for session
- `STARROCKS_MYSQL_AUTH_PLUGIN`: Auth plugin (e.g., mysql_clear_password)
- `STARROCKS_FE_ARROW_FLIGHT_SQL_PORT`: Enables Arrow Flight SQL mode
- `MCP_TRANSPORT_MODE`: Communication mode (stdio/streamable-http/sse)
## Code Patterns
### Error Handling
- Database errors trigger connection reset via `reset_connection()`
- All tools return string error messages rather than raising exceptions
- Cursors are always closed in finally blocks
### Security
- SQL injection prevention through parameterized queries and backtick escaping
- Plotly expressions are validated using AST parsing to prevent code injection (sketched below)
- Limited `eval()` usage with restricted scope for chart generation
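A minimal sketch of AST-based expression validation (illustrative; the whitelist and checks are assumptions, not the server's actual rules):
```python
import ast

ALLOWED_NAMES = {"px", "df"}  # assumed whitelist: Plotly Express and the query DataFrame

def validate_plotly_expr(expr: str) -> None:
    """Reject anything beyond a single expression over whitelisted names."""
    tree = ast.parse(expr, mode="eval")  # mode="eval" rejects statements outright
    for node in ast.walk(tree):
        if isinstance(node, ast.Name) and node.id not in ALLOWED_NAMES:
            raise ValueError(f"disallowed name: {node.id}")
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            raise ValueError("dunder attribute access is not allowed")

validate_plotly_expr('px.scatter(df, x="col1", y="col2")')        # passes
# validate_plotly_expr('__import__("os").system("echo pwned")')   # raises ValueError
```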
### Async Patterns
- Tools are defined as async functions even though database operations are synchronous
- Main server runs in async context using `FastMCP.run_async()`
## Package Structure
This is a simple Python package built with hatchling:
- Single module in `src/mcp_server_starrocks/`
- Entry point defined in pyproject.toml as `mcp-server-starrocks` command
- Dependencies managed through pyproject.toml, no requirements.txt files
```
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
```python
# Tests for mcp-server-starrocks
```
--------------------------------------------------------------------------------
/glama.json:
--------------------------------------------------------------------------------
```json
{
"$schema": "https://glama.ai/mcp/schemas/server.json",
"maintainers": [
"decster"
]
}
```
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
```
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
filterwarnings =
ignore::DeprecationWarning
ignore::PendingDeprecationWarning
```
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/__init__.py:
--------------------------------------------------------------------------------
```python
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import server
import asyncio
def main():
asyncio.run(server.main())
__all__ = ['main', 'server']
```
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
```toml
[project]
name = "mcp-server-starrocks"
version = "0.2.0"
description = "official MCP server for StarRocks"
readme = "README.md"
license = {text = "Apache-2.0"}
requires-python = ">=3.10"
dependencies = [
"loguru>=0.7.3",
"fastmcp>=2.12.0,<2.13.0",
"mysql-connector-python>=9.2.0",
"pandas>=2.2.3",
"plotly>=6.0.1",
"kaleido==0.2.1",
"adbc-driver-manager>=0.8.0",
"adbc-driver-flightsql>=0.8.0",
"pyarrow>=14.0.0",
]
[project.optional-dependencies]
test = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
]
[[project.authors]]
name = "changbinglin"
email = "[email protected]"
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"
[project.scripts]
mcp-server-starrocks = "mcp_server_starrocks:main"
[project.urls]
Home = "https://github.com/starrocks/mcp-server-starrocks"
```
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/connection_health_checker.py:
--------------------------------------------------------------------------------
```python
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from loguru import logger
class ConnectionHealthChecker:
"""
A singleton class that manages database connection health monitoring.
"""
def __init__(self, db_client, check_interval=30):
"""
Initialize the connection health checker.
Args:
db_client: Database client instance for health checks
check_interval: Health check interval in seconds (default: 30)
"""
self.db_client = db_client
self.check_interval = check_interval
self._health_check_thread = None
self._health_check_stop_event = threading.Event()
self._last_connection_status = None
self._last_healthy_log = None
def check_connection_health(self):
"""
Check database connection health by executing a simple query.
Returns tuple of (is_healthy: bool, error_message: str or None)
"""
try:
result = self.db_client.execute("show databases")
if result.success:
return True, None
else:
return False, result.error_message
except Exception as e:
return False, str(e)
def _connection_health_checker_loop(self):
"""
Background thread function that periodically checks connection health.
"""
logger.info(f"Starting connection health checker (interval: {self.check_interval}s)")
while True:
is_healthy, error_msg = self.check_connection_health()
# Log status changes or periodic status updates
if self._last_connection_status != is_healthy:
if is_healthy:
logger.info("Database connection is healthy")
else:
logger.warning(f"Database connection is unhealthy: {error_msg}")
else:
# Log periodic status (every 5 minutes when healthy, every check when unhealthy)
current_time = time.time()
if is_healthy:
if self._last_healthy_log is None:
self._last_healthy_log = current_time
elif current_time - self._last_healthy_log >= 300: # 5 minutes
logger.info("Database connection remains healthy")
self._last_healthy_log = current_time
else:
logger.warning(f"Database connection remains unhealthy: {error_msg}")
self._last_connection_status = is_healthy
# Wait for interval or stop event
if self._health_check_stop_event.wait(self.check_interval):
break
logger.info("Connection health checker stopped")
def start(self):
"""
Start the connection health checker thread.
"""
if self._health_check_thread is None or not self._health_check_thread.is_alive():
self._health_check_stop_event.clear()
self._health_check_thread = threading.Thread(
target=self._connection_health_checker_loop,
name="ConnectionHealthChecker",
daemon=True
)
self._health_check_thread.start()
logger.info("Connection health checker thread started")
def stop(self):
"""
Stop the connection health checker thread.
"""
if self._health_check_thread is not None:
self._health_check_stop_event.set()
self._health_check_thread.join(timeout=5)
if self._health_check_thread.is_alive():
logger.warning("Connection health checker thread did not stop gracefully")
else:
logger.info("Connection health checker thread stopped")
self._health_check_thread = None
# Global instance - will be initialized in server.py
_health_checker_instance = None
def initialize_health_checker(db_client, check_interval=30):
"""
Initialize the global connection health checker instance.
Args:
db_client: Database client instance
check_interval: Health check interval in seconds
"""
global _health_checker_instance
_health_checker_instance = ConnectionHealthChecker(db_client, check_interval)
return _health_checker_instance
def start_connection_health_checker():
"""
Start the connection health checker thread.
"""
if _health_checker_instance is None:
raise RuntimeError("Health checker not initialized. Call initialize_health_checker() first.")
_health_checker_instance.start()
def stop_connection_health_checker():
"""
Stop the connection health checker thread.
"""
if _health_checker_instance is not None:
_health_checker_instance.stop()
def check_connection_health():
"""
Check database connection health by executing a simple query.
Returns tuple of (is_healthy: bool, error_message: str or None)
"""
if _health_checker_instance is None:
raise RuntimeError("Health checker not initialized. Call initialize_health_checker() first.")
return _health_checker_instance.check_connection_health()
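# Example usage (illustrative; assumes a db_client exposing execute()):
#
#   checker = initialize_health_checker(db_client, check_interval=30)
#   start_connection_health_checker()         # spawns the daemon thread
#   healthy, err = check_connection_health()  # one-off check on demand
#   stop_connection_health_checker()          # signal the thread and join it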
```
--------------------------------------------------------------------------------
/RELEASE_NOTES.md:
--------------------------------------------------------------------------------
```markdown
# StarRocks MCP Server Release Notes
## Version 0.2.0
### Major Features and Enhancements
1. **Enhanced STARROCKS_URL Parsing** (commit 80ac0ba)
- Support for flexible connection URL formats including empty passwords
- Handle patterns like "root:@localhost:9030" and "root@localhost:9030"
- Support missing ports with default 9030: "root:password@localhost"
- Support minimal format: "user@host" with empty password and default port
- Maintain backward compatibility with existing valid URLs
- Comprehensive test coverage for edge cases
- Fixed DBClient to properly convert string port to integer
2. **Connection Health Monitoring** (commit b8a80c6)
- Added new connection_health_checker.py module
- Implemented health checking functionality for database connections
- Enhanced connection reliability and monitoring capabilities
- Proactive connection health management
3. **Visualization Enhancements** (commit b6f26ec)
- Added format parameter to query_and_plotly_chart tool
- Enhanced chart generation capabilities with configurable output formats
- Improved flexibility for data visualization workflows
### Testing and Infrastructure
- Added comprehensive test coverage for STARROCKS_URL parsing edge cases
- Enhanced test suite with new test cases for database client functionality
- Improved error handling and validation for connection scenarios
### Breaking Changes
None - this release maintains full backward compatibility with version 0.1.5.
## Version 0.1.5
### Major Features and Enhancements
1. **Connection Pooling and Architecture Refactor** (commit 0fc372d)
   - Major refactor introducing connection pooling for improved performance
   - Extracted database client logic into a separate db_client.py module
   - Enhanced connection management and reliability
2. **Enhanced Arrow Flight SQL Support** (commit 877338f)
   - Improved Arrow Flight SQL connection handling
   - Better result processing for high-performance queries
   - Enhanced error handling for Arrow Flight connections
3. **New Query Analysis Tools** (commit 60ca975)
   - Added collect_query_dump_and_profile functionality
   - Enhanced query performance analysis capabilities
4. **Database Summary Management** (commits d269ebe, 5b2ca59)
   - Added new db_summary_manager.py module
   - Implemented database summary functionality for better overview capabilities
   - Enhanced database exploration features
5. **Configuration Enhancements** (commit fb09271)
   - Added STARROCKS_URL configuration option
   - Improved connection configuration flexibility
### Testing and Infrastructure
- Updated test suite with new test cases for database client functionality
- Added comprehensive testing for Arrow Flight SQL features
- Improved test infrastructure with new README documentation
### Breaking Changes
- Major refactor may require configuration updates for some deployment scenarios
- Connection handling has been restructured (though backward compatibility is maintained)
## Version 0.1.4
## Version 0.1.3
1. Refactored using FastMCP
2. Added new config STARROCKS_MYSQL_AUTH_PLUGIN
## Version 0.1.2
Fixed an accidental extra import of sqlalchemy
## Version 0.1.1
1. Added new tool query_and_plotly_chart
2. Added new tools table_overview & db_overview
3. Added env configs STARROCKS_DB and STARROCKS_OVERVIEW_LIMIT, both optional
## Version 0.1.0 (Initial Release)
We are excited to announce the first release of the StarRocks MCP (Model Context Protocol) Server. This server enables AI assistants to interact directly with StarRocks databases, providing a seamless interface for executing queries and retrieving database information.
### Description
The StarRocks MCP Server acts as a bridge between AI assistants and StarRocks databases, allowing for direct SQL execution and database exploration without requiring complex setup or configuration. This initial release provides essential functionality for database interaction while maintaining security and performance.
### Features
- **SQL Query Execution**
- `read_query` tool for executing SELECT queries and commands that return result sets
- `write_query` tool for executing DDL/DML statements and other StarRocks commands
- Proper error handling and connection management
- **Database Exploration**
- List all databases in a StarRocks instance
- View table schemas using SHOW CREATE TABLE
- List all tables within a specific database
- **System Information Access**
- Access to StarRocks internal system information via proc-like interface
- Visibility into FE nodes, BE nodes, CN nodes, databases, tables, partitions, transactions, jobs, and more
- **Flexible Configuration**
- Configurable connection parameters (host, port, user, password)
- Support for both package installation and local directory execution
### Requirements
- Python 3.10 or higher
- Dependencies:
- mcp >= 1.0.0
- mysql-connector-python >= 9.2.0
### Configuration
The server can be configured through environment variables:
- `STARROCKS_HOST` (default: localhost)
- `STARROCKS_PORT` (default: 9030)
- `STARROCKS_USER` (default: root)
- `STARROCKS_PASSWORD` (default: empty)
- `STARROCKS_MYSQL_AUTH_PLUGIN` (default: `mysql_native_password`); other auth plugins such as `mysql_clear_password` can also be specified
### Installation
The server can be installed as a Python package:
```bash
pip install mcp-server-starrocks
```
Or run directly from the source:
```bash
uv --directory path/to/mcp-server-starrocks run mcp-server-starrocks
```
### MCP Integration
Add the following configuration to your MCP settings file:
```json
{
"mcpServers": {
"mcp-server-starrocks": {
"command": "uv",
"args": [
"run",
"--with",
"mcp-server-starrocks",
"mcp-server-starrocks"
],
"env": {
"STARROCKS_HOST": "localhost",
"STARROCKS_PORT": "9030",
"STARROCKS_USER": "root",
"STARROCKS_PASSWORD": "",
"STARROCKS_MYSQL_AUTH_PLUGIN":"mysql_clear_password"
}
}
}
}
```
---
We welcome feedback and contributions to improve the StarRocks MCP Server. Please report any issues or suggestions through our GitHub repository.
```
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/db_summary_manager.py:
--------------------------------------------------------------------------------
```python
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from loguru import logger
@dataclass
class ColumnInfo:
name: str
column_type: str
ordinal_position: int
@dataclass
class TableInfo:
name: str
database: str
size_bytes: int = 0
size_str: str = ""
replica_count: int = 0
columns: List[ColumnInfo] = field(default_factory=list)
create_statement: Optional[str] = None
last_updated: float = 0
error_message: Optional[str] = None
def __post_init__(self):
if not self.last_updated:
self.last_updated = time.time()
@staticmethod
def parse_size_string(size_str: str) -> int:
"""Parse size strings like '1.285 GB', '714.433 MB', '2.269 KB' to bytes"""
if not size_str or size_str == "0" or size_str.lower() == "total":
return 0
# Handle special cases
if size_str.lower() in ["quota", "left"]:
return 0
# Match pattern like "1.285 GB"
match = re.match(r'([\d.]+)\s*([KMGT]?B)', size_str.strip(), re.IGNORECASE)
if not match:
return 0
value, unit = match.groups()
try:
num_value = float(value)
except ValueError:
return 0
multipliers = {
'B': 1,
'KB': 1024,
'MB': 1024 ** 2,
'GB': 1024 ** 3,
'TB': 1024 ** 4
}
multiplier = multipliers.get(unit.upper(), 1)
return int(num_value * multiplier)
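    # Examples (illustrative):
    #   parse_size_string("2.269 KB") -> 2323          (2.269 * 1024, truncated)
    #   parse_size_string("1 GB")     -> 1073741824
    #   parse_size_string("Total")    -> 0             (summary rows are ignored)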
def is_large_table(self) -> bool:
"""Determine if table is considered large (replica_count > 64 OR size > 2GB)"""
return self.replica_count > 64 or self.size_bytes > (2 * 1024 ** 3)
def priority_score(self) -> float:
"""Calculate priority score combining size and replica count for sorting"""
# Normalize size to GB and combine with replica count
size_gb = self.size_bytes / (1024 ** 3)
return size_gb + (self.replica_count * 0.1) # Weight replica count less than size
def is_expired(self, expire_seconds: int = 120) -> bool:
"""Check if cache entry is expired (default 2 minutes)"""
return time.time() - self.last_updated > expire_seconds
class DatabaseSummaryManager:
def __init__(self, db_client):
self.db_client = db_client
# Cache: {(database, table_name): TableInfo}
self.table_cache: Dict[Tuple[str, str], TableInfo] = {}
# Database last sync time: {database: timestamp}
self.db_last_sync: Dict[str, float] = {}
def _sync_table_list(self, database: str, force: bool = False) -> bool:
"""Sync table list using SHOW DATA, detect new/dropped tables"""
current_time = time.time()
# Check if sync is needed (2min expiration or force)
if not force and database in self.db_last_sync:
if current_time - self.db_last_sync[database] < 120:
return True
logger.debug(f"Syncing table list for database {database}")
try:
# Execute SHOW DATA to get current table list with sizes
result = self.db_client.execute("SHOW DATA", db=database)
if not result.success:
logger.error(f"Failed to sync table list for {database}: {result.error_message}")
return False
if not result.rows:
logger.info(f"No tables found in database {database}")
# Clear cache for this database
keys_to_remove = [key for key in self.table_cache.keys() if key[0] == database]
for key in keys_to_remove:
del self.table_cache[key]
self.db_last_sync[database] = current_time
return True
# Parse current tables from SHOW DATA
current_tables = {}
for row in result.rows:
table_name = row[0]
# Skip summary rows (Total, Quota, Left)
if table_name.lower() in ['total', 'quota', 'left']:
continue
size_str = row[1] if len(row) > 1 else ""
replica_count = int(row[2]) if len(row) > 2 and str(row[2]).isdigit() else 0
size_bytes = TableInfo.parse_size_string(size_str)
current_tables[table_name] = {
'size_str': size_str,
'size_bytes': size_bytes,
'replica_count': replica_count
}
# Update cache: add new tables, update existing, remove dropped
cache_keys_for_db = {key[1]: key for key in self.table_cache.keys() if key[0] == database}
# Add or update existing tables
for table_name, table_data in current_tables.items():
cache_key = (database, table_name)
if cache_key in self.table_cache:
# Update existing table info
table_info = self.table_cache[cache_key]
table_info.size_str = table_data['size_str']
table_info.size_bytes = table_data['size_bytes']
table_info.replica_count = table_data['replica_count']
table_info.last_updated = current_time
else:
# Create new table info
self.table_cache[cache_key] = TableInfo(
name=table_name,
database=database,
size_str=table_data['size_str'],
size_bytes=table_data['size_bytes'],
replica_count=table_data['replica_count'],
last_updated=current_time
)
# Remove dropped tables
for table_name in cache_keys_for_db:
if table_name not in current_tables:
cache_key = cache_keys_for_db[table_name]
del self.table_cache[cache_key]
logger.debug(f"Removed dropped table {database}.{table_name} from cache")
self.db_last_sync[database] = current_time
logger.debug(f"Synced {len(current_tables)} tables for database {database}")
return True
except Exception as e:
logger.error(f"Error syncing table list for {database}: {e}")
return False
def _fetch_column_info(self, database: str, tables: List[str]) -> Dict[str, List[ColumnInfo]]:
"""Fetch column information for all tables using information_schema.columns"""
if not tables:
return {}
logger.debug(f"Fetching column info for {len(tables)} tables in {database}")
try:
# Build query to get column information for all tables
table_names_quoted = "', '".join(tables)
query = f"""
SELECT table_name, column_name, ordinal_position, column_type
FROM information_schema.columns
WHERE table_schema = '{database}'
AND table_name IN ('{table_names_quoted}')
ORDER BY table_name, ordinal_position
"""
result = self.db_client.execute(query)
if not result.success:
logger.error(f"Failed to fetch column info: {result.error_message}")
return {}
# Group columns by table
table_columns = {}
for row in result.rows:
table_name = row[0]
column_name = row[1]
ordinal_position = int(row[2]) if row[2] else 0
column_type = 'string' if row[3] == "varchar(65533)" else row[3]
if table_name not in table_columns:
table_columns[table_name] = []
table_columns[table_name].append(ColumnInfo(
name=column_name,
column_type=column_type,
ordinal_position=ordinal_position
))
logger.debug(f"Fetched column info for {len(table_columns)} tables")
return table_columns
except Exception as e:
logger.error(f"Error fetching column information: {e}")
return {}
def _fetch_create_statement(self, database: str, table: str) -> Optional[str]:
"""Fetch CREATE TABLE statement for large tables"""
try:
result = self.db_client.execute(f"SHOW CREATE TABLE `{database}`.`{table}`")
if result.success and result.rows and len(result.rows[0]) > 1:
return result.rows[0][1] # Second column contains CREATE statement
except Exception as e:
logger.error(f"Error fetching CREATE statement for {database}.{table}: {e}")
return None
def get_database_summary(self, database: str, limit: int = 10000, refresh: bool = False) -> str:
"""Generate comprehensive database summary with intelligent prioritization"""
if not database:
return "Error: Database name is required"
logger.info(f"Generating database summary for {database}, limit={limit}, refresh={refresh}")
        # Sync table list (force a resync when refresh is requested)
        if not self._sync_table_list(database, force=refresh):
            return f"Error: Failed to sync table information for database '{database}'"
# Get all tables for this database from cache
tables_info = []
for (db, table_name), table_info in self.table_cache.items():
if db == database:
tables_info.append(table_info)
if not tables_info:
return f"No tables found in database '{database}'"
# Sort tables by priority (large tables first)
tables_info.sort(key=lambda t: t.priority_score(), reverse=True)
# Check if any table needs column information refresh
need_column_refresh = refresh or any(not table_info.columns or table_info.is_expired() for table_info in tables_info)
# If any table needs refresh, fetch ALL tables' columns in one query (more efficient)
if need_column_refresh:
all_table_names = [table_info.name for table_info in tables_info]
table_columns = self._fetch_column_info(database, all_table_names)
# Update cache with column information for all tables
current_time = time.time()
for table_info in tables_info:
if table_info.name in table_columns:
table_info.columns = table_columns[table_info.name]
table_info.last_updated = current_time
# Identify large tables that need CREATE statements
large_tables = [t for t in tables_info if t.is_large_table()][:10] # Top 10 large tables
for table_info in large_tables:
if refresh or not table_info.create_statement:
table_info.create_statement = self._fetch_create_statement(database, table_info.name)
table_info.last_updated = time.time()
# Generate summary output
return self._format_database_summary(database, tables_info, limit)
def _format_database_summary(self, database: str, tables_info: List[TableInfo], limit: int) -> str:
"""Format database summary with intelligent truncation"""
lines = []
lines.append(f"=== Database Summary: '{database}' ===")
lines.append(f"Total tables: {len(tables_info)}")
# Calculate totals
total_size = sum(t.size_bytes for t in tables_info)
total_replicas = sum(t.replica_count for t in tables_info)
large_tables = [t for t in tables_info if t.is_large_table()]
lines.append(f"Total size: {self._format_bytes(total_size)}")
current_length = len("\n".join(lines))
table_limit = min(len(tables_info), 50) # Show max 50 tables
# Show large tables first with full details
if large_tables:
for i, table_info in enumerate(large_tables):
if current_length > limit * 0.8: # Reserve 20% for smaller tables
lines.append(f"... and {len(large_tables) - i} more large tables")
break
table_summary = self._format_table_info(table_info, detailed=True)
lines.append(table_summary)
lines.append("")
current_length = len("\n".join(lines))
# Show remaining tables with basic info
remaining_tables = [t for t in tables_info if not t.is_large_table()]
if remaining_tables and current_length < limit:
lines.append("--- Other Tables ---")
for i, table_info in enumerate(remaining_tables):
if current_length > limit:
lines.append(f"... and {len(remaining_tables) - i} more tables (use higher limit to see all)")
break
table_summary = self._format_table_info(table_info, detailed=False)
lines.append(table_summary)
current_length = len("\n".join(lines))
return "\n".join(lines)
def _format_table_info(self, table_info: TableInfo, detailed: bool = True) -> str:
"""Format individual table information"""
lines = []
# Basic info line
size_info = f"{table_info.size_str} ({table_info.replica_count} replicas)"
lines.append(f"Table: {table_info.name} - {size_info}")
if table_info.error_message:
lines.append(f" Error: {table_info.error_message}")
return "\n".join(lines)
# Show CREATE statement if available, otherwise show column list
if table_info.create_statement:
lines.append(table_info.create_statement)
elif table_info.columns:
# Sort columns by ordinal position and show as list
sorted_columns = sorted(table_info.columns, key=lambda c: c.ordinal_position)
if detailed or len(sorted_columns) <= 20:
for col in sorted_columns:
lines.append(f" {col.name} {col.column_type}")
else:
lines.append(f" Columns ({len(sorted_columns)}): {', '.join(col.name for col in sorted_columns[:100])}...")
return "\n".join(lines)
@staticmethod
def _format_bytes(bytes_count: int) -> str:
"""Format bytes to human readable string"""
if bytes_count == 0:
return "0 B"
units = ['B', 'KB', 'MB', 'GB', 'TB']
unit_index = 0
size = float(bytes_count)
while size >= 1024 and unit_index < len(units) - 1:
size /= 1024
unit_index += 1
if unit_index == 0:
return f"{int(size)} {units[unit_index]}"
else:
return f"{size:.2f} {units[unit_index]}"
def clear_cache(self, database: Optional[str] = None):
"""Clear cache for specific database or all databases"""
if database:
keys_to_remove = [key for key in self.table_cache.keys() if key[0] == database]
for key in keys_to_remove:
del self.table_cache[key]
if database in self.db_last_sync:
del self.db_last_sync[database]
logger.info(f"Cleared cache for database {database}")
else:
self.table_cache.clear()
self.db_last_sync.clear()
logger.info("Cleared all cache")
# Global instance (will be initialized in server.py)
_db_summary_manager: Optional[DatabaseSummaryManager] = None
def get_db_summary_manager(db_client) -> DatabaseSummaryManager:
"""Get or create global database summary manager instance"""
global _db_summary_manager
if _db_summary_manager is None:
_db_summary_manager = DatabaseSummaryManager(db_client)
return _db_summary_manager
```
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/db_client.py:
--------------------------------------------------------------------------------
```python
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import time
import re
import json
from typing import Optional, List, Any, Union, Literal, TypedDict, NotRequired
from dataclasses import dataclass
import mysql.connector
from mysql.connector import Error as MySQLError
import adbc_driver_manager
import adbc_driver_flightsql.dbapi as flight_sql
from adbc_driver_manager import Error as adbcError
import pandas as pd
@dataclass
class ResultSet:
"""Database query result set."""
success: bool
column_names: Optional[List[str]] = None
rows: Optional[List[List[Any]]] = None
rows_affected: Optional[int] = None
execution_time: Optional[float] = None
error_message: Optional[str] = None
pandas: Optional[pd.DataFrame] = None
def to_pandas(self) -> pd.DataFrame:
"""Convert ResultSet to pandas DataFrame."""
if self.pandas is not None:
return self.pandas
if not self.success:
raise ValueError(f"Cannot convert failed result to DataFrame: {self.error_message}")
if self.column_names is None or self.rows is None:
raise ValueError("No data available to convert to DataFrame")
return pd.DataFrame(self.rows, columns=self.column_names)
def to_string(self, limit: Optional[int] = None) -> str:
"""Format rows as CSV-like string with column names as first row."""
if not self.success:
return f"Error: {self.error_message}"
if self.column_names is None or self.rows is None:
return "No data"
def to_csv_line(row):
return ",".join(
str(item).replace("\"", "\"\"") if isinstance(item, str) else str(item) for item in row)
output = io.StringIO()
output.write(to_csv_line(self.column_names) + "\n")
for row in self.rows:
line = to_csv_line(row) + "\n"
if limit is not None and output.tell() + len(line) > limit:
output.write("...\n")
break
output.write(line)
output.write(f"Total rows: {len(self.rows)}\n")
output.write(f"Execution time: {self.execution_time:.3f}s\n");
return output.getvalue()
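    # Example output of to_string() (illustrative):
    #   id,name
    #   1,alice
    #   Total rows: 1
    #   Execution time: 0.012s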
def to_dict(self) -> dict:
ret = {
"success": self.success,
"execution_time": self.execution_time,
}
if self.column_names is not None:
ret["column_names"] = self.column_names
ret["rows"] = self.rows
if self.rows_affected is not None:
ret["rows_affected"] = self.rows_affected
if self.error_message:
ret["error_message"] = self.error_message
return ret
class PerfAnalysisInput(TypedDict):
error_message: NotRequired[Optional[str]]
query_id: NotRequired[Optional[str]]
rows_returned: NotRequired[Optional[int]]
duration: NotRequired[Optional[float]]
query_dump: NotRequired[Optional[dict]]
profile: NotRequired[Optional[str]]
analyze_profile: NotRequired[Optional[str]]
def parse_connection_url(connection_url: str) -> dict:
"""
Parse connection URL into dict with user, password, host, port, database.
Supports flexible formats:
- [<schema>://]<user>[:<password>]@<host>[:<port>][/<database>]
- Empty passwords: user:@host:port or user@host:port
- Missing ports (uses default 9030): user:pass@host
- All components are optional except user and host
"""
# More flexible regex pattern that handles optional password and port
pattern = re.compile(
r'^(?:(?P<schema>[\w+]+)://)?' # Optional schema://
r'(?P<user>[^:@]+)' # Required username (no : or @)
r'(?::(?P<password>[^@]*))?' # Optional :password (can be empty)
r'@(?P<host>[^:/]+)' # Required @host
r'(?::(?P<port>\d+))?' # Optional :port
r'(?:/(?P<database>[\w-]+))?$' # Optional /database
)
match = pattern.match(connection_url)
if not match:
raise ValueError(f"Invalid connection URL: {connection_url}")
result = match.groupdict()
# Only keep connection parameters that mysql.connector supports
# Filter out None values and schema (which is not a mysql.connector parameter)
filtered_result = {}
# Always include user and host as they are required
filtered_result['user'] = result['user']
filtered_result['host'] = result['host']
# Include password (default to empty string if None)
filtered_result['password'] = result['password'] if result['password'] is not None else ''
# Include port (default to 9030 if None)
filtered_result['port'] = result['port'] if result['port'] is not None else '9030'
# Always include database (None if not provided in URL)
filtered_result['database'] = result['database']
# Note: schema is intentionally excluded as it's not supported by mysql.connector
return filtered_result
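# Examples (illustrative):
#   parse_connection_url("root:pass@localhost:9030/test")
#     -> {'user': 'root', 'host': 'localhost', 'password': 'pass', 'port': '9030', 'database': 'test'}
#   parse_connection_url("root@localhost")
#     -> {'user': 'root', 'host': 'localhost', 'password': '', 'port': '9030', 'database': None}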
ANSI_ESCAPE_PATTERN = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
def remove_ansi_codes(text):
return ANSI_ESCAPE_PATTERN.sub('', text)
class DBClient:
"""Simplified database client for StarRocks connection and query execution."""
def __init__(self):
self.enable_dummy_test = bool(os.getenv('STARROCKS_DUMMY_TEST'))
self.enable_arrow_flight_sql = bool(os.getenv('STARROCKS_FE_ARROW_FLIGHT_SQL_PORT'))
if os.getenv('STARROCKS_URL'):
self.connection_params = parse_connection_url(os.getenv('STARROCKS_URL'))
# Convert port to integer for mysql.connector
self.connection_params['port'] = int(self.connection_params['port'])
else:
self.connection_params = {
'host': os.getenv('STARROCKS_HOST', 'localhost'),
'port': int(os.getenv('STARROCKS_PORT', '9030')),
'user': os.getenv('STARROCKS_USER', 'root'),
'password': os.getenv('STARROCKS_PASSWORD', ''),
'database': os.getenv('STARROCKS_DB', None),
}
self.connection_params.update(**{
'auth_plugin': os.getenv('STARROCKS_MYSQL_AUTH_PLUGIN', 'mysql_native_password'),
'pool_size': int(os.getenv('STARROCKS_POOL_SIZE', '10')),
'pool_name': 'mcp_starrocks_pool',
'pool_reset_session': True,
'autocommit': True,
'connection_timeout': int(os.getenv('STARROCKS_CONNECTION_TIMEOUT', '10')),
'connect_timeout': int(os.getenv('STARROCKS_CONNECTION_TIMEOUT', '10')),
})
self.default_database = self.connection_params.get('database')
# MySQL connection pool
self._connection_pool = None
# ADBC connection (singleton)
self._adbc_connection = None
def _get_connection_pool(self):
"""Get or create a connection pool for MySQL connections."""
if self._connection_pool is None:
try:
self._connection_pool = mysql.connector.pooling.MySQLConnectionPool(**self.connection_params)
except MySQLError as conn_err:
raise conn_err
return self._connection_pool
def _validate_connection(self, conn):
"""Validate that a MySQL connection is still alive and working."""
try:
conn.ping(reconnect=True, attempts=1, delay=0)
return True
except MySQLError:
return False
def _get_pooled_connection(self):
"""Get a MySQL connection from the pool with timeout and retry logic."""
pool = self._get_connection_pool()
try:
conn = pool.get_connection()
if not self._validate_connection(conn):
conn.close()
conn = pool.get_connection()
return conn
except mysql.connector.errors.PoolError as pool_err:
if "Pool is exhausted" in str(pool_err):
time.sleep(0.1)
try:
return pool.get_connection()
except mysql.connector.errors.PoolError:
self._connection_pool = None
new_pool = self._get_connection_pool()
return new_pool.get_connection()
raise pool_err
def _create_adbc_connection(self):
"""Create a new ADBC connection."""
fe_host = self.connection_params['host']
fe_port = os.getenv('STARROCKS_FE_ARROW_FLIGHT_SQL_PORT', '')
user = self.connection_params['user']
password = self.connection_params['password']
try:
connection = flight_sql.connect(
uri=f"grpc://{fe_host}:{fe_port}",
db_kwargs={
adbc_driver_manager.DatabaseOptions.USERNAME.value: user,
adbc_driver_manager.DatabaseOptions.PASSWORD.value: password,
}
)
# Switch to default database if set
if self.default_database:
try:
cursor = connection.cursor()
cursor.execute(f"USE {self.default_database}")
cursor.close()
except adbcError as db_err:
print(f"Warning: Could not switch to default database '{self.default_database}': {db_err}")
return connection
except adbcError as conn_err:
print(f"Error creating ADBC connection: {conn_err}")
raise
def _get_adbc_connection(self):
"""Get or create an ADBC connection with health check."""
if self._adbc_connection is None:
self._adbc_connection = self._create_adbc_connection()
# Health check for ADBC connection
if self._adbc_connection is not None:
try:
self._adbc_connection.adbc_get_info()
except adbcError as check_err:
print(f"Connection check failed: {check_err}, creating new ADBC connection.")
self._reset_adbc_connection()
self._adbc_connection = self._create_adbc_connection()
return self._adbc_connection
def _get_connection(self):
"""Get appropriate connection based on configuration."""
if self.enable_arrow_flight_sql:
return self._get_adbc_connection()
else:
return self._get_pooled_connection()
def _reset_adbc_connection(self):
"""Reset ADBC connection."""
if self._adbc_connection is not None:
try:
self._adbc_connection.close()
except Exception as e:
print(f"Error closing ADBC connection: {e}")
finally:
self._adbc_connection = None
def _reset_connection(self):
"""Reset connections based on configuration."""
if self.enable_arrow_flight_sql:
self._reset_adbc_connection()
else:
self._connection_pool = None
def _handle_db_error(self, error):
"""Handle database errors and reset connections as needed."""
if not self.enable_arrow_flight_sql and ("MySQL Connection not available" in str(error) or "Lost connection" in str(error)):
self._connection_pool = None
elif self.enable_arrow_flight_sql:
self._reset_adbc_connection()
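# Recovery strategy: in MySQL mode the pool is dropped only on lost-connection
# errors so the next call rebuilds it; in Arrow Flight mode the singleton
# connection is always recreated on error.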
def _execute(self, conn, statement: str, params=None, return_format:str="raw") -> ResultSet:
cursor = None
start_time = time.time()
try:
cursor = conn.cursor()
cursor.execute(statement, params)
# Initialize variables to track the last result set
last_result = None
last_affected_rows = None
# Process first result set
if cursor.description:
column_names = [desc[0] for desc in cursor.description]
if self.enable_arrow_flight_sql:
arrow_result = cursor.fetchallarrow()
df = arrow_result.to_pandas()  # convert once and reuse for both output formats
pandas_df = df if return_format == "pandas" else None
rows = df.values.tolist()
# Check if this is a status result for DML operations (INSERT/UPDATE/DELETE)
# Arrow Flight SQL returns status results as a single column 'StatusResult'
# Note: StarRocks Arrow Flight SQL seems to always return '0' in StatusResult,
# so we use cursor.rowcount when available as a fallback
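# e.g. an INSERT may come back as column_names=['StatusResult'] and rows=[['0']]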
if (len(column_names) == 1 and column_names[0] == 'StatusResult' and
len(rows) == 1 and len(rows[0]) == 1):
try:
status_value = int(rows[0][0])
# If status_value is 0 but we have cursor.rowcount, prefer that
if status_value == 0 and hasattr(cursor, 'rowcount') and cursor.rowcount > 0:
last_affected_rows = cursor.rowcount
else:
last_affected_rows = status_value
last_result = None # Don't treat this as a regular result set
except (ValueError, TypeError):
# If we can't parse the status result as an integer, treat it as a regular result
last_result = ResultSet(
success=True,
column_names=column_names,
rows=rows,
execution_time=0, # Will be set at the end
pandas=pandas_df
)
else:
last_result = ResultSet(
success=True,
column_names=column_names,
rows=rows,
execution_time=0, # Will be set at the end
pandas=pandas_df
)
else:
rows = cursor.fetchall()
pandas_df = pd.DataFrame(rows, columns=column_names) if return_format == "pandas" else None
last_result = ResultSet(
success=True,
column_names=column_names,
rows=rows,
execution_time=0, # Will be set at the end
pandas=pandas_df
)
else:
last_affected_rows = cursor.rowcount if cursor.rowcount >= 0 else None
# Process additional result sets (for multi-statement queries)
# Note: Arrow Flight SQL may not support nextset(), so we check for it
if not self.enable_arrow_flight_sql and hasattr(cursor, 'nextset'):
while cursor.nextset():
if cursor.description:
column_names = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
pandas_df = pd.DataFrame(rows, columns=column_names) if return_format == "pandas" else None
last_result = ResultSet(
success=True,
column_names=column_names,
rows=rows,
execution_time=0, # Will be set at the end
pandas=pandas_df
)
else:
last_affected_rows = cursor.rowcount if cursor.rowcount >= 0 else None
last_result = None
# Return the last result set found
if last_result is not None:
last_result.execution_time = time.time() - start_time
return last_result
else:
return ResultSet(
success=True,
rows_affected=last_affected_rows,
execution_time=time.time() - start_time
)
except (MySQLError, adbcError) as e:
self._handle_db_error(e)
return ResultSet(
success=False,
error_message=f"Error executing statement '{statement}': {str(e)}",
execution_time=time.time() - start_time
)
except Exception as e:
return ResultSet(
success=False,
error_message=f"Unexpected error executing statement '{statement}': {str(e)}",
execution_time=time.time() - start_time
)
finally:
if cursor:
try:
cursor.close()
except Exception:
pass
def execute(
self,
statement: str,
db: Optional[str] = None,
return_format: Literal["raw", "pandas"] = "raw"
) -> ResultSet:
"""
Execute a SQL statement and return results.
Args:
statement: SQL statement to execute
db: Optional database to use (overrides default)
return_format: "raw" returns ResultSet with rows, "pandas" also populates pandas field
Returns:
ResultSet with column_names and rows, optionally with pandas DataFrame
"""
# If dummy test mode is enabled, return dummy data without connecting to database
if self.enable_dummy_test:
column_names = ['name']
rows = [['aaa'], ['bbb'], ['ccc']]
pandas_df = None
if return_format == "pandas":
pandas_df = pd.DataFrame(rows, columns=column_names)
return ResultSet(
success=True,
column_names=column_names,
rows=rows,
execution_time=0.1,
pandas=pandas_df
)
conn = None
try:
conn = self._get_connection()
# Switch database if specified
if db and db != self.default_database:
cursor_temp = conn.cursor()
try:
cursor_temp.execute(f"USE `{db}`")
except (MySQLError, adbcError) as db_err:
cursor_temp.close()
return ResultSet(
success=False,
error_message=f"Error switching to database '{db}': {str(db_err)}",
execution_time=0
)
cursor_temp.close()
return self._execute(conn, statement, None, return_format)
except (MySQLError, adbcError) as e:
self._handle_db_error(e)
return ResultSet(
success=False,
error_message=f"Error executing statement '{statement}': {str(e)}",
)
except Exception as e:
return ResultSet(
success=False,
error_message=f"Unexpected error executing statement '{statement}': {str(e)}",
)
finally:
if conn and not self.enable_arrow_flight_sql:
try:
conn.close()
except Exception:
pass
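# Usage sketch (assumes a reachable StarRocks instance):
#   client = DBClient()
#   rs = client.execute("SELECT 1", return_format="pandas")
#   if rs.success:
#       text = rs.to_string()   # CSV-like text with a header row
#       df = rs.to_pandas()     # pandas DataFrame (already populated)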
def collect_perf_analysis_input(self, query: str, db:Optional[str]=None) -> PerfAnalysisInput:
conn = None
try:
conn = self._get_connection()
# Switch database if specified
if db and db != self.default_database:
cursor_temp = conn.cursor()
try:
cursor_temp.execute(f"USE `{db}`")
except (MySQLError, adbcError) as db_err:
return {"error_message":str(db_err)}
finally:
cursor_temp.close()
query_dump_result = self._execute(conn, "select get_query_dump(%s, %s)", (query, False))
if not query_dump_result.success:
return {"error_message":query_dump_result.error_message}
ret = {
"query_dump": json.loads(query_dump_result.rows[0][0]),
}
start_ts = time.time()
profile_query = "/*+ SET_VAR (enable_profile='true') */ " + query
query_result = self._execute(conn, profile_query)
duration = time.time() - start_ts
ret["duration"] = duration
if not query_result.success:
ret["error_message"] = query_result.error_message
return ret
ret["rows_returned"] = len(query_result.rows) if query_result.rows else 0
# Try to get query id
query_id_result = self._execute(conn, "select last_query_id()")
if not query_id_result.success:
ret["error_message"] = query_id_result.error_message
return ret
ret["query_id"] = query_id_result.rows[0][0]
# Try to get query profile with retries
query_profile = ''
retry_count = 0
while not query_profile and retry_count < 3:
time.sleep(1+retry_count)
query_profile_result = self._execute(conn,"select get_query_profile(%s)", (ret["query_id"],))
if query_profile_result.success:
query_profile = query_profile_result.rows[0][0]
retry_count += 1
if not query_profile:
ret['error_message'] = "Failed to get query profile after 3 retries"
return ret
ret['profile'] = query_profile
analyze_profile_result = self._execute(conn,"ANALYZE PROFILE FROM %s", (ret["query_id"],))
if not analyze_profile_result.success:
ret["error_message"] = analyze_profile_result.error_message
return ret
analyze_text = '\n'.join(row[0] for row in analyze_profile_result.rows)
ret['analyze_profile'] = remove_ansi_codes(analyze_text)
return ret
except (MySQLError, adbcError) as e:
self._handle_db_error(e)
return {"error_message":str(e)}
except Exception as e:
return {"error_message":str(e)}
finally:
if conn and not self.enable_arrow_flight_sql:
try:
conn.close()
except Exception:
pass
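# Shape of the returned PerfAnalysisInput dict (happy path; keys are set as steps succeed):
#   {"query_dump": {...}, "duration": <seconds>, "rows_returned": <int>,
#    "query_id": "<uuid>", "profile": "<raw profile text>",
#    "analyze_profile": "<ANALYZE PROFILE output, ANSI codes stripped>"}
# "error_message" is set instead when any step fails.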
def reset_connections(self):
"""Public method to reset all connections."""
self._reset_connection()
# Global singleton instance
_db_client_instance: Optional[DBClient] = None
def get_db_client() -> DBClient:
"""Get or create the global DBClient instance."""
global _db_client_instance
if _db_client_instance is None:
_db_client_instance = DBClient()
return _db_client_instance
def reset_db_connections():
"""Reset all database connections (useful for error recovery)."""
global _db_client_instance
if _db_client_instance is not None:
_db_client_instance.reset_connections()
```
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/server.py:
--------------------------------------------------------------------------------
```python
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import asyncio
import base64
import json
import math
import sys
import os
import traceback
import threading
import time
from fastmcp import FastMCP
from fastmcp.utilities.types import Image
from fastmcp.tools.tool import ToolResult
from mcp.types import TextContent, ImageContent
from fastmcp.exceptions import ToolError
from typing import Annotated
from pydantic import Field
import plotly.express as px
import plotly.graph_objs
from loguru import logger
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware import Middleware
from .db_client import get_db_client, reset_db_connections, ResultSet, PerfAnalysisInput
from .db_summary_manager import get_db_summary_manager
from .connection_health_checker import (
initialize_health_checker,
start_connection_health_checker,
stop_connection_health_checker,
check_connection_health
)
# Configure logging
logger.remove() # Remove default handler
logger.add(sys.stderr, level=os.getenv("LOG_LEVEL", "INFO"),
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}")
mcp = FastMCP('mcp-server-starrocks')
# a soft hint for output length limits; not strictly enforced
overview_length_limit = int(os.getenv('STARROCKS_OVERVIEW_LIMIT', str(20000)))
# Global cache for table overviews: {(db_name, table_name): overview_string}
global_table_overview_cache = {}
# Get database client instance
db_client = get_db_client()
# Get database summary manager instance
db_summary_manager = get_db_summary_manager(db_client)
# Description suffix for tools, if default db is set
description_suffix = f". db session already in default db `{db_client.default_database}`" if db_client.default_database else ""
# Initialize connection health checker
_health_checker = initialize_health_checker(db_client)
SR_PROC_DESC = '''
Internal information exposed by StarRocks similar to linux /proc, following are some common paths:
'/frontends' Shows the information of FE nodes.
'/backends' Shows the information of BE nodes if this StarRocks cluster is a non-cloud-native deployment.
'/compute_nodes' Shows the information of CN nodes if this StarRocks cluster is a cloud-native deployment.
'/dbs' Shows the information of databases.
'/dbs/<DB_ID>' Shows the information of a database by database ID.
'/dbs/<DB_ID>/<TABLE_ID>' Shows the information of a table by database ID and table ID.
'/dbs/<DB_ID>/<TABLE_ID>/partitions' Shows the information of partitions by database ID and table ID.
'/transactions' Shows the information of transactions by database.
'/transactions/<DB_ID>' Shows the information of transactions by database ID.
'/transactions/<DB_ID>/running' Shows the information of running transactions by database ID.
'/transactions/<DB_ID>/finished' Shows the information of finished transactions by database ID.
'/jobs' Shows the information of jobs.
'/statistic' Shows the statistics of each database.
'/tasks' Shows the total number of all generic tasks and the failed tasks.
'/cluster_balance' Shows the load balance information.
'/routine_loads' Shows the information of Routine Load.
'/colocation_group' Shows the information of Colocate Join groups.
'/catalog' Shows the information of catalogs.
'''
@mcp.resource(uri="starrocks:///databases", name="All Databases", description="List all databases in StarRocks",
mime_type="text/plain")
def get_all_databases() -> str:
logger.debug("Fetching all databases")
result = db_client.execute("SHOW DATABASES")
logger.debug(f"Found {len(result.rows) if result.success and result.rows else 0} databases")
return result.to_string()
@mcp.resource(uri="starrocks:///{db}/{table}/schema", name="Table Schema",
description="Get the schema of a table using SHOW CREATE TABLE", mime_type="text/plain")
def get_table_schema(db: str, table: str) -> str:
logger.debug(f"Fetching schema for table {db}.{table}")
return db_client.execute(f"SHOW CREATE TABLE {db}.{table}").to_string()
@mcp.resource(uri="starrocks:///{db}/tables", name="Database Tables",
description="List all tables in a specific database", mime_type="text/plain")
def get_database_tables(db: str) -> str:
logger.debug(f"Fetching tables from database {db}")
result = db_client.execute(f"SHOW TABLES FROM {db}")
logger.debug(f"Found {len(result.rows) if result.success and result.rows else 0} tables in {db}")
return result.to_string()
@mcp.resource(uri="proc:///{path*}", name="System internal information", description=SR_PROC_DESC,
mime_type="text/plain")
def get_system_internal_information(path: str) -> str:
logger.debug(f"Fetching system information for proc path: {path}")
return db_client.execute(f"show proc '{path}'").to_string(limit=overview_length_limit)
def _get_table_details(db_name, table_name, limit=None):
"""
Helper function to get description, sample rows, and count for a table.
Returns a formatted string. Handles DB errors internally and returns error messages.
"""
global global_table_overview_cache
logger.debug(f"Fetching table details for {db_name}.{table_name}")
output_lines = []
full_table_name = f"`{table_name}`"
if db_name:
full_table_name = f"`{db_name}`.`{table_name}`"
else:
output_lines.append(
f"Warning: Database name missing for table '{table_name}'. Using potentially incorrect context.")
logger.warning(f"Database name missing for table '{table_name}'")
count = 0
output_lines.append(f"--- Overview for {full_table_name} ---")
# 1. Get Row Count
query = f"SELECT COUNT(*) FROM {full_table_name}"
count_result = db_client.execute(query, db=db_name)
if count_result.success and count_result.rows:
count = count_result.rows[0][0]
output_lines.append(f"\nTotal rows: {count}")
logger.debug(f"Table {full_table_name} has {count} rows")
else:
output_lines.append(f"\nCould not determine total row count.")
if not count_result.success:
output_lines.append(f"Error: {count_result.error_message}")
logger.error(f"Failed to get row count for {full_table_name}: {count_result.error_message}")
# 2. Get Columns (DESCRIBE)
if count > 0:
query = f"DESCRIBE {full_table_name}"
desc_result = db_client.execute(query, db=db_name)
if desc_result.success and desc_result.column_names and desc_result.rows:
output_lines.append(f"\nColumns:")
output_lines.append(desc_result.to_string(limit=limit))
else:
output_lines.append("(Could not retrieve column information or table has no columns).")
if not desc_result.success:
output_lines.append(f"Error getting columns for {full_table_name}: {desc_result.error_message}")
return "\n".join(output_lines)
# 3. Get Sample Rows (LIMIT 3)
query = f"SELECT * FROM {full_table_name} LIMIT 3"
sample_result = db_client.execute(query, db=db_name)
if sample_result.success and sample_result.column_names and sample_result.rows:
output_lines.append(f"\nSample rows (limit 3):")
output_lines.append(sample_result.to_string(limit=limit))
else:
output_lines.append(f"(No rows found in {full_table_name}).")
if not sample_result.success:
output_lines.append(f"Error getting sample rows for {full_table_name}: {sample_result.error_message}")
overview_string = "\n".join(output_lines)
# Update cache even if there were partial errors, so we cache the error message too
cache_key = (db_name, table_name)
global_table_overview_cache[cache_key] = overview_string
return overview_string
# tools
@mcp.tool(description="Execute a SELECT query or commands that return a ResultSet" + description_suffix)
def read_query(query: Annotated[str, Field(description="SQL query to execute")],
db: Annotated[str|None, Field(description="database")] = None) -> ToolResult:
# return a CSV-like result set, with column names as the first row
logger.info(f"Executing read query: {query[:100]}{'...' if len(query) > 100 else ''}")
result = db_client.execute(query, db=db)
if result.success:
logger.info(f"Query executed successfully, returned {len(result.rows) if result.rows else 0} rows")
else:
logger.error(f"Query failed: {result.error_message}")
return ToolResult(content=[TextContent(type='text', text=result.to_string(limit=10000))],
structured_content=result.to_dict())
@mcp.tool(description="Execute a DDL/DML or other StarRocks command that do not have a ResultSet" + description_suffix)
def write_query(query: Annotated[str, Field(description="SQL to execute")],
db: Annotated[str|None, Field(description="database")] = None) -> ToolResult:
logger.info(f"Executing write query: {query[:100]}{'...' if len(query) > 100 else ''}")
result = db_client.execute(query, db=db)
if not result.success:
logger.error(f"Write query failed: {result.error_message}")
elif result.rows_affected is not None and result.rows_affected >= 0:
logger.info(f"Write query executed successfully, {result.rows_affected} rows affected in {result.execution_time:.2f}s")
else:
logger.info(f"Write query executed successfully in {result.execution_time:.2f}s")
return ToolResult(content=[TextContent(type='text', text=result.to_string(limit=2000))],
structured_content=result.to_dict())
@mcp.tool(description="Analyze a query and get analyze result using query profile" + description_suffix)
def analyze_query(
uuid: Annotated[
str|None, Field(description="Query ID, a string composed of 32 hexadecimal digits formatted as 8-4-4-4-12")]=None,
sql: Annotated[str|None, Field(description="Query SQL")]=None,
db: Annotated[str|None, Field(description="database")] = None
) -> str:
if uuid:
logger.info(f"Analyzing query profile for UUID: {uuid}")
return db_client.execute(f"ANALYZE PROFILE FROM '{uuid}'", db=db).to_string()
elif sql:
logger.info(f"Analyzing query: {sql[:100]}{'...' if len(sql) > 100 else ''}")
return db_client.execute(f"EXPLAIN ANALYZE {sql}", db=db).to_string()
else:
logger.warning("Analyze query called without valid UUID or SQL")
return f"Failed to analyze query, the reasons maybe: 1.query id is not standard uuid format; 2.the SQL statement have spelling error."
@mcp.tool(description="Run a query to get it's query dump and profile, output very large, need special tools to do further processing")
def collect_query_dump_and_profile(
query: Annotated[str, Field(description="query to execute")],
db: Annotated[str|None, Field(description="database")] = None
) -> ToolResult:
logger.info(f"Collecting query dump and profile for query: {query[:100]}{'...' if len(query) > 100 else ''}")
result : PerfAnalysisInput = db_client.collect_perf_analysis_input(query, db=db)
if result.get('error_message'):
status = f"collecting query dump and profile failed, query_id={result.get('query_id')} error_message={result.get('error_message')}"
logger.warning(status)
else:
status = f"collecting query dump and profile succeeded, but it's only for user/tool, not for AI, query_id={result.get('query_id')}"
logger.info(status)
return ToolResult(
content=[TextContent(type='text', text=status)],
structured_content=result,
)
def validate_plotly_expr(expr: str):
"""
Validates a string to ensure it represents a single call to a method
of the 'px' object, without containing other statements or imports,
and ensures its arguments do not contain nested function calls.
Args:
expr: The string expression to validate.
Raises:
ValueError: If the expression does not meet the security criteria.
SyntaxError: If the expression is not valid Python syntax.
"""
# 1. Check for valid Python syntax
try:
tree = ast.parse(expr)
except SyntaxError as e:
raise SyntaxError(f"Invalid Python syntax in expression: {e}") from e
# 2. Check that the tree contains exactly one top-level node (statement/expression)
if len(tree.body) != 1:
raise ValueError("Expression must be a single statement or expression.")
node = tree.body[0]
# 3. Check that the single node is an expression
if not isinstance(node, ast.Expr):
raise ValueError(
"Expression must be a single expression, not a statement (like assignment, function definition, import, etc.).")
# 4. Get the actual value of the expression and check it's a function call
expr_value = node.value
if not isinstance(expr_value, ast.Call):
raise ValueError("Expression must be a function call.")
# 5. Check that the function being called is an attribute lookup (like px.scatter)
if not isinstance(expr_value.func, ast.Attribute):
raise ValueError("Function call must be on an object attribute (e.g., px.scatter).")
# 6. Check that the attribute is being accessed on a simple variable name
if not isinstance(expr_value.func.value, ast.Name):
raise ValueError("Function call must be on a simple variable name (e.g., px.scatter, not obj.px.scatter).")
# 7. Check that the simple variable name is 'px'
if expr_value.func.value.id != 'px':
raise ValueError("Function call must be on the 'px' object.")
# Check positional arguments
for i, arg_node in enumerate(expr_value.args):
for sub_node in ast.walk(arg_node):
if isinstance(sub_node, ast.Call):
raise ValueError(f"Positional argument at index {i} contains a disallowed nested function call.")
# Check keyword arguments
for kw in expr_value.keywords:
for sub_node in ast.walk(kw.value):
if isinstance(sub_node, ast.Call):
keyword_name = kw.arg if kw.arg else '<unknown>'
raise ValueError(f"Keyword argument '{keyword_name}' contains a disallowed nested function call.")
def one_line_summary(text: str, limit:int=100) -> str:
"""Generate a one-line summary of the given text, truncated to the specified limit."""
single_line = ' '.join(text.split())
if len(single_line) > limit:
return single_line[:limit-3] + '...'
return single_line
@mcp.tool(description="using sql `query` to extract data from database, then using python `plotly_expr` to generate a chart for UI to display" + description_suffix)
def query_and_plotly_chart(
query: Annotated[str, Field(description="SQL query to execute")],
plotly_expr: Annotated[
str, Field(description="a one function call expression, with 2 vars binded: `px` as `import plotly.express as px`, and `df` as dataframe generated by query `plotly_expr` example: `px.scatter(df, x=\"sepal_width\", y=\"sepal_length\", color=\"species\", marginal_y=\"violin\", marginal_x=\"box\", trendline=\"ols\", template=\"simple_white\")`")],
format: Annotated[str, Field(description="chart output format, json|png|jpeg")] = "jpeg",
db: Annotated[str|None, Field(description="database")] = None
) -> ToolResult:
"""
Executes an SQL query, creates a Pandas DataFrame, generates a Plotly chart
using the provided expression, encodes the chart as a base64 PNG image,
and returns it along with optional text.
Args:
query: The SQL query string to execute.
plotly_expr: A Python string expression using 'px' (plotly.express)
and 'df' (the DataFrame from the query) to generate a figure.
Example: "px.scatter(df, x='col1', y='col2')"
format: chat output format, json|png|jpeg, default is jpeg
db: Optional database name to execute the query in.
Returns:
A list containing types.TextContent and types.ImageContent,
or just types.TextContent in case of an error or no data.
"""
try:
logger.info(f'query_and_plotly_chart query:{one_line_summary(query)}, plotly:{one_line_summary(plotly_expr)} format:{format}, db:{db}')
result = db_client.execute(query, db=db, return_format="pandas")
errmsg = None
if not result.success:
errmsg = result.error_message
elif result.pandas is None:
errmsg = 'Query did not return data suitable for plotting.'
else:
df = result.pandas
if df.empty:
errmsg = 'Query returned no data to plot.'
if errmsg:
logger.warning(f"Query or data issue: {errmsg}")
return ToolResult(
content=[TextContent(type='text', text=f'Error: {errmsg}')],
structured_content={'success': False, 'error_message': errmsg},
)
# Validate and evaluate the plotly expression using px and df
local_vars = {'df': df}
validate_plotly_expr(plotly_expr)
fig : plotly.graph_objs.Figure = eval(plotly_expr, {"px": px}, local_vars)
if format == 'json':
# return json representation of the figure for front-end rendering
plot_json = json.loads(fig.to_json())
structured_content = result.to_dict()
structured_content['data'] = plot_json['data']
structured_content['layout'] = plot_json['layout']
summary = result.to_string()
return ToolResult(
content=[
TextContent(type='text', text=f'{summary}\nChart Generated for UI rendering'),
],
structured_content=structured_content,
)
else:
if not hasattr(fig, 'to_image'):
raise ToolError(f"The evaluated expression did not return a Plotly figure object. Result type: {type(fig)}")
if format == 'jpg':
format = 'jpeg'
img_bytes = fig.to_image(format=format, width=960, height=720)
structured_content = result.to_dict()
structured_content['img_bytes_base64'] = base64.b64encode(img_bytes).decode('ascii')  # decode so the value is JSON-serializable
return ToolResult(
content=[
TextContent(type='text', text=f'dataframe data:\n{df}\nChart generated (for UI display only)'),
Image(data=img_bytes, format=format).to_image_content()
],
structured_content=structured_content
)
except Exception as err:
return ToolResult(
content=[TextContent(type='text', text=f'Error: {err}')],
structured_content={'success': False, 'error_message': str(err)},
)
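# Example tool invocation (hypothetical table and columns):
#   query_and_plotly_chart(
#       query="SELECT category, SUM(amount) AS total FROM sales GROUP BY category",
#       plotly_expr="px.bar(df, x='category', y='total')",
#       format="png")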
@mcp.tool(description="Get an overview of a specific table: columns, sample rows (up to 5), and total row count. Uses cache unless refresh=true" + description_suffix)
def table_overview(
table: Annotated[str, Field(
description="Table name, optionally prefixed with database name (e.g., 'db_name.table_name'). If database is omitted, uses the default database.")],
refresh: Annotated[
bool, Field(description="Set to true to force refresh, ignoring cache. Defaults to false.")] = False
) -> str:
try:
logger.info(f"Getting table overview for: {table}, refresh={refresh}")
if not table:
logger.error("Table overview called without table name")
return "Error: Missing 'table' argument."
# Parse table argument: [db.]<table>
parts = table.split('.', 1)
db_name = None
table_name = None
if len(parts) == 2:
db_name, table_name = parts[0], parts[1]
elif len(parts) == 1:
table_name = parts[0]
db_name = db_client.default_database # Use default if only table name is given
if not table_name: # Should not happen if 'table' is non-empty, but check anyway
logger.error(f"Invalid table name format: {table}")
return f"Error: Invalid table name format '{table}'."
if not db_name:
logger.error(f"No database specified for table {table_name}")
return f"Error: Database name not specified for table '{table_name}' and no default database is set."
cache_key = (db_name, table_name)
# Check cache
if not refresh and cache_key in global_table_overview_cache:
logger.debug(f"Using cached overview for {db_name}.{table_name}")
return global_table_overview_cache[cache_key]
logger.debug(f"Fetching fresh overview for {db_name}.{table_name}")
# Fetch details (will also update cache)
overview_text = _get_table_details(db_name, table_name, limit=overview_length_limit)
return overview_text
except Exception as e:
# Reset connections on unexpected errors
logger.exception(f"Unexpected error in table_overview for {table}")
reset_db_connections()
stack_trace = traceback.format_exc()
return f"Unexpected Error executing tool 'table_overview': {type(e).__name__}: {e}\nStack Trace:\n{stack_trace}"
# commented out in favor of the db_summary tool
#@mcp.tool(description="Get an overview (columns, sample rows, row count) for ALL tables in a database. Uses cache unless refresh=True" + description_suffix)
def db_overview(
db: Annotated[str|None, Field(
description="Database name. Optional: uses the default database if not provided.")] = None,
refresh: Annotated[
bool, Field(description="Set to true to force refresh, ignoring cache. Defaults to false.")] = False
) -> str:
try:
db_name = db if db else db_client.default_database
logger.info(f"Getting database overview for: {db_name}, refresh={refresh}")
if not db_name:
logger.error("Database overview called without database name")
return "Error: Database name not provided and no default database is set."
# List tables in the database
query = f"SHOW TABLES FROM `{db_name}`"
result = db_client.execute(query, db=db_name)
if not result.success:
logger.error(f"Failed to list tables in database {db_name}: {result.error_message}")
return f"Database Error listing tables in '{db_name}': {result.error_message}"
if not result.rows:
logger.info(f"No tables found in database {db_name}")
return f"No tables found in database '{db_name}'."
tables = [row[0] for row in result.rows]
logger.info(f"Found {len(tables)} tables in database {db_name}")
all_overviews = [f"--- Overview for Database: `{db_name}` ({len(tables)} tables) ---"]
total_length = 0
limit_per_table = overview_length_limit * (math.log10(len(tables)) + 1) // len(tables) # per-table budget: total limit scaled by log10(N)+1, split across N tables
for table_name in tables:
cache_key = (db_name, table_name)
overview_text = None
# Check cache first
if not refresh and cache_key in global_table_overview_cache:
logger.debug(f"Using cached overview for {db_name}.{table_name}")
overview_text = global_table_overview_cache[cache_key]
else:
logger.debug(f"Fetching fresh overview for {db_name}.{table_name}")
# Fetch details for this table (will update cache via _get_table_details)
overview_text = _get_table_details(db_name, table_name, limit=limit_per_table)
all_overviews.append(overview_text)
all_overviews.append("\n") # Add separator
total_length += len(overview_text) + 1
logger.info(f"Database overview completed for {db_name}, total length: {total_length}")
return "\n".join(all_overviews)
except Exception as e:
# Catch any other unexpected errors during tool execution
logger.exception(f"Unexpected error in db_overview for database {db}")
reset_db_connections()
stack_trace = traceback.format_exc()
return f"Unexpected Error executing tool 'db_overview': {type(e).__name__}: {e}\nStack Trace:\n{stack_trace}"
@mcp.tool(description="Quickly get summary of a database with tables' schema and size information" + description_suffix)
def db_summary(
db: Annotated[str|None, Field(
description="Database name. Optional: uses current database by default.")] = None,
limit: Annotated[int, Field(
description="Output length limit in characters. Defaults to 10000. Higher values show more tables and details.")] = 10000,
refresh: Annotated[bool, Field(
description="Set to true to force refresh, ignoring cache. Defaults to false.")] = False
) -> str:
try:
db_name = db if db else db_client.default_database
logger.info(f"Getting database summary for: {db_name}, limit={limit}, refresh={refresh}")
if not db_name:
logger.error("Database summary called without database name")
return "Error: Database name not provided and no default database is set."
# Use the database summary manager
summary = db_summary_manager.get_database_summary(db_name, limit=limit, refresh=refresh)
logger.info(f"Database summary completed for {db_name}")
return summary
except Exception as e:
# Reset connections on unexpected errors
logger.exception(f"Unexpected error in db_summary for database {db}")
reset_db_connections()
stack_trace = traceback.format_exc()
return f"Unexpected Error executing tool 'db_summary': {type(e).__name__}: {e}\nStack Trace:\n{stack_trace}"
async def main():
parser = argparse.ArgumentParser(description='StarRocks MCP Server')
parser.add_argument('--mode', choices=['stdio', 'sse', 'http', 'streamable-http'],
default=os.getenv('MCP_TRANSPORT_MODE', 'stdio'),
help='Transport mode (default: stdio)')
parser.add_argument('--host', default='localhost',
help='Server host (default: localhost)')
parser.add_argument('--port', type=int, default=3000,
help='Server port (default: 3000)')
parser.add_argument('--test', action='store_true',
help='Run in test mode')
args = parser.parse_args()
logger.info(f"Starting StarRocks MCP Server with mode={args.mode}, host={args.host}, port={args.port} default_db={db_client.default_database or 'None'}")
if args.test:
try:
logger.info("Starting tool test")
# Use the test version without tool wrapper
result = db_client.execute("show databases").to_string()
logger.info("Result:")
logger.info(result)
logger.info("Tool test completed")
finally:
stop_connection_health_checker()
reset_db_connections()
return
# Start connection health checker
start_connection_health_checker()
try:
# Add CORS middleware for HTTP transports to allow web frontend access
if args.mode in ['http', 'streamable-http', 'sse']:
cors_middleware = [
Middleware(
CORSMiddleware,
allow_origins=["*"], # Allow all origins for development. In production, specify exact origins
allow_credentials=True,
allow_methods=["*"], # Allow all HTTP methods
allow_headers=["*"], # Allow all headers
)
]
logger.info(f"CORS enabled for {args.mode} transport - allowing all origins")
await mcp.run_async(
transport=args.mode,
host=args.host,
port=args.port,
middleware=cors_middleware
)
else:
await mcp.run_async(transport=args.mode)
except Exception as e:
logger.exception("Failed to start MCP server")
raise
finally:
# Stop connection health checker when server shuts down
stop_connection_health_checker()
if __name__ == "__main__":
asyncio.run(main())
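# Launch sketches (assuming the package is importable):
#   python -m mcp_server_starrocks.server --mode stdio
#   python -m mcp_server_starrocks.server --mode http --host 0.0.0.0 --port 3000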
```
--------------------------------------------------------------------------------
/tests/test_db_client.py:
--------------------------------------------------------------------------------
```python
"""
Tests for db_client module.
These tests assume a StarRocks cluster is running on localhost with default configurations:
- Host: localhost
- Port: 9030 (MySQL protocol)
- User: root
- Password: (empty)
- No default database set
Run tests with: pytest tests/test_db_client.py -v
"""
import os
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
# Set up test environment variables
os.environ.pop('STARROCKS_FE_ARROW_FLIGHT_SQL_PORT', None) # Force MySQL mode for tests
os.environ.pop('STARROCKS_DB', None) # No default database
from src.mcp_server_starrocks.db_client import (
DBClient,
ResultSet,
get_db_client,
reset_db_connections,
parse_connection_url
)
class TestDBClient:
"""Test cases for DBClient class."""
@pytest.fixture
def db_client(self):
"""Create a fresh DBClient instance for each test."""
# Reset global state
reset_db_connections()
return DBClient()
def test_client_initialization(self, db_client):
"""Test DBClient initialization with default settings."""
assert db_client.enable_arrow_flight_sql is False
assert db_client.default_database is None
assert db_client._connection_pool is None
assert db_client._adbc_connection is None
def test_singleton_pattern(self):
"""Test that get_db_client returns the same instance."""
client1 = get_db_client()
client2 = get_db_client()
assert client1 is client2
def test_execute_show_databases(self, db_client):
"""Test executing SHOW DATABASES query."""
result = db_client.execute("SHOW DATABASES")
assert isinstance(result, ResultSet)
assert result.success is True
assert result.column_names is not None
assert len(result.column_names) == 1
assert result.rows is not None
assert len(result.rows) > 0
assert result.execution_time is not None
assert result.execution_time > 0
# Check that information_schema is present (standard in StarRocks)
database_names = [row[0] for row in result.rows]
assert 'information_schema' in database_names
def test_execute_show_databases_pandas(self, db_client):
"""Test executing SHOW DATABASES with pandas return format."""
result = db_client.execute("SHOW DATABASES", return_format="pandas")
assert isinstance(result, ResultSet)
assert result.success is True
assert result.pandas is not None
assert isinstance(result.pandas, pd.DataFrame)
assert len(result.pandas.columns) == 1
assert len(result.pandas) > 0
# Test that to_pandas() returns the same DataFrame
df = result.to_pandas()
assert df is result.pandas
def test_execute_invalid_query(self, db_client):
"""Test executing an invalid SQL query."""
result = db_client.execute("SELECT * FROM nonexistent_table_12345")
assert isinstance(result, ResultSet)
assert result.success is False
assert result.error_message is not None
assert "nonexistent_table_12345" in result.error_message or "doesn't exist" in result.error_message.lower()
assert result.execution_time is not None
def test_execute_create_and_drop_database(self, db_client):
"""Test creating and dropping a test database."""
test_db_name = "test_mcp_db_client"
# Clean up first (in case previous test failed)
db_client.execute(f"DROP DATABASE IF EXISTS {test_db_name}")
# Create database
create_result = db_client.execute(f"CREATE DATABASE {test_db_name}")
assert create_result.success is True
assert create_result.rows_affected is not None # DDL returns row count (usually 0)
# Verify database exists
show_result = db_client.execute("SHOW DATABASES")
database_names = [row[0] for row in show_result.rows]
assert test_db_name in database_names
# Drop database
drop_result = db_client.execute(f"DROP DATABASE {test_db_name}")
assert drop_result.success is True
# Verify database is gone
show_result = db_client.execute("SHOW DATABASES")
database_names = [row[0] for row in show_result.rows]
assert test_db_name not in database_names
def test_execute_with_specific_database(self, db_client):
"""Test executing query with specific database context."""
# Use information_schema which should always be available
result = db_client.execute("SHOW TABLES", db="information_schema")
assert result.success is True
assert result.column_names is not None
assert result.rows is not None
assert len(result.rows) > 0 # information_schema should have tables
# Check for expected information_schema tables
table_names = [row[0] for row in result.rows]
expected_tables = ['tables', 'columns', 'schemata']
found_expected = any(table in table_names for table in expected_tables)
assert found_expected, f"Expected at least one of {expected_tables} in {table_names}"
def test_execute_with_invalid_database(self, db_client):
"""Test executing query with non-existent database."""
result = db_client.execute("SHOW TABLES", db="nonexistent_db_12345")
assert result.success is False
assert result.error_message is not None
assert "nonexistent_db_12345" in result.error_message
def test_execute_table_operations(self, db_client):
"""Test creating, inserting, querying, and dropping a table."""
test_db = "test_mcp_table_ops"
test_table = "test_table"
try:
# Create database
create_db_result = db_client.execute(f"CREATE DATABASE IF NOT EXISTS {test_db}")
assert create_db_result.success is True
# Create table (with replication_num=1 for single-node setup)
create_table_sql = f"""
CREATE TABLE {test_db}.{test_table} (
id INT,
name STRING,
value DOUBLE
)
PROPERTIES ("replication_num" = "1")
"""
create_result = db_client.execute(create_table_sql)
assert create_result.success is True
# Insert data
insert_sql = f"""
INSERT INTO {test_db}.{test_table} VALUES
(1, 'test1', 1.5),
(2, 'test2', 2.5),
(3, 'test3', 3.5)
"""
insert_result = db_client.execute(insert_sql)
assert insert_result.success is True
assert insert_result.rows_affected == 3
# Query data
select_result = db_client.execute(f"SELECT * FROM {test_db}.{test_table} ORDER BY id")
assert select_result.success is True
assert len(select_result.column_names) == 3
assert select_result.column_names == ['id', 'name', 'value']
assert len(select_result.rows) == 3
# MySQL connector returns tuples, convert to lists for comparison
assert list(select_result.rows[0]) == [1, 'test1', 1.5]
assert list(select_result.rows[1]) == [2, 'test2', 2.5]
assert list(select_result.rows[2]) == [3, 'test3', 3.5]
# Test COUNT query
count_result = db_client.execute(f"SELECT COUNT(*) as cnt FROM {test_db}.{test_table}")
assert count_result.success is True
assert count_result.rows[0][0] == 3
# Test with specific database context
ctx_result = db_client.execute(f"SELECT * FROM {test_table}", db=test_db)
assert ctx_result.success is True
assert len(ctx_result.rows) == 3
finally:
# Clean up
db_client.execute(f"DROP DATABASE IF EXISTS {test_db}")
def test_execute_pandas_format_with_data(self, db_client):
"""Test pandas format with actual data."""
test_db = "test_mcp_pandas"
try:
# Setup test data
db_client.execute(f"CREATE DATABASE IF NOT EXISTS {test_db}")
db_client.execute(f"""
CREATE TABLE {test_db}.pandas_test (
id INT,
category STRING,
amount DECIMAL(10,2)
)
PROPERTIES ("replication_num" = "1")
""")
db_client.execute(f"""
INSERT INTO {test_db}.pandas_test VALUES
(1, 'A', 100.50),
(2, 'B', 200.75),
(3, 'A', 150.25)
""")
# Test executing query with pandas format
result = db_client.execute(f"SELECT * FROM {test_db}.pandas_test ORDER BY id", return_format="pandas")
assert isinstance(result, ResultSet)
assert result.success is True
assert result.pandas is not None
assert isinstance(result.pandas, pd.DataFrame)
assert len(result.pandas) == 3
assert list(result.pandas.columns) == ['id', 'category', 'amount']
assert result.pandas.iloc[0]['id'] == 1
assert result.pandas.iloc[0]['category'] == 'A'
assert float(result.pandas.iloc[0]['amount']) == 100.50
# Test that to_pandas() returns the same DataFrame
df = result.to_pandas()
assert df is result.pandas
finally:
db_client.execute(f"DROP DATABASE IF EXISTS {test_db}")
def test_connection_error_handling(self, db_client):
"""Test error handling when connection fails."""
# Mock a connection failure
with patch.object(db_client, '_get_connection', side_effect=Exception("Connection failed")):
result = db_client.execute("SHOW DATABASES")
assert result.success is False
assert "Connection failed" in result.error_message
assert result.execution_time is not None
def test_reset_connections(self, db_client):
"""Test connection reset functionality."""
# First execute a query to establish connection
result1 = db_client.execute("SHOW DATABASES")
assert result1.success is True
# Reset connections
db_client.reset_connections()
# Should still work after reset
result2 = db_client.execute("SHOW DATABASES")
assert result2.success is True
def test_describe_table(self, db_client):
"""Test DESCRIBE table functionality."""
test_db = "test_mcp_describe"
test_table = "describe_test"
try:
# Create test table
db_result = db_client.execute(f"CREATE DATABASE IF NOT EXISTS {test_db}")
assert db_result.success, f"Failed to create database: {db_result.error_message}"
table_result = db_client.execute(f"""
CREATE TABLE {test_db}.{test_table} (
id BIGINT NOT NULL COMMENT 'Primary key',
name VARCHAR(100) COMMENT 'Name field',
created_at DATETIME,
is_active BOOLEAN
)
PROPERTIES ("replication_num" = "1")
""")
assert table_result.success, f"Failed to create table: {table_result.error_message}"
# Verify table exists first
show_result = db_client.execute(f"SHOW TABLES", db=test_db)
assert show_result.success, f"Failed to show tables: {show_result.error_message}"
table_names = [row[0] for row in show_result.rows]
assert test_table in table_names, f"Table {test_table} not found in {table_names}"
# Describe table (use full table name for clarity)
result = db_client.execute(f"DESCRIBE {test_db}.{test_table}")
assert result.success is True
assert result.column_names is not None
assert len(result.rows) == 4 # 4 columns
# Check column names in result (should include Field, Type, etc.)
expected_columns = ['Field', 'Type', 'Null', 'Key', 'Default', 'Extra']
for expected_col in expected_columns[:len(result.column_names)]:
assert expected_col in result.column_names
# Check that our table columns are present
field_names = [row[0] for row in result.rows]
assert 'id' in field_names
assert 'name' in field_names
assert 'created_at' in field_names
assert 'is_active' in field_names
finally:
db_client.execute(f"DROP DATABASE IF EXISTS {test_db}")
class TestDBClientWithArrowFlight:
"""Test cases for DBClient with Arrow Flight SQL (if configured)."""
@pytest.fixture
def arrow_client(self):
"""Create DBClient with Arrow Flight SQL if available."""
# Check if Arrow Flight SQL port is configured (either from env or default test port)
arrow_port = os.getenv('STARROCKS_FE_ARROW_FLIGHT_SQL_PORT', '9408')
# Test if Arrow Flight SQL is actually available by trying to connect
try:
with patch.dict(os.environ, {'STARROCKS_FE_ARROW_FLIGHT_SQL_PORT': arrow_port}):
reset_db_connections()
client = DBClient()
assert client.enable_arrow_flight_sql is True
# Test basic connectivity
result = client.execute("SHOW DATABASES")
if not result.success:
pytest.skip(f"Arrow Flight SQL not available on port {arrow_port}: {result.error_message}")
return client
except Exception as e:
pytest.skip(f"Arrow Flight SQL not available: {e}")
def test_arrow_flight_basic_query(self, arrow_client):
"""Test basic query with Arrow Flight SQL."""
result = arrow_client.execute("SHOW DATABASES")
assert isinstance(result, ResultSet)
assert result.success is True
assert result.column_names is not None
assert result.rows is not None
assert len(result.rows) > 0
# Verify we're actually using Arrow Flight SQL
assert arrow_client.enable_arrow_flight_sql is True
def test_arrow_flight_pandas_format(self, arrow_client):
"""Test pandas format with Arrow Flight SQL."""
result = arrow_client.execute("SHOW DATABASES", return_format="pandas")
assert isinstance(result, ResultSet)
assert result.success is True
assert result.pandas is not None
assert isinstance(result.pandas, pd.DataFrame)
assert len(result.pandas) > 0
assert len(result.pandas.columns) == 1
# Test that to_pandas() returns the same DataFrame
df = result.to_pandas()
assert df is result.pandas
# Verify we're actually using Arrow Flight SQL
assert arrow_client.enable_arrow_flight_sql is True
def test_arrow_flight_table_operations(self, arrow_client):
"""Test table operations with Arrow Flight SQL."""
test_db = "test_arrow_flight"
test_table = "arrow_test"
try:
# Create database
create_db_result = arrow_client.execute(f"CREATE DATABASE IF NOT EXISTS {test_db}")
assert create_db_result.success is True
# Create table
create_table_sql = f"""
CREATE TABLE {test_db}.{test_table} (
id INT,
name STRING,
value DOUBLE
)
PROPERTIES ("replication_num" = "1")
"""
create_result = arrow_client.execute(create_table_sql)
assert create_result.success is True
# Insert data
insert_sql = f"""
INSERT INTO {test_db}.{test_table} VALUES
(1, 'arrow1', 1.1),
(2, 'arrow2', 2.2)
"""
insert_result = arrow_client.execute(insert_sql)
assert insert_result.success is True
# Note: StarRocks Arrow Flight SQL always returns 0 for rows_affected due to implementation limitations
assert insert_result.rows_affected == 0
# Query data with pandas format
select_result = arrow_client.execute(f"SELECT * FROM {test_db}.{test_table} ORDER BY id", return_format="pandas")
assert isinstance(select_result, ResultSet)
assert select_result.success is True
assert select_result.pandas is not None
assert isinstance(select_result.pandas, pd.DataFrame)
assert len(select_result.pandas) == 2
# Note: StarRocks Arrow Flight SQL loses column names in SELECT results (known limitation)
# The columns come back as empty strings, but the data is correct
assert len(select_result.pandas.columns) == 3
# Since column names are empty, access by position instead
assert select_result.pandas.iloc[0, 0] == 1 # id column
assert select_result.pandas.iloc[0, 1] == 'arrow1' # name column
assert select_result.pandas.iloc[0, 2] == 1.1 # value column
# Test that to_pandas() returns the same DataFrame
df = select_result.to_pandas()
assert df is select_result.pandas
# Query data with raw format
raw_result = arrow_client.execute(f"SELECT * FROM {test_db}.{test_table} ORDER BY id")
assert raw_result.success is True
assert len(raw_result.rows) == 2
# Note: Column names are empty due to StarRocks Arrow Flight SQL limitation
assert raw_result.column_names == ['', '', '']
# But the data is correct
assert raw_result.rows[0] == [1, 'arrow1', 1.1]
assert raw_result.rows[1] == [2, 'arrow2', 2.2]
finally:
# Clean up
arrow_client.execute(f"DROP DATABASE IF EXISTS {test_db}")
def test_arrow_flight_error_handling(self, arrow_client):
"""Test error handling with Arrow Flight SQL."""
# Test invalid query
result = arrow_client.execute("SELECT * FROM nonexistent_arrow_table")
assert result.success is False
assert result.error_message is not None
# Test invalid database - Note: Arrow Flight SQL may fail with connection errors
# before database validation, so we just check that it fails
result = arrow_client.execute("SHOW TABLES", db="nonexistent_arrow_db")
assert result.success is False
assert result.error_message is not None
class TestResultSet:
"""Test cases for ResultSet dataclass."""
def test_result_set_creation(self):
"""Test ResultSet creation with various parameters."""
# Success case
result = ResultSet(
success=True,
column_names=['id', 'name'],
rows=[[1, 'test'], [2, 'test2']],
execution_time=0.5
)
assert result.success is True
assert result.column_names == ['id', 'name']
assert result.rows == [[1, 'test'], [2, 'test2']]
assert result.execution_time == 0.5
assert result.rows_affected is None
assert result.error_message is None
def test_result_set_to_pandas_from_rows(self):
"""Test ResultSet to_pandas conversion from rows."""
result = ResultSet(
success=True,
column_names=['id', 'name', 'value'],
rows=[[1, 'test1', 10.5], [2, 'test2', 20.5]],
execution_time=0.1
)
df = result.to_pandas()
assert isinstance(df, pd.DataFrame)
assert len(df) == 2
assert list(df.columns) == ['id', 'name', 'value']
assert df.iloc[0]['id'] == 1
assert df.iloc[0]['name'] == 'test1'
assert df.iloc[0]['value'] == 10.5
assert df.iloc[1]['id'] == 2
assert df.iloc[1]['name'] == 'test2'
assert df.iloc[1]['value'] == 20.5
def test_result_set_to_pandas_from_pandas_field(self):
"""Test ResultSet to_pandas returns existing pandas field if available."""
original_df = pd.DataFrame({
'id': [1, 2],
'name': ['test1', 'test2'],
'value': [10.5, 20.5]
})
result = ResultSet(
success=True,
column_names=['id', 'name', 'value'],
rows=[[1, 'test1', 10.5], [2, 'test2', 20.5]],
pandas=original_df,
execution_time=0.1
)
df = result.to_pandas()
assert df is original_df # Should return the same object
def test_result_set_to_string(self):
"""Test ResultSet to_string conversion."""
result = ResultSet(
success=True,
column_names=['id', 'name', 'value'],
rows=[[1, 'test1', 10.5], [2, 'test2', 20.5]],
execution_time=0.1
)
string_output = result.to_string()
expected_lines = [
'id,name,value',
'1,test1,10.5',
'2,test2,20.5',
''
]
assert string_output == '\n'.join(expected_lines)
def test_result_set_to_string_with_limit(self):
"""Test ResultSet to_string with limit."""
result = ResultSet(
success=True,
column_names=['id', 'name'],
rows=[[1, 'very_long_test_string'], [2, 'another_long_string']],
execution_time=0.1
)
# Test with very small limit
string_output = result.to_string(limit=20)
lines = string_output.split('\n')
assert lines[0] == 'id,name' # Header should always be included
# Should stop before all rows due to limit
assert len(lines) < 4 # Should be less than header + 2 rows + empty line
def test_result_set_to_string_error_cases(self):
"""Test ResultSet to_string error handling."""
# Test with failed result
failed_result = ResultSet(
success=False,
error_message="Test error"
)
string_output = failed_result.to_string()
assert string_output == "Error: Test error"
# Test with no data
no_data_result = ResultSet(
success=True,
column_names=None,
rows=None
)
string_output = no_data_result.to_string()
assert string_output == "No data"
def test_result_set_to_pandas_error_cases(self):
"""Test ResultSet to_pandas error handling."""
# Test with failed result
failed_result = ResultSet(
success=False,
error_message="Test error"
)
with pytest.raises(ValueError, match="Cannot convert failed result to DataFrame"):
failed_result.to_pandas()
# Test with no data
no_data_result = ResultSet(
success=True,
column_names=None,
rows=None
)
with pytest.raises(ValueError, match="No data available to convert to DataFrame"):
no_data_result.to_pandas()
def test_result_set_error_case(self):
"""Test ResultSet for error cases."""
result = ResultSet(
success=False,
error_message="Test error",
execution_time=0.1
)
assert result.success is False
assert result.error_message == "Test error"
assert result.execution_time == 0.1
assert result.column_names is None
assert result.rows is None
assert result.rows_affected is None
def test_result_set_write_operation(self):
"""Test ResultSet for write operations."""
result = ResultSet(
success=True,
rows_affected=5,
execution_time=0.2
)
assert result.success is True
assert result.rows_affected == 5
assert result.execution_time == 0.2
assert result.column_names is None
assert result.rows is None
assert result.error_message is None
class TestParseConnectionUrl:
"""Test cases for parse_connection_url function."""
def test_parse_basic_url(self):
"""Test parsing basic connection URL without schema."""
url = "root:password123@localhost:9030/test_db"
result = parse_connection_url(url)
expected = {
'user': 'root',
'password': 'password123',
'host': 'localhost',
'port': '9030',
'database': 'test_db'
}
assert result == expected
def test_parse_url_with_schema(self):
"""Test parsing connection URL with schema."""
url = "mysql://admin:[email protected]:3306/production"
result = parse_connection_url(url)
expected = {
'user': 'admin',
'password': 'secret',
'host': 'db.example.com',
'port': '3306',
'database': 'production'
}
assert result == expected
def test_parse_url_with_different_schemas(self):
"""Test parsing URLs with various schema types."""
test_cases = [
("starrocks://user:pass@host:9030/db", "starrocks"),
("jdbc+mysql://user:pass@host:3306/db", "jdbc+mysql"),
("postgresql://user:pass@host:5432/db", "postgresql"),
]
for url, expected_schema in test_cases:
result = parse_connection_url(url)
# Schema is no longer returned in the result
assert result['user'] == 'user'
assert result['password'] == 'pass'
assert result['host'] == 'host'
assert result['database'] == 'db'
def test_parse_url_empty_password_succeeds(self):
"""Test that URL with empty password now works."""
url = "root:@localhost:9030/test_db"
result = parse_connection_url(url)
expected = {
'user': 'root',
'password': '', # Empty password
'host': 'localhost',
'port': '9030',
'database': 'test_db'
}
assert result == expected
def test_parse_url_no_password_colon(self):
"""Test URL without password colon (e.g., root@localhost:9030)."""
url = "root@localhost:9030"
result = parse_connection_url(url)
expected = {
'user': 'root',
'password': '', # Default empty password
'host': 'localhost',
'port': '9030',
'database': None
}
assert result == expected
def test_parse_url_missing_port_uses_default(self):
"""Test URL without port uses default 9030."""
url = "root:password@localhost/mydb"
result = parse_connection_url(url)
expected = {
'user': 'root',
'password': 'password',
'host': 'localhost',
'port': '9030', # Default port
'database': 'mydb'
}
assert result == expected
def test_parse_url_minimal_format(self):
"""Test minimal URL format (just user@host)."""
url = "user@host"
result = parse_connection_url(url)
expected = {
'user': 'user',
'password': '', # Default empty password
'host': 'host',
'port': '9030', # Default port
'database': None
}
assert result == expected
    def test_parse_url_double_colon_password(self):
        """Test URL where a double colon yields a literal colon as the password."""
        url = "user::@host:9030/db"
        result = parse_connection_url(url)
        expected = {
            'user': 'user',
            'password': ':',  # The second colon is kept as a literal-colon password
'host': 'host',
'port': '9030',
'database': 'db'
}
assert result == expected
def test_parse_url_complex_password_limitation(self):
"""Test that password with @ symbol has regex limitation (parses incorrectly)."""
url = "user:p@ssw0rd!@server:9030/mydb"
result = parse_connection_url(url)
# Due to regex limitation, @ in password causes incorrect parsing
assert result['user'] == 'user'
assert result['password'] == 'p' # Only gets characters before first @
assert result['host'] == 'ssw0rd!@server' # Rest becomes host
assert result['port'] == '9030'
assert result['database'] == 'mydb'
def test_parse_url_password_without_at_symbol(self):
"""Test parsing URL with complex password without @ symbol."""
url = "user:p#ssw0rd!$%^&*()@server:9030/mydb"
result = parse_connection_url(url)
assert result['user'] == 'user'
assert result['password'] == 'p#ssw0rd!$%^&*()'
assert result['host'] == 'server'
assert result['port'] == '9030'
assert result['database'] == 'mydb'
def test_parse_url_complex_username_with_at_symbol_limitation(self):
"""Test that username with @ symbol fails (regex limitation)."""
url = "user.name+tag@domain:password123@host:9030/db"
# This should fail because our regex cannot distinguish between
# the @ in username vs the @ separator for host
with pytest.raises(ValueError, match="Invalid connection URL"):
parse_connection_url(url)
def test_parse_url_complex_username_without_at(self):
"""Test parsing URL with complex username without @ symbol."""
url = "user.name+tag_domain:password123@host:9030/db"
result = parse_connection_url(url)
assert result['user'] == 'user.name+tag_domain'
assert result['password'] == 'password123'
assert result['host'] == 'host'
assert result['port'] == '9030'
assert result['database'] == 'db'
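    # Illustrative only: nothing below exists in db_client. A parser that
    # split credentials from location on the *last* '@' (host, port, and
    # database cannot contain '@') would avoid both limitations documented
    # above, keeping '@' usable in passwords and usernames.
    def test_last_at_split_workaround_sketch(self):
        """Sketch of an rpartition-based split that tolerates '@' in credentials."""
        url = "user:p@ssw0rd!@server:9030/mydb"
        credentials, _, location = url.rpartition('@')
        user, _, password = credentials.partition(':')
        assert user == 'user'
        assert password == 'p@ssw0rd!'
        assert location == 'server:9030/mydb'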
def test_parse_url_numeric_database(self):
"""Test parsing URL with numeric database name."""
url = "root:pass@localhost:9030/db123"
result = parse_connection_url(url)
assert result['database'] == 'db123'
def test_parse_url_database_with_hyphens(self):
"""Test parsing URL with database name containing hyphens."""
url = "root:pass@localhost:9030/test-db-name"
result = parse_connection_url(url)
assert result['database'] == 'test-db-name'
def test_parse_url_ip_address_host(self):
"""Test parsing URL with IP address as host."""
url = "root:[email protected]:9030/testdb"
result = parse_connection_url(url)
assert result['host'] == '192.168.1.100'
assert result['port'] == '9030'
assert result['database'] == 'testdb'
def test_parse_url_different_ports(self):
"""Test parsing URLs with different port numbers."""
test_cases = [
("user:pass@host:3306/db", "3306"),
("user:pass@host:5432/db", "5432"),
("user:pass@host:27017/db", "27017"),
("user:pass@host:1/db", "1"),
("user:pass@host:65535/db", "65535"),
]
for url, expected_port in test_cases:
result = parse_connection_url(url)
assert result['port'] == expected_port
def test_parse_invalid_urls(self):
"""Test that invalid URLs raise ValueError."""
invalid_urls = [
# Missing required parts
"@host:9030/db", # Missing user
"user:pass@:9030/db", # Missing host
# Malformed URLs
"user:pass@host:port/db", # Non-numeric port
"user:pass@host:9030/", # Empty database
"user:pass@host:9030/db/extra", # Extra path component
"", # Empty string
"random-string-not-url", # Not a URL format
# Special cases
"://user:pass@host:9030/db", # Empty schema
"user:pass@host:-1/db", # Negative port
]
for invalid_url in invalid_urls:
with pytest.raises(ValueError, match="Invalid connection URL"):
parse_connection_url(invalid_url)
def test_parse_url_colon_in_password_works(self):
"""Test that colon in password actually works (unlike @ symbol)."""
url = "user:pass:extra@host:9030/db"
result = parse_connection_url(url)
assert result['user'] == 'user'
assert result['password'] == 'pass:extra' # Colons in password are fine
assert result['host'] == 'host'
assert result['port'] == '9030'
assert result['database'] == 'db'
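        # Colons survive because only the first ':' splits user from
        # password, whereas the first '@' (rather than the last) splits
        # password from host, which is exactly why '@' breaks parsing
        # (see test_parse_url_complex_password_limitation above).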
def test_parse_url_without_database(self):
"""Test parsing URL without database (database is optional)."""
url = "user:password@host:9030"
result = parse_connection_url(url)
assert result['user'] == 'user'
assert result['password'] == 'password'
assert result['host'] == 'host'
assert result['port'] == '9030'
        assert result['database'] is None  # Database should be None when omitted
def test_parse_url_with_schema_without_database(self):
"""Test parsing URL with schema but without database."""
url = "mysql://admin:[email protected]:3306"
result = parse_connection_url(url)
assert result['user'] == 'admin'
assert result['password'] == 'secret'
assert result['host'] == 'db.example.com'
assert result['port'] == '3306'
        assert result['database'] is None
def test_parse_url_various_schemas_without_database(self):
"""Test parsing URLs with various schemas but no database."""
test_cases = [
("starrocks://user:pass@host:9030", "starrocks"),
("jdbc+mysql://user:pass@host:3306", "jdbc+mysql"),
("postgresql://user:pass@host:5432", "postgresql"),
]
        for url, _schema in test_cases:
            result = parse_connection_url(url)
            # The schema prefix is accepted but no longer returned in the result
assert result['user'] == 'user'
assert result['password'] == 'pass'
assert result['host'] == 'host'
            assert result['database'] is None
def test_parse_url_edge_cases(self):
"""Test edge cases that should work."""
# Single character components
url = "a:b@c:1/d"
result = parse_connection_url(url)
assert result['user'] == 'a'
assert result['password'] == 'b'
assert result['host'] == 'c'
assert result['port'] == '1'
assert result['database'] == 'd'
# Long components
long_user = "a" * 100
long_pass = "b" * 100
long_host = "c" * 50
long_db = "d" * 50
url = f"{long_user}:{long_pass}@{long_host}:9030/{long_db}"
result = parse_connection_url(url)
assert result['user'] == long_user
assert result['password'] == long_pass
assert result['host'] == long_host
assert result['database'] == long_db
def test_parse_url_returns_dict_with_all_keys(self):
"""Test that parse_connection_url always returns dict with all expected keys."""
test_cases = [
"root:pass@localhost:9030/db",
"mysql://root:pass@localhost:3306/db",
]
expected_keys = {'user', 'password', 'host', 'port', 'database'}
for url in test_cases:
result = parse_connection_url(url)
assert isinstance(result, dict)
assert set(result.keys()) == expected_keys
def test_parse_url_regex_pattern_comprehensive(self):
"""Test comprehensive regex pattern matching."""
# Test that the regex correctly captures each group
url = "custom+schema://test_user:[email protected]:12345/my_db-name"
result = parse_connection_url(url)
# Schema is no longer returned in the result
assert result['user'] == 'test_user'
assert result['password'] == 'complex!pass'
assert result['host'] == 'sub.domain.com'
assert result['port'] == '12345'
assert result['database'] == 'my_db-name'
class TestDummyMode:
"""Test cases for STARROCKS_DUMMY_TEST environment variable."""
def test_dummy_mode_enabled(self):
"""Test that dummy mode returns expected dummy data."""
# Set dummy test environment variable
with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
client = DBClient()
assert client.enable_dummy_test is True
# Test basic query
result = client.execute("SELECT * FROM any_table")
assert result.success is True
assert result.column_names == ['name']
assert result.rows == [['aaa'], ['bbb'], ['ccc']]
assert result.execution_time is not None
assert result.execution_time > 0
assert result.pandas is None # pandas should be None for raw format
def test_dummy_mode_with_pandas_format(self):
"""Test dummy mode with pandas return format."""
with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
client = DBClient()
result = client.execute("SELECT * FROM any_table", return_format="pandas")
assert result.success is True
assert result.column_names == ['name']
assert result.rows == [['aaa'], ['bbb'], ['ccc']]
assert result.pandas is not None
assert isinstance(result.pandas, pd.DataFrame)
assert len(result.pandas) == 3
assert list(result.pandas.columns) == ['name']
assert result.pandas.iloc[0]['name'] == 'aaa'
assert result.pandas.iloc[1]['name'] == 'bbb'
assert result.pandas.iloc[2]['name'] == 'ccc'
def test_dummy_mode_ignores_statement_and_db(self):
"""Test that dummy mode returns same data regardless of SQL statement or database."""
with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
client = DBClient()
# Test different statements
result1 = client.execute("SHOW DATABASES")
result2 = client.execute("CREATE TABLE test (id INT)")
result3 = client.execute("SELECT COUNT(*) FROM users", db="production")
# All should return the same dummy data
for result in [result1, result2, result3]:
assert result.success is True
assert result.column_names == ['name']
assert result.rows == [['aaa'], ['bbb'], ['ccc']]
def test_dummy_mode_disabled_by_default(self):
"""Test that dummy mode is disabled when environment variable is not set."""
# Ensure STARROCKS_DUMMY_TEST is not set
        # clear=True empties os.environ inside the context, so
        # STARROCKS_DUMMY_TEST is guaranteed to be unset here
        with patch.dict(os.environ, {}, clear=True):
client = DBClient()
assert client.enable_dummy_test is False
def test_dummy_mode_with_empty_string(self):
"""Test that empty string for STARROCKS_DUMMY_TEST disables dummy mode."""
with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': ''}):
client = DBClient()
assert client.enable_dummy_test is False
def test_dummy_mode_with_various_truthy_values(self):
"""Test that various truthy values enable dummy mode."""
truthy_values = ['1', 'true', 'True', 'yes', 'on', 'any_non_empty_string']
for value in truthy_values:
with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': value}):
client = DBClient()
assert client.enable_dummy_test is True, f"Failed for value: {value}"
def test_dummy_mode_to_pandas_conversion(self):
"""Test to_pandas() method works with dummy data."""
with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
client = DBClient()
# Test raw format conversion
result = client.execute("SELECT * FROM test")
df = result.to_pandas()
assert isinstance(df, pd.DataFrame)
assert len(df) == 3
assert list(df.columns) == ['name']
assert df.iloc[0]['name'] == 'aaa'
# Test pandas format (should return same DataFrame)
result_pandas = client.execute("SELECT * FROM test", return_format="pandas")
df_pandas = result_pandas.to_pandas()
assert df_pandas is result_pandas.pandas
def test_dummy_mode_to_string_conversion(self):
"""Test to_string() method works with dummy data."""
with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
client = DBClient()
result = client.execute("SELECT * FROM test")
string_output = result.to_string()
expected_lines = [
'name',
'aaa',
'bbb',
'ccc',
''
]
assert string_output == '\n'.join(expected_lines)
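    # Illustrative alternative: pytest's built-in monkeypatch fixture can
    # replace unittest.mock.patch.dict for environment setup; both undo
    # their changes automatically when the test finishes.
    def test_dummy_mode_with_monkeypatch(self, monkeypatch):
        """Sketch: enable dummy mode via pytest's monkeypatch fixture."""
        monkeypatch.setenv('STARROCKS_DUMMY_TEST', '1')
        client = DBClient()
        assert client.enable_dummy_test is True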
if __name__ == "__main__":
pytest.main([__file__, "-v"])
```