# Directory Structure
```
├── .gitignore
├── .python-version
├── CLAUDE.md
├── glama.json
├── LICENSE
├── mcpserverdemo.jpg
├── pyproject.toml
├── pytest.ini
├── README.md
├── RELEASE_NOTES.md
├── src
│ └── mcp_server_starrocks
│ ├── __init__.py
│ ├── connection_health_checker.py
│ ├── db_client.py
│ ├── db_summary_manager.py
│ └── server.py
└── tests
├── __init__.py
├── README.md
└── test_db_client.py
```
# Files
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
```
1 | 3.12
2 |
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
1 | # Python-generated files
2 | __pycache__/
3 | *.py[oc]
4 | build/
5 | dist/
6 | wheels/
7 | *.egg-info
8 |
9 | # Virtual environments
10 | .venv
11 | uv.lock
12 |
13 | # IDE files
14 | .idea/
15 | .vscode/
16 |
17 | # Exclude Mac generated files
18 | .DS_Store
```
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
```markdown
1 | # Tests for mcp-server-starrocks
2 |
3 | ## Prerequisites
4 |
5 | 1. **StarRocks cluster running on localhost** with default configuration:
6 | - Host: localhost
7 | - Port: 9030 (MySQL protocol)
8 | - User: root
9 | - Password: (empty)
10 | - At least one BE node available
11 |
12 | 2. **Optional: Arrow Flight SQL enabled** (for Arrow Flight tests):
13 | - Port: 9408 (or custom port)
14 | - Add `arrow_flight_sql_port = 9408` to `fe.conf`
15 | - Restart FE service
16 | - Verify with: `python test_arrow_flight.py`
17 |
18 | 3. **Test dependencies installed**:
19 | ```bash
20 | uv add --optional test pytest pytest-cov
21 | ```
22 |
23 | ## Running Tests
24 |
25 | ### Quick Connection Test
26 | First, verify your StarRocks connection:
27 | ```bash
28 | # Test MySQL connection and basic operations
29 | python test_connection.py
30 |
31 | # Test Arrow Flight SQL connectivity (if enabled)
32 | python test_arrow_flight.py
33 | ```
34 |
35 | The MySQL test will verify basic connectivity and table operations. The Arrow Flight test will diagnose Arrow Flight SQL availability and performance.
36 |
37 | ### Full Test Suite
38 | Run the complete db_client test suite:
39 | ```bash
40 | # Run all tests (MySQL only)
41 | uv run pytest tests/test_db_client.py::TestDBClient -v
42 |
43 | # Run Arrow Flight SQL tests (if enabled)
44 | STARROCKS_FE_ARROW_FLIGHT_SQL_PORT=9408 uv run pytest tests/test_db_client.py::TestDBClientWithArrowFlight -v
45 |
46 | # Run all tests (both MySQL and Arrow Flight if available)
47 | uv run pytest tests/test_db_client.py -v
48 |
49 | # Run specific test
50 | uv run pytest tests/test_db_client.py::TestDBClient::test_execute_show_databases -v
51 | ```
52 |
53 | ### Test Coverage
54 |
55 | The test suite covers:
56 |
57 | - **Connection Management**: MySQL pooled connections and ADBC Arrow Flight SQL
58 | - **Query Execution**: SELECT, DDL, DML operations with both success and error cases
59 | - **Result Formats**: Raw ResultSet and pandas DataFrame outputs
60 | - **Database Context**: Switching databases for queries
61 | - **Error Handling**: Connection failures, invalid queries, malformed SQL
62 | - **Resource Management**: Connection pooling, cursor cleanup, connection reset
63 | - **Edge Cases**: Empty results, type conversion, schema operations
64 |
65 | ### Test Configuration
66 |
67 | - **Single-node setup**: Tests create tables with `PROPERTIES ("replication_num" = "1")`
68 | - **Temporary databases**: Tests create and clean up test databases automatically
69 | - **Arrow Flight SQL**: Tests are skipped if `STARROCKS_FE_ARROW_FLIGHT_SQL_PORT` is not set
70 | - **Isolation**: Each test uses a fresh DBClient instance with reset connections
71 |
72 | ## Test Results
73 |
74 | When all tests pass, you should see:
75 | ```
76 | ======================== 16 passed, 2 skipped in 1.30s =========================
77 | ```
78 |
79 | The 2 skipped tests are Arrow Flight SQL tests that only run when the environment variable is configured.
80 |
81 | ## Troubleshooting
82 |
83 | **Connection issues**:
84 | - Ensure StarRocks FE is running on localhost:9030
85 | - Check that the `root` user has no password set
86 | - Verify at least one BE node is available
87 |
88 | **Table creation failures**:
89 | - Single-node clusters need `replication_num=1`
90 | - Check StarRocks logs for detailed error messages
91 |
92 | **Import errors**:
93 | - Ensure you're running from the project root directory
94 | - Check that `src/mcp_server_starrocks` is in your Python path
```
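
For reference, `test_connection.py` and `test_arrow_flight.py` are referenced above but not included in this listing; a minimal standalone connectivity check in the same spirit might look like the sketch below (assuming the default localhost setup from the prerequisites):

```python
# Hypothetical minimal connectivity check, assuming StarRocks FE at
# localhost:9030 with user root and an empty password.
import mysql.connector

conn = mysql.connector.connect(host="localhost", port=9030, user="root", password="")
try:
    cursor = conn.cursor()
    cursor.execute("SHOW DATABASES")
    for (name,) in cursor.fetchall():
        print(name)
    cursor.close()
finally:
    conn.close()
```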
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
1 | [MseeP.ai](https://mseep.ai/app/starrocks-mcp-server-starrocks)
2 |
3 | # StarRocks Official MCP Server
4 |
5 | The StarRocks MCP Server acts as a bridge between AI assistants and StarRocks databases. It allows for direct SQL execution, database exploration, data visualization via charts, and retrieving detailed schema/data overviews without requiring complex client-side setup.
6 |
7 | <a href="https://glama.ai/mcp/servers/@StarRocks/mcp-server-starrocks">
8 | <img width="380" height="200" src="https://glama.ai/mcp/servers/@StarRocks/mcp-server-starrocks/badge" alt="StarRocks Server MCP server" />
9 | </a>
10 |
11 | ## Features
12 |
13 | - **Direct SQL Execution:** Run `SELECT` queries (`read_query`) and DDL/DML commands (`write_query`).
14 | - **Database Exploration:** List databases and tables, retrieve table schemas (`starrocks://` resources).
15 | - **System Information:** Access internal StarRocks metrics and states via the `proc://` resource path.
16 | - **Detailed Overviews:** Get comprehensive summaries of tables (`table_overview`) or entire databases (`db_overview`), including column definitions, row counts, and sample data.
17 | - **Data Visualization:** Execute a query and generate a Plotly chart directly from the results (`query_and_plotly_chart`).
18 | - **Intelligent Caching:** Table and database overviews are cached in memory to speed up repeated requests. Cache can be bypassed when needed.
19 | - **Flexible Configuration:** Set connection details and behavior via environment variables.
20 |
21 | ## Configuration
22 |
23 | The MCP server is typically run via an MCP host. Configuration is passed to the host, specifying how to launch the StarRocks MCP server process.
24 |
25 | **Using Streamable HTTP (recommended):**
26 |
27 | To start the server in Streamable HTTP mode:
28 |
29 | First, verify that the connection works:
30 | ```
31 | $ STARROCKS_URL=root:@localhost:9030 uv run mcp-server-starrocks --test
32 | ```
33 |
34 | Start the server:
35 |
36 | ```
37 | uv run mcp-server-starrocks --mode streamable-http --port 8000
38 | ```
39 |
40 | Then configure the MCP client like this:
41 |
42 | ```json
43 | {
44 | "mcpServers": {
45 | "mcp-server-starrocks": {
46 | "url": "http://localhost:8000/mcp"
47 | }
48 | }
49 | }
50 | ```
51 |
52 |
53 | **Using `uv` with installed package (individual environment variables):**
54 |
55 | ```json
56 | {
57 | "mcpServers": {
58 | "mcp-server-starrocks": {
59 | "command": "uv",
60 | "args": ["run", "--with", "mcp-server-starrocks", "mcp-server-starrocks"],
61 | "env": {
62 | "STARROCKS_HOST": "default localhost",
63 | "STARROCKS_PORT": "default 9030",
64 | "STARROCKS_USER": "default root",
65 | "STARROCKS_PASSWORD": "default empty",
66 | "STARROCKS_DB": "default empty"
67 | }
68 | }
69 | }
70 | }
71 | ```
72 |
73 | **Using `uv` with installed package (connection URL):**
74 |
75 | ```json
76 | {
77 | "mcpServers": {
78 | "mcp-server-starrocks": {
79 | "command": "uv",
80 | "args": ["run", "--with", "mcp-server-starrocks", "mcp-server-starrocks"],
81 | "env": {
82 | "STARROCKS_URL": "root:password@localhost:9030/my_database"
83 | }
84 | }
85 | }
86 | }
87 | ```
88 |
89 | **Using `uv` with local directory (for development):**
90 |
91 | ```json
92 | {
93 | "mcpServers": {
94 | "mcp-server-starrocks": {
95 | "command": "uv",
96 | "args": [
97 | "--directory",
98 | "path/to/mcp-server-starrocks", // <-- Update this path
99 | "run",
100 | "mcp-server-starrocks"
101 | ],
102 | "env": {
103 | "STARROCKS_HOST": "default localhost",
104 | "STARROCKS_PORT": "default 9030",
105 | "STARROCKS_USER": "default root",
106 | "STARROCKS_PASSWORD": "default empty",
107 | "STARROCKS_DB": "default empty"
108 | }
109 | }
110 | }
111 | }
112 | ```
113 |
114 | **Using `uv` with local directory and connection URL:**
115 |
116 | ```json
117 | {
118 | "mcpServers": {
119 | "mcp-server-starrocks": {
120 | "command": "uv",
121 | "args": [
122 | "--directory",
123 | "path/to/mcp-server-starrocks", // <-- Update this path
124 | "run",
125 | "mcp-server-starrocks"
126 | ],
127 | "env": {
128 | "STARROCKS_URL": "root:password@localhost:9030/my_database"
129 | }
130 | }
131 | }
132 | }
133 | ```
134 |
135 | **Command-line Arguments:**
136 |
137 | The server supports the following command-line arguments:
138 |
139 | ```bash
140 | uv run mcp-server-starrocks --help
141 | ```
142 |
143 | - `--mode {stdio,sse,http,streamable-http}`: Transport mode (default: stdio or MCP_TRANSPORT_MODE env var)
144 | - `--host HOST`: Server host for HTTP modes (default: localhost)
145 | - `--port PORT`: Server port for HTTP modes
146 | - `--test`: Run in test mode to verify functionality
147 |
148 | Examples:
149 |
150 | ```bash
151 | # Start in streamable HTTP mode on custom host/port
152 | uv run mcp-server-starrocks --mode streamable-http --host 0.0.0.0 --port 8080
153 |
154 | # Start in stdio mode (default)
155 | uv run mcp-server-starrocks --mode stdio
156 |
157 | # Run test mode
158 | uv run mcp-server-starrocks --test
159 | ```
160 |
161 | - The `url` field should point to the Streamable HTTP endpoint of your MCP server (adjust host/port as needed).
162 | - With this configuration, clients can interact with the server using standard JSON over HTTP POST requests. No special SDK is required.
163 | - All tool APIs accept and return standard JSON, as described in the Components section below.
164 |
165 | > **Note:**
166 | > The `sse` (Server-Sent Events) mode is deprecated and no longer maintained. Please use Streamable HTTP mode for all new integrations.
167 |
168 | **Environment Variables:**
169 |
170 | ### Connection Configuration
171 |
172 | You can configure StarRocks connection using either individual environment variables or a single connection URL:
173 |
174 | **Option 1: Individual Environment Variables**
175 |
176 | - `STARROCKS_HOST`: (Optional) Hostname or IP address of the StarRocks FE service. Defaults to `localhost`.
177 | - `STARROCKS_PORT`: (Optional) MySQL protocol port of the StarRocks FE service. Defaults to `9030`.
178 | - `STARROCKS_USER`: (Optional) StarRocks username. Defaults to `root`.
179 | - `STARROCKS_PASSWORD`: (Optional) StarRocks password. Defaults to empty string.
180 | - `STARROCKS_DB`: (Optional) Default database to use if not specified in tool arguments or resource URIs. If set, the connection will attempt to `USE` this database. Tools like `table_overview` and `db_overview` will use this if the database part is omitted in their arguments. Defaults to empty (no default database).
181 |
182 | **Option 2: Connection URL (takes precedence over individual variables)**
183 |
184 | - `STARROCKS_URL`: (Optional) A connection URL string that contains all connection parameters in a single variable. Format: `[<schema>://]user:password@host:port/database`. The schema part is optional. When this variable is set, it takes precedence over the individual `STARROCKS_HOST`, `STARROCKS_PORT`, `STARROCKS_USER`, `STARROCKS_PASSWORD`, and `STARROCKS_DB` variables.
185 |
186 | Examples:
187 | - `root:mypass@localhost:9030/test_db`
188 | - `mysql://admin:[email protected]:9030/production`
189 | - `starrocks://user:[email protected]:9030/analytics`
190 |
191 | ### Additional Configuration
192 |
193 | - `STARROCKS_OVERVIEW_LIMIT`: (Optional) An _approximate_ character limit for the _total_ text generated by overview tools (`table_overview`, `db_overview`) when fetching data to populate the cache. This helps prevent excessive memory usage for very large schemas or numerous tables. Defaults to `20000`.
194 |
195 | - `STARROCKS_MYSQL_AUTH_PLUGIN`: (Optional) Specifies the authentication plugin to use when connecting to the StarRocks FE service. For example, set to `mysql_clear_password` if your StarRocks deployment requires clear text password authentication (such as when using certain LDAP or external authentication setups). Only set this if your environment specifically requires it; otherwise, the default auth_plugin is used.
196 |
197 | - `MCP_TRANSPORT_MODE`: (Optional) Communication mode that specifies how the MCP Server exposes its services. Available options:
198 | - `stdio` (default): Communicates through standard input/output; suitable for being launched and managed by an MCP host.
199 | - `streamable-http`: Starts a Streamable HTTP server that accepts standard JSON over HTTP POST requests.
200 | - `sse`: **(Deprecated, not recommended)** Starts in Server-Sent Events (SSE) streaming mode, for scenarios requiring streaming responses. **Note: SSE mode is no longer maintained; use Streamable HTTP mode instead.**
201 |
202 | ## Components
203 |
204 | ### Tools
205 |
206 | - `read_query`
207 |
208 | - **Description:** Execute a SELECT query or other commands that return a ResultSet (e.g., `SHOW`, `DESCRIBE`).
209 | - **Input:**
210 | ```json
211 | {
212 | "query": "SQL query string",
213 | "db": "database name (optional, uses default database if not specified)"
214 | }
215 | ```
216 | - **Output:** Text content containing the query results in a CSV-like format, including a header row and a row count summary. Returns an error message on failure.
217 |
218 | - `write_query`
219 |
220 | - **Description:** Execute a DDL (`CREATE`, `ALTER`, `DROP`), DML (`INSERT`, `UPDATE`, `DELETE`), or other StarRocks command that does not return a ResultSet.
221 | - **Input:**
222 | ```json
223 | {
224 | "query": "SQL command string",
225 | "db": "database name (optional, uses default database if not specified)"
226 | }
227 | ```
228 | - **Output:** Text content confirming success (e.g., "Query OK, X rows affected") or reporting an error. Changes are committed automatically on success.
229 |
230 | - `analyze_query`
231 |
232 | - **Description:** Analyze a query using its query profile (when a query ID is provided) or `EXPLAIN ANALYZE` (when SQL text is provided).
233 | - **Input:**
234 | ```json
235 | {
236 | "uuid": "Query ID, a string composed of 32 hexadecimal digits formatted as 8-4-4-4-12",
237 | "sql": "Query SQL to analyze",
238 | "db": "database name (optional, uses default database if not specified)"
239 | }
240 | ```
241 | - **Output:** Text content containing the query analysis results. Uses `ANALYZE PROFILE FROM` if uuid is provided, otherwise uses `EXPLAIN ANALYZE` if sql is provided.
242 |
243 | - `query_and_plotly_chart`
244 |
245 | - **Description:** Executes a SQL query, loads the results into a Pandas DataFrame, and generates a Plotly chart using a provided Python expression. Designed for visualization in supporting UIs.
246 | - **Input:**
247 | ```json
248 | {
249 | "query": "SQL query to fetch data",
250 | "plotly_expr": "Python expression string using 'px' (Plotly Express) and 'df' (DataFrame). Example: 'px.scatter(df, x=\"col1\", y=\"col2\")'",
251 | "db": "database name (optional, uses default database if not specified)"
252 | }
253 | ```
254 | - **Output:** A list containing:
255 | 1. `TextContent`: A text representation of the DataFrame and a note that the chart is for UI display.
256 | 2. `ImageContent`: The generated Plotly chart encoded as a base64 PNG image (`image/png`). A text error message is returned instead on failure or if the query yields no data.
257 |
258 | - `table_overview`
259 |
260 | - **Description:** Get an overview of a specific table: columns (from `DESCRIBE`), total row count, and sample rows (`LIMIT 3`). Uses an in-memory cache unless `refresh` is true.
261 | - **Input:**
262 | ```json
263 | {
264 | "table": "Table name, optionally prefixed with database name (e.g., 'db_name.table_name' or 'table_name'). If database is omitted, uses STARROCKS_DB environment variable if set.",
265 | "refresh": false // Optional, boolean. Set to true to bypass the cache. Defaults to false.
266 | }
267 | ```
268 | - **Output:** Text content containing the formatted overview (columns, row count, sample data) or an error message. Cached results include previous errors if applicable.
269 |
270 | - `db_overview`
271 | - **Description:** Get an overview (columns, row count, sample rows) for _all_ tables within a specified database. Uses the table-level cache for each table unless `refresh` is true.
272 | - **Input:**
273 | ```json
274 | {
275 | "db": "database_name", // Optional if default database is set.
276 | "refresh": false // Optional, boolean. Set to true to bypass the cache for all tables in the DB. Defaults to false.
277 | }
278 | ```
279 | - **Output:** Text content containing concatenated overviews for all tables found in the database, separated by headers. Returns an error message if the database cannot be accessed or contains no tables.
280 |
281 | ### Resources
282 |
283 | #### Direct Resources
284 |
285 | - `starrocks:///databases`
286 | - **Description:** Lists all databases accessible to the configured user.
287 | - **Equivalent Query:** `SHOW DATABASES`
288 | - **MIME Type:** `text/plain`
289 |
290 | #### Resource Templates
291 |
292 | - `starrocks:///{db}/{table}/schema`
293 |
294 | - **Description:** Gets the schema definition of a specific table.
295 | - **Equivalent Query:** `SHOW CREATE TABLE {db}.{table}`
296 | - **MIME Type:** `text/plain`
297 |
298 | - `starrocks:///{db}/tables`
299 |
300 | - **Description:** Lists all tables within a specific database.
301 | - **Equivalent Query:** `SHOW TABLES FROM {db}`
302 | - **MIME Type:** `text/plain`
303 |
304 | - `proc:///{+path}`
305 | - **Description:** Accesses StarRocks internal system information, similar to Linux `/proc`. The `path` parameter specifies the desired information node.
306 | - **Equivalent Query:** `SHOW PROC '/{path}'`
307 | - **MIME Type:** `text/plain`
308 | - **Common Paths:**
309 | - `/frontends` - Information about FE nodes.
310 | - `/backends` - Information about BE nodes (for non-cloud native deployments).
311 | - `/compute_nodes` - Information about CN nodes (for cloud native deployments).
312 | - `/dbs` - Information about databases.
313 | - `/dbs/<DB_ID>` - Information about a specific database by ID.
314 | - `/dbs/<DB_ID>/<TABLE_ID>` - Information about a specific table by ID.
315 | - `/dbs/<DB_ID>/<TABLE_ID>/partitions` - Partition information for a table.
316 | - `/transactions` - Transaction information grouped by database.
317 | - `/transactions/<DB_ID>` - Transaction information for a specific database ID.
318 | - `/transactions/<DB_ID>/running` - Running transactions for a database ID.
319 | - `/transactions/<DB_ID>/finished` - Finished transactions for a database ID.
320 | - `/jobs` - Information about asynchronous jobs (Schema Change, Rollup, etc.).
321 | - `/statistic` - Statistics for each database.
322 | - `/tasks` - Information about agent tasks.
323 | - `/cluster_balance` - Load balance status information.
324 | - `/routine_loads` - Information about Routine Load jobs.
325 | - `/colocation_group` - Information about Colocation Join groups.
326 | - `/catalog` - Information about configured catalogs (e.g., Hive, Iceberg).
327 |
328 | ### Prompts
329 |
330 | None defined by this server.
331 |
332 | ## Caching Behavior
333 |
334 | - The `table_overview` and `db_overview` tools utilize an in-memory cache to store the generated overview text.
335 | - The cache key is a tuple of `(database_name, table_name)`.
336 | - When `table_overview` is called, it checks the cache first. If a result exists and the `refresh` parameter is `false` (default), the cached result is returned immediately. Otherwise, it fetches the data from StarRocks, stores it in the cache, and then returns it.
337 | - When `db_overview` is called, it lists all tables in the database and retrieves the overview for _each table_ using the same caching logic as `table_overview` (checking the cache first, fetching on a miss). If `refresh` is `true` for `db_overview`, it forces a refresh for _all_ tables in that database.
338 | - The `STARROCKS_OVERVIEW_LIMIT` environment variable provides a _soft target_ for the maximum length of the overview string generated _per table_ when populating the cache, helping to manage memory usage.
339 | - Cached results, including any error messages encountered during the original fetch, are stored and returned on subsequent cache hits.
340 |
341 | ## Debug
342 |
343 | After starting the MCP server, you can use the MCP Inspector to debug it:
344 | ```
345 | npx @modelcontextprotocol/inspector
346 | ```
347 |
348 | ## Demo
349 |
350 | 
351 |
```
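
The "Caching Behavior" section above can be condensed into a short sketch (illustrative only; `fetch_overview_from_starrocks` is a hypothetical stand-in for the real `DESCRIBE`/row-count/sample queries, not the server's actual internals):

```python
# Overview cache keyed by (database, table); results, including error
# text, are returned on a hit unless refresh=True forces a re-fetch.
overview_cache: dict[tuple[str, str], str] = {}

def fetch_overview_from_starrocks(db: str, table: str) -> str:
    # Stand-in for the real DESCRIBE, COUNT(*), and LIMIT 3 queries.
    return f"overview of {db}.{table}"

def table_overview(db: str, table: str, refresh: bool = False) -> str:
    key = (db, table)
    if not refresh and key in overview_cache:
        return overview_cache[key]
    text = fetch_overview_from_starrocks(db, table)
    overview_cache[key] = text  # errors are cached too, per the notes above
    return text
```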
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
```markdown
1 | # CLAUDE.md
2 |
3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4 |
5 | ## Project Overview
6 |
7 | StarRocks Official MCP Server - A bridge between AI assistants and StarRocks databases, built using FastMCP framework. Enables direct SQL execution, database exploration, data visualization, and schema introspection through the Model Context Protocol (MCP).
8 |
9 | ## Development Commands
10 |
11 | **Local Development:**
12 | ```bash
13 | # Run the server directly for testing
14 | uv run mcp-server-starrocks
15 |
16 | # Run with test mode to verify table overview functionality
17 | uv run mcp-server-starrocks --test
18 |
19 | # Run in Streamable HTTP mode (recommended for integration)
20 | export MCP_TRANSPORT_MODE=streamable-http
21 | uv run mcp-server-starrocks
22 | ```
23 |
24 | **Package Management:**
25 | ```bash
26 | # Install dependencies (handled by uv automatically)
27 | uv sync
28 |
29 | # Build package
30 | uv build
31 | ```
32 |
33 | ## Architecture Overview
34 |
35 | ### Core Components
36 |
37 | - **`src/mcp_server_starrocks/server.py`**: Main server implementation containing all MCP tools, resources, and database connection logic
38 | - **`src/mcp_server_starrocks/__init__.py`**: Entry point that starts the async server
39 |
40 | ### Connection Architecture
41 |
42 | The server supports two connection modes:
43 | - **Standard MySQL Protocol**: Default connection using `mysql.connector`
44 | - **Arrow Flight SQL**: High-performance connection using ADBC drivers (enabled when `STARROCKS_FE_ARROW_FLIGHT_SQL_PORT` is set)
45 |
46 | Connection management uses a global singleton pattern with automatic reconnection handling.
47 |
48 | ### Tool Categories
49 |
50 | 1. **Query Execution Tools**:
51 | - `read_query`: Execute SELECT and other result-returning queries
52 | - `write_query`: Execute DDL/DML commands
53 | - `analyze_query`: Query performance analysis via EXPLAIN ANALYZE
54 |
55 | 2. **Overview Tools with Caching**:
56 | - `table_overview`: Get table schema, row count, and sample data (cached)
57 | - `db_overview`: Get overview of all tables in a database (uses table cache)
58 |
59 | 3. **Visualization Tool**:
60 | - `query_and_plotly_chart`: Execute query and generate Plotly charts from results
61 |
62 | ### Resource Endpoints
63 |
64 | - `starrocks:///databases`: List all databases
65 | - `starrocks:///{db}/tables`: List tables in a database
66 | - `starrocks:///{db}/{table}/schema`: Get table CREATE statement
67 | - `proc:///{path}`: Access StarRocks internal system information (similar to Linux /proc)
68 |
69 | ### Caching System
70 |
71 | In-memory cache for table overviews keyed by `(database_name, table_name)`. The cache stores both successful results and error messages. The size of the generated overview text is soft-capped by the `STARROCKS_OVERVIEW_LIMIT` environment variable (default: 20000 characters).
72 |
73 | ## Configuration
74 |
75 | Environment variables for database connection:
76 | - `STARROCKS_HOST`: Database host (default: localhost)
77 | - `STARROCKS_PORT`: MySQL port (default: 9030)
78 | - `STARROCKS_USER`: Username (default: root)
79 | - `STARROCKS_PASSWORD`: Password (default: empty)
80 | - `STARROCKS_DB`: Default database for session
81 | - `STARROCKS_MYSQL_AUTH_PLUGIN`: Auth plugin (e.g., mysql_clear_password)
82 | - `STARROCKS_FE_ARROW_FLIGHT_SQL_PORT`: Enables Arrow Flight SQL mode
83 | - `MCP_TRANSPORT_MODE`: Communication mode (stdio/streamable-http/sse)
84 |
85 | ## Code Patterns
86 |
87 | ### Error Handling
88 | - Database errors trigger connection reset via `reset_connection()`
89 | - All tools return string error messages rather than raising exceptions
90 | - Cursors are always closed in finally blocks
91 |
92 | ### Security
93 | - SQL injection prevention through parameterized queries and backtick escaping
94 | - Plotly expressions are validated using AST parsing to prevent code injection
95 | - Limited `eval()` usage with restricted scope for chart generation
96 |
97 | ### Async Patterns
98 | - Tools are defined as async functions even though database operations are synchronous
99 | - Main server runs in async context using `FastMCP.run_async()`
100 |
101 | ## Package Structure
102 |
103 | This is a simple Python package built with hatchling:
104 | - Single module in `src/mcp_server_starrocks/`
105 | - Entry point defined in pyproject.toml as `mcp-server-starrocks` command
106 | - Dependencies managed through pyproject.toml, no requirements.txt files
```
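
The plotly-expression validation mentioned under Security can be illustrated with a hedged sketch (the allow-list below is illustrative; the actual server may permit a different set of AST node types):

```python
# AST-validated, restricted eval of a plotly_expr string: only simple
# expression nodes are allowed, and only px and df are in scope.
import ast

import pandas as pd
import plotly.express as px

ALLOWED_NODES = (ast.Expression, ast.Call, ast.Attribute, ast.Name, ast.Load,
                 ast.Constant, ast.keyword, ast.List, ast.Tuple, ast.Dict)

def safe_plotly_eval(expr: str, df: pd.DataFrame):
    tree = ast.parse(expr, mode="eval")
    for node in ast.walk(tree):
        if not isinstance(node, ALLOWED_NODES):
            raise ValueError(f"disallowed syntax: {type(node).__name__}")
    return eval(compile(tree, "<plotly_expr>", "eval"),
                {"__builtins__": {}}, {"px": px, "df": df})

fig = safe_plotly_eval('px.scatter(df, x="col1", y="col2")',
                       pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 1, 9]}))
```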
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
```python
1 | # Tests for mcp-server-starrocks
```
--------------------------------------------------------------------------------
/glama.json:
--------------------------------------------------------------------------------
```json
1 | {
2 | "$schema": "https://glama.ai/mcp/schemas/server.json",
3 | "maintainers": [
4 | "decster"
5 | ]
6 | }
7 |
```
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
```
1 | [pytest]
2 | testpaths = tests
3 | python_files = test_*.py
4 | python_classes = Test*
5 | python_functions = test_*
6 | addopts = -v --tb=short
7 | filterwarnings =
8 | ignore::DeprecationWarning
9 | ignore::PendingDeprecationWarning
```
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/__init__.py:
--------------------------------------------------------------------------------
```python
1 | # Copyright 2021-present StarRocks, Inc. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from . import server
15 | import asyncio
16 |
17 |
18 | def main():
19 | asyncio.run(server.main())
20 |
21 |
22 | __all__ = ['main', 'server']
23 |
```
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
```toml
1 | [project]
2 | name = "mcp-server-starrocks"
3 | version = "0.2.0"
4 | description = "official MCP server for StarRocks"
5 | readme = "README.md"
6 | license = {text = "Apache-2.0"}
7 | requires-python = ">=3.10"
8 | dependencies = [
9 | "loguru>=0.7.3",
10 | "fastmcp>=2.12.0,<2.13.0",
11 | "mysql-connector-python>=9.2.0",
12 | "pandas>=2.2.3",
13 | "plotly>=6.0.1",
14 | "kaleido==0.2.1",
15 | "adbc-driver-manager>=0.8.0",
16 | "adbc-driver-flightsql>=0.8.0",
17 | "pyarrow>=14.0.0",
18 | ]
19 |
20 | [project.optional-dependencies]
21 | test = [
22 | "pytest>=7.0.0",
23 | "pytest-cov>=4.0.0",
24 | ]
25 |
26 | [[project.authors]]
27 | name = "changbinglin"
28 | email = "[email protected]"
29 |
30 | [build-system]
31 | requires = [ "hatchling",]
32 | build-backend = "hatchling.build"
33 |
34 | [project.scripts]
35 | mcp-server-starrocks = "mcp_server_starrocks:main"
36 |
37 | [project.urls]
38 | Home = "https://github.com/starrocks/mcp-server-starrocks"
39 |
```
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/connection_health_checker.py:
--------------------------------------------------------------------------------
```python
1 | # Copyright 2021-present StarRocks, Inc. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import threading
16 | import time
17 | from loguru import logger
18 |
19 |
20 | class ConnectionHealthChecker:
21 | """
22 | Manages database connection health monitoring; a single module-level instance is created via initialize_health_checker().
23 | """
24 |
25 | def __init__(self, db_client, check_interval=30):
26 | """
27 | Initialize the connection health checker.
28 |
29 | Args:
30 | db_client: Database client instance for health checks
31 | check_interval: Health check interval in seconds (default: 30)
32 | """
33 | self.db_client = db_client
34 | self.check_interval = check_interval
35 | self._health_check_thread = None
36 | self._health_check_stop_event = threading.Event()
37 | self._last_connection_status = None
38 | self._last_healthy_log = None
39 |
40 | def check_connection_health(self):
41 | """
42 | Check database connection health by executing a simple query.
43 | Returns tuple of (is_healthy: bool, error_message: str or None)
44 | """
45 | try:
46 | result = self.db_client.execute("show databases")
47 | if result.success:
48 | return True, None
49 | else:
50 | return False, result.error_message
51 | except Exception as e:
52 | return False, str(e)
53 |
54 | def _connection_health_checker_loop(self):
55 | """
56 | Background thread function that periodically checks connection health.
57 | """
58 | logger.info(f"Starting connection health checker (interval: {self.check_interval}s)")
59 | while True:
60 | is_healthy, error_msg = self.check_connection_health()
61 | # Log status changes or periodic status updates
62 | if self._last_connection_status != is_healthy:
63 | if is_healthy:
64 | logger.info("Database connection is healthy")
65 | else:
66 | logger.warning(f"Database connection is unhealthy: {error_msg}")
67 | else:
68 | # Log periodic status (every 5 minutes when healthy, every check when unhealthy)
69 | current_time = time.time()
70 | if is_healthy:
71 | if self._last_healthy_log is None:
72 | self._last_healthy_log = current_time
73 | elif current_time - self._last_healthy_log >= 300: # 5 minutes
74 | logger.info("Database connection remains healthy")
75 | self._last_healthy_log = current_time
76 | else:
77 | logger.warning(f"Database connection remains unhealthy: {error_msg}")
78 | self._last_connection_status = is_healthy
79 | # Wait for interval or stop event
80 | if self._health_check_stop_event.wait(self.check_interval):
81 | break
82 | logger.info("Connection health checker stopped")
83 |
84 | def start(self):
85 | """
86 | Start the connection health checker thread.
87 | """
88 | if self._health_check_thread is None or not self._health_check_thread.is_alive():
89 | self._health_check_stop_event.clear()
90 | self._health_check_thread = threading.Thread(
91 | target=self._connection_health_checker_loop,
92 | name="ConnectionHealthChecker",
93 | daemon=True
94 | )
95 | self._health_check_thread.start()
96 | logger.info("Connection health checker thread started")
97 |
98 | def stop(self):
99 | """
100 | Stop the connection health checker thread.
101 | """
102 | if self._health_check_thread is not None:
103 | self._health_check_stop_event.set()
104 | self._health_check_thread.join(timeout=5)
105 | if self._health_check_thread.is_alive():
106 | logger.warning("Connection health checker thread did not stop gracefully")
107 | else:
108 | logger.info("Connection health checker thread stopped")
109 | self._health_check_thread = None
110 |
111 |
112 | # Global instance - will be initialized in server.py
113 | _health_checker_instance = None
114 |
115 |
116 | def initialize_health_checker(db_client, check_interval=30):
117 | """
118 | Initialize the global connection health checker instance.
119 |
120 | Args:
121 | db_client: Database client instance
122 | check_interval: Health check interval in seconds
123 | """
124 | global _health_checker_instance
125 | _health_checker_instance = ConnectionHealthChecker(db_client, check_interval)
126 | return _health_checker_instance
127 |
128 |
129 | def start_connection_health_checker():
130 | """
131 | Start the connection health checker thread.
132 | """
133 | if _health_checker_instance is None:
134 | raise RuntimeError("Health checker not initialized. Call initialize_health_checker() first.")
135 | _health_checker_instance.start()
136 |
137 |
138 | def stop_connection_health_checker():
139 | """
140 | Stop the connection health checker thread.
141 | """
142 | if _health_checker_instance is not None:
143 | _health_checker_instance.stop()
144 |
145 |
146 | def check_connection_health():
147 | """
148 | Check database connection health by executing a simple query.
149 | Returns tuple of (is_healthy: bool, error_message: str or None)
150 | """
151 | if _health_checker_instance is None:
152 | raise RuntimeError("Health checker not initialized. Call initialize_health_checker() first.")
153 | return _health_checker_instance.check_connection_health()
```
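
A minimal usage sketch for this module, mirroring the wiring that server.py is described as doing (the `db_client` construction is elided; see db_client.py):

```python
# Hypothetical wiring; db_client is assumed to be a constructed DBClient.
from mcp_server_starrocks.connection_health_checker import (
    initialize_health_checker,
    start_connection_health_checker,
    stop_connection_health_checker,
)

initialize_health_checker(db_client, check_interval=30)
start_connection_health_checker()
# ... serve MCP requests ...
stop_connection_health_checker()
```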
--------------------------------------------------------------------------------
/RELEASE_NOTES.md:
--------------------------------------------------------------------------------
```markdown
1 | # StarRocks MCP Server Release Notes
2 |
3 | ## Version 0.2.0
4 |
5 | ### Major Features and Enhancements
6 |
7 | 1. **Enhanced STARROCKS_URL Parsing** (commit 80ac0ba)
8 | - Support for flexible connection URL formats including empty passwords
9 | - Handle patterns like "root:@localhost:9030" and "root@localhost:9030"
10 | - Support missing ports with default 9030: "root:password@localhost"
11 | - Support minimal format: "user@host" with empty password and default port
12 | - Maintain backward compatibility with existing valid URLs
13 | - Comprehensive test coverage for edge cases
14 | - Fixed DBClient to properly convert string port to integer
15 |
16 | 2. **Connection Health Monitoring** (commit b8a80c6)
17 | - Added new connection_health_checker.py module
18 | - Implemented health checking functionality for database connections
19 | - Enhanced connection reliability and monitoring capabilities
20 | - Proactive connection health management
21 |
22 | 3. **Visualization Enhancements** (commit b6f26ec)
23 | - Added format parameter to query_and_plotly_chart tool
24 | - Enhanced chart generation capabilities with configurable output formats
25 | - Improved flexibility for data visualization workflows
26 |
27 | ### Testing and Infrastructure
28 |
29 | - Added comprehensive test coverage for STARROCKS_URL parsing edge cases
30 | - Enhanced test suite with new test cases for database client functionality
31 | - Improved error handling and validation for connection scenarios
32 |
33 | ### Breaking Changes
34 |
35 | None - this release maintains full backward compatibility with version 0.1.5.
36 |
37 | ## Version 0.1.5
38 |
39 | ### Major Features and Enhancements
40 |
41 | 1. Connection Pooling and Architecture Refactor (commit 0fc372d)
42 | - Major refactor introducing connection pooling for improved performance
43 | - Extracted database client logic into separate db_client.py module
44 | - Enhanced connection management and reliability
45 | 2. Enhanced Arrow Flight SQL Support (commit 877338f)
46 | - Improved Arrow Flight SQL connection handling
47 | - Better result processing for high-performance queries
48 | - Enhanced error handling for Arrow Flight connections
49 | 3. New Query Analysis Tools (commit 60ca975)
50 | - Added collect_query_dump_and_profile functionality
51 | - Enhanced query performance analysis capabilities
52 | 4. Database Summary Management (commits d269ebe, 5b2ca59)
53 | - Added new db_summary_manager.py module
54 | - Implemented database summary functionality for better overview capabilities
55 | - Enhanced database exploration features
56 | 5. Configuration Enhancements (commit fb09271)
57 | - Added STARROCKS_URL configuration option
58 | - Improved connection configuration flexibility
59 |
60 | ### Testing and Infrastructure
61 |
62 | - Updated test suite with new test cases for database client functionality
63 | - Added comprehensive testing for Arrow Flight SQL features
64 | - Improved test infrastructure with new README documentation
65 |
66 | ### Breaking Changes
67 |
68 | - Major refactor may require configuration updates for some deployment scenarios
69 | - Connection handling has been restructured (though backwards compatibility is maintained)
70 |
71 | ## Version 0.1.4
72 |
73 |
74 | ## Version 0.1.3
75 |
76 | 1. Refactor using FastMCP
77 | 2. Add new config STARROCKS_MYSQL_AUTH_PLUGIN
78 |
79 | ## Version 0.1.2
80 |
81 | Fix accidental extra import of sqlalchemy
82 |
83 | ## Version 0.1.1
84 |
85 | 1. Add new tool query_and_plotly_chart
86 | 2. Add new tools table_overview & db_overview
87 | 3. Add env configs STARROCKS_DB and STARROCKS_OVERVIEW_LIMIT, both optional
88 |
89 |
90 | ## Version 0.1.0 (Initial Release)
91 |
92 | We are excited to announce the first release of the StarRocks MCP (Model Context Protocol) Server. This server enables AI assistants to interact directly with StarRocks databases, providing a seamless interface for executing queries and retrieving database information.
93 |
94 | ### Description
95 |
96 | The StarRocks MCP Server acts as a bridge between AI assistants and StarRocks databases, allowing for direct SQL execution and database exploration without requiring complex setup or configuration. This initial release provides essential functionality for database interaction while maintaining security and performance.
97 |
98 | ### Features
99 |
100 | - **SQL Query Execution**
101 | - `read_query` tool for executing SELECT queries and commands that return result sets
102 | - `write_query` tool for executing DDL/DML statements and other StarRocks commands
103 | - Proper error handling and connection management
104 |
105 | - **Database Exploration**
106 | - List all databases in a StarRocks instance
107 | - View table schemas using SHOW CREATE TABLE
108 | - List all tables within a specific database
109 |
110 | - **System Information Access**
111 | - Access to StarRocks internal system information via proc-like interface
112 | - Visibility into FE nodes, BE nodes, CN nodes, databases, tables, partitions, transactions, jobs, and more
113 |
114 | - **Flexible Configuration**
115 | - Configurable connection parameters (host, port, user, password)
116 | - Support for both package installation and local directory execution
117 |
118 | ### Requirements
119 |
120 | - Python 3.10 or higher
121 | - Dependencies:
122 | - mcp >= 1.0.0
123 | - mysql-connector-python >= 9.2.0
124 |
125 | ### Configuration
126 |
127 | The server can be configured through environment variables:
128 |
129 | - `STARROCKS_HOST` (default: localhost)
130 | - `STARROCKS_PORT` (default: 9030)
131 | - `STARROCKS_USER` (default: root)
132 | - `STARROCKS_PASSWORD` (default: empty)
133 | - `STARROCKS_MYSQL_AUTH_PLUGIN` (default: mysql_native_password); users can also pass a different auth plugin such as `mysql_clear_password`
134 |
135 | ### Installation
136 |
137 | The server can be installed as a Python package:
138 |
139 | ```bash
140 | pip install mcp-server-starrocks
141 | ```
142 |
143 | Or run directly from the source:
144 |
145 | ```bash
146 | uv --directory path/to/mcp-server-starrocks run mcp-server-starrocks
147 | ```
148 |
149 | ### MCP Integration
150 |
151 | Add the following configuration to your MCP settings file:
152 |
153 | ```json
154 | {
155 | "mcpServers": {
156 | "mcp-server-starrocks": {
157 | "command": "uv",
158 | "args": [
159 | "run",
160 | "--with",
161 | "mcp-server-starrocks",
162 | "mcp-server-starrocks"
163 | ],
164 | "env": {
165 | "STARROCKS_HOST": "localhost",
166 | "STARROCKS_PORT": "9030",
167 | "STARROCKS_USER": "root",
168 | "STARROCKS_PASSWORD": "",
169 | "STARROCKS_MYSQL_AUTH_PLUGIN":"mysql_clear_password"
170 | }
171 | }
172 | }
173 | }
174 | ```
175 |
176 | ---
177 |
178 | We welcome feedback and contributions to improve the StarRocks MCP Server. Please report any issues or suggestions through our GitHub repository.
179 |
```
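
The flexible `STARROCKS_URL` formats listed under version 0.2.0 (optional scheme, empty password, missing port, missing database) can be handled by a parse along these lines (an illustrative sketch, not the actual DBClient code):

```python
# Hedged sketch: scheme, password, port, and database are all optional;
# the port defaults to 9030 when omitted.
import re

def parse_starrocks_url(url: str):
    url = re.sub(r"^\w+://", "", url)            # strip optional scheme
    userinfo, _, rest = url.rpartition("@")
    user, _, password = userinfo.partition(":")  # "root:@host" -> empty password
    hostport, _, database = rest.partition("/")
    host, _, port = hostport.partition(":")
    return user, password, host, int(port or 9030), database or None

assert parse_starrocks_url("root:@localhost:9030") == ("root", "", "localhost", 9030, None)
assert parse_starrocks_url("root@localhost") == ("root", "", "localhost", 9030, None)
```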
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/db_summary_manager.py:
--------------------------------------------------------------------------------
```python
1 | # Copyright 2021-present StarRocks, Inc. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | import time
17 | from dataclasses import dataclass, field
18 | from typing import Dict, List, Optional, Tuple
19 | from loguru import logger
20 |
21 |
22 | @dataclass
23 | class ColumnInfo:
24 | name: str
25 | column_type: str
26 | ordinal_position: int
27 |
28 |
29 | @dataclass
30 | class TableInfo:
31 | name: str
32 | database: str
33 | size_bytes: int = 0
34 | size_str: str = ""
35 | replica_count: int = 0
36 | columns: List[ColumnInfo] = field(default_factory=list)
37 | create_statement: Optional[str] = None
38 | last_updated: float = 0
39 | error_message: Optional[str] = None
40 |
41 | def __post_init__(self):
42 | if not self.last_updated:
43 | self.last_updated = time.time()
44 |
45 | @staticmethod
46 | def parse_size_string(size_str: str) -> int:
47 | """Parse size strings like '1.285 GB', '714.433 MB', '2.269 KB' to bytes"""
48 | if not size_str or size_str == "0" or size_str.lower() == "total":
49 | return 0
50 |
51 | # Handle special cases
52 | if size_str.lower() in ["quota", "left"]:
53 | return 0
54 |
55 | # Match pattern like "1.285 GB"
56 | match = re.match(r'([\d.]+)\s*([KMGT]?B)', size_str.strip(), re.IGNORECASE)
57 | if not match:
58 | return 0
59 |
60 | value, unit = match.groups()
61 | try:
62 | num_value = float(value)
63 | except ValueError:
64 | return 0
65 |
66 | multipliers = {
67 | 'B': 1,
68 | 'KB': 1024,
69 | 'MB': 1024 ** 2,
70 | 'GB': 1024 ** 3,
71 | 'TB': 1024 ** 4
72 | }
73 |
74 | multiplier = multipliers.get(unit.upper(), 1)
75 | return int(num_value * multiplier)
76 |
77 | def is_large_table(self) -> bool:
78 | """Determine if table is considered large (replica_count > 64 OR size > 2GB)"""
79 | return self.replica_count > 64 or self.size_bytes > (2 * 1024 ** 3)
80 |
81 | def priority_score(self) -> float:
82 | """Calculate priority score combining size and replica count for sorting"""
83 | # Normalize size to GB and combine with replica count
84 | size_gb = self.size_bytes / (1024 ** 3)
85 | return size_gb + (self.replica_count * 0.1) # Weight replica count less than size
86 |
87 | def is_expired(self, expire_seconds: int = 120) -> bool:
88 | """Check if cache entry is expired (default 2 minutes)"""
89 | return time.time() - self.last_updated > expire_seconds
90 |
91 |
92 | class DatabaseSummaryManager:
93 | def __init__(self, db_client):
94 | self.db_client = db_client
95 | # Cache: {(database, table_name): TableInfo}
96 | self.table_cache: Dict[Tuple[str, str], TableInfo] = {}
97 | # Database last sync time: {database: timestamp}
98 | self.db_last_sync: Dict[str, float] = {}
99 |
100 | def _sync_table_list(self, database: str, force: bool = False) -> bool:
101 | """Sync table list using SHOW DATA, detect new/dropped tables"""
102 | current_time = time.time()
103 |
104 | # Check if sync is needed (2min expiration or force)
105 | if not force and database in self.db_last_sync:
106 | if current_time - self.db_last_sync[database] < 120:
107 | return True
108 |
109 | logger.debug(f"Syncing table list for database {database}")
110 |
111 | try:
112 | # Execute SHOW DATA to get current table list with sizes
113 | result = self.db_client.execute("SHOW DATA", db=database)
114 | if not result.success:
115 | logger.error(f"Failed to sync table list for {database}: {result.error_message}")
116 | return False
117 |
118 | if not result.rows:
119 | logger.info(f"No tables found in database {database}")
120 | # Clear cache for this database
121 | keys_to_remove = [key for key in self.table_cache.keys() if key[0] == database]
122 | for key in keys_to_remove:
123 | del self.table_cache[key]
124 | self.db_last_sync[database] = current_time
125 | return True
126 |
127 | # Parse current tables from SHOW DATA
128 | current_tables = {}
129 | for row in result.rows:
130 | table_name = row[0]
131 | # Skip summary rows (Total, Quota, Left)
132 | if table_name.lower() in ['total', 'quota', 'left']:
133 | continue
134 |
135 | size_str = row[1] if len(row) > 1 else ""
136 | replica_count = int(row[2]) if len(row) > 2 and str(row[2]).isdigit() else 0
137 |
138 | size_bytes = TableInfo.parse_size_string(size_str)
139 | current_tables[table_name] = {
140 | 'size_str': size_str,
141 | 'size_bytes': size_bytes,
142 | 'replica_count': replica_count
143 | }
144 |
145 | # Update cache: add new tables, update existing, remove dropped
146 | cache_keys_for_db = {key[1]: key for key in self.table_cache.keys() if key[0] == database}
147 |
148 | # Add or update existing tables
149 | for table_name, table_data in current_tables.items():
150 | cache_key = (database, table_name)
151 |
152 | if cache_key in self.table_cache:
153 | # Update existing table info
154 | table_info = self.table_cache[cache_key]
155 | table_info.size_str = table_data['size_str']
156 | table_info.size_bytes = table_data['size_bytes']
157 | table_info.replica_count = table_data['replica_count']
158 | table_info.last_updated = current_time
159 | else:
160 | # Create new table info
161 | self.table_cache[cache_key] = TableInfo(
162 | name=table_name,
163 | database=database,
164 | size_str=table_data['size_str'],
165 | size_bytes=table_data['size_bytes'],
166 | replica_count=table_data['replica_count'],
167 | last_updated=current_time
168 | )
169 |
170 | # Remove dropped tables
171 | for table_name in cache_keys_for_db:
172 | if table_name not in current_tables:
173 | cache_key = cache_keys_for_db[table_name]
174 | del self.table_cache[cache_key]
175 | logger.debug(f"Removed dropped table {database}.{table_name} from cache")
176 |
177 | self.db_last_sync[database] = current_time
178 | logger.debug(f"Synced {len(current_tables)} tables for database {database}")
179 | return True
180 |
181 | except Exception as e:
182 | logger.error(f"Error syncing table list for {database}: {e}")
183 | return False
184 |
185 | def _fetch_column_info(self, database: str, tables: List[str]) -> Dict[str, List[ColumnInfo]]:
186 | """Fetch column information for all tables using information_schema.columns"""
187 | if not tables:
188 | return {}
189 |
190 | logger.debug(f"Fetching column info for {len(tables)} tables in {database}")
191 |
192 | try:
193 | # Build query to get column information for all tables
194 | table_names_quoted = "', '".join(tables)
195 | query = f"""
196 | SELECT table_name, column_name, ordinal_position, column_type
197 | FROM information_schema.columns
198 | WHERE table_schema = '{database}'
199 | AND table_name IN ('{table_names_quoted}')
200 | ORDER BY table_name, ordinal_position
201 | """
202 |
203 | result = self.db_client.execute(query)
204 | if not result.success:
205 | logger.error(f"Failed to fetch column info: {result.error_message}")
206 | return {}
207 |
208 | # Group columns by table
209 | table_columns = {}
210 | for row in result.rows:
211 | table_name = row[0]
212 | column_name = row[1]
213 | ordinal_position = int(row[2]) if row[2] else 0
214 | column_type = 'string' if row[3] == "varchar(65533)" else row[3]
215 |
216 | if table_name not in table_columns:
217 | table_columns[table_name] = []
218 |
219 | table_columns[table_name].append(ColumnInfo(
220 | name=column_name,
221 | column_type=column_type,
222 | ordinal_position=ordinal_position
223 | ))
224 |
225 | logger.debug(f"Fetched column info for {len(table_columns)} tables")
226 | return table_columns
227 |
228 | except Exception as e:
229 | logger.error(f"Error fetching column information: {e}")
230 | return {}
231 |
232 | def _fetch_create_statement(self, database: str, table: str) -> Optional[str]:
233 | """Fetch CREATE TABLE statement for large tables"""
234 | try:
235 | result = self.db_client.execute(f"SHOW CREATE TABLE `{database}`.`{table}`")
236 | if result.success and result.rows and len(result.rows[0]) > 1:
237 | return result.rows[0][1] # Second column contains CREATE statement
238 | except Exception as e:
239 | logger.error(f"Error fetching CREATE statement for {database}.{table}: {e}")
240 | return None
241 |
242 | def get_database_summary(self, database: str, limit: int = 10000, refresh: bool = False) -> str:
243 | """Generate comprehensive database summary with intelligent prioritization"""
244 | if not database:
245 | return "Error: Database name is required"
246 |
247 | logger.info(f"Generating database summary for {database}, limit={limit}, refresh={refresh}")
248 |
249 | # Sync table list
250 |         if not self._sync_table_list(database, force=refresh):
251 | return f"Error: Failed to sync table information for database '{database}'"
252 |
253 | # Get all tables for this database from cache
254 | tables_info = []
255 | for (db, table_name), table_info in self.table_cache.items():
256 | if db == database:
257 | tables_info.append(table_info)
258 |
259 | if not tables_info:
260 | return f"No tables found in database '{database}'"
261 |
262 | # Sort tables by priority (large tables first)
263 | tables_info.sort(key=lambda t: t.priority_score(), reverse=True)
264 |
265 | # Check if any table needs column information refresh
266 | need_column_refresh = refresh or any(not table_info.columns or table_info.is_expired() for table_info in tables_info)
267 |
268 | # If any table needs refresh, fetch ALL tables' columns in one query (more efficient)
269 | if need_column_refresh:
270 | all_table_names = [table_info.name for table_info in tables_info]
271 | table_columns = self._fetch_column_info(database, all_table_names)
272 |
273 | # Update cache with column information for all tables
274 | current_time = time.time()
275 | for table_info in tables_info:
276 | if table_info.name in table_columns:
277 | table_info.columns = table_columns[table_info.name]
278 | table_info.last_updated = current_time
279 |
280 | # Identify large tables that need CREATE statements
281 | large_tables = [t for t in tables_info if t.is_large_table()][:10] # Top 10 large tables
282 | for table_info in large_tables:
283 | if refresh or not table_info.create_statement:
284 | table_info.create_statement = self._fetch_create_statement(database, table_info.name)
285 | table_info.last_updated = time.time()
286 |
287 | # Generate summary output
288 | return self._format_database_summary(database, tables_info, limit)
289 |
290 | def _format_database_summary(self, database: str, tables_info: List[TableInfo], limit: int) -> str:
291 | """Format database summary with intelligent truncation"""
292 | lines = []
293 | lines.append(f"=== Database Summary: '{database}' ===")
294 | lines.append(f"Total tables: {len(tables_info)}")
295 |
296 | # Calculate totals
297 | total_size = sum(t.size_bytes for t in tables_info)
298 | total_replicas = sum(t.replica_count for t in tables_info)
299 | large_tables = [t for t in tables_info if t.is_large_table()]
300 |
301 | lines.append(f"Total size: {self._format_bytes(total_size)}")
302 |
303 | current_length = len("\n".join(lines))
304 | table_limit = min(len(tables_info), 50) # Show max 50 tables
305 |
306 | # Show large tables first with full details
307 | if large_tables:
308 | for i, table_info in enumerate(large_tables):
309 | if current_length > limit * 0.8: # Reserve 20% for smaller tables
310 | lines.append(f"... and {len(large_tables) - i} more large tables")
311 | break
312 |
313 | table_summary = self._format_table_info(table_info, detailed=True)
314 | lines.append(table_summary)
315 | lines.append("")
316 | current_length = len("\n".join(lines))
317 |
318 | # Show remaining tables with basic info
319 | remaining_tables = [t for t in tables_info if not t.is_large_table()]
320 | if remaining_tables and current_length < limit:
321 | lines.append("--- Other Tables ---")
322 |
323 | for i, table_info in enumerate(remaining_tables):
324 | if current_length > limit:
325 | lines.append(f"... and {len(remaining_tables) - i} more tables (use higher limit to see all)")
326 | break
327 |
328 | table_summary = self._format_table_info(table_info, detailed=False)
329 | lines.append(table_summary)
330 | current_length = len("\n".join(lines))
331 |
332 | return "\n".join(lines)
333 |
334 | def _format_table_info(self, table_info: TableInfo, detailed: bool = True) -> str:
335 | """Format individual table information"""
336 | lines = []
337 |
338 | # Basic info line
339 | size_info = f"{table_info.size_str} ({table_info.replica_count} replicas)"
340 | lines.append(f"Table: {table_info.name} - {size_info}")
341 |
342 | if table_info.error_message:
343 | lines.append(f" Error: {table_info.error_message}")
344 | return "\n".join(lines)
345 |
346 | # Show CREATE statement if available, otherwise show column list
347 | if table_info.create_statement:
348 | lines.append(table_info.create_statement)
349 | elif table_info.columns:
350 | # Sort columns by ordinal position and show as list
351 | sorted_columns = sorted(table_info.columns, key=lambda c: c.ordinal_position)
352 | if detailed or len(sorted_columns) <= 20:
353 | for col in sorted_columns:
354 | lines.append(f" {col.name} {col.column_type}")
355 | else:
356 | lines.append(f" Columns ({len(sorted_columns)}): {', '.join(col.name for col in sorted_columns[:100])}...")
357 |
358 | return "\n".join(lines)
359 |
360 | @staticmethod
361 | def _format_bytes(bytes_count: int) -> str:
362 | """Format bytes to human readable string"""
363 | if bytes_count == 0:
364 | return "0 B"
365 |
366 | units = ['B', 'KB', 'MB', 'GB', 'TB']
367 | unit_index = 0
368 | size = float(bytes_count)
369 |
370 | while size >= 1024 and unit_index < len(units) - 1:
371 | size /= 1024
372 | unit_index += 1
373 |
374 | if unit_index == 0:
375 | return f"{int(size)} {units[unit_index]}"
376 | else:
377 | return f"{size:.2f} {units[unit_index]}"
378 |
379 | def clear_cache(self, database: Optional[str] = None):
380 | """Clear cache for specific database or all databases"""
381 | if database:
382 | keys_to_remove = [key for key in self.table_cache.keys() if key[0] == database]
383 | for key in keys_to_remove:
384 | del self.table_cache[key]
385 | if database in self.db_last_sync:
386 | del self.db_last_sync[database]
387 | logger.info(f"Cleared cache for database {database}")
388 | else:
389 | self.table_cache.clear()
390 | self.db_last_sync.clear()
391 | logger.info("Cleared all cache")
392 |
393 |
394 | # Global instance (will be initialized in server.py)
395 | _db_summary_manager: Optional[DatabaseSummaryManager] = None
396 |
397 |
398 | def get_db_summary_manager(db_client) -> DatabaseSummaryManager:
399 | """Get or create global database summary manager instance"""
400 | global _db_summary_manager
401 | if _db_summary_manager is None:
402 | _db_summary_manager = DatabaseSummaryManager(db_client)
403 | return _db_summary_manager
```
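
A quick note on the byte-formatting helper above: sizes below 1 KB print as plain integers, and everything larger carries two decimals. A minimal sketch of the expected outputs (the import path assumes the repo-root layout used by the tests; `_format_bytes` is a static method, so no database connection is needed):

```python
# Illustrative checks of DatabaseSummaryManager._format_bytes (no DB required).
from src.mcp_server_starrocks.db_summary_manager import DatabaseSummaryManager

fmt = DatabaseSummaryManager._format_bytes
assert fmt(0) == "0 B"                # zero short-circuits
assert fmt(512) == "512 B"            # below 1024: integer, no decimals
assert fmt(1536) == "1.50 KB"         # 1536 / 1024 = 1.5
assert fmt(5 * 1024**3) == "5.00 GB"  # scales through KB/MB/GB/TB
```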
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/db_client.py:
--------------------------------------------------------------------------------
```python
1 | # Copyright 2021-present StarRocks, Inc. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import io
16 | import os
17 | import time
18 | import re
19 | import json
20 | from typing import Optional, List, Any, Union, Literal, TypedDict, NotRequired
21 | from dataclasses import dataclass
22 | import mysql.connector
23 | from mysql.connector import Error as MySQLError
24 | import adbc_driver_manager
25 | import adbc_driver_flightsql.dbapi as flight_sql
26 | from adbc_driver_manager import Error as adbcError
27 | import pandas as pd
28 |
29 |
30 | @dataclass
31 | class ResultSet:
32 | """Database query result set."""
33 | success: bool
34 | column_names: Optional[List[str]] = None
35 | rows: Optional[List[List[Any]]] = None
36 | rows_affected: Optional[int] = None
37 | execution_time: Optional[float] = None
38 | error_message: Optional[str] = None
39 | pandas: Optional[pd.DataFrame] = None
40 |
41 | def to_pandas(self) -> pd.DataFrame:
42 | """Convert ResultSet to pandas DataFrame."""
43 | if self.pandas is not None:
44 | return self.pandas
45 |
46 | if not self.success:
47 | raise ValueError(f"Cannot convert failed result to DataFrame: {self.error_message}")
48 |
49 | if self.column_names is None or self.rows is None:
50 | raise ValueError("No data available to convert to DataFrame")
51 |
52 | return pd.DataFrame(self.rows, columns=self.column_names)
53 |
54 | def to_string(self, limit: Optional[int] = None) -> str:
55 | """Format rows as CSV-like string with column names as first row."""
56 | if not self.success:
57 | return f"Error: {self.error_message}"
58 | if self.column_names is None or self.rows is None:
59 | return "No data"
60 | def to_csv_line(row):
61 | return ",".join(
62 | str(item).replace("\"", "\"\"") if isinstance(item, str) else str(item) for item in row)
63 | output = io.StringIO()
64 | output.write(to_csv_line(self.column_names) + "\n")
65 | for row in self.rows:
66 | line = to_csv_line(row) + "\n"
67 | if limit is not None and output.tell() + len(line) > limit:
68 | output.write("...\n")
69 | break
70 | output.write(line)
71 | output.write(f"Total rows: {len(self.rows)}\n")
72 |         output.write(f"Execution time: {self.execution_time:.3f}s\n")
73 | return output.getvalue()
74 |
75 | def to_dict(self) -> dict:
76 | ret = {
77 | "success": self.success,
78 | "execution_time": self.execution_time,
79 | }
80 | if self.column_names is not None:
81 | ret["column_names"] = self.column_names
82 | ret["rows"] = self.rows
83 | if self.rows_affected is not None:
84 | ret["rows_affected"] = self.rows_affected
85 | if self.error_message:
86 | ret["error_message"] = self.error_message
87 | return ret
88 |
89 |
90 | class PerfAnalysisInput(TypedDict):
91 | error_message: NotRequired[Optional[str]]
92 | query_id: NotRequired[Optional[str]]
93 | rows_returned: NotRequired[Optional[int]]
94 | duration: NotRequired[Optional[float]]
95 | query_dump: NotRequired[Optional[dict]]
96 | profile: NotRequired[Optional[str]]
97 | analyze_profile: NotRequired[Optional[str]]
98 |
99 |
100 | def parse_connection_url(connection_url: str) -> dict:
101 | """
102 | Parse connection URL into dict with user, password, host, port, database.
103 |
104 | Supports flexible formats:
105 | - [<schema>://]<user>[:<password>]@<host>[:<port>][/<database>]
106 | - Empty passwords: user:@host:port or user@host:port
107 | - Missing ports (uses default 9030): user:pass@host
108 | - All components are optional except user and host
109 | """
110 | # More flexible regex pattern that handles optional password and port
111 | pattern = re.compile(
112 | r'^(?:(?P<schema>[\w+]+)://)?' # Optional schema://
113 | r'(?P<user>[^:@]+)' # Required username (no : or @)
114 | r'(?::(?P<password>[^@]*))?' # Optional :password (can be empty)
115 | r'@(?P<host>[^:/]+)' # Required @host
116 | r'(?::(?P<port>\d+))?' # Optional :port
117 | r'(?:/(?P<database>[\w-]+))?$' # Optional /database
118 | )
119 |
120 | match = pattern.match(connection_url)
121 | if not match:
122 | raise ValueError(f"Invalid connection URL: {connection_url}")
123 |
124 | result = match.groupdict()
125 |
126 | # Only keep connection parameters that mysql.connector supports
127 | # Filter out None values and schema (which is not a mysql.connector parameter)
128 | filtered_result = {}
129 |
130 | # Always include user and host as they are required
131 | filtered_result['user'] = result['user']
132 | filtered_result['host'] = result['host']
133 |
134 | # Include password (default to empty string if None)
135 | filtered_result['password'] = result['password'] if result['password'] is not None else ''
136 |
137 | # Include port (default to 9030 if None)
138 | filtered_result['port'] = result['port'] if result['port'] is not None else '9030'
139 |
140 | # Always include database (None if not provided in URL)
141 | filtered_result['database'] = result['database']
142 |
143 | # Note: schema is intentionally excluded as it's not supported by mysql.connector
144 |
145 | return filtered_result
146 |
147 | ANSI_ESCAPE_PATTERN = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
148 |
149 |
150 | def remove_ansi_codes(text):
151 | return ANSI_ESCAPE_PATTERN.sub('', text)
152 |
153 |
154 | class DBClient:
155 | """Simplified database client for StarRocks connection and query execution."""
156 |
157 | def __init__(self):
158 | self.enable_dummy_test = bool(os.getenv('STARROCKS_DUMMY_TEST'))
159 | self.enable_arrow_flight_sql = bool(os.getenv('STARROCKS_FE_ARROW_FLIGHT_SQL_PORT'))
160 | if os.getenv('STARROCKS_URL'):
161 | self.connection_params = parse_connection_url(os.getenv('STARROCKS_URL'))
162 | # Convert port to integer for mysql.connector
163 | self.connection_params['port'] = int(self.connection_params['port'])
164 | else:
165 | self.connection_params = {
166 | 'host': os.getenv('STARROCKS_HOST', 'localhost'),
167 | 'port': int(os.getenv('STARROCKS_PORT', '9030')),
168 | 'user': os.getenv('STARROCKS_USER', 'root'),
169 | 'password': os.getenv('STARROCKS_PASSWORD', ''),
170 | 'database': os.getenv('STARROCKS_DB', None),
171 | }
172 |         self.connection_params.update({
173 | 'auth_plugin': os.getenv('STARROCKS_MYSQL_AUTH_PLUGIN', 'mysql_native_password'),
174 | 'pool_size': int(os.getenv('STARROCKS_POOL_SIZE', '10')),
175 | 'pool_name': 'mcp_starrocks_pool',
176 | 'pool_reset_session': True,
177 | 'autocommit': True,
178 | 'connection_timeout': int(os.getenv('STARROCKS_CONNECTION_TIMEOUT', '10')),
179 | 'connect_timeout': int(os.getenv('STARROCKS_CONNECTION_TIMEOUT', '10')),
180 | })
181 | self.default_database = self.connection_params.get('database')
182 |
183 | # MySQL connection pool
184 | self._connection_pool = None
185 |
186 | # ADBC connection (singleton)
187 | self._adbc_connection = None
188 |
189 | def _get_connection_pool(self):
190 | """Get or create a connection pool for MySQL connections."""
191 | if self._connection_pool is None:
192 | try:
193 | self._connection_pool = mysql.connector.pooling.MySQLConnectionPool(**self.connection_params)
194 | except MySQLError as conn_err:
195 | raise conn_err
196 |
197 | return self._connection_pool
198 |
199 | def _validate_connection(self, conn):
200 | """Validate that a MySQL connection is still alive and working."""
201 | try:
202 | conn.ping(reconnect=True, attempts=1, delay=0)
203 | return True
204 | except MySQLError:
205 | return False
206 |
207 | def _get_pooled_connection(self):
208 | """Get a MySQL connection from the pool with timeout and retry logic."""
209 | pool = self._get_connection_pool()
210 | try:
211 | conn = pool.get_connection()
212 | if not self._validate_connection(conn):
213 | conn.close()
214 | conn = pool.get_connection()
215 | return conn
216 | except mysql.connector.errors.PoolError as pool_err:
217 | if "Pool is exhausted" in str(pool_err):
218 | time.sleep(0.1)
219 | try:
220 | return pool.get_connection()
221 | except mysql.connector.errors.PoolError:
222 | self._connection_pool = None
223 | new_pool = self._get_connection_pool()
224 | return new_pool.get_connection()
225 | raise pool_err
226 |
227 | def _create_adbc_connection(self):
228 | """Create a new ADBC connection."""
229 | fe_host = self.connection_params['host']
230 | fe_port = os.getenv('STARROCKS_FE_ARROW_FLIGHT_SQL_PORT', '')
231 | user = self.connection_params['user']
232 | password = self.connection_params['password']
233 |
234 | try:
235 | connection = flight_sql.connect(
236 | uri=f"grpc://{fe_host}:{fe_port}",
237 | db_kwargs={
238 | adbc_driver_manager.DatabaseOptions.USERNAME.value: user,
239 | adbc_driver_manager.DatabaseOptions.PASSWORD.value: password,
240 | }
241 | )
242 |
243 | # Switch to default database if set
244 | if self.default_database:
245 | try:
246 | cursor = connection.cursor()
247 | cursor.execute(f"USE {self.default_database}")
248 | cursor.close()
249 | except adbcError as db_err:
250 | print(f"Warning: Could not switch to default database '{self.default_database}': {db_err}")
251 |
252 | return connection
253 |         except adbcError as conn_err:
254 |             print(f"Error creating ADBC connection: {conn_err}")
255 | raise
256 |
257 | def _get_adbc_connection(self):
258 | """Get or create an ADBC connection with health check."""
259 | if self._adbc_connection is None:
260 | self._adbc_connection = self._create_adbc_connection()
261 |
262 | # Health check for ADBC connection
263 | if self._adbc_connection is not None:
264 | try:
265 | self._adbc_connection.adbc_get_info()
266 | except adbcError as check_err:
267 | print(f"Connection check failed: {check_err}, creating new ADBC connection.")
268 | self._reset_adbc_connection()
269 | self._adbc_connection = self._create_adbc_connection()
270 |
271 | return self._adbc_connection
272 |
273 | def _get_connection(self):
274 | """Get appropriate connection based on configuration."""
275 | if self.enable_arrow_flight_sql:
276 | return self._get_adbc_connection()
277 | else:
278 | return self._get_pooled_connection()
279 |
280 | def _reset_adbc_connection(self):
281 | """Reset ADBC connection."""
282 | if self._adbc_connection is not None:
283 | try:
284 | self._adbc_connection.close()
285 | except Exception as e:
286 | print(f"Error closing ADBC connection: {e}")
287 | finally:
288 | self._adbc_connection = None
289 |
290 | def _reset_connection(self):
291 | """Reset connections based on configuration."""
292 | if self.enable_arrow_flight_sql:
293 | self._reset_adbc_connection()
294 | else:
295 | self._connection_pool = None
296 |
297 | def _handle_db_error(self, error):
298 | """Handle database errors and reset connections as needed."""
299 | if not self.enable_arrow_flight_sql and ("MySQL Connection not available" in str(error) or "Lost connection" in str(error)):
300 | self._connection_pool = None
301 | elif self.enable_arrow_flight_sql:
302 | self._reset_adbc_connection()
303 |
304 |
305 | def _execute(self, conn, statement: str, params=None, return_format:str="raw") -> ResultSet:
306 | cursor = None
307 | start_time = time.time()
308 | try:
309 | cursor = conn.cursor()
310 | cursor.execute(statement, params)
311 | # Initialize variables to track the last result set
312 | last_result = None
313 | last_affected_rows = None
314 | # Process first result set
315 | if cursor.description:
316 | column_names = [desc[0] for desc in cursor.description]
317 | if self.enable_arrow_flight_sql:
318 | arrow_result = cursor.fetchallarrow()
319 |                     pandas_df = arrow_result.to_pandas() if return_format == "pandas" else None
320 |                     rows = (pandas_df if pandas_df is not None else arrow_result.to_pandas()).values.tolist()
321 |
322 | # Check if this is a status result for DML operations (INSERT/UPDATE/DELETE)
323 | # Arrow Flight SQL returns status results as a single column 'StatusResult'
324 | # Note: StarRocks Arrow Flight SQL seems to always return '0' in StatusResult,
325 | # so we use cursor.rowcount when available as a fallback
326 | if (len(column_names) == 1 and column_names[0] == 'StatusResult' and
327 | len(rows) == 1 and len(rows[0]) == 1):
328 | try:
329 | status_value = int(rows[0][0])
330 | # If status_value is 0 but we have cursor.rowcount, prefer that
331 | if status_value == 0 and hasattr(cursor, 'rowcount') and cursor.rowcount > 0:
332 | last_affected_rows = cursor.rowcount
333 | else:
334 | last_affected_rows = status_value
335 | last_result = None # Don't treat this as a regular result set
336 | except (ValueError, TypeError):
337 | # If we can't parse the status result as an integer, treat it as a regular result
338 | last_result = ResultSet(
339 | success=True,
340 | column_names=column_names,
341 | rows=rows,
342 | execution_time=0, # Will be set at the end
343 | pandas=pandas_df
344 | )
345 | else:
346 | last_result = ResultSet(
347 | success=True,
348 | column_names=column_names,
349 | rows=rows,
350 | execution_time=0, # Will be set at the end
351 | pandas=pandas_df
352 | )
353 | else:
354 | rows = cursor.fetchall()
355 | pandas_df = pd.DataFrame(rows, columns=column_names) if return_format == "pandas" else None
356 |
357 | last_result = ResultSet(
358 | success=True,
359 | column_names=column_names,
360 | rows=rows,
361 | execution_time=0, # Will be set at the end
362 | pandas=pandas_df
363 | )
364 | else:
365 | last_affected_rows = cursor.rowcount if cursor.rowcount >= 0 else None
366 | # Process additional result sets (for multi-statement queries)
367 | # Note: Arrow Flight SQL may not support nextset(), so we check for it
368 | if not self.enable_arrow_flight_sql and hasattr(cursor, 'nextset'):
369 | while cursor.nextset():
370 | if cursor.description:
371 | column_names = [desc[0] for desc in cursor.description]
372 | rows = cursor.fetchall()
373 | pandas_df = pd.DataFrame(rows, columns=column_names) if return_format == "pandas" else None
374 |
375 | last_result = ResultSet(
376 | success=True,
377 | column_names=column_names,
378 | rows=rows,
379 | execution_time=0, # Will be set at the end
380 | pandas=pandas_df
381 | )
382 | else:
383 | last_affected_rows = cursor.rowcount if cursor.rowcount >= 0 else None
384 | last_result = None
385 | # Return the last result set found
386 | if last_result is not None:
387 | last_result.execution_time = time.time() - start_time
388 | return last_result
389 | else:
390 | return ResultSet(
391 | success=True,
392 | rows_affected=last_affected_rows,
393 | execution_time=time.time() - start_time
394 | )
395 | except (MySQLError, adbcError) as e:
396 | self._handle_db_error(e)
397 | return ResultSet(
398 | success=False,
399 | error_message=f"Error executing statement '{statement}': {str(e)}",
400 | execution_time=time.time() - start_time
401 | )
402 | except Exception as e:
403 | return ResultSet(
404 | success=False,
405 | error_message=f"Unexpected error executing statement '{statement}': {str(e)}",
406 | execution_time=time.time() - start_time
407 | )
408 | finally:
409 | if cursor:
410 | try:
411 | cursor.close()
412 |                 except Exception:
413 |                     pass
414 |
415 |
416 | def execute(
417 | self,
418 | statement: str,
419 | db: Optional[str] = None,
420 | return_format: Literal["raw", "pandas"] = "raw"
421 | ) -> ResultSet:
422 | """
423 | Execute a SQL statement and return results.
424 |
425 | Args:
426 | statement: SQL statement to execute
427 | db: Optional database to use (overrides default)
428 | return_format: "raw" returns ResultSet with rows, "pandas" also populates pandas field
429 |
430 | Returns:
431 | ResultSet with column_names and rows, optionally with pandas DataFrame
432 | """
433 | # If dummy test mode is enabled, return dummy data without connecting to database
434 | if self.enable_dummy_test:
435 | column_names = ['name']
436 | rows = [['aaa'], ['bbb'], ['ccc']]
437 | pandas_df = None
438 |
439 | if return_format == "pandas":
440 | pandas_df = pd.DataFrame(rows, columns=column_names)
441 |
442 | return ResultSet(
443 | success=True,
444 | column_names=column_names,
445 | rows=rows,
446 | execution_time=0.1,
447 | pandas=pandas_df
448 | )
449 | conn = None
450 | try:
451 | conn = self._get_connection()
452 | # Switch database if specified
453 | if db and db != self.default_database:
454 | cursor_temp = conn.cursor()
455 | try:
456 | cursor_temp.execute(f"USE `{db}`")
457 | except (MySQLError, adbcError) as db_err:
458 | cursor_temp.close()
459 | return ResultSet(
460 | success=False,
461 | error_message=f"Error switching to database '{db}': {str(db_err)}",
462 | execution_time=0
463 | )
464 | cursor_temp.close()
465 | return self._execute(conn, statement, None, return_format)
466 | except (MySQLError, adbcError) as e:
467 | self._handle_db_error(e)
468 | return ResultSet(
469 | success=False,
470 | error_message=f"Error executing statement '{statement}': {str(e)}",
471 | )
472 | except Exception as e:
473 | return ResultSet(
474 | success=False,
475 | error_message=f"Unexpected error executing statement '{statement}': {str(e)}",
476 | )
477 | finally:
478 | if conn and not self.enable_arrow_flight_sql:
479 | try:
480 | conn.close()
481 |                 except Exception:
482 |                     pass
483 |
484 | def collect_perf_analysis_input(self, query: str, db:Optional[str]=None) -> PerfAnalysisInput:
485 | conn = None
486 | try:
487 | conn = self._get_connection()
488 | # Switch database if specified
489 | if db and db != self.default_database:
490 | cursor_temp = conn.cursor()
491 | try:
492 | cursor_temp.execute(f"USE `{db}`")
493 | except (MySQLError, adbcError) as db_err:
494 | return {"error_message":str(db_err)}
495 | finally:
496 | cursor_temp.close()
497 | query_dump_result = self._execute(conn, "select get_query_dump(%s, %s)", (query, False))
498 | if not query_dump_result.success:
499 | return {"error_message":query_dump_result.error_message}
500 | ret = {
501 | "query_dump": json.loads(query_dump_result.rows[0][0]),
502 | }
503 | start_ts = time.time()
504 | profile_query = "/*+ SET_VAR (enable_profile='true') */ " + query
505 | query_result = self._execute(conn, profile_query)
506 | duration = time.time() - start_ts
507 | ret["duration"] = duration
508 | if not query_result.success:
509 | ret["error_message"] = query_result.error_message
510 | return ret
511 | ret["rows_returned"] = len(query_result.rows) if query_result.rows else 0
512 | # Try to get query id
513 | query_id_result = self._execute(conn, "select last_query_id()")
514 | if not query_id_result.success:
515 | ret["error_message"] = query_id_result.error_message
516 | return ret
517 | ret["query_id"] = query_id_result.rows[0][0]
518 | # Try to get query profile with retries
519 | query_profile = ''
520 | retry_count = 0
521 | while not query_profile and retry_count < 3:
522 | time.sleep(1+retry_count)
523 | query_profile_result = self._execute(conn,"select get_query_profile(%s)", (ret["query_id"],))
524 | if query_profile_result.success:
525 | query_profile = query_profile_result.rows[0][0]
526 | retry_count += 1
527 | if not query_profile:
528 | ret['error_message'] = "Failed to get query profile after 3 retries"
529 | return ret
530 | ret['profile'] = query_profile
531 | analyze_profile_result = self._execute(conn,"ANALYZE PROFILE FROM %s", (ret["query_id"],))
532 | if not analyze_profile_result.success:
533 | ret["error_message"] = analyze_profile_result.error_message
534 | return ret
535 | analyze_text = '\n'.join(row[0] for row in analyze_profile_result.rows)
536 | ret['analyze_profile'] = remove_ansi_codes(analyze_text)
537 | return ret
538 | except (MySQLError, adbcError) as e:
539 | self._handle_db_error(e)
540 | return {"error_message":str(e)}
541 | except Exception as e:
542 | return {"error_message":str(e)}
543 | finally:
544 | if conn and not self.enable_arrow_flight_sql:
545 | try:
546 | conn.close()
547 |                 except Exception:
548 |                     pass
549 |
550 | def reset_connections(self):
551 | """Public method to reset all connections."""
552 | self._reset_connection()
553 |
554 |
555 | # Global singleton instance
556 | _db_client_instance: Optional[DBClient] = None
557 |
558 |
559 | def get_db_client() -> DBClient:
560 | """Get or create the global DBClient instance."""
561 | global _db_client_instance
562 | if _db_client_instance is None:
563 | _db_client_instance = DBClient()
564 | return _db_client_instance
565 |
566 |
567 | def reset_db_connections():
568 | """Reset all database connections (useful for error recovery)."""
569 | global _db_client_instance
570 | if _db_client_instance is not None:
571 | _db_client_instance.reset_connections()
```
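
To illustrate how the URL parser above resolves defaults, here is a minimal sketch; the expected values follow directly from the regex and the default-filling code, and the import path assumes the repo-root layout used by the tests:

```python
# Illustrative parse_connection_url results.
from src.mcp_server_starrocks.db_client import parse_connection_url

assert parse_connection_url("root:@127.0.0.1:9030/tpch") == {
    'user': 'root', 'password': '', 'host': '127.0.0.1',
    'port': '9030', 'database': 'tpch',
}
# A schema prefix is accepted but dropped; the port defaults to '9030'
# (kept as a string here and converted to int in DBClient.__init__);
# the database defaults to None when omitted.
assert parse_connection_url("mysql://root@localhost") == {
    'user': 'root', 'password': '', 'host': 'localhost',
    'port': '9030', 'database': None,
}
```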
--------------------------------------------------------------------------------
/src/mcp_server_starrocks/server.py:
--------------------------------------------------------------------------------
```python
1 | # Copyright 2021-present StarRocks, Inc. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | import ast
16 | import asyncio
17 | import base64
18 | import json
19 | import math
20 | import sys
21 | import os
22 | import traceback
23 | import threading
24 | import time
25 | from fastmcp import FastMCP
26 | from fastmcp.utilities.types import Image
27 | from fastmcp.tools.tool import ToolResult
28 | from mcp.types import TextContent, ImageContent
29 | from fastmcp.exceptions import ToolError
30 | from typing import Annotated
31 | from pydantic import Field
32 | import plotly.express as px
33 | import plotly.graph_objs
34 | from loguru import logger
35 | from starlette.middleware.cors import CORSMiddleware
36 | from starlette.middleware import Middleware
37 | from .db_client import get_db_client, reset_db_connections, ResultSet, PerfAnalysisInput
38 | from .db_summary_manager import get_db_summary_manager
39 | from .connection_health_checker import (
40 | initialize_health_checker,
41 | start_connection_health_checker,
42 | stop_connection_health_checker,
43 | check_connection_health
44 | )
45 |
46 | # Configure logging
47 | logger.remove() # Remove default handler
48 | logger.add(sys.stderr, level=os.getenv("LOG_LEVEL", "INFO"),
49 | format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}")
50 |
51 | mcp = FastMCP('mcp-server-starrocks')
52 |
53 | # a hint for soft limit, not enforced
54 | overview_length_limit = int(os.getenv('STARROCKS_OVERVIEW_LIMIT', str(20000)))
55 | # Global cache for table overviews: {(db_name, table_name): overview_string}
56 | global_table_overview_cache = {}
57 |
58 | # Get database client instance
59 | db_client = get_db_client()
60 | # Get database summary manager instance
61 | db_summary_manager = get_db_summary_manager(db_client)
62 | # Description suffix for tools, if default db is set
63 | description_suffix = f". db session already in default db `{db_client.default_database}`" if db_client.default_database else ""
64 |
65 | # Initialize connection health checker
66 | _health_checker = initialize_health_checker(db_client)
67 |
68 |
69 | SR_PROC_DESC = '''
70 | Internal information exposed by StarRocks, similar to Linux /proc. Some common paths:
71 |
72 | '/frontends' Shows the information of FE nodes.
73 | '/backends' Shows the information of BE nodes if this SR is a non-cloud-native deployment.
74 | '/compute_nodes' Shows the information of CN nodes if this SR is a cloud-native deployment.
75 | '/dbs' Shows the information of databases.
76 | '/dbs/<DB_ID>' Shows the information of a database by database ID.
77 | '/dbs/<DB_ID>/<TABLE_ID>' Shows the information of a table by database ID and table ID.
78 | '/dbs/<DB_ID>/<TABLE_ID>/partitions' Shows the information of partitions by database ID and table ID.
79 | '/transactions' Shows the information of transactions by database.
80 | '/transactions/<DB_ID>' Shows the information of transactions by database ID.
81 | '/transactions/<DB_ID>/running' Shows the information of running transactions by database ID.
82 | '/transactions/<DB_ID>/finished' Shows the information of finished transactions by database ID.
83 | '/jobs' Shows the information of jobs.
84 | '/statistic' Shows the statistics of each database.
85 | '/tasks' Shows the total number of all generic tasks and the failed tasks.
86 | '/cluster_balance' Shows the load balance information.
87 | '/routine_loads' Shows the information of Routine Load.
88 | '/colocation_group' Shows the information of Colocate Join groups.
89 | '/catalog' Shows the information of catalogs.
90 | '''
91 |
92 |
93 | @mcp.resource(uri="starrocks:///databases", name="All Databases", description="List all databases in StarRocks",
94 | mime_type="text/plain")
95 | def get_all_databases() -> str:
96 | logger.debug("Fetching all databases")
97 | result = db_client.execute("SHOW DATABASES")
98 | logger.debug(f"Found {len(result.rows) if result.success and result.rows else 0} databases")
99 | return result.to_string()
100 |
101 |
102 | @mcp.resource(uri="starrocks:///{db}/{table}/schema", name="Table Schema",
103 | description="Get the schema of a table using SHOW CREATE TABLE", mime_type="text/plain")
104 | def get_table_schema(db: str, table: str) -> str:
105 | logger.debug(f"Fetching schema for table {db}.{table}")
106 | return db_client.execute(f"SHOW CREATE TABLE {db}.{table}").to_string()
107 |
108 |
109 | @mcp.resource(uri="starrocks:///{db}/tables", name="Database Tables",
110 | description="List all tables in a specific database", mime_type="text/plain")
111 | def get_database_tables(db: str) -> str:
112 | logger.debug(f"Fetching tables from database {db}")
113 | result = db_client.execute(f"SHOW TABLES FROM {db}")
114 | logger.debug(f"Found {len(result.rows) if result.success and result.rows else 0} tables in {db}")
115 | return result.to_string()
116 |
117 |
118 | @mcp.resource(uri="proc:///{path*}", name="System internal information", description=SR_PROC_DESC,
119 | mime_type="text/plain")
120 | def get_system_internal_information(path: str) -> str:
121 | logger.debug(f"Fetching system information for proc path: {path}")
122 | return db_client.execute(f"show proc '{path}'").to_string(limit=overview_length_limit)
123 |
124 |
125 | def _get_table_details(db_name, table_name, limit=None):
126 | """
127 | Helper function to get description, sample rows, and count for a table.
128 | Returns a formatted string. Handles DB errors internally and returns error messages.
129 | """
130 | global global_table_overview_cache
131 | logger.debug(f"Fetching table details for {db_name}.{table_name}")
132 | output_lines = []
133 |
134 | full_table_name = f"`{table_name}`"
135 | if db_name:
136 | full_table_name = f"`{db_name}`.`{table_name}`"
137 | else:
138 | output_lines.append(
139 | f"Warning: Database name missing for table '{table_name}'. Using potentially incorrect context.")
140 | logger.warning(f"Database name missing for table '{table_name}'")
141 |
142 | count = 0
143 | output_lines.append(f"--- Overview for {full_table_name} ---")
144 |
145 | # 1. Get Row Count
146 | query = f"SELECT COUNT(*) FROM {full_table_name}"
147 | count_result = db_client.execute(query, db=db_name)
148 | if count_result.success and count_result.rows:
149 | count = count_result.rows[0][0]
150 | output_lines.append(f"\nTotal rows: {count}")
151 | logger.debug(f"Table {full_table_name} has {count} rows")
152 | else:
153 |         output_lines.append("\nCould not determine total row count.")
154 | if not count_result.success:
155 | output_lines.append(f"Error: {count_result.error_message}")
156 | logger.error(f"Failed to get row count for {full_table_name}: {count_result.error_message}")
157 |
158 | # 2. Get Columns (DESCRIBE)
159 | if count > 0:
160 | query = f"DESCRIBE {full_table_name}"
161 | desc_result = db_client.execute(query, db=db_name)
162 | if desc_result.success and desc_result.column_names and desc_result.rows:
163 | output_lines.append(f"\nColumns:")
164 | output_lines.append(desc_result.to_string(limit=limit))
165 | else:
166 | output_lines.append("(Could not retrieve column information or table has no columns).")
167 | if not desc_result.success:
168 | output_lines.append(f"Error getting columns for {full_table_name}: {desc_result.error_message}")
169 | return "\n".join(output_lines)
170 |
171 | # 3. Get Sample Rows (LIMIT 3)
172 | query = f"SELECT * FROM {full_table_name} LIMIT 3"
173 | sample_result = db_client.execute(query, db=db_name)
174 | if sample_result.success and sample_result.column_names and sample_result.rows:
175 | output_lines.append(f"\nSample rows (limit 3):")
176 | output_lines.append(sample_result.to_string(limit=limit))
177 | else:
178 | output_lines.append(f"(No rows found in {full_table_name}).")
179 | if not sample_result.success:
180 | output_lines.append(f"Error getting sample rows for {full_table_name}: {sample_result.error_message}")
181 |
182 | overview_string = "\n".join(output_lines)
183 | # Update cache even if there were partial errors, so we cache the error message too
184 | cache_key = (db_name, table_name)
185 | global_table_overview_cache[cache_key] = overview_string
186 | return overview_string
187 |
188 |
189 | # tools
190 |
191 | @mcp.tool(description="Execute a SELECT query or other command that returns a ResultSet" + description_suffix)
192 | def read_query(query: Annotated[str, Field(description="SQL query to execute")],
193 | db: Annotated[str|None, Field(description="database")] = None) -> ToolResult:
194 | # return csv like result set, with column names as first row
195 | logger.info(f"Executing read query: {query[:100]}{'...' if len(query) > 100 else ''}")
196 | result = db_client.execute(query, db=db)
197 | if result.success:
198 | logger.info(f"Query executed successfully, returned {len(result.rows) if result.rows else 0} rows")
199 | else:
200 | logger.error(f"Query failed: {result.error_message}")
201 | return ToolResult(content=[TextContent(type='text', text=result.to_string(limit=10000))],
202 | structured_content=result.to_dict())
203 |
204 |
205 | @mcp.tool(description="Execute a DDL/DML or other StarRocks command that does not return a ResultSet" + description_suffix)
206 | def write_query(query: Annotated[str, Field(description="SQL to execute")],
207 | db: Annotated[str|None, Field(description="database")] = None) -> ToolResult:
208 | logger.info(f"Executing write query: {query[:100]}{'...' if len(query) > 100 else ''}")
209 | result = db_client.execute(query, db=db)
210 | if not result.success:
211 | logger.error(f"Write query failed: {result.error_message}")
212 | elif result.rows_affected is not None and result.rows_affected >= 0:
213 | logger.info(f"Write query executed successfully, {result.rows_affected} rows affected in {result.execution_time:.2f}s")
214 | else:
215 | logger.info(f"Write query executed successfully in {result.execution_time:.2f}s")
216 | return ToolResult(content=[TextContent(type='text', text=result.to_string(limit=2000))],
217 | structured_content=result.to_dict())
218 |
219 | @mcp.tool(description="Analyze a query and get the analysis result using its query profile" + description_suffix)
220 | def analyze_query(
221 | uuid: Annotated[
222 | str|None, Field(description="Query ID, a string composed of 32 hexadecimal digits formatted as 8-4-4-4-12")]=None,
223 | sql: Annotated[str|None, Field(description="Query SQL")]=None,
224 | db: Annotated[str|None, Field(description="database")] = None
225 | ) -> str:
226 | if uuid:
227 | logger.info(f"Analyzing query profile for UUID: {uuid}")
228 | return db_client.execute(f"ANALYZE PROFILE FROM '{uuid}'", db=db).to_string()
229 | elif sql:
230 | logger.info(f"Analyzing query: {sql[:100]}{'...' if len(sql) > 100 else ''}")
231 | return db_client.execute(f"EXPLAIN ANALYZE {sql}", db=db).to_string()
232 | else:
233 | logger.warning("Analyze query called without valid UUID or SQL")
234 |         return "Failed to analyze query. Possible reasons: 1. the query ID is not in standard UUID format; 2. the SQL statement contains a syntax error."
235 |
236 |
237 | @mcp.tool(description="Run a query to collect its query dump and profile; the output is very large and requires special tools for further processing")
238 | def collect_query_dump_and_profile(
239 | query: Annotated[str, Field(description="query to execute")],
240 | db: Annotated[str|None, Field(description="database")] = None
241 | ) -> ToolResult:
242 | logger.info(f"Collecting query dump and profile for query: {query[:100]}{'...' if len(query) > 100 else ''}")
243 | result : PerfAnalysisInput = db_client.collect_perf_analysis_input(query, db=db)
244 | if result.get('error_message'):
245 | status = f"collecting query dump and profile failed, query_id={result.get('query_id')} error_message={result.get('error_message')}"
246 | logger.warning(status)
247 | else:
248 |         status = f"collecting query dump and profile succeeded; the output is intended for downstream tools rather than direct AI consumption, query_id={result.get('query_id')}"
249 | logger.info(status)
250 | return ToolResult(
251 | content=[TextContent(type='text', text=status)],
252 | structured_content=result,
253 | )
254 |
255 |
256 | def validate_plotly_expr(expr: str):
257 | """
258 | Validates a string to ensure it represents a single call to a method
259 | of the 'px' object, without containing other statements or imports,
260 | and ensures its arguments do not contain nested function calls.
261 |
262 | Args:
263 | expr: The string expression to validate.
264 |
265 | Raises:
266 | ValueError: If the expression does not meet the security criteria.
267 | SyntaxError: If the expression is not valid Python syntax.
268 | """
269 | # 1. Check for valid Python syntax
270 | try:
271 | tree = ast.parse(expr)
272 | except SyntaxError as e:
273 | raise SyntaxError(f"Invalid Python syntax in expression: {e}") from e
274 |
275 | # 2. Check that the tree contains exactly one top-level node (statement/expression)
276 | if len(tree.body) != 1:
277 | raise ValueError("Expression must be a single statement or expression.")
278 |
279 | node = tree.body[0]
280 |
281 | # 3. Check that the single node is an expression
282 | if not isinstance(node, ast.Expr):
283 | raise ValueError(
284 | "Expression must be a single expression, not a statement (like assignment, function definition, import, etc.).")
285 |
286 | # 4. Get the actual value of the expression and check it's a function call
287 | expr_value = node.value
288 | if not isinstance(expr_value, ast.Call):
289 | raise ValueError("Expression must be a function call.")
290 |
291 | # 5. Check that the function being called is an attribute lookup (like px.scatter)
292 | if not isinstance(expr_value.func, ast.Attribute):
293 | raise ValueError("Function call must be on an object attribute (e.g., px.scatter).")
294 |
295 | # 6. Check that the attribute is being accessed on a simple variable name
296 | if not isinstance(expr_value.func.value, ast.Name):
297 | raise ValueError("Function call must be on a simple variable name (e.g., px.scatter, not obj.px.scatter).")
298 |
299 | # 7. Check that the simple variable name is 'px'
300 | if expr_value.func.value.id != 'px':
301 | raise ValueError("Function call must be on the 'px' object.")
302 |
303 | # Check positional arguments
304 | for i, arg_node in enumerate(expr_value.args):
305 | for sub_node in ast.walk(arg_node):
306 | if isinstance(sub_node, ast.Call):
307 | raise ValueError(f"Positional argument at index {i} contains a disallowed nested function call.")
308 | # Check keyword arguments
309 | for kw in expr_value.keywords:
310 | for sub_node in ast.walk(kw.value):
311 | if isinstance(sub_node, ast.Call):
312 | keyword_name = kw.arg if kw.arg else '<unknown>'
313 | raise ValueError(f"Keyword argument '{keyword_name}' contains a disallowed nested function call.")
314 |
315 |
316 | def one_line_summary(text: str, limit:int=100) -> str:
317 | """Generate a one-line summary of the given text, truncated to the specified limit."""
318 | single_line = ' '.join(text.split())
319 | if len(single_line) > limit:
320 | return single_line[:limit-3] + '...'
321 | return single_line
322 |
323 |
324 | @mcp.tool(description="Use SQL `query` to extract data from the database, then use the Python `plotly_expr` expression to generate a chart for the UI to display" + description_suffix)
325 | def query_and_plotly_chart(
326 | query: Annotated[str, Field(description="SQL query to execute")],
327 | plotly_expr: Annotated[
328 |         str, Field(description="a single function-call expression with two variables bound: `px` as `import plotly.express as px`, and `df` as the DataFrame produced by `query`. Example: `px.scatter(df, x=\"sepal_width\", y=\"sepal_length\", color=\"species\", marginal_y=\"violin\", marginal_x=\"box\", trendline=\"ols\", template=\"simple_white\")`")],
329 | format: Annotated[str, Field(description="chart output format, json|png|jpeg")] = "jpeg",
330 | db: Annotated[str|None, Field(description="database")] = None
331 | ) -> ToolResult:
332 | """
333 |     Executes an SQL query, creates a pandas DataFrame, generates a Plotly chart
334 |     using the provided expression, encodes the chart as a base64 image (or JSON),
335 |     and returns it along with descriptive text.
336 |
337 | Args:
338 | query: The SQL query string to execute.
339 | plotly_expr: A Python string expression using 'px' (plotly.express)
340 | and 'df' (the DataFrame from the query) to generate a figure.
341 | Example: "px.scatter(df, x='col1', y='col2')"
342 |         format: chart output format, json|png|jpeg, default is jpeg
343 | db: Optional database name to execute the query in.
344 |
345 |     Returns:
346 |         A ToolResult containing TextContent and, for image formats, ImageContent;
347 |         just TextContent in case of an error or no data.
348 | """
349 | try:
350 | logger.info(f'query_and_plotly_chart query:{one_line_summary(query)}, plotly:{one_line_summary(plotly_expr)} format:{format}, db:{db}')
351 | result = db_client.execute(query, db=db, return_format="pandas")
352 | errmsg = None
353 | if not result.success:
354 | errmsg = result.error_message
355 | elif result.pandas is None:
356 | errmsg = 'Query did not return data suitable for plotting.'
357 | else:
358 | df = result.pandas
359 | if df.empty:
360 | errmsg = 'Query returned no data to plot.'
361 | if errmsg:
362 | logger.warning(f"Query or data issue: {errmsg}")
363 | return ToolResult(
364 | content=[TextContent(type='text', text=f'Error: {errmsg}')],
365 | structured_content={'success': False, 'error_message': errmsg},
366 | )
367 | # Validate and evaluate the plotly expression using px and df
368 | local_vars = {'df': df}
369 | validate_plotly_expr(plotly_expr)
370 | fig : plotly.graph_objs.Figure = eval(plotly_expr, {"px": px}, local_vars)
371 | if format == 'json':
372 | # return json representation of the figure for front-end rendering
373 | plot_json = json.loads(fig.to_json())
374 | structured_content = result.to_dict()
375 | structured_content['data'] = plot_json['data']
376 | structured_content['layout'] = plot_json['layout']
377 | summary = result.to_string()
378 | return ToolResult(
379 | content=[
380 | TextContent(type='text', text=f'{summary}\nChart Generated for UI rendering'),
381 | ],
382 | structured_content=structured_content,
383 | )
384 | else:
385 | if not hasattr(fig, 'to_image'):
386 | raise ToolError(f"The evaluated expression did not return a Plotly figure object. Result type: {type(fig)}")
387 | if format == 'jpg':
388 | format = 'jpeg'
389 | img_bytes = fig.to_image(format=format, width=960, height=720)
390 | structured_content = result.to_dict()
391 |             structured_content['img_bytes_base64'] = base64.b64encode(img_bytes).decode('ascii')  # decode so the value stays JSON-serializable
392 | return ToolResult(
393 | content=[
394 |                     TextContent(type='text', text=f'dataframe data:\n{df}\nChart generated for UI display only'),
395 |                     Image(data=img_bytes, format=format).to_image_content()
396 | ],
397 | structured_content=structured_content
398 | )
399 | except Exception as err:
400 | return ToolResult(
401 | content=[TextContent(type='text', text=f'Error: {err}')],
402 | structured_content={'success': False, 'error_message': str(err)},
403 | )
404 |
405 |
406 | @mcp.tool(description="Get an overview of a specific table: columns, sample rows (up to 3), and total row count. Uses cache unless refresh=true" + description_suffix)
407 | def table_overview(
408 | table: Annotated[str, Field(
409 | description="Table name, optionally prefixed with database name (e.g., 'db_name.table_name'). If database is omitted, uses the default database.")],
410 | refresh: Annotated[
411 | bool, Field(description="Set to true to force refresh, ignoring cache. Defaults to false.")] = False
412 | ) -> str:
413 | try:
414 | logger.info(f"Getting table overview for: {table}, refresh={refresh}")
415 | if not table:
416 | logger.error("Table overview called without table name")
417 | return "Error: Missing 'table' argument."
418 |
419 | # Parse table argument: [db.]<table>
420 | parts = table.split('.', 1)
421 | db_name = None
422 | table_name = None
423 | if len(parts) == 2:
424 | db_name, table_name = parts[0], parts[1]
425 | elif len(parts) == 1:
426 | table_name = parts[0]
427 | db_name = db_client.default_database # Use default if only table name is given
428 |
429 |         if not table_name:  # Should not happen if `table` is non-empty, but check
430 | logger.error(f"Invalid table name format: {table}")
431 | return f"Error: Invalid table name format '{table}'."
432 | if not db_name:
433 | logger.error(f"No database specified for table {table_name}")
434 | return f"Error: Database name not specified for table '{table_name}' and no default database is set."
435 |
436 | cache_key = (db_name, table_name)
437 |
438 | # Check cache
439 | if not refresh and cache_key in global_table_overview_cache:
440 | logger.debug(f"Using cached overview for {db_name}.{table_name}")
441 | return global_table_overview_cache[cache_key]
442 |
443 | logger.debug(f"Fetching fresh overview for {db_name}.{table_name}")
444 | # Fetch details (will also update cache)
445 | overview_text = _get_table_details(db_name, table_name, limit=overview_length_limit)
446 | return overview_text
447 | except Exception as e:
448 | # Reset connections on unexpected errors
449 | logger.exception(f"Unexpected error in table_overview for {table}")
450 | reset_db_connections()
451 | stack_trace = traceback.format_exc()
452 | return f"Unexpected Error executing tool 'table_overview': {type(e).__name__}: {e}\nStack Trace:\n{stack_trace}"
453 |
454 | # comment out to prefer db_summary tool
455 | #@mcp.tool(description="Get an overview (columns, sample rows, row count) for ALL tables in a database. Uses cache unless refresh=True" + description_suffix)
456 | def db_overview(
457 |     db: Annotated[str|None, Field(
458 | description="Database name. Optional: uses the default database if not provided.")] = None,
459 | refresh: Annotated[
460 | bool, Field(description="Set to true to force refresh, ignoring cache. Defaults to false.")] = False
461 | ) -> str:
462 | try:
463 | db_name = db if db else db_client.default_database
464 | logger.info(f"Getting database overview for: {db_name}, refresh={refresh}")
465 | if not db_name:
466 | logger.error("Database overview called without database name")
467 | return "Error: Database name not provided and no default database is set."
468 |
469 | # List tables in the database
470 | query = f"SHOW TABLES FROM `{db_name}`"
471 | result = db_client.execute(query, db=db_name)
472 |
473 | if not result.success:
474 | logger.error(f"Failed to list tables in database {db_name}: {result.error_message}")
475 | return f"Database Error listing tables in '{db_name}': {result.error_message}"
476 |
477 | if not result.rows:
478 | logger.info(f"No tables found in database {db_name}")
479 | return f"No tables found in database '{db_name}'."
480 |
481 | tables = [row[0] for row in result.rows]
482 | logger.info(f"Found {len(tables)} tables in database {db_name}")
483 | all_overviews = [f"--- Overview for Database: `{db_name}` ({len(tables)} tables) ---"]
484 |
485 | total_length = 0
486 | limit_per_table = overview_length_limit * (math.log10(len(tables)) + 1) // len(tables) # Limit per table
487 | for table_name in tables:
488 | cache_key = (db_name, table_name)
489 | overview_text = None
490 |
491 | # Check cache first
492 | if not refresh and cache_key in global_table_overview_cache:
493 | logger.debug(f"Using cached overview for {db_name}.{table_name}")
494 | overview_text = global_table_overview_cache[cache_key]
495 | else:
496 | logger.debug(f"Fetching fresh overview for {db_name}.{table_name}")
497 | # Fetch details for this table (will update cache via _get_table_details)
498 | overview_text = _get_table_details(db_name, table_name, limit=limit_per_table)
499 |
500 | all_overviews.append(overview_text)
501 | all_overviews.append("\n") # Add separator
502 | total_length += len(overview_text) + 1
503 |
504 | logger.info(f"Database overview completed for {db_name}, total length: {total_length}")
505 | return "\n".join(all_overviews)
506 |
507 | except Exception as e:
508 | # Catch any other unexpected errors during tool execution
509 | logger.exception(f"Unexpected error in db_overview for database {db}")
510 | reset_db_connections()
511 | stack_trace = traceback.format_exc()
512 | return f"Unexpected Error executing tool 'db_overview': {type(e).__name__}: {e}\nStack Trace:\n{stack_trace}"
513 |
514 |
515 | @mcp.tool(description="Quickly get summary of a database with tables' schema and size information" + description_suffix)
516 | def db_summary(
517 | db: Annotated[str|None, Field(
518 | description="Database name. Optional: uses current database by default.")] = None,
519 | limit: Annotated[int, Field(
520 | description="Output length limit in characters. Defaults to 10000. Higher values show more tables and details.")] = 10000,
521 | refresh: Annotated[bool, Field(
522 | description="Set to true to force refresh, ignoring cache. Defaults to false.")] = False
523 | ) -> str:
524 | try:
525 | db_name = db if db else db_client.default_database
526 | logger.info(f"Getting database summary for: {db_name}, limit={limit}, refresh={refresh}")
527 |
528 | if not db_name:
529 | logger.error("Database summary called without database name")
530 | return "Error: Database name not provided and no default database is set."
531 |
532 | # Use the database summary manager
533 | summary = db_summary_manager.get_database_summary(db_name, limit=limit, refresh=refresh)
534 | logger.info(f"Database summary completed for {db_name}")
535 | return summary
536 |
537 | except Exception as e:
538 | # Reset connections on unexpected errors
539 | logger.exception(f"Unexpected error in db_summary for database {db}")
540 | reset_db_connections()
541 | stack_trace = traceback.format_exc()
542 | return f"Unexpected Error executing tool 'db_summary': {type(e).__name__}: {e}\nStack Trace:\n{stack_trace}"
543 |
544 |
545 | async def main():
546 | parser = argparse.ArgumentParser(description='StarRocks MCP Server')
547 | parser.add_argument('--mode', choices=['stdio', 'sse', 'http', 'streamable-http'],
548 | default=os.getenv('MCP_TRANSPORT_MODE', 'stdio'),
549 | help='Transport mode (default: stdio)')
550 | parser.add_argument('--host', default='localhost',
551 | help='Server host (default: localhost)')
552 | parser.add_argument('--port', type=int, default=3000,
553 | help='Server port (default: 3000)')
554 | parser.add_argument('--test', action='store_true',
555 | help='Run in test mode')
556 |
557 | args = parser.parse_args()
558 |
559 |     logger.info(f"Starting StarRocks MCP Server with mode={args.mode}, host={args.host}, port={args.port}, default_db={db_client.default_database or 'None'}")
560 |
561 | if args.test:
562 | try:
563 | logger.info("Starting tool test")
564 | # Use the test version without tool wrapper
565 | result = db_client.execute("show databases").to_string()
566 | logger.info("Result:")
567 | logger.info(result)
568 | logger.info("Tool test completed")
569 | finally:
570 | stop_connection_health_checker()
571 | reset_db_connections()
572 | return
573 |
574 | # Start connection health checker
575 | start_connection_health_checker()
576 | try:
577 | # Add CORS middleware for HTTP transports to allow web frontend access
578 | if args.mode in ['http', 'streamable-http', 'sse']:
579 | cors_middleware = [
580 | Middleware(
581 | CORSMiddleware,
582 | allow_origins=["*"], # Allow all origins for development. In production, specify exact origins
583 | allow_credentials=True,
584 | allow_methods=["*"], # Allow all HTTP methods
585 | allow_headers=["*"], # Allow all headers
586 | )
587 | ]
588 | logger.info(f"CORS enabled for {args.mode} transport - allowing all origins")
589 | await mcp.run_async(
590 | transport=args.mode,
591 | host=args.host,
592 | port=args.port,
593 | middleware=cors_middleware
594 | )
595 | else:
596 | await mcp.run_async(transport=args.mode)
597 | except Exception as e:
598 | logger.exception("Failed to start MCP server")
599 | raise
600 | finally:
601 | # Stop connection health checker when server shuts down
602 | stop_connection_health_checker()
603 |
604 |
605 | if __name__ == "__main__":
606 | asyncio.run(main())
607 |
```
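
The AST whitelist in `validate_plotly_expr` is easiest to understand by example. Below is a minimal sketch of what it accepts and rejects; note that importing `server` runs module-level setup (it instantiates the MCP app and DB client), so this is an in-process illustration rather than a standalone script:

```python
# Illustrative accept/reject behavior of validate_plotly_expr.
from src.mcp_server_starrocks.server import validate_plotly_expr

# A single call on the `px` object with literal arguments is accepted.
validate_plotly_expr("px.bar(df, x='name', y='value')")  # no exception

# Rejected: nested calls in arguments, non-`px` targets, multiple statements.
for bad in ("px.bar(df, x=str(1))",       # nested call in a keyword argument
            "os.system('ls')",            # call target is not `px`
            "px.bar(df); px.line(df)"):   # more than one statement
    try:
        validate_plotly_expr(bad)
    except (ValueError, SyntaxError):
        pass  # expected
    else:
        raise AssertionError(f"should have been rejected: {bad}")
```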
--------------------------------------------------------------------------------
/tests/test_db_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Tests for db_client module.
3 |
4 | These tests assume a StarRocks cluster is running on localhost with default configurations:
5 | - Host: localhost
6 | - Port: 9030 (MySQL protocol)
7 | - User: root
8 | - Password: (empty)
9 | - No default database set
10 |
11 | Run tests with: pytest tests/test_db_client.py -v
12 | """
13 |
14 | import os
15 | import pytest
16 | import pandas as pd
17 | from unittest.mock import patch, MagicMock
18 |
19 | # Set up test environment variables
20 | os.environ.pop('STARROCKS_FE_ARROW_FLIGHT_SQL_PORT', None) # Force MySQL mode for tests
21 | os.environ.pop('STARROCKS_DB', None) # No default database
22 |
23 | from src.mcp_server_starrocks.db_client import (
24 | DBClient,
25 | ResultSet,
26 | get_db_client,
27 | reset_db_connections,
28 | parse_connection_url
29 | )
30 |
31 |
32 | class TestDBClient:
33 | """Test cases for DBClient class."""
34 |
35 | @pytest.fixture
36 | def db_client(self):
37 | """Create a fresh DBClient instance for each test."""
38 | # Reset global state
39 | reset_db_connections()
40 | return DBClient()
41 |
42 | def test_client_initialization(self, db_client):
43 | """Test DBClient initialization with default settings."""
44 | assert db_client.enable_arrow_flight_sql is False
45 | assert db_client.default_database is None
46 | assert db_client._connection_pool is None
47 | assert db_client._adbc_connection is None
48 |
49 | def test_singleton_pattern(self):
50 | """Test that get_db_client returns the same instance."""
51 | client1 = get_db_client()
52 | client2 = get_db_client()
53 | assert client1 is client2
54 |
55 | def test_execute_show_databases(self, db_client):
56 | """Test executing SHOW DATABASES query."""
57 | result = db_client.execute("SHOW DATABASES")
58 |
59 | assert isinstance(result, ResultSet)
60 | assert result.success is True
61 | assert result.column_names is not None
62 | assert len(result.column_names) == 1
63 | assert result.rows is not None
64 | assert len(result.rows) > 0
65 | assert result.execution_time is not None
66 | assert result.execution_time > 0
67 |
68 | # Check that information_schema is present (standard in StarRocks)
69 | database_names = [row[0] for row in result.rows]
70 | assert 'information_schema' in database_names
71 |
72 | def test_execute_show_databases_pandas(self, db_client):
73 | """Test executing SHOW DATABASES with pandas return format."""
74 | result = db_client.execute("SHOW DATABASES", return_format="pandas")
75 |
76 | assert isinstance(result, ResultSet)
77 | assert result.success is True
78 | assert result.pandas is not None
79 | assert isinstance(result.pandas, pd.DataFrame)
80 | assert len(result.pandas.columns) == 1
81 | assert len(result.pandas) > 0
82 |
83 | # Test that to_pandas() returns the same DataFrame
84 | df = result.to_pandas()
85 | assert df is result.pandas
86 |
87 | def test_execute_invalid_query(self, db_client):
88 | """Test executing an invalid SQL query."""
89 | result = db_client.execute("SELECT * FROM nonexistent_table_12345")
90 |
91 | assert isinstance(result, ResultSet)
92 | assert result.success is False
93 | assert result.error_message is not None
94 | assert "nonexistent_table_12345" in result.error_message or "doesn't exist" in result.error_message.lower()
95 | assert result.execution_time is not None
96 |
97 | def test_execute_create_and_drop_database(self, db_client):
98 | """Test creating and dropping a test database."""
99 | test_db_name = "test_mcp_db_client"
100 |
101 | # Clean up first (in case previous test failed)
102 | db_client.execute(f"DROP DATABASE IF EXISTS {test_db_name}")
103 |
104 | # Create database
105 | create_result = db_client.execute(f"CREATE DATABASE {test_db_name}")
106 | assert create_result.success is True
107 | assert create_result.rows_affected is not None # DDL returns row count (usually 0)
108 |
109 | # Verify database exists
110 | show_result = db_client.execute("SHOW DATABASES")
111 | database_names = [row[0] for row in show_result.rows]
112 | assert test_db_name in database_names
113 |
114 | # Drop database
115 | drop_result = db_client.execute(f"DROP DATABASE {test_db_name}")
116 | assert drop_result.success is True
117 |
118 | # Verify database is gone
119 | show_result = db_client.execute("SHOW DATABASES")
120 | database_names = [row[0] for row in show_result.rows]
121 | assert test_db_name not in database_names
122 |
123 | def test_execute_with_specific_database(self, db_client):
124 | """Test executing query with specific database context."""
125 | # Use information_schema which should always be available
126 | result = db_client.execute("SHOW TABLES", db="information_schema")
127 |
128 | assert result.success is True
129 | assert result.column_names is not None
130 | assert result.rows is not None
131 | assert len(result.rows) > 0 # information_schema should have tables
132 |
133 | # Check for expected information_schema tables
134 | table_names = [row[0] for row in result.rows]
135 | expected_tables = ['tables', 'columns', 'schemata']
136 | found_expected = any(table in table_names for table in expected_tables)
137 | assert found_expected, f"Expected at least one of {expected_tables} in {table_names}"
138 |
139 | def test_execute_with_invalid_database(self, db_client):
140 | """Test executing query with non-existent database."""
141 | result = db_client.execute("SHOW TABLES", db="nonexistent_db_12345")
142 |
143 | assert result.success is False
144 | assert result.error_message is not None
145 | assert "nonexistent_db_12345" in result.error_message
146 |
147 | def test_execute_table_operations(self, db_client):
148 | """Test creating, inserting, querying, and dropping a table."""
149 | test_db = "test_mcp_table_ops"
150 | test_table = "test_table"
151 |
152 | try:
153 | # Create database
154 | create_db_result = db_client.execute(f"CREATE DATABASE IF NOT EXISTS {test_db}")
155 | assert create_db_result.success is True
156 |
157 | # Create table (with replication_num=1 for single-node setup)
158 | create_table_sql = f"""
159 | CREATE TABLE {test_db}.{test_table} (
160 | id INT,
161 | name STRING,
162 | value DOUBLE
163 | )
164 | PROPERTIES ("replication_num" = "1")
165 | """
166 | create_result = db_client.execute(create_table_sql)
167 | assert create_result.success is True
168 |
169 | # Insert data
170 | insert_sql = f"""
171 | INSERT INTO {test_db}.{test_table} VALUES
172 | (1, 'test1', 1.5),
173 | (2, 'test2', 2.5),
174 | (3, 'test3', 3.5)
175 | """
176 | insert_result = db_client.execute(insert_sql)
177 | assert insert_result.success is True
178 | assert insert_result.rows_affected == 3
179 |
180 | # Query data
181 | select_result = db_client.execute(f"SELECT * FROM {test_db}.{test_table} ORDER BY id")
182 | assert select_result.success is True
183 | assert len(select_result.column_names) == 3
184 | assert select_result.column_names == ['id', 'name', 'value']
185 | assert len(select_result.rows) == 3
186 | # MySQL connector returns tuples, convert to lists for comparison
187 | assert list(select_result.rows[0]) == [1, 'test1', 1.5]
188 | assert list(select_result.rows[1]) == [2, 'test2', 2.5]
189 | assert list(select_result.rows[2]) == [3, 'test3', 3.5]
190 |
191 | # Test COUNT query
192 | count_result = db_client.execute(f"SELECT COUNT(*) as cnt FROM {test_db}.{test_table}")
193 | assert count_result.success is True
194 | assert count_result.rows[0][0] == 3
195 |
196 | # Test with specific database context
197 | ctx_result = db_client.execute(f"SELECT * FROM {test_table}", db=test_db)
198 | assert ctx_result.success is True
199 | assert len(ctx_result.rows) == 3
200 |
201 | finally:
202 | # Clean up
203 | db_client.execute(f"DROP DATABASE IF EXISTS {test_db}")
204 |
205 | def test_execute_pandas_format_with_data(self, db_client):
206 | """Test pandas format with actual data."""
207 | test_db = "test_mcp_pandas"
208 |
209 | try:
210 | # Setup test data
211 | db_client.execute(f"CREATE DATABASE IF NOT EXISTS {test_db}")
212 | db_client.execute(f"""
213 | CREATE TABLE {test_db}.pandas_test (
214 | id INT,
215 | category STRING,
216 | amount DECIMAL(10,2)
217 | )
218 | PROPERTIES ("replication_num" = "1")
219 | """)
220 | db_client.execute(f"""
221 | INSERT INTO {test_db}.pandas_test VALUES
222 | (1, 'A', 100.50),
223 | (2, 'B', 200.75),
224 | (3, 'A', 150.25)
225 | """)
226 |
227 | # Test executing query with pandas format
228 | result = db_client.execute(f"SELECT * FROM {test_db}.pandas_test ORDER BY id", return_format="pandas")
229 |
230 | assert isinstance(result, ResultSet)
231 | assert result.success is True
232 | assert result.pandas is not None
233 | assert isinstance(result.pandas, pd.DataFrame)
234 | assert len(result.pandas) == 3
235 | assert list(result.pandas.columns) == ['id', 'category', 'amount']
236 | assert result.pandas.iloc[0]['id'] == 1
237 | assert result.pandas.iloc[0]['category'] == 'A'
238 | assert float(result.pandas.iloc[0]['amount']) == 100.50
239 |
240 | # Test that to_pandas() returns the same DataFrame
241 | df = result.to_pandas()
242 | assert df is result.pandas
243 |
244 | finally:
245 | db_client.execute(f"DROP DATABASE IF EXISTS {test_db}")
246 |
247 | def test_connection_error_handling(self, db_client):
248 | """Test error handling when connection fails."""
249 | # Mock a connection failure
250 | with patch.object(db_client, '_get_connection', side_effect=Exception("Connection failed")):
251 | result = db_client.execute("SHOW DATABASES")
252 |
253 | assert result.success is False
254 | assert "Connection failed" in result.error_message
255 | assert result.execution_time is not None
256 |
257 | def test_reset_connections(self, db_client):
258 | """Test connection reset functionality."""
259 | # First execute a query to establish connection
260 | result1 = db_client.execute("SHOW DATABASES")
261 | assert result1.success is True
262 |
263 | # Reset connections
264 | db_client.reset_connections()
265 |
266 | # Should still work after reset
267 | result2 = db_client.execute("SHOW DATABASES")
268 | assert result2.success is True
269 |
270 | def test_describe_table(self, db_client):
271 | """Test DESCRIBE table functionality."""
272 | test_db = "test_mcp_describe"
273 | test_table = "describe_test"
274 |
275 | try:
276 | # Create test table
277 | db_result = db_client.execute(f"CREATE DATABASE IF NOT EXISTS {test_db}")
278 | assert db_result.success, f"Failed to create database: {db_result.error_message}"
279 |
280 | table_result = db_client.execute(f"""
281 | CREATE TABLE {test_db}.{test_table} (
282 | id BIGINT NOT NULL COMMENT 'Primary key',
283 | name VARCHAR(100) COMMENT 'Name field',
284 | created_at DATETIME,
285 | is_active BOOLEAN
286 | )
287 | PROPERTIES ("replication_num" = "1")
288 | """)
289 | assert table_result.success, f"Failed to create table: {table_result.error_message}"
290 |
291 | # Verify table exists first
292 |             show_result = db_client.execute("SHOW TABLES", db=test_db)
293 | assert show_result.success, f"Failed to show tables: {show_result.error_message}"
294 | table_names = [row[0] for row in show_result.rows]
295 | assert test_table in table_names, f"Table {test_table} not found in {table_names}"
296 |
297 | # Describe table (use full table name for clarity)
298 | result = db_client.execute(f"DESCRIBE {test_db}.{test_table}")
299 |
300 | assert result.success is True
301 | assert result.column_names is not None
302 | assert len(result.rows) == 4 # 4 columns
303 |
304 | # Check column names in result (should include Field, Type, etc.)
305 | expected_columns = ['Field', 'Type', 'Null', 'Key', 'Default', 'Extra']
306 | for expected_col in expected_columns[:len(result.column_names)]:
307 | assert expected_col in result.column_names
308 |
309 | # Check that our table columns are present
310 | field_names = [row[0] for row in result.rows]
311 | assert 'id' in field_names
312 | assert 'name' in field_names
313 | assert 'created_at' in field_names
314 | assert 'is_active' in field_names
315 |
316 | finally:
317 | db_client.execute(f"DROP DATABASE IF EXISTS {test_db}")
318 |
319 |
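# Orientation for the class below: these tests encode two known StarRocks
# Arrow Flight SQL quirks rather than ideal behavior -- INSERT reports
# rows_affected == 0, and SELECT results come back with empty column names,
# so the assertions check data by position instead of by column label.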
320 | class TestDBClientWithArrowFlight:
321 | """Test cases for DBClient with Arrow Flight SQL (if configured)."""
322 |
323 | @pytest.fixture
324 | def arrow_client(self):
325 | """Create DBClient with Arrow Flight SQL if available."""
326 | # Check if Arrow Flight SQL port is configured (either from env or default test port)
327 | arrow_port = os.getenv('STARROCKS_FE_ARROW_FLIGHT_SQL_PORT', '9408')
328 |
329 | # Test if Arrow Flight SQL is actually available by trying to connect
330 | try:
331 | with patch.dict(os.environ, {'STARROCKS_FE_ARROW_FLIGHT_SQL_PORT': arrow_port}):
332 | reset_db_connections()
333 | client = DBClient()
334 | assert client.enable_arrow_flight_sql is True
335 |
336 | # Test basic connectivity
337 | result = client.execute("SHOW DATABASES")
338 | if not result.success:
339 | pytest.skip(f"Arrow Flight SQL not available on port {arrow_port}: {result.error_message}")
340 |
341 | return client
342 | except Exception as e:
343 | pytest.skip(f"Arrow Flight SQL not available: {e}")
344 |
345 | def test_arrow_flight_basic_query(self, arrow_client):
346 | """Test basic query with Arrow Flight SQL."""
347 | result = arrow_client.execute("SHOW DATABASES")
348 |
349 | assert isinstance(result, ResultSet)
350 | assert result.success is True
351 | assert result.column_names is not None
352 | assert result.rows is not None
353 | assert len(result.rows) > 0
354 |
355 | # Verify we're actually using Arrow Flight SQL
356 | assert arrow_client.enable_arrow_flight_sql is True
357 |
358 | def test_arrow_flight_pandas_format(self, arrow_client):
359 | """Test pandas format with Arrow Flight SQL."""
360 | result = arrow_client.execute("SHOW DATABASES", return_format="pandas")
361 |
362 | assert isinstance(result, ResultSet)
363 | assert result.success is True
364 | assert result.pandas is not None
365 | assert isinstance(result.pandas, pd.DataFrame)
366 | assert len(result.pandas) > 0
367 | assert len(result.pandas.columns) == 1
368 |
369 | # Test that to_pandas() returns the same DataFrame
370 | df = result.to_pandas()
371 | assert df is result.pandas
372 |
373 | # Verify we're actually using Arrow Flight SQL
374 | assert arrow_client.enable_arrow_flight_sql is True
375 |
376 | def test_arrow_flight_table_operations(self, arrow_client):
377 | """Test table operations with Arrow Flight SQL."""
378 | test_db = "test_arrow_flight"
379 | test_table = "arrow_test"
380 |
381 | try:
382 | # Create database
383 | create_db_result = arrow_client.execute(f"CREATE DATABASE IF NOT EXISTS {test_db}")
384 | assert create_db_result.success is True
385 |
386 | # Create table
387 | create_table_sql = f"""
388 | CREATE TABLE {test_db}.{test_table} (
389 | id INT,
390 | name STRING,
391 | value DOUBLE
392 | )
393 | PROPERTIES ("replication_num" = "1")
394 | """
395 | create_result = arrow_client.execute(create_table_sql)
396 | assert create_result.success is True
397 |
398 | # Insert data
399 | insert_sql = f"""
400 | INSERT INTO {test_db}.{test_table} VALUES
401 | (1, 'arrow1', 1.1),
402 | (2, 'arrow2', 2.2)
403 | """
404 | insert_result = arrow_client.execute(insert_sql)
405 | assert insert_result.success is True
406 | # Note: StarRocks Arrow Flight SQL always returns 0 for rows_affected due to implementation limitations
407 | assert insert_result.rows_affected == 0
408 |
409 | # Query data with pandas format
410 | select_result = arrow_client.execute(f"SELECT * FROM {test_db}.{test_table} ORDER BY id", return_format="pandas")
411 | assert isinstance(select_result, ResultSet)
412 | assert select_result.success is True
413 | assert select_result.pandas is not None
414 | assert isinstance(select_result.pandas, pd.DataFrame)
415 | assert len(select_result.pandas) == 2
416 | # Note: StarRocks Arrow Flight SQL loses column names in SELECT results (known limitation)
417 | # The columns come back as empty strings, but the data is correct
418 | assert len(select_result.pandas.columns) == 3
419 | # Since column names are empty, access by position instead
420 | assert select_result.pandas.iloc[0, 0] == 1 # id column
421 | assert select_result.pandas.iloc[0, 1] == 'arrow1' # name column
422 | assert select_result.pandas.iloc[0, 2] == 1.1 # value column
423 |
424 | # Test that to_pandas() returns the same DataFrame
425 | df = select_result.to_pandas()
426 | assert df is select_result.pandas
427 |
428 | # Query data with raw format
429 | raw_result = arrow_client.execute(f"SELECT * FROM {test_db}.{test_table} ORDER BY id")
430 | assert raw_result.success is True
431 | assert len(raw_result.rows) == 2
432 | # Note: Column names are empty due to StarRocks Arrow Flight SQL limitation
433 | assert raw_result.column_names == ['', '', '']
434 | # But the data is correct
435 | assert raw_result.rows[0] == [1, 'arrow1', 1.1]
436 | assert raw_result.rows[1] == [2, 'arrow2', 2.2]
437 |
438 | finally:
439 | # Clean up
440 | arrow_client.execute(f"DROP DATABASE IF EXISTS {test_db}")
441 |
442 | def test_arrow_flight_error_handling(self, arrow_client):
443 | """Test error handling with Arrow Flight SQL."""
444 | # Test invalid query
445 | result = arrow_client.execute("SELECT * FROM nonexistent_arrow_table")
446 | assert result.success is False
447 | assert result.error_message is not None
448 |
449 | # Test invalid database - Note: Arrow Flight SQL may fail with connection errors
450 | # before database validation, so we just check that it fails
451 | result = arrow_client.execute("SHOW TABLES", db="nonexistent_arrow_db")
452 | assert result.success is False
453 | assert result.error_message is not None
454 |
455 |
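# A minimal sketch of the ResultSet interface the tests below assume. Field
# names are inferred from the assertions; the authoritative definition lives
# in src/mcp_server_starrocks/db_client.py and may differ in details:
#
#     @dataclass
#     class ResultSet:
#         success: bool
#         column_names: Optional[List[str]] = None
#         rows: Optional[List[List[Any]]] = None
#         pandas: Optional[pd.DataFrame] = None
#         rows_affected: Optional[int] = None
#         error_message: Optional[str] = None
#         execution_time: Optional[float] = None
#
#         def to_pandas(self) -> pd.DataFrame:
#             # returns self.pandas if set, else builds a DataFrame from rows;
#             # raises ValueError on failed results or missing data
#             ...
#
#         def to_string(self, limit: int = ...) -> str:
#             # CSV-style text with a header line and trailing newline;
#             # "Error: <msg>" on failure, "No data" when rows are absent
#             ...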
456 | class TestResultSet:
457 | """Test cases for ResultSet dataclass."""
458 |
459 | def test_result_set_creation(self):
460 | """Test ResultSet creation with various parameters."""
461 | # Success case
462 | result = ResultSet(
463 | success=True,
464 | column_names=['id', 'name'],
465 | rows=[[1, 'test'], [2, 'test2']],
466 | execution_time=0.5
467 | )
468 |
469 | assert result.success is True
470 | assert result.column_names == ['id', 'name']
471 | assert result.rows == [[1, 'test'], [2, 'test2']]
472 | assert result.execution_time == 0.5
473 | assert result.rows_affected is None
474 | assert result.error_message is None
475 |
476 | def test_result_set_to_pandas_from_rows(self):
477 | """Test ResultSet to_pandas conversion from rows."""
478 | result = ResultSet(
479 | success=True,
480 | column_names=['id', 'name', 'value'],
481 | rows=[[1, 'test1', 10.5], [2, 'test2', 20.5]],
482 | execution_time=0.1
483 | )
484 |
485 | df = result.to_pandas()
486 | assert isinstance(df, pd.DataFrame)
487 | assert len(df) == 2
488 | assert list(df.columns) == ['id', 'name', 'value']
489 | assert df.iloc[0]['id'] == 1
490 | assert df.iloc[0]['name'] == 'test1'
491 | assert df.iloc[0]['value'] == 10.5
492 | assert df.iloc[1]['id'] == 2
493 | assert df.iloc[1]['name'] == 'test2'
494 | assert df.iloc[1]['value'] == 20.5
495 |
496 | def test_result_set_to_pandas_from_pandas_field(self):
497 | """Test ResultSet to_pandas returns existing pandas field if available."""
498 | original_df = pd.DataFrame({
499 | 'id': [1, 2],
500 | 'name': ['test1', 'test2'],
501 | 'value': [10.5, 20.5]
502 | })
503 |
504 | result = ResultSet(
505 | success=True,
506 | column_names=['id', 'name', 'value'],
507 | rows=[[1, 'test1', 10.5], [2, 'test2', 20.5]],
508 | pandas=original_df,
509 | execution_time=0.1
510 | )
511 |
512 | df = result.to_pandas()
513 | assert df is original_df # Should return the same object
514 |
515 | def test_result_set_to_string(self):
516 | """Test ResultSet to_string conversion."""
517 | result = ResultSet(
518 | success=True,
519 | column_names=['id', 'name', 'value'],
520 | rows=[[1, 'test1', 10.5], [2, 'test2', 20.5]],
521 | execution_time=0.1
522 | )
523 |
524 | string_output = result.to_string()
525 | expected_lines = [
526 | 'id,name,value',
527 | '1,test1,10.5',
528 | '2,test2,20.5',
529 | ''
530 | ]
531 | assert string_output == '\n'.join(expected_lines)
532 |
533 | def test_result_set_to_string_with_limit(self):
534 | """Test ResultSet to_string with limit."""
535 | result = ResultSet(
536 | success=True,
537 | column_names=['id', 'name'],
538 | rows=[[1, 'very_long_test_string'], [2, 'another_long_string']],
539 | execution_time=0.1
540 | )
541 |
542 | # Test with very small limit
543 | string_output = result.to_string(limit=20)
544 | lines = string_output.split('\n')
545 | assert lines[0] == 'id,name' # Header should always be included
546 | # Should stop before all rows due to limit
547 | assert len(lines) < 4 # Should be less than header + 2 rows + empty line
548 |
549 | def test_result_set_to_string_error_cases(self):
550 | """Test ResultSet to_string error handling."""
551 | # Test with failed result
552 | failed_result = ResultSet(
553 | success=False,
554 | error_message="Test error"
555 | )
556 |
557 | string_output = failed_result.to_string()
558 | assert string_output == "Error: Test error"
559 |
560 | # Test with no data
561 | no_data_result = ResultSet(
562 | success=True,
563 | column_names=None,
564 | rows=None
565 | )
566 |
567 | string_output = no_data_result.to_string()
568 | assert string_output == "No data"
569 |
570 | def test_result_set_to_pandas_error_cases(self):
571 | """Test ResultSet to_pandas error handling."""
572 | # Test with failed result
573 | failed_result = ResultSet(
574 | success=False,
575 | error_message="Test error"
576 | )
577 |
578 | with pytest.raises(ValueError, match="Cannot convert failed result to DataFrame"):
579 | failed_result.to_pandas()
580 |
581 | # Test with no data
582 | no_data_result = ResultSet(
583 | success=True,
584 | column_names=None,
585 | rows=None
586 | )
587 |
588 | with pytest.raises(ValueError, match="No data available to convert to DataFrame"):
589 | no_data_result.to_pandas()
590 |
591 | def test_result_set_error_case(self):
592 | """Test ResultSet for error cases."""
593 | result = ResultSet(
594 | success=False,
595 | error_message="Test error",
596 | execution_time=0.1
597 | )
598 |
599 | assert result.success is False
600 | assert result.error_message == "Test error"
601 | assert result.execution_time == 0.1
602 | assert result.column_names is None
603 | assert result.rows is None
604 | assert result.rows_affected is None
605 |
606 | def test_result_set_write_operation(self):
607 | """Test ResultSet for write operations."""
608 | result = ResultSet(
609 | success=True,
610 | rows_affected=5,
611 | execution_time=0.2
612 | )
613 |
614 | assert result.success is True
615 | assert result.rows_affected == 5
616 | assert result.execution_time == 0.2
617 | assert result.column_names is None
618 | assert result.rows is None
619 | assert result.error_message is None
620 |
621 |
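# Reading aid for the class below: a hypothetical regex that is consistent
# with every behavior asserted here, including the documented '@'-in-password
# limitation (the lazy password group stops at the FIRST '@'). The real
# pattern lives in src/mcp_server_starrocks/db_client.py and may differ:
#
#     URL_PATTERN = re.compile(
#         r'^(?:[\w+]+://)?'    # optional schema, matched but discarded
#         r'([^:@]+)'           # user: may not contain ':' or '@'
#         r'(?::(.*?))?'        # optional password, lazy up to the first '@'
#         r'@([^:/]+)'          # host: may absorb a stray '@' (the limitation)
#         r'(?::(\d+))?'        # optional port; caller defaults it to 9030
#         r'(?:/([\w-]+))?$'    # optional database name
#     )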
622 | class TestParseConnectionUrl:
623 | """Test cases for parse_connection_url function."""
624 |
625 | def test_parse_basic_url(self):
626 | """Test parsing basic connection URL without schema."""
627 | url = "root:password123@localhost:9030/test_db"
628 | result = parse_connection_url(url)
629 |
630 | expected = {
631 | 'user': 'root',
632 | 'password': 'password123',
633 | 'host': 'localhost',
634 | 'port': '9030',
635 | 'database': 'test_db'
636 | }
637 | assert result == expected
638 |
639 | def test_parse_url_with_schema(self):
640 | """Test parsing connection URL with schema."""
641 | url = "mysql://admin:[email protected]:3306/production"
642 | result = parse_connection_url(url)
643 |
644 | expected = {
645 | 'user': 'admin',
646 | 'password': 'secret',
647 | 'host': 'db.example.com',
648 | 'port': '3306',
649 | 'database': 'production'
650 | }
651 | assert result == expected
652 |
653 | def test_parse_url_with_different_schemas(self):
654 | """Test parsing URLs with various schema types."""
655 | test_cases = [
656 | ("starrocks://user:pass@host:9030/db", "starrocks"),
657 | ("jdbc+mysql://user:pass@host:3306/db", "jdbc+mysql"),
658 | ("postgresql://user:pass@host:5432/db", "postgresql"),
659 | ]
660 |
661 |         for url, _ in test_cases:
662 | result = parse_connection_url(url)
663 | # Schema is no longer returned in the result
664 | assert result['user'] == 'user'
665 | assert result['password'] == 'pass'
666 | assert result['host'] == 'host'
667 | assert result['database'] == 'db'
668 |
669 | def test_parse_url_empty_password_succeeds(self):
670 |         """Test that a URL with an empty password parses correctly."""
671 | url = "root:@localhost:9030/test_db"
672 | result = parse_connection_url(url)
673 |
674 | expected = {
675 | 'user': 'root',
676 | 'password': '', # Empty password
677 | 'host': 'localhost',
678 | 'port': '9030',
679 | 'database': 'test_db'
680 | }
681 | assert result == expected
682 |
683 | def test_parse_url_no_password_colon(self):
684 | """Test URL without password colon (e.g., root@localhost:9030)."""
685 | url = "root@localhost:9030"
686 | result = parse_connection_url(url)
687 |
688 | expected = {
689 | 'user': 'root',
690 | 'password': '', # Default empty password
691 | 'host': 'localhost',
692 | 'port': '9030',
693 | 'database': None
694 | }
695 | assert result == expected
696 |
697 | def test_parse_url_missing_port_uses_default(self):
698 | """Test URL without port uses default 9030."""
699 | url = "root:password@localhost/mydb"
700 | result = parse_connection_url(url)
701 |
702 | expected = {
703 | 'user': 'root',
704 | 'password': 'password',
705 | 'host': 'localhost',
706 | 'port': '9030', # Default port
707 | 'database': 'mydb'
708 | }
709 | assert result == expected
710 |
711 | def test_parse_url_minimal_format(self):
712 | """Test minimal URL format (just user@host)."""
713 | url = "user@host"
714 | result = parse_connection_url(url)
715 |
716 | expected = {
717 | 'user': 'user',
718 | 'password': '', # Default empty password
719 | 'host': 'host',
720 | 'port': '9030', # Default port
721 | 'database': None
722 | }
723 | assert result == expected
724 |
725 |     def test_parse_url_double_colon_password(self):
726 |         """Test that a double colon yields a literal ':' as the password."""
727 | url = "user::@host:9030/db"
728 | result = parse_connection_url(url)
729 |
730 | expected = {
731 | 'user': 'user',
732 | 'password': ':', # Literal colon as password
733 | 'host': 'host',
734 | 'port': '9030',
735 | 'database': 'db'
736 | }
737 | assert result == expected
738 |
739 | def test_parse_url_complex_password_limitation(self):
740 |         """Test that a password containing '@' is mis-parsed (known regex limitation)."""
741 | url = "user:p@ssw0rd!@server:9030/mydb"
742 | result = parse_connection_url(url)
743 |
744 | # Due to regex limitation, @ in password causes incorrect parsing
745 | assert result['user'] == 'user'
746 | assert result['password'] == 'p' # Only gets characters before first @
747 | assert result['host'] == 'ssw0rd!@server' # Rest becomes host
748 | assert result['port'] == '9030'
749 | assert result['database'] == 'mydb'
750 |
751 | def test_parse_url_password_without_at_symbol(self):
752 | """Test parsing URL with complex password without @ symbol."""
753 | url = "user:p#ssw0rd!$%^&*()@server:9030/mydb"
754 | result = parse_connection_url(url)
755 |
756 | assert result['user'] == 'user'
757 | assert result['password'] == 'p#ssw0rd!$%^&*()'
758 | assert result['host'] == 'server'
759 | assert result['port'] == '9030'
760 | assert result['database'] == 'mydb'
761 |
762 | def test_parse_url_complex_username_with_at_symbol_limitation(self):
763 | """Test that username with @ symbol fails (regex limitation)."""
764 | url = "user.name+tag@domain:password123@host:9030/db"
765 | # This should fail because our regex cannot distinguish between
766 | # the @ in username vs the @ separator for host
767 | with pytest.raises(ValueError, match="Invalid connection URL"):
768 | parse_connection_url(url)
769 |
770 | def test_parse_url_complex_username_without_at(self):
771 | """Test parsing URL with complex username without @ symbol."""
772 | url = "user.name+tag_domain:password123@host:9030/db"
773 | result = parse_connection_url(url)
774 |
775 | assert result['user'] == 'user.name+tag_domain'
776 | assert result['password'] == 'password123'
777 | assert result['host'] == 'host'
778 | assert result['port'] == '9030'
779 | assert result['database'] == 'db'
780 |
781 | def test_parse_url_numeric_database(self):
782 | """Test parsing URL with numeric database name."""
783 | url = "root:pass@localhost:9030/db123"
784 | result = parse_connection_url(url)
785 |
786 | assert result['database'] == 'db123'
787 |
788 | def test_parse_url_database_with_hyphens(self):
789 | """Test parsing URL with database name containing hyphens."""
790 | url = "root:pass@localhost:9030/test-db-name"
791 | result = parse_connection_url(url)
792 |
793 | assert result['database'] == 'test-db-name'
794 |
795 | def test_parse_url_ip_address_host(self):
796 | """Test parsing URL with IP address as host."""
797 | url = "root:[email protected]:9030/testdb"
798 | result = parse_connection_url(url)
799 |
800 | assert result['host'] == '192.168.1.100'
801 | assert result['port'] == '9030'
802 | assert result['database'] == 'testdb'
803 |
804 | def test_parse_url_different_ports(self):
805 | """Test parsing URLs with different port numbers."""
806 | test_cases = [
807 | ("user:pass@host:3306/db", "3306"),
808 | ("user:pass@host:5432/db", "5432"),
809 | ("user:pass@host:27017/db", "27017"),
810 | ("user:pass@host:1/db", "1"),
811 | ("user:pass@host:65535/db", "65535"),
812 | ]
813 |
814 | for url, expected_port in test_cases:
815 | result = parse_connection_url(url)
816 | assert result['port'] == expected_port
817 |
818 | def test_parse_invalid_urls(self):
819 | """Test that invalid URLs raise ValueError."""
820 | invalid_urls = [
821 | # Missing required parts
822 | "@host:9030/db", # Missing user
823 | "user:pass@:9030/db", # Missing host
824 |
825 | # Malformed URLs
826 | "user:pass@host:port/db", # Non-numeric port
827 | "user:pass@host:9030/", # Empty database
828 | "user:pass@host:9030/db/extra", # Extra path component
829 | "", # Empty string
830 | "random-string-not-url", # Not a URL format
831 |
832 | # Special cases
833 | "://user:pass@host:9030/db", # Empty schema
834 | "user:pass@host:-1/db", # Negative port
835 | ]
836 |
837 | for invalid_url in invalid_urls:
838 | with pytest.raises(ValueError, match="Invalid connection URL"):
839 | parse_connection_url(invalid_url)
840 |
841 | def test_parse_url_colon_in_password_works(self):
842 | """Test that colon in password actually works (unlike @ symbol)."""
843 | url = "user:pass:extra@host:9030/db"
844 | result = parse_connection_url(url)
845 |
846 | assert result['user'] == 'user'
847 | assert result['password'] == 'pass:extra' # Colons in password are fine
848 | assert result['host'] == 'host'
849 | assert result['port'] == '9030'
850 | assert result['database'] == 'db'
851 |
852 | def test_parse_url_without_database(self):
853 | """Test parsing URL without database (database is optional)."""
854 | url = "user:password@host:9030"
855 | result = parse_connection_url(url)
856 |
857 | assert result['user'] == 'user'
858 | assert result['password'] == 'password'
859 | assert result['host'] == 'host'
860 | assert result['port'] == '9030'
861 |         assert result['database'] is None  # Database should be None when omitted
862 |
863 | def test_parse_url_with_schema_without_database(self):
864 | """Test parsing URL with schema but without database."""
865 | url = "mysql://admin:[email protected]:3306"
866 | result = parse_connection_url(url)
867 |
868 | assert result['user'] == 'admin'
869 | assert result['password'] == 'secret'
870 | assert result['host'] == 'db.example.com'
871 | assert result['port'] == '3306'
872 |         assert result['database'] is None
873 |
874 | def test_parse_url_various_schemas_without_database(self):
875 | """Test parsing URLs with various schemas but no database."""
876 | test_cases = [
877 | ("starrocks://user:pass@host:9030", "starrocks"),
878 | ("jdbc+mysql://user:pass@host:3306", "jdbc+mysql"),
879 | ("postgresql://user:pass@host:5432", "postgresql"),
880 | ]
881 |
882 |         for url, _ in test_cases:
883 | result = parse_connection_url(url)
884 | # Schema is no longer returned in the result
885 | assert result['user'] == 'user'
886 | assert result['password'] == 'pass'
887 | assert result['host'] == 'host'
888 |             assert result['database'] is None
889 |
890 | def test_parse_url_edge_cases(self):
891 | """Test edge cases that should work."""
892 | # Single character components
893 | url = "a:b@c:1/d"
894 | result = parse_connection_url(url)
895 | assert result['user'] == 'a'
896 | assert result['password'] == 'b'
897 | assert result['host'] == 'c'
898 | assert result['port'] == '1'
899 | assert result['database'] == 'd'
900 |
901 | # Long components
902 | long_user = "a" * 100
903 | long_pass = "b" * 100
904 | long_host = "c" * 50
905 | long_db = "d" * 50
906 | url = f"{long_user}:{long_pass}@{long_host}:9030/{long_db}"
907 | result = parse_connection_url(url)
908 | assert result['user'] == long_user
909 | assert result['password'] == long_pass
910 | assert result['host'] == long_host
911 | assert result['database'] == long_db
912 |
913 | def test_parse_url_returns_dict_with_all_keys(self):
914 | """Test that parse_connection_url always returns dict with all expected keys."""
915 | test_cases = [
916 | "root:pass@localhost:9030/db",
917 | "mysql://root:pass@localhost:3306/db",
918 | ]
919 |
920 | expected_keys = {'user', 'password', 'host', 'port', 'database'}
921 |
922 | for url in test_cases:
923 | result = parse_connection_url(url)
924 | assert isinstance(result, dict)
925 | assert set(result.keys()) == expected_keys
926 |
927 | def test_parse_url_regex_pattern_comprehensive(self):
928 | """Test comprehensive regex pattern matching."""
929 | # Test that the regex correctly captures each group
930 | url = "custom+schema://test_user:[email protected]:12345/my_db-name"
931 | result = parse_connection_url(url)
932 |
933 | # Schema is no longer returned in the result
934 | assert result['user'] == 'test_user'
935 | assert result['password'] == 'complex!pass'
936 | assert result['host'] == 'sub.domain.com'
937 | assert result['port'] == '12345'
938 | assert result['database'] == 'my_db-name'
939 |
940 |
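# The behavior pinned down below suggests dummy mode is gated on plain string
# truthiness -- roughly bool(os.getenv('STARROCKS_DUMMY_TEST')) -- which is
# why any non-empty value enables it and '' disables it. The canned result is
# always a single 'name' column with rows 'aaa'/'bbb'/'ccc'.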
941 | class TestDummyMode:
942 | """Test cases for STARROCKS_DUMMY_TEST environment variable."""
943 |
944 | def test_dummy_mode_enabled(self):
945 | """Test that dummy mode returns expected dummy data."""
946 | # Set dummy test environment variable
947 | with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
948 | client = DBClient()
949 | assert client.enable_dummy_test is True
950 |
951 | # Test basic query
952 | result = client.execute("SELECT * FROM any_table")
953 |
954 | assert result.success is True
955 | assert result.column_names == ['name']
956 | assert result.rows == [['aaa'], ['bbb'], ['ccc']]
957 | assert result.execution_time is not None
958 | assert result.execution_time > 0
959 | assert result.pandas is None # pandas should be None for raw format
960 |
961 | def test_dummy_mode_with_pandas_format(self):
962 | """Test dummy mode with pandas return format."""
963 | with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
964 | client = DBClient()
965 |
966 | result = client.execute("SELECT * FROM any_table", return_format="pandas")
967 |
968 | assert result.success is True
969 | assert result.column_names == ['name']
970 | assert result.rows == [['aaa'], ['bbb'], ['ccc']]
971 | assert result.pandas is not None
972 | assert isinstance(result.pandas, pd.DataFrame)
973 | assert len(result.pandas) == 3
974 | assert list(result.pandas.columns) == ['name']
975 | assert result.pandas.iloc[0]['name'] == 'aaa'
976 | assert result.pandas.iloc[1]['name'] == 'bbb'
977 | assert result.pandas.iloc[2]['name'] == 'ccc'
978 |
979 | def test_dummy_mode_ignores_statement_and_db(self):
980 | """Test that dummy mode returns same data regardless of SQL statement or database."""
981 | with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
982 | client = DBClient()
983 |
984 | # Test different statements
985 | result1 = client.execute("SHOW DATABASES")
986 | result2 = client.execute("CREATE TABLE test (id INT)")
987 | result3 = client.execute("SELECT COUNT(*) FROM users", db="production")
988 |
989 | # All should return the same dummy data
990 | for result in [result1, result2, result3]:
991 | assert result.success is True
992 | assert result.column_names == ['name']
993 | assert result.rows == [['aaa'], ['bbb'], ['ccc']]
994 |
995 | def test_dummy_mode_disabled_by_default(self):
996 | """Test that dummy mode is disabled when environment variable is not set."""
997 | # Ensure STARROCKS_DUMMY_TEST is not set
998 | with patch.dict(os.environ, {}, clear=True):
999 |             os.environ.pop('STARROCKS_DUMMY_TEST', None)  # Defensive; clear=True already emptied os.environ
1000 | client = DBClient()
1001 | assert client.enable_dummy_test is False
1002 |
1003 | def test_dummy_mode_with_empty_string(self):
1004 | """Test that empty string for STARROCKS_DUMMY_TEST disables dummy mode."""
1005 | with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': ''}):
1006 | client = DBClient()
1007 | assert client.enable_dummy_test is False
1008 |
1009 | def test_dummy_mode_with_various_truthy_values(self):
1010 | """Test that various truthy values enable dummy mode."""
1011 | truthy_values = ['1', 'true', 'True', 'yes', 'on', 'any_non_empty_string']
1012 |
1013 | for value in truthy_values:
1014 | with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': value}):
1015 | client = DBClient()
1016 | assert client.enable_dummy_test is True, f"Failed for value: {value}"
1017 |
1018 | def test_dummy_mode_to_pandas_conversion(self):
1019 | """Test to_pandas() method works with dummy data."""
1020 | with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
1021 | client = DBClient()
1022 |
1023 | # Test raw format conversion
1024 | result = client.execute("SELECT * FROM test")
1025 | df = result.to_pandas()
1026 | assert isinstance(df, pd.DataFrame)
1027 | assert len(df) == 3
1028 | assert list(df.columns) == ['name']
1029 | assert df.iloc[0]['name'] == 'aaa'
1030 |
1031 | # Test pandas format (should return same DataFrame)
1032 | result_pandas = client.execute("SELECT * FROM test", return_format="pandas")
1033 | df_pandas = result_pandas.to_pandas()
1034 | assert df_pandas is result_pandas.pandas
1035 |
1036 | def test_dummy_mode_to_string_conversion(self):
1037 | """Test to_string() method works with dummy data."""
1038 | with patch.dict(os.environ, {'STARROCKS_DUMMY_TEST': '1'}):
1039 | client = DBClient()
1040 |
1041 | result = client.execute("SELECT * FROM test")
1042 | string_output = result.to_string()
1043 |
1044 | expected_lines = [
1045 | 'name',
1046 | 'aaa',
1047 | 'bbb',
1048 | 'ccc',
1049 | ''
1050 | ]
1051 | assert string_output == '\n'.join(expected_lines)
1052 |
1053 |
1054 | if __name__ == "__main__":
1055 | pytest.main([__file__, "-v"])
```