This is page 3 of 6. Use http://codebase.md/alexander-zuev/supabase-mcp-server?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .claude │ └── settings.local.json ├── .dockerignore ├── .env.example ├── .env.test.example ├── .github │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.md │ │ ├── feature_request.md │ │ └── roadmap_item.md │ ├── PULL_REQUEST_TEMPLATE.md │ └── workflows │ ├── ci.yaml │ ├── docs │ │ └── release-checklist.md │ └── publish.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── CHANGELOG.MD ├── codecov.yml ├── CONTRIBUTING.MD ├── Dockerfile ├── LICENSE ├── llms-full.txt ├── pyproject.toml ├── README.md ├── smithery.yaml ├── supabase_mcp │ ├── __init__.py │ ├── clients │ │ ├── api_client.py │ │ ├── base_http_client.py │ │ ├── management_client.py │ │ └── sdk_client.py │ ├── core │ │ ├── __init__.py │ │ ├── container.py │ │ └── feature_manager.py │ ├── exceptions.py │ ├── logger.py │ ├── main.py │ ├── services │ │ ├── __init__.py │ │ ├── api │ │ │ ├── __init__.py │ │ │ ├── api_manager.py │ │ │ ├── spec_manager.py │ │ │ └── specs │ │ │ └── api_spec.json │ │ ├── database │ │ │ ├── __init__.py │ │ │ ├── migration_manager.py │ │ │ ├── postgres_client.py │ │ │ ├── query_manager.py │ │ │ └── sql │ │ │ ├── loader.py │ │ │ ├── models.py │ │ │ ├── queries │ │ │ │ ├── create_migration.sql │ │ │ │ ├── get_migrations.sql │ │ │ │ ├── get_schemas.sql │ │ │ │ ├── get_table_schema.sql │ │ │ │ ├── get_tables.sql │ │ │ │ ├── init_migrations.sql │ │ │ │ └── logs │ │ │ │ ├── auth_logs.sql │ │ │ │ ├── cron_logs.sql │ │ │ │ ├── edge_logs.sql │ │ │ │ ├── function_edge_logs.sql │ │ │ │ ├── pgbouncer_logs.sql │ │ │ │ ├── postgres_logs.sql │ │ │ │ ├── postgrest_logs.sql │ │ │ │ ├── realtime_logs.sql │ │ │ │ ├── storage_logs.sql │ │ │ │ └── supavisor_logs.sql │ │ │ └── validator.py │ │ ├── logs │ │ │ ├── __init__.py │ │ │ └── log_manager.py │ │ ├── safety │ │ │ ├── __init__.py │ │ │ ├── models.py │ │ │ ├── safety_configs.py │ │ │ └── safety_manager.py │ │ └── sdk │ │ ├── __init__.py │ │ ├── auth_admin_models.py │ │ └── auth_admin_sdk_spec.py │ ├── settings.py │ └── tools │ ├── __init__.py │ ├── descriptions │ │ ├── api_tools.yaml │ │ ├── database_tools.yaml │ │ ├── logs_and_analytics_tools.yaml │ │ ├── safety_tools.yaml │ │ └── sdk_tools.yaml │ ├── manager.py │ └── registry.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── services │ │ ├── __init__.py │ │ ├── api │ │ │ ├── __init__.py │ │ │ ├── test_api_client.py │ │ │ ├── test_api_manager.py │ │ │ └── test_spec_manager.py │ │ ├── database │ │ │ ├── sql │ │ │ │ ├── __init__.py │ │ │ │ ├── conftest.py │ │ │ │ ├── test_loader.py │ │ │ │ ├── test_sql_validator_integration.py │ │ │ │ └── test_sql_validator.py │ │ │ ├── test_migration_manager.py │ │ │ ├── test_postgres_client.py │ │ │ └── test_query_manager.py │ │ ├── logs │ │ │ └── test_log_manager.py │ │ ├── safety │ │ │ ├── test_api_safety_config.py │ │ │ ├── test_safety_manager.py │ │ │ └── test_sql_safety_config.py │ │ └── sdk │ │ ├── test_auth_admin_models.py │ │ └── test_sdk_client.py │ ├── test_container.py │ ├── test_main.py │ ├── test_settings.py │ ├── test_tool_manager.py │ ├── test_tools_integration.py.bak │ └── test_tools.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /supabase_mcp/services/api/api_manager.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | 3 | from enum import Enum 4 | from typing import Any 5 | 6 | from supabase_mcp.clients.management_client import ManagementAPIClient 7 | from supabase_mcp.logger import logger 8 | from supabase_mcp.services.api.spec_manager import ApiSpecManager 9 | from supabase_mcp.services.logs.log_manager import LogManager 10 | from supabase_mcp.services.safety.models import ClientType 11 | from supabase_mcp.services.safety.safety_manager import SafetyManager 12 | from supabase_mcp.settings import settings 13 | 14 | 15 | class PathPlaceholder(str, Enum): 16 | """Enum of all possible path placeholders in the Supabase Management API.""" 17 | 18 | REF = "ref" 19 | FUNCTION_SLUG = "function_slug" 20 | ID = "id" 21 | SLUG = "slug" 22 | BRANCH_ID = "branch_id" 23 | PROVIDER_ID = "provider_id" 24 | TPA_ID = "tpa_id" 25 | 26 | 27 | class SupabaseApiManager: 28 | """ 29 | Manages the Supabase Management API. 30 | """ 31 | 32 | _instance: SupabaseApiManager | None = None 33 | 34 | def __init__( 35 | self, 36 | api_client: ManagementAPIClient, 37 | safety_manager: SafetyManager, 38 | spec_manager: ApiSpecManager | None = None, 39 | log_manager: LogManager | None = None, 40 | ) -> None: 41 | """Initialize the API manager.""" 42 | self.spec_manager = spec_manager or ApiSpecManager() # this is so that I don't have to pass it 43 | self.client = api_client 44 | self.safety_manager = safety_manager 45 | self.log_manager = log_manager or LogManager() 46 | 47 | @classmethod 48 | def get_instance( 49 | cls, 50 | api_client: ManagementAPIClient, 51 | safety_manager: SafetyManager, 52 | spec_manager: ApiSpecManager | None = None, 53 | ) -> SupabaseApiManager: 54 | """Get the singleton instance""" 55 | if cls._instance is None: 56 | cls._instance = SupabaseApiManager(api_client, safety_manager, spec_manager) 57 | return cls._instance 58 | 59 | @classmethod 60 | def reset(cls) -> None: 61 | """Reset the singleton instance""" 62 | if cls._instance is not None: 63 | cls._instance = None 64 | logger.info("SupabaseApiManager instance reset complete") 65 | 66 | def get_safety_rules(self) -> str: 67 | """ 68 | Get safety rules with human-readable descriptions. 69 | 70 | Returns: 71 | str: Human readable safety rules explanation 72 | """ 73 | # Get safety configuration from the safety manager 74 | safety_manager = self.safety_manager 75 | 76 | # Get risk levels and operations by risk level 77 | extreme_risk_ops = safety_manager.get_operations_by_risk_level("extreme", ClientType.API) 78 | high_risk_ops = safety_manager.get_operations_by_risk_level("high", ClientType.API) 79 | medium_risk_ops = safety_manager.get_operations_by_risk_level("medium", ClientType.API) 80 | 81 | # Create human-readable explanations 82 | extreme_risk_summary = ( 83 | "\n".join([f"- {method} {path}" for method, paths in extreme_risk_ops.items() for path in paths]) 84 | if extreme_risk_ops 85 | else "None" 86 | ) 87 | 88 | high_risk_summary = ( 89 | "\n".join([f"- {method} {path}" for method, paths in high_risk_ops.items() for path in paths]) 90 | if high_risk_ops 91 | else "None" 92 | ) 93 | 94 | medium_risk_summary = ( 95 | "\n".join([f"- {method} {path}" for method, paths in medium_risk_ops.items() for path in paths]) 96 | if medium_risk_ops 97 | else "None" 98 | ) 99 | 100 | current_mode = safety_manager.get_current_mode(ClientType.API) 101 | 102 | return f"""MCP Server Safety Rules: 103 | 104 | EXTREME RISK Operations (never allowed by the server): 105 | {extreme_risk_summary} 106 | 107 | HIGH RISK Operations (require unsafe mode): 108 | {high_risk_summary} 109 | 110 | MEDIUM RISK Operations (require unsafe mode): 111 | {medium_risk_summary} 112 | 113 | All other operations are LOW RISK (always allowed). 114 | 115 | Current mode: {current_mode} 116 | In safe mode, only low risk operations are allowed. 117 | Use live_dangerously() to enable unsafe mode for medium and high risk operations. 118 | """ 119 | 120 | def replace_path_params(self, path: str, path_params: dict[str, Any] | None = None) -> str: 121 | """ 122 | Replace path parameters in the path string with actual values. 123 | 124 | This method: 125 | 1. Automatically injects the project ref from settings 126 | 2. Replaces all placeholders in the path with values from path_params 127 | 3. Validates that all placeholders are replaced 128 | 129 | Args: 130 | path: The API path with placeholders (e.g., "/v1/projects/{ref}/functions/{function_slug}") 131 | path_params: Dictionary of path parameters to replace (e.g., {"function_slug": "my-function"}) 132 | 133 | Returns: 134 | The path with all placeholders replaced 135 | 136 | Raises: 137 | ValueError: If any placeholders remain after replacement or if invalid placeholders are provided 138 | """ 139 | # Create a working copy of path_params to avoid modifying the original 140 | working_params = {} if path_params is None else path_params.copy() 141 | 142 | # Check if user provided ref and raise an error 143 | if working_params and PathPlaceholder.REF.value in working_params: 144 | raise ValueError( 145 | "Do not provide 'ref' in path_params. The project reference is automatically injected from settings. " 146 | "If you need to change the project reference, modify the environment variables instead." 147 | ) 148 | 149 | # Validate that all provided path parameters are known placeholders 150 | if working_params: 151 | for key in working_params: 152 | try: 153 | PathPlaceholder(key) 154 | except ValueError as e: 155 | raise ValueError( 156 | f"Unknown path parameter: '{key}'. Valid placeholders are: " 157 | f"{', '.join([p.value for p in PathPlaceholder])}" 158 | ) from e 159 | 160 | # Get project ref from settings and add it to working_params 161 | working_params[PathPlaceholder.REF.value] = settings.supabase_project_ref 162 | 163 | logger.info(f"Replacing path parameters in path: {working_params}") 164 | 165 | # Replace all placeholders in the path 166 | for key, value in working_params.items(): 167 | placeholder = "{" + key + "}" 168 | if placeholder in path: 169 | path = path.replace(placeholder, str(value)) 170 | logger.debug(f"Replaced {placeholder} with {value}") 171 | 172 | # Check if any placeholders remain 173 | import re 174 | 175 | remaining_placeholders = re.findall(r"\{([^}]+)\}", path) 176 | if remaining_placeholders: 177 | raise ValueError( 178 | f"Missing path parameters: {', '.join(remaining_placeholders)}. " 179 | f"Please provide values for these placeholders in the path_params dictionary." 180 | ) 181 | 182 | return path 183 | 184 | async def execute_request( 185 | self, 186 | method: str, 187 | path: str, 188 | path_params: dict[str, Any] | None = None, 189 | request_params: dict[str, Any] | None = None, 190 | request_body: dict[str, Any] | None = None, 191 | has_confirmation: bool = False, 192 | ) -> dict[str, Any]: 193 | """ 194 | Execute Management API request with safety validation. 195 | 196 | Args: 197 | method: HTTP method to use 198 | path: API path to call 199 | request_params: Query parameters to include 200 | request_body: Request body to send 201 | has_confirmation: Whether the operation has been confirmed by the user 202 | Returns: 203 | API response as a dictionary 204 | 205 | Raises: 206 | SafetyError: If the operation is not allowed by safety rules 207 | """ 208 | # Log the request with proper formatting 209 | logger.info( 210 | f"API Request: {method} {path} | Path params: {path_params or {}} | Query params: {request_params or {}} | Body: {request_body or {}}" 211 | ) 212 | 213 | # Create an operation object for validation 214 | operation = (method, path, path_params, request_params, request_body) 215 | 216 | # Use the safety manager to validate the operation 217 | logger.debug(f"Validating operation safety: {method} {path}") 218 | self.safety_manager.validate_operation(ClientType.API, operation, has_confirmation=has_confirmation) 219 | 220 | # Replace path parameters in the path string with actual values 221 | path = self.replace_path_params(path, path_params) 222 | 223 | # Execute the request using the API client 224 | return await self.client.execute_request(method, path, request_params, request_body) 225 | 226 | async def handle_confirmation(self, confirmation_id: str) -> dict[str, Any]: 227 | """Handle a confirmation request.""" 228 | # retrieve the operation from the confirmation id 229 | operation = self.safety_manager.get_stored_operation(confirmation_id) 230 | if not operation: 231 | raise ValueError("No operation found for confirmation id") 232 | 233 | # execute the operation 234 | return await self.execute_request( 235 | method=operation[0], 236 | path=operation[1], 237 | path_params=operation[2], 238 | request_params=operation[3], 239 | request_body=operation[4], 240 | has_confirmation=True, 241 | ) 242 | 243 | async def handle_spec_request( 244 | self, 245 | path: str | None = None, 246 | method: str | None = None, 247 | domain: str | None = None, 248 | all_paths: bool | None = False, 249 | ) -> dict[str, Any]: 250 | """Handle a spec request. 251 | 252 | Args: 253 | path: Optional API path 254 | method: Optional HTTP method 255 | api_domain: Optional domain/tag name 256 | full_spec: If True, returns all paths and methods 257 | 258 | Returns: 259 | API specification based on the provided parameters 260 | """ 261 | spec_manager = self.spec_manager 262 | 263 | if spec_manager is None: 264 | raise RuntimeError("API spec manager is not initialized") 265 | 266 | # Ensure spec is loaded 267 | await spec_manager.get_spec() 268 | 269 | # Option 1: Get spec for specific path and method 270 | if path and method: 271 | method = method.lower() # Normalize method to lowercase 272 | result = spec_manager.get_spec_for_path_and_method(path, method) 273 | if result is None: 274 | return {"error": f"No specification found for {method.upper()} {path}"} 275 | return result 276 | 277 | # Option 2: Get all paths and methods for a specific domain 278 | elif domain: 279 | result = spec_manager.get_paths_and_methods_by_domain(domain) 280 | if not result: 281 | # Check if the domain exists 282 | all_domains = spec_manager.get_all_domains() 283 | if domain not in all_domains: 284 | return {"error": f"Domain '{domain}' not found", "available_domains": all_domains} 285 | return {"domain": domain, "paths": result} 286 | 287 | # Option 4: Get all paths and methods 288 | elif all_paths: 289 | return {"paths": spec_manager.get_all_paths_and_methods()} 290 | 291 | # Option 3: Get all domains (default) 292 | else: 293 | return {"domains": spec_manager.get_all_domains()} 294 | 295 | async def retrieve_logs( 296 | self, 297 | collection: str, 298 | limit: int = 20, 299 | hours_ago: int | None = 1, 300 | filters: list[dict[str, Any]] | None = None, 301 | search: str | None = None, 302 | custom_query: str | None = None, 303 | ) -> dict[str, Any]: 304 | """Retrieve logs from a Supabase service. 305 | 306 | Args: 307 | collection: The log collection to query 308 | limit: Maximum number of log entries to return 309 | hours_ago: Retrieve logs from the last N hours 310 | filters: List of filter objects with field, operator, and value 311 | search: Text to search for in event messages 312 | custom_query: Complete custom SQL query to execute 313 | 314 | Returns: 315 | The query result 316 | 317 | Raises: 318 | ValueError: If the collection is unknown 319 | """ 320 | log_manager = self.log_manager 321 | 322 | # Build the SQL query using LogManager 323 | sql = log_manager.build_logs_query( 324 | collection=collection, 325 | limit=limit, 326 | hours_ago=hours_ago, 327 | filters=filters, 328 | search=search, 329 | custom_query=custom_query, 330 | ) 331 | 332 | logger.debug(f"Executing log query: {sql}") 333 | 334 | # Make the API request 335 | try: 336 | response = await self.execute_request( 337 | method="GET", 338 | path="/v1/projects/{ref}/analytics/endpoints/logs.all", 339 | path_params={}, 340 | request_params={"sql": sql}, 341 | request_body={}, 342 | ) 343 | 344 | return response 345 | except Exception as e: 346 | logger.error(f"Error retrieving logs: {e}") 347 | raise 348 | ``` -------------------------------------------------------------------------------- /tests/services/safety/test_safety_manager.py: -------------------------------------------------------------------------------- ```python 1 | import time 2 | 3 | import pytest 4 | 5 | from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError 6 | from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel, SafetyMode 7 | from supabase_mcp.services.safety.safety_configs import SafetyConfigBase 8 | from supabase_mcp.services.safety.safety_manager import SafetyManager 9 | 10 | 11 | class MockSafetyConfig(SafetyConfigBase[str]): 12 | """Mock safety configuration for testing.""" 13 | 14 | def get_risk_level(self, operation: str) -> OperationRiskLevel: 15 | """Get the risk level for an operation.""" 16 | if operation == "low_risk": 17 | return OperationRiskLevel.LOW 18 | elif operation == "medium_risk": 19 | return OperationRiskLevel.MEDIUM 20 | elif operation == "high_risk": 21 | return OperationRiskLevel.HIGH 22 | elif operation == "extreme_risk": 23 | return OperationRiskLevel.EXTREME 24 | else: 25 | return OperationRiskLevel.LOW 26 | 27 | 28 | @pytest.mark.unit 29 | class TestSafetyManager: 30 | """Unit test cases for the SafetyManager class.""" 31 | 32 | @pytest.fixture(autouse=True) 33 | def setup_and_teardown(self): 34 | """Setup and teardown for each test.""" 35 | # Reset the singleton before each test 36 | # pylint: disable=protected-access 37 | SafetyManager._instance = None # type: ignore 38 | yield 39 | # Reset the singleton after each test 40 | SafetyManager._instance = None # type: ignore 41 | 42 | def test_singleton_pattern(self): 43 | """Test that SafetyManager follows the singleton pattern.""" 44 | # Get two instances of the SafetyManager 45 | manager1 = SafetyManager.get_instance() 46 | manager2 = SafetyManager.get_instance() 47 | 48 | # Verify they are the same instance 49 | assert manager1 is manager2 50 | 51 | # Verify that creating a new instance directly doesn't affect the singleton 52 | direct_instance = SafetyManager() 53 | assert direct_instance is not manager1 54 | 55 | def test_register_config(self): 56 | """Test registering a safety configuration.""" 57 | manager = SafetyManager.get_instance() 58 | mock_config = MockSafetyConfig() 59 | 60 | # Register the config for DATABASE client type 61 | manager.register_config(ClientType.DATABASE, mock_config) 62 | 63 | # Verify the config was registered 64 | assert manager._safety_configs[ClientType.DATABASE] is mock_config 65 | 66 | # Test that registering a config for the same client type overwrites the previous config 67 | new_mock_config = MockSafetyConfig() 68 | manager.register_config(ClientType.DATABASE, new_mock_config) 69 | assert manager._safety_configs[ClientType.DATABASE] is new_mock_config 70 | 71 | def test_get_safety_mode_default(self): 72 | """Test getting the default safety mode for an unregistered client type.""" 73 | manager = SafetyManager.get_instance() 74 | 75 | # Create a custom client type that hasn't been registered 76 | class CustomClientType(str): 77 | pass 78 | 79 | custom_type = CustomClientType("custom") 80 | 81 | # Verify that getting a safety mode for an unregistered client type returns SafetyMode.SAFE 82 | assert manager.get_safety_mode(custom_type) == SafetyMode.SAFE # type: ignore 83 | 84 | def test_get_safety_mode_registered(self): 85 | """Test getting the safety mode for a registered client type.""" 86 | manager = SafetyManager.get_instance() 87 | 88 | # Set a safety mode for a client type 89 | manager._safety_modes[ClientType.API] = SafetyMode.UNSAFE 90 | 91 | # Verify it's returned correctly 92 | assert manager.get_safety_mode(ClientType.API) == SafetyMode.UNSAFE 93 | 94 | def test_set_safety_mode(self): 95 | """Test setting the safety mode for a client type.""" 96 | manager = SafetyManager.get_instance() 97 | 98 | # Set a safety mode for a client type 99 | manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE) 100 | 101 | # Verify it was updated 102 | assert manager._safety_modes[ClientType.DATABASE] == SafetyMode.UNSAFE 103 | 104 | # Change it back to SAFE 105 | manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE) 106 | 107 | # Verify it was updated again 108 | assert manager._safety_modes[ClientType.DATABASE] == SafetyMode.SAFE 109 | 110 | def test_validate_operation_allowed(self): 111 | """Test validating an operation that is allowed.""" 112 | manager = SafetyManager.get_instance() 113 | mock_config = MockSafetyConfig() 114 | 115 | # Register the config 116 | manager.register_config(ClientType.DATABASE, mock_config) 117 | 118 | # Set safety mode to SAFE 119 | manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE) 120 | 121 | # Validate a low risk operation (should be allowed in SAFE mode) 122 | # This should not raise an exception 123 | manager.validate_operation(ClientType.DATABASE, "low_risk") 124 | 125 | # Set safety mode to UNSAFE 126 | manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE) 127 | 128 | # Validate medium risk operation (should be allowed in UNSAFE mode) 129 | # This should not raise an exception 130 | manager.validate_operation(ClientType.DATABASE, "medium_risk") 131 | 132 | # High risk operations require confirmation, so we test with confirmation=True 133 | manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=True) 134 | 135 | def test_validate_operation_not_allowed(self): 136 | """Test validating an operation that is not allowed.""" 137 | manager = SafetyManager.get_instance() 138 | mock_config = MockSafetyConfig() 139 | 140 | # Register the config 141 | manager.register_config(ClientType.DATABASE, mock_config) 142 | 143 | # Set safety mode to SAFE 144 | manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE) 145 | 146 | # Validate medium risk operation (should not be allowed in SAFE mode) 147 | with pytest.raises(OperationNotAllowedError): 148 | manager.validate_operation(ClientType.DATABASE, "medium_risk") 149 | 150 | # Validate high risk operation (should not be allowed in SAFE mode) 151 | with pytest.raises(OperationNotAllowedError): 152 | manager.validate_operation(ClientType.DATABASE, "high_risk") 153 | 154 | # Validate extreme risk operation (should not be allowed in SAFE mode) 155 | with pytest.raises(OperationNotAllowedError): 156 | manager.validate_operation(ClientType.DATABASE, "extreme_risk") 157 | 158 | def test_validate_operation_requires_confirmation(self): 159 | """Test validating an operation that requires confirmation.""" 160 | manager = SafetyManager.get_instance() 161 | mock_config = MockSafetyConfig() 162 | 163 | # Register the config 164 | manager.register_config(ClientType.DATABASE, mock_config) 165 | 166 | # Set safety mode to UNSAFE 167 | manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE) 168 | 169 | # Validate high risk operation without confirmation 170 | # Should raise ConfirmationRequiredError 171 | with pytest.raises(ConfirmationRequiredError): 172 | manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=False) 173 | 174 | # Extreme risk operations are not allowed even in UNSAFE mode 175 | with pytest.raises(OperationNotAllowedError): 176 | manager.validate_operation(ClientType.DATABASE, "extreme_risk", has_confirmation=False) 177 | 178 | # Even with confirmation, extreme risk operations are not allowed 179 | with pytest.raises(OperationNotAllowedError): 180 | manager.validate_operation(ClientType.DATABASE, "extreme_risk", has_confirmation=True) 181 | 182 | def test_store_confirmation(self): 183 | """Test storing a confirmation for an operation.""" 184 | manager = SafetyManager.get_instance() 185 | 186 | # Store a confirmation 187 | confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME) 188 | 189 | # Verify that a confirmation ID is returned 190 | assert confirmation_id is not None 191 | assert confirmation_id.startswith("conf_") 192 | 193 | # Verify that the confirmation can be retrieved 194 | confirmation = manager._get_confirmation(confirmation_id) 195 | assert confirmation is not None 196 | assert confirmation["operation"] == "test_operation" 197 | assert confirmation["client_type"] == ClientType.DATABASE 198 | assert confirmation["risk_level"] == OperationRiskLevel.EXTREME 199 | assert "timestamp" in confirmation 200 | 201 | def test_get_confirmation_valid(self): 202 | """Test getting a valid confirmation.""" 203 | manager = SafetyManager.get_instance() 204 | 205 | # Store a confirmation 206 | confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME) 207 | 208 | # Retrieve the confirmation 209 | confirmation = manager._get_confirmation(confirmation_id) 210 | 211 | # Verify it matches what was stored 212 | assert confirmation is not None 213 | assert confirmation["operation"] == "test_operation" 214 | assert confirmation["client_type"] == ClientType.DATABASE 215 | assert confirmation["risk_level"] == OperationRiskLevel.EXTREME 216 | 217 | def test_get_confirmation_invalid(self): 218 | """Test getting an invalid confirmation.""" 219 | manager = SafetyManager.get_instance() 220 | 221 | # Try to retrieve a confirmation with an invalid ID 222 | confirmation = manager._get_confirmation("invalid_id") 223 | 224 | # Verify that None is returned 225 | assert confirmation is None 226 | 227 | def test_get_confirmation_expired(self): 228 | """Test getting an expired confirmation.""" 229 | manager = SafetyManager.get_instance() 230 | 231 | # Store a confirmation with a past expiration time 232 | confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME) 233 | 234 | # Manually set the timestamp to be older than the expiry time 235 | manager._pending_confirmations[confirmation_id]["timestamp"] = time.time() - manager._confirmation_expiry - 10 236 | 237 | # Try to retrieve the confirmation 238 | confirmation = manager._get_confirmation(confirmation_id) 239 | 240 | # Verify that None is returned 241 | assert confirmation is None 242 | 243 | def test_cleanup_expired_confirmations(self): 244 | """Test cleaning up expired confirmations.""" 245 | manager = SafetyManager.get_instance() 246 | 247 | # Store multiple confirmations with different expiration times 248 | valid_id = manager._store_confirmation(ClientType.DATABASE, "valid_operation", OperationRiskLevel.EXTREME) 249 | 250 | expired_id = manager._store_confirmation(ClientType.DATABASE, "expired_operation", OperationRiskLevel.EXTREME) 251 | 252 | # Manually set the timestamp of the expired confirmation to be older than the expiry time 253 | manager._pending_confirmations[expired_id]["timestamp"] = time.time() - manager._confirmation_expiry - 10 254 | 255 | # Call cleanup 256 | manager._cleanup_expired_confirmations() 257 | 258 | # Verify that expired confirmations are removed and valid ones remain 259 | assert valid_id in manager._pending_confirmations 260 | assert expired_id not in manager._pending_confirmations 261 | 262 | def test_get_stored_operation(self): 263 | """Test getting a stored operation.""" 264 | manager = SafetyManager.get_instance() 265 | 266 | # Store a confirmation for an operation 267 | confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME) 268 | 269 | # Retrieve the operation 270 | operation = manager.get_stored_operation(confirmation_id) 271 | 272 | # Verify that the retrieved operation matches the original 273 | assert operation == "test_operation" 274 | 275 | # Test with an invalid ID 276 | assert manager.get_stored_operation("invalid_id") is None 277 | 278 | def test_integration_validate_and_confirm(self): 279 | """Test the full flow of validating an operation that requires confirmation and then confirming it.""" 280 | manager = SafetyManager.get_instance() 281 | mock_config = MockSafetyConfig() 282 | 283 | # Register the config 284 | manager.register_config(ClientType.DATABASE, mock_config) 285 | 286 | # Set safety mode to UNSAFE 287 | manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE) 288 | 289 | # Try to validate a high risk operation and catch the ConfirmationRequiredError 290 | confirmation_id = None 291 | try: 292 | manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=False) 293 | except ConfirmationRequiredError as e: 294 | # Extract the confirmation ID from the error message 295 | error_message = str(e) 296 | # Find the confirmation ID in the message 297 | import re 298 | 299 | match = re.search(r"ID: (conf_[a-f0-9]+)", error_message) 300 | if match: 301 | confirmation_id = match.group(1) 302 | 303 | # Verify that we got a confirmation ID 304 | assert confirmation_id is not None 305 | 306 | # Now validate the operation again with the confirmation ID 307 | # This should not raise an exception 308 | manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=True) 309 | ``` -------------------------------------------------------------------------------- /supabase_mcp/services/database/sql/validator.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any 2 | 3 | from pglast.parser import ParseError, parse_sql 4 | 5 | from supabase_mcp.exceptions import ValidationError 6 | from supabase_mcp.logger import logger 7 | from supabase_mcp.services.database.sql.models import ( 8 | QueryValidationResults, 9 | SQLQueryCategory, 10 | SQLQueryCommand, 11 | ValidatedStatement, 12 | ) 13 | from supabase_mcp.services.safety.safety_configs import SQLSafetyConfig 14 | 15 | 16 | class SQLValidator: 17 | """SQL validator class that is based on pglast library. 18 | 19 | Responsible for: 20 | - SQL query syntax validation 21 | - SQL query categorization""" 22 | 23 | # Mapping from statement types to object types 24 | STATEMENT_TYPE_TO_OBJECT_TYPE = { 25 | "CreateFunctionStmt": "function", 26 | "ViewStmt": "view", 27 | "CreateTableAsStmt": "materialized_view", # When relkind is 'm', otherwise 'table' 28 | "CreateEnumStmt": "type", 29 | "CreateTypeStmt": "type", 30 | "CreateExtensionStmt": "extension", 31 | "CreateForeignTableStmt": "foreign_table", 32 | "CreatePolicyStmt": "policy", 33 | "CreateTrigStmt": "trigger", 34 | "IndexStmt": "index", 35 | "CreateStmt": "table", 36 | "AlterTableStmt": "table", 37 | "GrantStmt": "privilege", 38 | "RevokeStmt": "privilege", 39 | "CreateProcStmt": "procedure", # For CREATE PROCEDURE 40 | } 41 | 42 | def __init__(self, safety_config: SQLSafetyConfig | None = None) -> None: 43 | self.safety_config = safety_config or SQLSafetyConfig() 44 | 45 | def validate_schema_name(self, schema_name: str) -> str: 46 | """Validate schema name. 47 | 48 | Rules: 49 | - Must be a string 50 | - Cannot be empty 51 | - Cannot contain spaces or special characters 52 | """ 53 | if not schema_name.strip(): 54 | raise ValidationError("Schema name cannot be empty") 55 | if " " in schema_name: 56 | raise ValidationError("Schema name cannot contain spaces") 57 | return schema_name 58 | 59 | def validate_table_name(self, table: str) -> str: 60 | """Validate table name. 61 | 62 | Rules: 63 | - Must be a string 64 | - Cannot be empty 65 | - Cannot contain spaces or special characters 66 | """ 67 | if not table.strip(): 68 | raise ValidationError("Table name cannot be empty") 69 | if " " in table: 70 | raise ValidationError("Table name cannot contain spaces") 71 | return table 72 | 73 | def basic_query_validation(self, query: str) -> str: 74 | """Validate SQL query. 75 | 76 | Rules: 77 | - Must be a string 78 | - Cannot be empty 79 | """ 80 | if not query.strip(): 81 | raise ValidationError("Query cannot be empty") 82 | return query 83 | 84 | @classmethod 85 | def validate_transaction_control(cls, query: str) -> bool: 86 | """Check if the query contains transaction control statements. 87 | 88 | Args: 89 | query: SQL query string 90 | 91 | Returns: 92 | bool: True if the query contains any transaction control statements 93 | """ 94 | return any(x in query.upper() for x in ["BEGIN", "COMMIT", "ROLLBACK"]) 95 | 96 | def validate_query(self, sql_query: str) -> QueryValidationResults: 97 | """ 98 | Identify the type of SQL query using PostgreSQL's parser. 99 | 100 | Args: 101 | sql_query: A SQL query string to parse 102 | 103 | Returns: 104 | QueryValidationResults: A validation result object containing information about the SQL statements 105 | Raises: 106 | ValidationError: If the query is not valid or contains TCL statements 107 | """ 108 | try: 109 | # Validate raw input 110 | sql_query = self.basic_query_validation(sql_query) 111 | 112 | # Parse the SQL using PostgreSQL's parser 113 | parse_tree = parse_sql(sql_query) 114 | if parse_tree is None: 115 | logger.debug("No statements found in the query") 116 | # logger.debug(f"Parse tree generated with {parse_tree} statements") 117 | 118 | # Validate statements 119 | result = self.validate_statements(original_query=sql_query, parse_tree=parse_tree) 120 | 121 | # Check if the query contains transaction control statements and reject them 122 | for statement in result.statements: 123 | if statement.category == SQLQueryCategory.TCL: 124 | logger.warning(f"Transaction control statement detected: {statement.command}") 125 | raise ValidationError( 126 | "Transaction control statements (BEGIN, COMMIT, ROLLBACK) are not allowed. " 127 | "Queries will be automatically wrapped in transactions by the system." 128 | ) 129 | 130 | return result 131 | except ParseError as e: 132 | logger.exception(f"SQL syntax error: {str(e)}") 133 | raise ValidationError(f"SQL syntax error: {str(e)}") from e 134 | except ValidationError: 135 | # let it propagate 136 | raise 137 | except Exception as e: 138 | logger.exception(f"Unexpected error during SQL validation: {str(e)}") 139 | raise ValidationError(f"Unexpected error during SQL validation: {str(e)}") from e 140 | 141 | def _map_to_command(self, stmt_type: str) -> SQLQueryCommand: 142 | """Map a pglast statement type to our SQLQueryCommand enum.""" 143 | 144 | mapping = { 145 | # DQL Commands 146 | "SelectStmt": SQLQueryCommand.SELECT, 147 | # DML Commands 148 | "InsertStmt": SQLQueryCommand.INSERT, 149 | "UpdateStmt": SQLQueryCommand.UPDATE, 150 | "DeleteStmt": SQLQueryCommand.DELETE, 151 | "MergeStmt": SQLQueryCommand.MERGE, 152 | # DDL Commands 153 | "CreateStmt": SQLQueryCommand.CREATE, 154 | "CreateTableAsStmt": SQLQueryCommand.CREATE, 155 | "CreateSchemaStmt": SQLQueryCommand.CREATE, 156 | "CreateExtensionStmt": SQLQueryCommand.CREATE, 157 | "CreateFunctionStmt": SQLQueryCommand.CREATE, 158 | "CreateTrigStmt": SQLQueryCommand.CREATE, 159 | "ViewStmt": SQLQueryCommand.CREATE, 160 | "IndexStmt": SQLQueryCommand.CREATE, 161 | # Additional DDL Commands 162 | "CreateEnumStmt": SQLQueryCommand.CREATE, 163 | "CreateTypeStmt": SQLQueryCommand.CREATE, 164 | "CreateDomainStmt": SQLQueryCommand.CREATE, 165 | "CreateSeqStmt": SQLQueryCommand.CREATE, 166 | "CreateForeignTableStmt": SQLQueryCommand.CREATE, 167 | "CreatePolicyStmt": SQLQueryCommand.CREATE, 168 | "CreateCastStmt": SQLQueryCommand.CREATE, 169 | "CreateOpClassStmt": SQLQueryCommand.CREATE, 170 | "CreateOpFamilyStmt": SQLQueryCommand.CREATE, 171 | "AlterTableStmt": SQLQueryCommand.ALTER, 172 | "AlterDomainStmt": SQLQueryCommand.ALTER, 173 | "AlterEnumStmt": SQLQueryCommand.ALTER, 174 | "AlterSeqStmt": SQLQueryCommand.ALTER, 175 | "AlterOwnerStmt": SQLQueryCommand.ALTER, 176 | "AlterObjectSchemaStmt": SQLQueryCommand.ALTER, 177 | "DropStmt": SQLQueryCommand.DROP, 178 | "TruncateStmt": SQLQueryCommand.TRUNCATE, 179 | "CommentStmt": SQLQueryCommand.COMMENT, 180 | "RenameStmt": SQLQueryCommand.RENAME, 181 | # DCL Commands 182 | "GrantStmt": SQLQueryCommand.GRANT, 183 | "GrantRoleStmt": SQLQueryCommand.GRANT, 184 | "RevokeStmt": SQLQueryCommand.REVOKE, 185 | "RevokeRoleStmt": SQLQueryCommand.REVOKE, 186 | "CreateRoleStmt": SQLQueryCommand.CREATE, 187 | "AlterRoleStmt": SQLQueryCommand.ALTER, 188 | "DropRoleStmt": SQLQueryCommand.DROP, 189 | # TCL Commands 190 | "TransactionStmt": SQLQueryCommand.BEGIN, # Will need refinement for different transaction types 191 | # PostgreSQL-specific Commands 192 | "VacuumStmt": SQLQueryCommand.VACUUM, 193 | "ExplainStmt": SQLQueryCommand.EXPLAIN, 194 | "CopyStmt": SQLQueryCommand.COPY, 195 | "ListenStmt": SQLQueryCommand.LISTEN, 196 | "NotifyStmt": SQLQueryCommand.NOTIFY, 197 | "PrepareStmt": SQLQueryCommand.PREPARE, 198 | "ExecuteStmt": SQLQueryCommand.EXECUTE, 199 | "DeallocateStmt": SQLQueryCommand.DEALLOCATE, 200 | } 201 | 202 | # Try to map the statement type, default to UNKNOWN 203 | return mapping.get(stmt_type, SQLQueryCommand.UNKNOWN) 204 | 205 | def validate_statements(self, original_query: str, parse_tree: Any) -> QueryValidationResults: 206 | """Validate the statements in the parse tree. 207 | 208 | Args: 209 | parse_tree: The parse tree to validate 210 | 211 | Returns: 212 | SQLBatchValidationResult: A validation result object containing information about the SQL statements 213 | Raises: 214 | ValidationError: If the query is not valid 215 | """ 216 | result = QueryValidationResults(original_query=original_query) 217 | 218 | if parse_tree is None: 219 | return result 220 | 221 | try: 222 | for stmt in parse_tree: 223 | if not hasattr(stmt, "stmt"): 224 | continue 225 | 226 | stmt_node = stmt.stmt 227 | stmt_type = stmt_node.__class__.__name__ 228 | logger.debug(f"Processing statement node type: {stmt_type}") 229 | # logger.debug(f"DEBUGGING stmt_node: {stmt_node}") 230 | logger.debug(f"DEBUGGING stmt_node.stmt_location: {stmt.stmt_location}") 231 | 232 | # Extract the object type if available 233 | object_type = None 234 | schema_name = None 235 | if hasattr(stmt_node, "relation") and stmt_node.relation is not None: 236 | if hasattr(stmt_node.relation, "relname"): 237 | object_type = stmt_node.relation.relname 238 | if hasattr(stmt_node.relation, "schemaname"): 239 | schema_name = stmt_node.relation.schemaname 240 | # For statements with 'relations' list (like TRUNCATE) 241 | elif hasattr(stmt_node, "relations") and stmt_node.relations: 242 | for relation in stmt_node.relations: 243 | if hasattr(relation, "relname"): 244 | object_type = relation.relname 245 | if hasattr(relation, "schemaname"): 246 | schema_name = relation.schemaname 247 | break 248 | 249 | # Simple approach: Set object_type based on statement type if not already set 250 | if object_type is None and stmt_type in self.STATEMENT_TYPE_TO_OBJECT_TYPE: 251 | object_type = self.STATEMENT_TYPE_TO_OBJECT_TYPE[stmt_type] 252 | 253 | # Default schema to public if not set 254 | if schema_name is None: 255 | schema_name = "public" 256 | 257 | # Get classification for this statement type 258 | classification = self.safety_config.classify_statement(stmt_type, stmt_node) 259 | logger.debug( 260 | f"Statement category classified as: {classification.get('category', 'UNKNOWN')} - risk level: {classification.get('risk_level', 'UNKNOWN')}" 261 | ) 262 | logger.debug(f"DEBUGGING QUERY EXTRACTION LOCATION: {stmt.stmt_location} - {stmt.stmt_len}") 263 | 264 | # Create validation result 265 | query_result = ValidatedStatement( 266 | category=classification["category"], 267 | command=self._map_to_command(stmt_type), 268 | risk_level=classification["risk_level"], 269 | needs_migration=classification["needs_migration"], 270 | object_type=object_type, 271 | schema_name=schema_name, 272 | query=original_query[stmt.stmt_location : stmt.stmt_location + stmt.stmt_len] 273 | if hasattr(stmt, "stmt_location") and hasattr(stmt, "stmt_len") 274 | else None, 275 | ) 276 | # logger.debug(f"Isolated query: {query_result.query}") 277 | logger.debug( 278 | "Query validation result:", 279 | { 280 | "statement_category": query_result.category, 281 | "risk_level": query_result.risk_level, 282 | "needs_migration": query_result.needs_migration, 283 | "object_type": query_result.object_type, 284 | "schema_name": query_result.schema_name, 285 | "query": query_result.query, 286 | }, 287 | ) 288 | 289 | # Add result to the batch 290 | result.statements.append(query_result) 291 | 292 | # Update highest risk level 293 | if query_result.risk_level > result.highest_risk_level: 294 | result.highest_risk_level = query_result.risk_level 295 | logger.debug(f"Updated batch validation result to: {query_result.risk_level}") 296 | if len(result.statements) == 0: 297 | logger.debug("No valid statements found in the query") 298 | raise ValidationError("No queries were parsed - please check correctness of your query") 299 | logger.debug( 300 | f"Validated a total of {len(result.statements)} with the highest risk level of: {result.highest_risk_level}" 301 | ) 302 | return result 303 | 304 | except AttributeError as e: 305 | # Handle attempting to access missing attributes in the parse tree 306 | raise ValidationError(f"Error accessing parse tree structure: {str(e)}") from e 307 | except KeyError as e: 308 | # Handle missing keys in classification dictionary 309 | raise ValidationError(f"Missing classification key: {str(e)}") from e 310 | ``` -------------------------------------------------------------------------------- /tests/services/database/test_postgres_client.py: -------------------------------------------------------------------------------- ```python 1 | import asyncpg 2 | import pytest 3 | from unittest.mock import AsyncMock, MagicMock, patch 4 | 5 | from supabase_mcp.exceptions import ConnectionError, QueryError, PermissionError as SupabasePermissionError 6 | from supabase_mcp.services.database.postgres_client import PostgresClient, QueryResult, StatementResult 7 | from supabase_mcp.services.database.sql.validator import ( 8 | QueryValidationResults, 9 | SQLQueryCategory, 10 | SQLQueryCommand, 11 | ValidatedStatement, 12 | ) 13 | from supabase_mcp.services.safety.models import OperationRiskLevel 14 | from supabase_mcp.settings import Settings 15 | 16 | 17 | @pytest.mark.asyncio(loop_scope="class") 18 | class TestPostgresClient: 19 | """Unit tests for the Postgres client.""" 20 | 21 | @pytest.fixture 22 | def mock_settings(self): 23 | """Create mock settings for testing.""" 24 | settings = MagicMock(spec=Settings) 25 | settings.supabase_project_ref = "test-project-ref" 26 | settings.supabase_db_password = "test-password" 27 | settings.supabase_region = "us-east-1" 28 | settings.database_url = "postgresql://test:test@localhost:5432/test" 29 | return settings 30 | 31 | @pytest.fixture 32 | async def mock_postgres_client(self, mock_settings): 33 | """Create a mock Postgres client for testing.""" 34 | # Reset the singleton first 35 | await PostgresClient.reset() 36 | 37 | # Create client and mock execute_query directly 38 | client = PostgresClient(settings=mock_settings) 39 | return client 40 | 41 | async def test_execute_simple_select(self, mock_postgres_client: PostgresClient): 42 | """Test executing a simple SELECT query.""" 43 | # Create a simple validation result with a SELECT query 44 | query = "SELECT 1 as number;" 45 | statement = ValidatedStatement( 46 | query=query, 47 | command=SQLQueryCommand.SELECT, 48 | category=SQLQueryCategory.DQL, 49 | risk_level=OperationRiskLevel.LOW, 50 | needs_migration=False, 51 | object_type=None, 52 | schema_name=None, 53 | ) 54 | validation_result = QueryValidationResults( 55 | statements=[statement], 56 | original_query=query, 57 | highest_risk_level=OperationRiskLevel.LOW, 58 | ) 59 | 60 | # Mock the query result 61 | expected_result = QueryResult(results=[ 62 | StatementResult(rows=[{"number": 1}]) 63 | ]) 64 | 65 | with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result): 66 | # Execute the query 67 | result = await mock_postgres_client.execute_query(validation_result) 68 | 69 | # Verify the result 70 | assert isinstance(result, QueryResult) 71 | assert len(result.results) == 1 72 | assert isinstance(result.results[0], StatementResult) 73 | assert len(result.results[0].rows) == 1 74 | assert result.results[0].rows[0]["number"] == 1 75 | 76 | async def test_execute_multiple_statements(self, mock_postgres_client: PostgresClient): 77 | """Test executing multiple SQL statements in a single query.""" 78 | # Create validation result with multiple statements 79 | query = "SELECT 1 as first; SELECT 2 as second;" 80 | statements = [ 81 | ValidatedStatement( 82 | query="SELECT 1 as first;", 83 | command=SQLQueryCommand.SELECT, 84 | category=SQLQueryCategory.DQL, 85 | risk_level=OperationRiskLevel.LOW, 86 | needs_migration=False, 87 | object_type=None, 88 | schema_name=None, 89 | ), 90 | ValidatedStatement( 91 | query="SELECT 2 as second;", 92 | command=SQLQueryCommand.SELECT, 93 | category=SQLQueryCategory.DQL, 94 | risk_level=OperationRiskLevel.LOW, 95 | needs_migration=False, 96 | object_type=None, 97 | schema_name=None, 98 | ), 99 | ] 100 | validation_result = QueryValidationResults( 101 | statements=statements, 102 | original_query=query, 103 | highest_risk_level=OperationRiskLevel.LOW, 104 | ) 105 | 106 | # Mock the query result 107 | expected_result = QueryResult(results=[ 108 | StatementResult(rows=[{"first": 1}]), 109 | StatementResult(rows=[{"second": 2}]) 110 | ]) 111 | 112 | with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result): 113 | # Execute the query 114 | result = await mock_postgres_client.execute_query(validation_result) 115 | 116 | # Verify the result 117 | assert isinstance(result, QueryResult) 118 | assert len(result.results) == 2 119 | assert result.results[0].rows[0]["first"] == 1 120 | assert result.results[1].rows[0]["second"] == 2 121 | 122 | async def test_execute_query_with_parameters(self, mock_postgres_client: PostgresClient): 123 | """Test executing a query with parameters.""" 124 | query = "SELECT 'test' as name, 42 as value;" 125 | statement = ValidatedStatement( 126 | query=query, 127 | command=SQLQueryCommand.SELECT, 128 | category=SQLQueryCategory.DQL, 129 | risk_level=OperationRiskLevel.LOW, 130 | needs_migration=False, 131 | object_type=None, 132 | schema_name=None, 133 | ) 134 | validation_result = QueryValidationResults( 135 | statements=[statement], 136 | original_query=query, 137 | highest_risk_level=OperationRiskLevel.LOW, 138 | ) 139 | 140 | # Mock the query result 141 | expected_result = QueryResult(results=[ 142 | StatementResult(rows=[{"name": "test", "value": 42}]) 143 | ]) 144 | 145 | with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result): 146 | # Execute the query 147 | result = await mock_postgres_client.execute_query(validation_result) 148 | 149 | # Verify the result 150 | assert isinstance(result, QueryResult) 151 | assert len(result.results) == 1 152 | assert result.results[0].rows[0]["name"] == "test" 153 | assert result.results[0].rows[0]["value"] == 42 154 | 155 | async def test_permission_error(self, mock_postgres_client: PostgresClient): 156 | """Test handling a permission error.""" 157 | # Create a mock error 158 | error = asyncpg.exceptions.InsufficientPrivilegeError("Permission denied") 159 | 160 | # Verify that the method raises PermissionError with the expected message 161 | with pytest.raises(SupabasePermissionError) as exc_info: 162 | await mock_postgres_client._handle_postgres_error(error) 163 | 164 | # Verify the error message 165 | assert "Access denied" in str(exc_info.value) 166 | assert "Permission denied" in str(exc_info.value) 167 | assert "live_dangerously" in str(exc_info.value) 168 | 169 | async def test_query_error(self, mock_postgres_client: PostgresClient): 170 | """Test handling a query error.""" 171 | # Create a validation result with a syntactically valid but semantically incorrect query 172 | query = "SELECT * FROM nonexistent_table;" 173 | statement = ValidatedStatement( 174 | query=query, 175 | command=SQLQueryCommand.SELECT, 176 | category=SQLQueryCategory.DQL, 177 | risk_level=OperationRiskLevel.LOW, 178 | needs_migration=False, 179 | object_type="TABLE", 180 | schema_name="public", 181 | ) 182 | validation_result = QueryValidationResults( 183 | statements=[statement], 184 | original_query=query, 185 | highest_risk_level=OperationRiskLevel.LOW, 186 | ) 187 | 188 | # Mock execute_query to raise a QueryError 189 | with patch.object(mock_postgres_client, 'execute_query', 190 | side_effect=QueryError("relation \"nonexistent_table\" does not exist")): 191 | # Execute the query - should raise a QueryError 192 | with pytest.raises(QueryError) as excinfo: 193 | await mock_postgres_client.execute_query(validation_result) 194 | 195 | # Verify the error message contains the specific error 196 | assert "nonexistent_table" in str(excinfo.value) 197 | 198 | async def test_schema_error(self, mock_postgres_client: PostgresClient): 199 | """Test handling a schema error.""" 200 | # Create a validation result with a query referencing a non-existent column 201 | query = "SELECT nonexistent_column FROM information_schema.tables;" 202 | statement = ValidatedStatement( 203 | query=query, 204 | command=SQLQueryCommand.SELECT, 205 | category=SQLQueryCategory.DQL, 206 | risk_level=OperationRiskLevel.LOW, 207 | needs_migration=False, 208 | object_type="TABLE", 209 | schema_name="information_schema", 210 | ) 211 | validation_result = QueryValidationResults( 212 | statements=[statement], 213 | original_query=query, 214 | highest_risk_level=OperationRiskLevel.LOW, 215 | ) 216 | 217 | # Mock execute_query to raise a QueryError 218 | with patch.object(mock_postgres_client, 'execute_query', 219 | side_effect=QueryError("column \"nonexistent_column\" does not exist")): 220 | # Execute the query - should raise a QueryError 221 | with pytest.raises(QueryError) as excinfo: 222 | await mock_postgres_client.execute_query(validation_result) 223 | 224 | # Verify the error message contains the specific error 225 | assert "nonexistent_column" in str(excinfo.value) 226 | 227 | async def test_write_operation(self, mock_postgres_client: PostgresClient): 228 | """Test a basic write operation (INSERT).""" 229 | # Create insert query 230 | insert_query = "INSERT INTO test_write (name) VALUES ('test_value') RETURNING id, name;" 231 | insert_statement = ValidatedStatement( 232 | query=insert_query, 233 | command=SQLQueryCommand.INSERT, 234 | category=SQLQueryCategory.DML, 235 | risk_level=OperationRiskLevel.MEDIUM, 236 | needs_migration=False, 237 | object_type="TABLE", 238 | schema_name="public", 239 | ) 240 | insert_validation = QueryValidationResults( 241 | statements=[insert_statement], 242 | original_query=insert_query, 243 | highest_risk_level=OperationRiskLevel.MEDIUM, 244 | ) 245 | 246 | # Mock the query result 247 | expected_result = QueryResult(results=[ 248 | StatementResult(rows=[{"id": 1, "name": "test_value"}]) 249 | ]) 250 | 251 | with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result): 252 | # Execute the insert query 253 | result = await mock_postgres_client.execute_query(insert_validation, readonly=False) 254 | 255 | # Verify the result 256 | assert isinstance(result, QueryResult) 257 | assert len(result.results) == 1 258 | assert result.results[0].rows[0]["name"] == "test_value" 259 | assert result.results[0].rows[0]["id"] == 1 260 | 261 | async def test_ddl_operation(self, mock_postgres_client: PostgresClient): 262 | """Test a basic DDL operation (CREATE TABLE).""" 263 | # Create a test table 264 | create_query = "CREATE TEMPORARY TABLE test_ddl (id SERIAL PRIMARY KEY, value TEXT);" 265 | create_statement = ValidatedStatement( 266 | query=create_query, 267 | command=SQLQueryCommand.CREATE, 268 | category=SQLQueryCategory.DDL, 269 | risk_level=OperationRiskLevel.MEDIUM, 270 | needs_migration=False, 271 | object_type="TABLE", 272 | schema_name="public", 273 | ) 274 | create_validation = QueryValidationResults( 275 | statements=[create_statement], 276 | original_query=create_query, 277 | highest_risk_level=OperationRiskLevel.MEDIUM, 278 | ) 279 | 280 | # Mock the query result - DDL typically returns empty results 281 | expected_result = QueryResult(results=[ 282 | StatementResult(rows=[]) 283 | ]) 284 | 285 | with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result): 286 | # Execute the create table query 287 | result = await mock_postgres_client.execute_query(create_validation, readonly=False) 288 | 289 | # Verify the result 290 | assert isinstance(result, QueryResult) 291 | assert len(result.results) == 1 292 | # DDL operations typically don't return rows 293 | assert result.results[0].rows == [] 294 | 295 | async def test_execute_metadata_query(self, mock_postgres_client: PostgresClient): 296 | """Test executing a metadata query.""" 297 | # Create a simple validation result with a SELECT query 298 | query = "SELECT schema_name FROM information_schema.schemata LIMIT 5;" 299 | statement = ValidatedStatement( 300 | query=query, 301 | command=SQLQueryCommand.SELECT, 302 | category=SQLQueryCategory.DQL, 303 | risk_level=OperationRiskLevel.LOW, 304 | needs_migration=False, 305 | object_type="schemata", 306 | schema_name="information_schema", 307 | ) 308 | validation_result = QueryValidationResults( 309 | statements=[statement], 310 | original_query=query, 311 | highest_risk_level=OperationRiskLevel.LOW, 312 | ) 313 | 314 | # Mock the query result 315 | expected_result = QueryResult(results=[ 316 | StatementResult(rows=[ 317 | {"schema_name": "public"}, 318 | {"schema_name": "information_schema"}, 319 | {"schema_name": "pg_catalog"}, 320 | {"schema_name": "auth"}, 321 | {"schema_name": "storage"} 322 | ]) 323 | ]) 324 | 325 | with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result): 326 | # Execute the query 327 | result = await mock_postgres_client.execute_query(validation_result) 328 | 329 | # Verify the result 330 | assert isinstance(result, QueryResult) 331 | assert len(result.results) == 1 332 | assert len(result.results[0].rows) == 5 333 | assert "schema_name" in result.results[0].rows[0] 334 | 335 | async def test_connection_retry_mechanism(self, mock_postgres_client: PostgresClient): 336 | """Test that the tenacity retry mechanism works correctly for database connections.""" 337 | # Reset the pool 338 | mock_postgres_client._pool = None 339 | 340 | # Mock create_pool to always raise a connection error 341 | with patch.object(mock_postgres_client, 'create_pool', 342 | side_effect=ConnectionError("Could not connect to database")): 343 | # This should trigger the retry mechanism and eventually fail 344 | with pytest.raises(ConnectionError) as exc_info: 345 | await mock_postgres_client.ensure_pool() 346 | 347 | # Verify the error message indicates a connection failure after retries 348 | assert "Could not connect to database" in str(exc_info.value) ``` -------------------------------------------------------------------------------- /supabase_mcp/services/sdk/auth_admin_sdk_spec.py: -------------------------------------------------------------------------------- ```python 1 | def get_auth_admin_methods_spec() -> dict: 2 | """Returns a detailed specification of all Auth Admin methods.""" 3 | return { 4 | "get_user_by_id": { 5 | "description": "Retrieve a user by their ID", 6 | "parameters": {"uid": {"type": "string", "description": "The user's UUID", "required": True}}, 7 | "returns": {"type": "object", "description": "User object containing all user data"}, 8 | "example": { 9 | "request": {"uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"}, 10 | "response": { 11 | "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", 12 | "email": "[email protected]", 13 | "phone": "", 14 | "created_at": "2023-01-01T00:00:00Z", 15 | "confirmed_at": "2023-01-01T00:00:00Z", 16 | "last_sign_in_at": "2023-01-01T00:00:00Z", 17 | "user_metadata": {"name": "John Doe"}, 18 | "app_metadata": {}, 19 | }, 20 | }, 21 | }, 22 | "list_users": { 23 | "description": "List all users with pagination", 24 | "parameters": { 25 | "page": { 26 | "type": "integer", 27 | "description": "Page number (starts at 1)", 28 | "required": False, 29 | "default": 1, 30 | }, 31 | "per_page": { 32 | "type": "integer", 33 | "description": "Number of users per page", 34 | "required": False, 35 | "default": 50, 36 | }, 37 | }, 38 | "returns": {"type": "object", "description": "Paginated list of users with metadata"}, 39 | "example": { 40 | "request": {"page": 1, "per_page": 10}, 41 | "response": { 42 | "users": [ 43 | { 44 | "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", 45 | "email": "[email protected]", 46 | "user_metadata": {"name": "John Doe"}, 47 | } 48 | ], 49 | "aud": "authenticated", 50 | "total_count": 100, 51 | "next_page": 2, 52 | }, 53 | }, 54 | }, 55 | "create_user": { 56 | "description": "Create a new user. Does not send a confirmation email by default.", 57 | "parameters": { 58 | "email": {"type": "string", "description": "The user's email address"}, 59 | "password": {"type": "string", "description": "The user's password"}, 60 | "email_confirm": { 61 | "type": "boolean", 62 | "description": "Confirms the user's email address if set to true", 63 | "default": False, 64 | }, 65 | "phone": {"type": "string", "description": "The user's phone number with country code"}, 66 | "phone_confirm": { 67 | "type": "boolean", 68 | "description": "Confirms the user's phone number if set to true", 69 | "default": False, 70 | }, 71 | "user_metadata": { 72 | "type": "object", 73 | "description": "A custom data object to store the user's metadata", 74 | }, 75 | "app_metadata": { 76 | "type": "object", 77 | "description": "A custom data object to store the user's application specific metadata", 78 | }, 79 | "role": {"type": "string", "description": "The role claim set in the user's access token JWT"}, 80 | "ban_duration": {"type": "string", "description": "Determines how long a user is banned for"}, 81 | "nonce": { 82 | "type": "string", 83 | "description": "The nonce (required for reauthentication if updating password)", 84 | }, 85 | }, 86 | "returns": {"type": "object", "description": "Created user object"}, 87 | "example": { 88 | "request": { 89 | "email": "[email protected]", 90 | "password": "secure-password", 91 | "email_confirm": True, 92 | "user_metadata": {"name": "New User"}, 93 | }, 94 | "response": { 95 | "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", 96 | "email": "[email protected]", 97 | "email_confirmed_at": "2023-01-01T00:00:00Z", 98 | "user_metadata": {"name": "New User"}, 99 | }, 100 | }, 101 | "notes": "Either email or phone must be provided. Use invite_user_by_email() if you want to send an email invite.", 102 | }, 103 | "delete_user": { 104 | "description": "Delete a user by their ID. Requires a service_role key.", 105 | "parameters": { 106 | "id": {"type": "string", "description": "The user's UUID", "required": True}, 107 | "should_soft_delete": { 108 | "type": "boolean", 109 | "description": "If true, the user will be soft-deleted (preserving their data but disabling the account). Defaults to false.", 110 | "required": False, 111 | "default": False, 112 | }, 113 | }, 114 | "returns": {"type": "object", "description": "Success message"}, 115 | "example": { 116 | "request": {"id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"}, 117 | "response": {"message": "User deleted successfully"}, 118 | }, 119 | "notes": "This function should only be called on a server. Never expose your service_role key in the browser.", 120 | }, 121 | "invite_user_by_email": { 122 | "description": "Sends an invite link to a user's email address. Typically used by administrators to invite users to join the application.", 123 | "parameters": { 124 | "email": {"type": "string", "description": "The email address of the user", "required": True}, 125 | "options": { 126 | "type": "object", 127 | "description": "Optional settings for the invite", 128 | "required": False, 129 | "properties": { 130 | "data": { 131 | "type": "object", 132 | "description": "A custom data object to store additional metadata about the user. Maps to auth.users.user_metadata", 133 | "required": False, 134 | }, 135 | "redirect_to": { 136 | "type": "string", 137 | "description": "The URL which will be appended to the email link. Once clicked the user will end up on this URL", 138 | "required": False, 139 | }, 140 | }, 141 | }, 142 | }, 143 | "returns": {"type": "object", "description": "User object for the invited user"}, 144 | "example": { 145 | "request": { 146 | "email": "[email protected]", 147 | "options": {"data": {"name": "John Doe"}, "redirect_to": "https://example.com/welcome"}, 148 | }, 149 | "response": { 150 | "id": "a1a1a1a1-a1a1-a1a1-a1a1-a1a1a1a1a1a1", 151 | "email": "[email protected]", 152 | "role": "authenticated", 153 | "email_confirmed_at": None, 154 | "invited_at": "2023-01-01T00:00:00Z", 155 | }, 156 | }, 157 | "notes": "Note that PKCE is not supported when using invite_user_by_email. This is because the browser initiating the invite is often different from the browser accepting the invite.", 158 | }, 159 | "generate_link": { 160 | "description": "Generate an email link for various authentication purposes. Handles user creation for signup, invite and magiclink types.", 161 | "parameters": { 162 | "type": { 163 | "type": "string", 164 | "description": "Link type: 'signup', 'invite', 'magiclink', 'recovery', 'email_change_current', 'email_change_new', 'phone_change'", 165 | "required": True, 166 | "enum": [ 167 | "signup", 168 | "invite", 169 | "magiclink", 170 | "recovery", 171 | "email_change_current", 172 | "email_change_new", 173 | "phone_change", 174 | ], 175 | }, 176 | "email": {"type": "string", "description": "User's email address", "required": True}, 177 | "password": { 178 | "type": "string", 179 | "description": "User's password. Only required if type is signup", 180 | "required": False, 181 | }, 182 | "new_email": { 183 | "type": "string", 184 | "description": "New email address. Only required if type is email_change_current or email_change_new", 185 | "required": False, 186 | }, 187 | "options": { 188 | "type": "object", 189 | "description": "Additional options for the link", 190 | "required": False, 191 | "properties": { 192 | "data": { 193 | "type": "object", 194 | "description": "Custom JSON object containing user metadata. Only accepted if type is signup, invite, or magiclink", 195 | "required": False, 196 | }, 197 | "redirect_to": { 198 | "type": "string", 199 | "description": "A redirect URL which will be appended to the generated email link", 200 | "required": False, 201 | }, 202 | }, 203 | }, 204 | }, 205 | "returns": {"type": "object", "description": "Generated link details"}, 206 | "example": { 207 | "request": { 208 | "type": "signup", 209 | "email": "[email protected]", 210 | "password": "secure-password", 211 | "options": {"data": {"name": "John Doe"}, "redirect_to": "https://example.com/welcome"}, 212 | }, 213 | "response": { 214 | "action_link": "https://your-project.supabase.co/auth/v1/verify?token=...", 215 | "email_otp": "123456", 216 | "hashed_token": "...", 217 | "redirect_to": "https://example.com/welcome", 218 | "verification_type": "signup", 219 | }, 220 | }, 221 | "notes": "generate_link() only generates the email link for email_change_email if the Secure email change is enabled in your project's email auth provider settings.", 222 | }, 223 | "update_user_by_id": { 224 | "description": "Update user attributes by ID. Requires a service_role key.", 225 | "parameters": { 226 | "uid": {"type": "string", "description": "The user's UUID", "required": True}, 227 | "attributes": { 228 | "type": "object", 229 | "description": "The user attributes to update.", 230 | "required": True, 231 | "properties": { 232 | "email": {"type": "string", "description": "The user's email"}, 233 | "phone": {"type": "string", "description": "The user's phone"}, 234 | "password": {"type": "string", "description": "The user's password"}, 235 | "email_confirm": { 236 | "type": "boolean", 237 | "description": "Confirms the user's email address if set to true", 238 | }, 239 | "phone_confirm": { 240 | "type": "boolean", 241 | "description": "Confirms the user's phone number if set to true", 242 | }, 243 | "user_metadata": { 244 | "type": "object", 245 | "description": "A custom data object to store the user's metadata.", 246 | }, 247 | "app_metadata": { 248 | "type": "object", 249 | "description": "A custom data object to store the user's application specific metadata.", 250 | }, 251 | "role": { 252 | "type": "string", 253 | "description": "The role claim set in the user's access token JWT", 254 | }, 255 | "ban_duration": { 256 | "type": "string", 257 | "description": "Determines how long a user is banned for", 258 | }, 259 | "nonce": { 260 | "type": "string", 261 | "description": "The nonce sent for reauthentication if the user's password is to be updated", 262 | }, 263 | }, 264 | }, 265 | }, 266 | "returns": {"type": "object", "description": "Updated user object"}, 267 | "example": { 268 | "request": { 269 | "uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", 270 | "attributes": {"email": "[email protected]", "user_metadata": {"name": "Updated Name"}}, 271 | }, 272 | "response": { 273 | "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", 274 | "email": "[email protected]", 275 | "user_metadata": {"name": "Updated Name"}, 276 | }, 277 | }, 278 | "notes": "This function should only be called on a server. Never expose your service_role key in the browser.", 279 | }, 280 | "delete_factor": { 281 | "description": "Deletes a factor on a user. This will log the user out of all active sessions if the deleted factor was verified.", 282 | "parameters": { 283 | "user_id": { 284 | "type": "string", 285 | "description": "ID of the user whose factor is being deleted", 286 | "required": True, 287 | }, 288 | "id": {"type": "string", "description": "ID of the MFA factor to delete", "required": True}, 289 | }, 290 | "returns": {"type": "object", "description": "Success message"}, 291 | "example": { 292 | "request": {"user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", "id": "totp-factor-id-123"}, 293 | "response": {"message": "Factor deleted successfully"}, 294 | }, 295 | "notes": "This will log the user out of all active sessions if the deleted factor was verified.", 296 | }, 297 | } 298 | ``` -------------------------------------------------------------------------------- /tests/services/sdk/test_sdk_client.py: -------------------------------------------------------------------------------- ```python 1 | import time 2 | import uuid 3 | from datetime import datetime 4 | from unittest.mock import AsyncMock, MagicMock, patch 5 | 6 | import pytest 7 | 8 | from supabase_mcp.clients.sdk_client import SupabaseSDKClient 9 | from supabase_mcp.exceptions import PythonSDKError 10 | from supabase_mcp.settings import Settings 11 | 12 | # Unique identifier for test users to avoid conflicts 13 | TEST_ID = f"test-{int(time.time())}-{uuid.uuid4().hex[:6]}" 14 | 15 | 16 | # Create unique test emails 17 | def get_test_email(prefix: str = "user"): 18 | """Generate a unique test email""" 19 | return f"a.zuev+{prefix}-{TEST_ID}@outlook.com" 20 | 21 | 22 | @pytest.mark.asyncio(loop_scope="module") 23 | class TestSDKClientIntegration: 24 | """ 25 | Unit tests for the SupabaseSDKClient. 26 | """ 27 | 28 | @pytest.fixture 29 | def mock_settings(self): 30 | """Create mock settings for testing.""" 31 | settings = MagicMock(spec=Settings) 32 | settings.supabase_project_ref = "test-project-ref" 33 | settings.supabase_service_role_key = "test-service-role-key" 34 | settings.supabase_region = "us-east-1" 35 | settings.supabase_url = "https://test-project-ref.supabase.co" 36 | return settings 37 | 38 | @pytest.fixture 39 | async def mock_sdk_client(self, mock_settings): 40 | """Create a mock SDK client for testing.""" 41 | # Reset singleton 42 | SupabaseSDKClient.reset() 43 | 44 | # Mock the Supabase client 45 | mock_supabase = MagicMock() 46 | mock_auth_admin = MagicMock() 47 | mock_supabase.auth.admin = mock_auth_admin 48 | 49 | # Mock the create_async_client function to return our mock client 50 | with patch('supabase_mcp.clients.sdk_client.create_async_client', return_value=mock_supabase): 51 | # Create client - this will now use our mocked create_async_client 52 | client = SupabaseSDKClient.get_instance(settings=mock_settings) 53 | # Manually set the client to ensure it's available 54 | client.client = mock_supabase 55 | 56 | return client 57 | 58 | async def test_list_users(self, mock_sdk_client: SupabaseSDKClient): 59 | """Test listing users with pagination""" 60 | # Mock user data 61 | mock_users = [ 62 | MagicMock(id="user1", email="[email protected]", user_metadata={}), 63 | MagicMock(id="user2", email="[email protected]", user_metadata={}) 64 | ] 65 | 66 | # Mock the list_users method as an async function 67 | mock_sdk_client.client.auth.admin.list_users = AsyncMock(return_value=mock_users) 68 | 69 | # Create test parameters 70 | list_params = {"page": 1, "per_page": 10} 71 | 72 | # List users 73 | result = await mock_sdk_client.call_auth_admin_method("list_users", list_params) 74 | 75 | # Verify response format 76 | assert result is not None 77 | assert hasattr(result, "__iter__") # Should be iterable (list of users) 78 | assert len(result) == 2 79 | 80 | # Check that the first user has expected attributes 81 | first_user = result[0] 82 | assert hasattr(first_user, "id") 83 | assert hasattr(first_user, "email") 84 | assert hasattr(first_user, "user_metadata") 85 | 86 | # Test with invalid parameters - mock the validation error 87 | mock_sdk_client.client.auth.admin.list_users = AsyncMock(side_effect=Exception("Bad Pagination Parameters")) 88 | 89 | invalid_params = {"page": -1, "per_page": 10} 90 | with pytest.raises(PythonSDKError) as excinfo: 91 | await mock_sdk_client.call_auth_admin_method("list_users", invalid_params) 92 | 93 | assert "Bad Pagination Parameters" in str(excinfo.value) 94 | 95 | async def test_get_user_by_id(self, mock_sdk_client: SupabaseSDKClient): 96 | """Test retrieving a user by ID""" 97 | # Mock user data 98 | test_email = get_test_email("get") 99 | user_id = str(uuid.uuid4()) 100 | 101 | mock_user = MagicMock( 102 | id=user_id, 103 | email=test_email, 104 | user_metadata={"name": "Test User", "test_id": TEST_ID} 105 | ) 106 | mock_response = MagicMock(user=mock_user) 107 | 108 | # Mock the get_user_by_id method as an async function 109 | mock_sdk_client.client.auth.admin.get_user_by_id = AsyncMock(return_value=mock_response) 110 | 111 | # Get the user by ID 112 | get_params = {"uid": user_id} 113 | get_result = await mock_sdk_client.call_auth_admin_method("get_user_by_id", get_params) 114 | 115 | # Verify user data 116 | assert get_result is not None 117 | assert hasattr(get_result, "user") 118 | assert get_result.user.id == user_id 119 | assert get_result.user.email == test_email 120 | 121 | # Test with invalid parameters (non-existent user ID) 122 | mock_sdk_client.client.auth.admin.get_user_by_id = AsyncMock(side_effect=Exception("user_id must be an UUID")) 123 | 124 | invalid_params = {"uid": "non-existent-user-id"} 125 | with pytest.raises(PythonSDKError) as excinfo: 126 | await mock_sdk_client.call_auth_admin_method("get_user_by_id", invalid_params) 127 | 128 | assert "user_id must be an UUID" in str(excinfo.value) 129 | 130 | async def test_create_user(self, mock_sdk_client: SupabaseSDKClient): 131 | """Test creating a new user""" 132 | # Create a new test user 133 | test_email = get_test_email("create") 134 | user_id = str(uuid.uuid4()) 135 | 136 | mock_user = MagicMock( 137 | id=user_id, 138 | email=test_email, 139 | user_metadata={"name": "Test User", "test_id": TEST_ID} 140 | ) 141 | mock_response = MagicMock(user=mock_user) 142 | 143 | # Mock the create_user method as an async function 144 | mock_sdk_client.client.auth.admin.create_user = AsyncMock(return_value=mock_response) 145 | 146 | create_params = { 147 | "email": test_email, 148 | "password": f"Password123!{TEST_ID}", 149 | "email_confirm": True, 150 | "user_metadata": {"name": "Test User", "test_id": TEST_ID}, 151 | } 152 | 153 | # Create the user 154 | create_result = await mock_sdk_client.call_auth_admin_method("create_user", create_params) 155 | assert create_result is not None 156 | assert hasattr(create_result, "user") 157 | assert hasattr(create_result.user, "id") 158 | assert create_result.user.id == user_id 159 | 160 | # Test with invalid parameters (missing required fields) 161 | mock_sdk_client.client.auth.admin.create_user = AsyncMock(side_effect=Exception("Invalid parameters")) 162 | 163 | invalid_params = {"user_metadata": {"name": "Invalid User"}} 164 | with pytest.raises(PythonSDKError) as excinfo: 165 | await mock_sdk_client.call_auth_admin_method("create_user", invalid_params) 166 | 167 | assert "Invalid parameters" in str(excinfo.value) 168 | 169 | async def test_update_user_by_id(self, mock_sdk_client: SupabaseSDKClient): 170 | """Test updating a user's attributes""" 171 | # Mock user data 172 | test_email = get_test_email("update") 173 | user_id = str(uuid.uuid4()) 174 | 175 | mock_user = MagicMock( 176 | id=user_id, 177 | email=test_email, 178 | user_metadata={"email": "[email protected]"} 179 | ) 180 | mock_response = MagicMock(user=mock_user) 181 | 182 | # Mock the update_user_by_id method as an async function 183 | mock_sdk_client.client.auth.admin.update_user_by_id = AsyncMock(return_value=mock_response) 184 | 185 | # Update the user 186 | update_params = { 187 | "uid": user_id, 188 | "attributes": { 189 | "user_metadata": { 190 | "email": "[email protected]", 191 | } 192 | }, 193 | } 194 | 195 | update_result = await mock_sdk_client.call_auth_admin_method("update_user_by_id", update_params) 196 | 197 | # Verify user was updated 198 | assert update_result is not None 199 | assert hasattr(update_result, "user") 200 | assert update_result.user.id == user_id 201 | assert update_result.user.user_metadata["email"] == "[email protected]" 202 | 203 | # Test with invalid parameters (non-existent user ID) 204 | mock_sdk_client.client.auth.admin.update_user_by_id = AsyncMock(side_effect=Exception("user_id must be an uuid")) 205 | 206 | invalid_params = { 207 | "uid": "non-existent-user-id", 208 | "attributes": {"user_metadata": {"name": "Invalid Update"}}, 209 | } 210 | with pytest.raises(PythonSDKError) as excinfo: 211 | await mock_sdk_client.call_auth_admin_method("update_user_by_id", invalid_params) 212 | 213 | assert "user_id must be an uuid" in str(excinfo.value).lower() 214 | 215 | async def test_delete_user(self, mock_sdk_client: SupabaseSDKClient): 216 | """Test deleting a user""" 217 | # Mock user data 218 | user_id = str(uuid.uuid4()) 219 | 220 | # Mock the delete_user method as an async function to return None (success) 221 | mock_sdk_client.client.auth.admin.delete_user = AsyncMock(return_value=None) 222 | 223 | # Delete the user 224 | delete_params = {"id": user_id} 225 | # The delete_user method returns None on success 226 | result = await mock_sdk_client.call_auth_admin_method("delete_user", delete_params) 227 | assert result is None 228 | 229 | # Test with invalid parameters (non-UUID format user ID) 230 | mock_sdk_client.client.auth.admin.delete_user = AsyncMock(side_effect=Exception("user_id must be an uuid")) 231 | 232 | invalid_params = {"id": "non-existent-user-id"} 233 | with pytest.raises(PythonSDKError) as excinfo: 234 | await mock_sdk_client.call_auth_admin_method("delete_user", invalid_params) 235 | 236 | assert "user_id must be an uuid" in str(excinfo.value).lower() 237 | 238 | async def test_invite_user_by_email(self, mock_sdk_client: SupabaseSDKClient): 239 | """Test inviting a user by email""" 240 | # Mock user data 241 | test_email = get_test_email("invite") 242 | user_id = str(uuid.uuid4()) 243 | 244 | mock_user = MagicMock( 245 | id=user_id, 246 | email=test_email, 247 | invited_at=datetime.now().isoformat() 248 | ) 249 | mock_response = MagicMock(user=mock_user) 250 | 251 | # Mock the invite_user_by_email method as an async function 252 | mock_sdk_client.client.auth.admin.invite_user_by_email = AsyncMock(return_value=mock_response) 253 | 254 | # Create invite parameters 255 | invite_params = { 256 | "email": test_email, 257 | "options": {"data": {"name": "Invited User", "test_id": TEST_ID, "invited_at": datetime.now().isoformat()}}, 258 | } 259 | 260 | # Invite the user 261 | result = await mock_sdk_client.call_auth_admin_method("invite_user_by_email", invite_params) 262 | 263 | # Verify response 264 | assert result is not None 265 | assert hasattr(result, "user") 266 | assert result.user.email == test_email 267 | assert hasattr(result.user, "invited_at") 268 | 269 | # Test with invalid parameters (missing email) 270 | mock_sdk_client.client.auth.admin.invite_user_by_email = AsyncMock(side_effect=Exception("Invalid parameters")) 271 | 272 | invalid_params = {"options": {"data": {"name": "Invalid Invite"}}} 273 | with pytest.raises(PythonSDKError) as excinfo: 274 | await mock_sdk_client.call_auth_admin_method("invite_user_by_email", invalid_params) 275 | 276 | assert "Invalid parameters" in str(excinfo.value) 277 | 278 | async def test_generate_link(self, mock_sdk_client: SupabaseSDKClient): 279 | """Test generating authentication links""" 280 | # Mock response for generate_link 281 | mock_properties = MagicMock(action_link="https://example.com/auth/link") 282 | mock_response = MagicMock(properties=mock_properties) 283 | 284 | # Mock the generate_link method as an async function 285 | mock_sdk_client.client.auth.admin.generate_link = AsyncMock(return_value=mock_response) 286 | 287 | # Test signup link 288 | link_params = { 289 | "type": "signup", 290 | "email": get_test_email("signup"), 291 | "password": f"Password123!{TEST_ID}", 292 | "options": { 293 | "data": {"name": "Signup User", "test_id": TEST_ID}, 294 | "redirect_to": "https://example.com/welcome", 295 | }, 296 | } 297 | 298 | # Generate link 299 | result = await mock_sdk_client.call_auth_admin_method("generate_link", link_params) 300 | 301 | # Verify response 302 | assert result is not None 303 | assert hasattr(result, "properties") 304 | assert hasattr(result.properties, "action_link") 305 | 306 | # Test with invalid parameters (invalid link type) 307 | mock_sdk_client.client.auth.admin.generate_link = AsyncMock(side_effect=Exception("Invalid parameters")) 308 | 309 | invalid_params = {"type": "invalid_type", "email": get_test_email("invalid")} 310 | with pytest.raises(PythonSDKError) as excinfo: 311 | await mock_sdk_client.call_auth_admin_method("generate_link", invalid_params) 312 | 313 | assert "Invalid parameters" in str(excinfo.value) or "invalid type" in str(excinfo.value).lower() 314 | 315 | async def test_delete_factor(self, mock_sdk_client: SupabaseSDKClient): 316 | """Test deleting an MFA factor""" 317 | # Mock the delete_factor method as an async function to raise not implemented 318 | mock_sdk_client.client.auth.admin.delete_factor = AsyncMock(side_effect=AttributeError("method not found")) 319 | 320 | # Attempt to delete a factor 321 | delete_factor_params = {"user_id": str(uuid.uuid4()), "id": "non-existent-factor-id"} 322 | 323 | with pytest.raises(PythonSDKError) as excinfo: 324 | await mock_sdk_client.call_auth_admin_method("delete_factor", delete_factor_params) 325 | 326 | # We expect this to fail with a specific error message 327 | assert "not implemented" in str(excinfo.value).lower() or "method not found" in str(excinfo.value).lower() 328 | 329 | async def test_empty_parameters(self, mock_sdk_client: SupabaseSDKClient): 330 | """Test validation errors with empty parameters for various methods""" 331 | # Test methods with empty parameters 332 | methods = ["get_user_by_id", "create_user", "update_user_by_id", "delete_user", "generate_link"] 333 | 334 | for method in methods: 335 | empty_params = {} 336 | 337 | # Mock the method to raise validation error 338 | setattr(mock_sdk_client.client.auth.admin, method, AsyncMock(side_effect=Exception("Invalid parameters"))) 339 | 340 | # Should raise PythonSDKError containing validation error details 341 | with pytest.raises(PythonSDKError) as excinfo: 342 | await mock_sdk_client.call_auth_admin_method(method, empty_params) 343 | 344 | # Verify error message contains validation details 345 | assert "Invalid parameters" in str(excinfo.value) or "validation error" in str(excinfo.value).lower() 346 | 347 | async def test_client_without_service_role_key(self, mock_settings): 348 | """Test that an exception is raised when attempting to use the SDK client without a service role key.""" 349 | # Create settings without service role key 350 | mock_settings.supabase_service_role_key = None 351 | 352 | # Reset singleton 353 | SupabaseSDKClient.reset() 354 | 355 | # Create client 356 | client = SupabaseSDKClient.get_instance(settings=mock_settings) 357 | 358 | # Attempt to call a method - should raise an exception 359 | with pytest.raises(PythonSDKError) as excinfo: 360 | await client.call_auth_admin_method("list_users", {}) 361 | 362 | assert "service role key is not configured" in str(excinfo.value) ``` -------------------------------------------------------------------------------- /supabase_mcp/services/database/postgres_client.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | 3 | import urllib.parse 4 | from collections.abc import Awaitable, Callable 5 | from typing import Any, TypeVar 6 | 7 | import asyncpg 8 | from pydantic import BaseModel, Field 9 | from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential 10 | 11 | from supabase_mcp.exceptions import ConnectionError, PermissionError, QueryError 12 | from supabase_mcp.logger import logger 13 | from supabase_mcp.services.database.sql.models import QueryValidationResults 14 | from supabase_mcp.services.database.sql.validator import SQLValidator 15 | from supabase_mcp.settings import Settings 16 | 17 | # Define a type variable for generic return types 18 | T = TypeVar("T") 19 | 20 | # TODO: Use a context manager to properly handle the connection pool 21 | 22 | 23 | class StatementResult(BaseModel): 24 | """Represents the result of a single SQL statement.""" 25 | 26 | rows: list[dict[str, Any]] = Field( 27 | default_factory=list, 28 | description="List of rows returned by the statement. Is empty if the statement is a DDL statement.", 29 | ) 30 | 31 | 32 | class QueryResult(BaseModel): 33 | """Represents results of query execution, consisting of one or more statements.""" 34 | 35 | results: list[StatementResult] = Field( 36 | description="List of results from the statements in the query.", 37 | ) 38 | 39 | 40 | # Helper function for retry decorator to safely log exceptions 41 | def log_db_retry_attempt(retry_state: RetryCallState) -> None: 42 | """Log database retry attempts. 43 | 44 | Args: 45 | retry_state: Current retry state from tenacity 46 | """ 47 | if retry_state.outcome is not None and retry_state.outcome.failed: 48 | exception = retry_state.outcome.exception() 49 | exception_str = str(exception) 50 | logger.warning(f"Database error, retrying ({retry_state.attempt_number}/3): {exception_str}") 51 | 52 | 53 | # Add the new AsyncSupabaseClient class 54 | class PostgresClient: 55 | """Asynchronous client for interacting with Supabase PostgreSQL database.""" 56 | 57 | _instance: PostgresClient | None = None # Singleton instance 58 | 59 | def __init__( 60 | self, 61 | settings: Settings, 62 | project_ref: str | None = None, 63 | db_password: str | None = None, 64 | db_region: str | None = None, 65 | ): 66 | """Initialize client configuration (but don't connect yet). 67 | 68 | Args: 69 | settings_instance: Settings instance to use for configuration. 70 | project_ref: Optional Supabase project reference. If not provided, will be taken from settings. 71 | db_password: Optional database password. If not provided, will be taken from settings. 72 | db_region: Optional database region. If not provided, will be taken from settings. 73 | """ 74 | self._pool: asyncpg.Pool[asyncpg.Record] | None = None 75 | self._settings = settings 76 | self.project_ref = project_ref or self._settings.supabase_project_ref 77 | self.db_password = db_password or self._settings.supabase_db_password 78 | self.db_region = db_region or self._settings.supabase_region 79 | self.db_url = self._build_connection_string() 80 | self.sql_validator: SQLValidator = SQLValidator() 81 | 82 | # Only log once during initialization with clear project info 83 | is_local = self.project_ref.startswith("127.0.0.1") 84 | logger.info( 85 | f"✔️ PostgreSQL client initialized successfully for {'local' if is_local else 'remote'} " 86 | f"project: {self.project_ref} (region: {self.db_region})" 87 | ) 88 | 89 | @classmethod 90 | def get_instance( 91 | cls, 92 | settings: Settings, 93 | project_ref: str | None = None, 94 | db_password: str | None = None, 95 | ) -> PostgresClient: 96 | """Create and return a configured AsyncSupabaseClient instance. 97 | 98 | This is the recommended way to create a client instance. 99 | 100 | Args: 101 | settings_instance: Settings instance to use for configuration 102 | project_ref: Optional Supabase project reference 103 | db_password: Optional database password 104 | 105 | Returns: 106 | Configured AsyncSupabaseClient instance 107 | """ 108 | if cls._instance is None: 109 | cls._instance = cls( 110 | settings=settings, 111 | project_ref=project_ref, 112 | db_password=db_password, 113 | ) 114 | # Doesn't connect yet - will connect lazily when needed 115 | return cls._instance 116 | 117 | def _build_connection_string(self) -> str: 118 | """Build the database connection string for asyncpg. 119 | 120 | Returns: 121 | PostgreSQL connection string compatible with asyncpg 122 | """ 123 | encoded_password = urllib.parse.quote_plus(self.db_password) 124 | 125 | if self.project_ref.startswith("127.0.0.1"): 126 | # Local development 127 | connection_string = f"postgresql://postgres:{encoded_password}@{self.project_ref}/postgres" 128 | return connection_string 129 | 130 | # Production Supabase - via transaction pooler 131 | connection_string = ( 132 | f"postgresql://postgres.{self.project_ref}:{encoded_password}" 133 | f"@aws-0-{self._settings.supabase_region}.pooler.supabase.com:6543/postgres" 134 | ) 135 | return connection_string 136 | 137 | @retry( 138 | retry=retry_if_exception_type( 139 | ( 140 | asyncpg.exceptions.ConnectionDoesNotExistError, # Connection lost 141 | asyncpg.exceptions.InterfaceError, # Connection disruption 142 | asyncpg.exceptions.TooManyConnectionsError, # Temporary connection limit 143 | OSError, # Network issues 144 | ) 145 | ), 146 | stop=stop_after_attempt(3), 147 | wait=wait_exponential(multiplier=1, min=2, max=10), 148 | before_sleep=log_db_retry_attempt, 149 | ) 150 | async def create_pool(self) -> asyncpg.Pool[asyncpg.Record]: 151 | """Create and configure a database connection pool. 152 | 153 | Returns: 154 | Configured asyncpg connection pool 155 | 156 | Raises: 157 | ConnectionError: If unable to establish a connection to the database 158 | """ 159 | try: 160 | logger.debug(f"Creating connection pool for project: {self.project_ref}") 161 | 162 | # Create the pool with optimal settings 163 | pool = await asyncpg.create_pool( 164 | self.db_url, 165 | min_size=2, # Minimum connections to keep ready 166 | max_size=10, # Maximum connections allowed (same as current) 167 | statement_cache_size=0, 168 | command_timeout=30.0, # Command timeout in seconds 169 | max_inactive_connection_lifetime=300.0, # 5 minutes 170 | ) 171 | 172 | # Test the connection with a simple query 173 | async with pool.acquire() as conn: 174 | await conn.execute("SELECT 1") 175 | 176 | logger.info("✓ Database connection established successfully") 177 | return pool 178 | 179 | except asyncpg.PostgresError as e: 180 | # Extract connection details for better error reporting 181 | host_part = self.db_url.split("@")[1].split("/")[0] if "@" in self.db_url else "unknown" 182 | 183 | # Check specifically for the "Tenant or user not found" error which is often caused by region mismatch 184 | if "Tenant or user not found" in str(e): 185 | error_message = ( 186 | "CONNECTION ERROR: Region mismatch detected!\n\n" 187 | f"Could not connect to Supabase project '{self.project_ref}'.\n\n" 188 | "This error typically occurs when your SUPABASE_REGION setting doesn't match your project's actual region.\n" 189 | f"Your configuration is using region: '{self.db_region}' (default: us-east-1)\n\n" 190 | "ACTION REQUIRED: Please set the correct SUPABASE_REGION in your MCP server configuration.\n" 191 | "You can find your project's region in the Supabase dashboard under Project Settings." 192 | ) 193 | else: 194 | error_message = ( 195 | f"Could not connect to database: {e}\n" 196 | f"Connection attempted to: {host_part}\n via Transaction Pooler\n" 197 | f"Project ref: {self.project_ref}\n" 198 | f"Region: {self.db_region}\n\n" 199 | f"Please check:\n" 200 | f"1. Your Supabase project reference is correct\n" 201 | f"2. Your database password is correct\n" 202 | f"3. Your region setting matches your Supabase project region\n" 203 | f"4. Your Supabase project is active and the database is online\n" 204 | ) 205 | 206 | logger.error(f"Failed to connect to database: {e}") 207 | logger.error(f"Connection details: {host_part}, Project: {self.project_ref}, Region: {self.db_region}") 208 | 209 | raise ConnectionError(error_message) from e 210 | 211 | except OSError as e: 212 | # For network-related errors, provide a different message that clearly indicates 213 | # this is a network/system issue rather than a database configuration problem 214 | host_part = self.db_url.split("@")[1].split("/")[0] if "@" in self.db_url else "unknown" 215 | 216 | error_message = ( 217 | f"Network error while connecting to database: {e}\n" 218 | f"Connection attempted to: {host_part}\n\n" 219 | f"This appears to be a network or system issue rather than a database configuration problem.\n" 220 | f"Please check:\n" 221 | f"1. Your internet connection is working\n" 222 | f"2. Any firewalls or network security settings allow connections to {host_part}\n" 223 | f"3. DNS resolution is working correctly\n" 224 | f"4. The Supabase service is not experiencing an outage\n" 225 | ) 226 | 227 | logger.error(f"Network error connecting to database: {e}") 228 | logger.error(f"Connection details: {host_part}") 229 | raise ConnectionError(error_message) from e 230 | 231 | async def ensure_pool(self) -> None: 232 | """Ensure a valid connection pool exists. 233 | 234 | This method is called before executing queries to make sure 235 | we have an active connection pool. 236 | """ 237 | if self._pool is None: 238 | logger.debug("No active connection pool, creating one") 239 | self._pool = await self.create_pool() 240 | else: 241 | logger.debug("Using existing connection pool") 242 | 243 | async def close(self) -> None: 244 | """Close the connection pool and release all resources. 245 | 246 | This should be called when shutting down the application. 247 | """ 248 | import asyncio 249 | 250 | if self._pool: 251 | await asyncio.wait_for(self._pool.close(), timeout=5.0) 252 | self._pool = None 253 | else: 254 | logger.debug("No PostgreSQL connection pool to close") 255 | 256 | @classmethod 257 | async def reset(cls) -> None: 258 | """Reset the singleton instance cleanly. 259 | 260 | This closes any open connections and resets the singleton instance. 261 | """ 262 | if cls._instance is not None: 263 | await cls._instance.close() 264 | cls._instance = None 265 | logger.info("AsyncSupabaseClient instance reset complete") 266 | 267 | async def with_connection(self, operation_func: Callable[[asyncpg.Connection[Any]], Awaitable[T]]) -> T: 268 | """Execute an operation with a database connection. 269 | 270 | Args: 271 | operation_func: Async function that takes a connection and returns a result 272 | 273 | Returns: 274 | The result of the operation function 275 | 276 | Raises: 277 | ConnectionError: If a database connection issue occurs 278 | """ 279 | # Ensure we have an active connection pool 280 | await self.ensure_pool() 281 | 282 | # Acquire a connection from the pool and execute the operation 283 | async with self._pool.acquire() as conn: 284 | return await operation_func(conn) 285 | 286 | async def with_transaction( 287 | self, conn: asyncpg.Connection[Any], operation_func: Callable[[], Awaitable[T]], readonly: bool = False 288 | ) -> T: 289 | """Execute an operation within a transaction. 290 | 291 | Args: 292 | conn: Database connection 293 | operation_func: Async function that executes within the transaction 294 | readonly: Whether the transaction is read-only 295 | 296 | Returns: 297 | The result of the operation function 298 | 299 | Raises: 300 | QueryError: If the query execution fails 301 | """ 302 | # Execute the operation within a transaction 303 | async with conn.transaction(readonly=readonly): 304 | return await operation_func() 305 | 306 | async def execute_statement(self, conn: asyncpg.Connection[Any], query: str) -> StatementResult: 307 | """Execute a single SQL statement. 308 | 309 | Args: 310 | conn: Database connection 311 | query: SQL query to execute 312 | 313 | Returns: 314 | StatementResult containing the rows returned by the statement 315 | 316 | Raises: 317 | QueryError: If the statement execution fails 318 | """ 319 | try: 320 | # Execute the query 321 | result = await conn.fetch(query) 322 | 323 | # Convert records to dictionaries 324 | rows = [dict(record) for record in result] 325 | 326 | # Log success 327 | logger.debug(f"Statement executed successfully, rows: {len(rows)}") 328 | 329 | # Return the result 330 | return StatementResult(rows=rows) 331 | 332 | except asyncpg.PostgresError as e: 333 | await self._handle_postgres_error(e) 334 | 335 | @retry( 336 | retry=retry_if_exception_type( 337 | ( 338 | asyncpg.exceptions.ConnectionDoesNotExistError, # Connection lost 339 | asyncpg.exceptions.InterfaceError, # Connection disruption 340 | asyncpg.exceptions.TooManyConnectionsError, # Temporary connection limit 341 | OSError, # Network issues 342 | ) 343 | ), 344 | stop=stop_after_attempt(3), 345 | wait=wait_exponential(multiplier=1, min=2, max=10), 346 | before_sleep=log_db_retry_attempt, 347 | ) 348 | async def execute_query( 349 | self, 350 | validated_query: QueryValidationResults, 351 | readonly: bool = True, # Default to read-only for safety 352 | ) -> QueryResult: 353 | """Execute a SQL query asynchronously with proper transaction management. 354 | 355 | Args: 356 | validated_query: Validated query containing statements to execute 357 | readonly: Whether to execute in read-only mode 358 | 359 | Returns: 360 | QueryResult containing the results of all statements 361 | 362 | Raises: 363 | ConnectionError: If a database connection issue occurs 364 | QueryError: If the query execution fails 365 | PermissionError: When user lacks required privileges 366 | """ 367 | # Log query execution (truncate long queries for readability) 368 | truncated_query = ( 369 | validated_query.original_query[:100] + "..." 370 | if len(validated_query.original_query) > 100 371 | else validated_query.original_query 372 | ) 373 | logger.debug(f"Executing query (readonly={readonly}): {truncated_query}") 374 | 375 | # Define the operation to execute all statements within a transaction 376 | async def execute_all_statements(conn): 377 | async def transaction_operation(): 378 | results = [] 379 | for statement in validated_query.statements: 380 | if statement.query: # Skip statements with no query 381 | result = await self.execute_statement(conn, statement.query) 382 | results.append(result) 383 | else: 384 | logger.warning(f"Statement has no query, statement: {statement}") 385 | return results 386 | 387 | # Execute the operation within a transaction 388 | results = await self.with_transaction(conn, transaction_operation, readonly) 389 | return QueryResult(results=results) 390 | 391 | # Execute the operation with a connection 392 | return await self.with_connection(execute_all_statements) 393 | 394 | async def _handle_postgres_error(self, error: asyncpg.PostgresError) -> None: 395 | """Handle PostgreSQL errors and convert to appropriate exceptions. 396 | 397 | Args: 398 | error: PostgreSQL error 399 | 400 | Raises: 401 | PermissionError: When user lacks required privileges 402 | QueryError: For schema errors or general query errors 403 | """ 404 | if isinstance(error, asyncpg.exceptions.InsufficientPrivilegeError): 405 | logger.error(f"Permission denied: {error}") 406 | raise PermissionError( 407 | f"Access denied: {str(error)}. Use live_dangerously('database', True) for write operations." 408 | ) from error 409 | elif isinstance( 410 | error, 411 | ( 412 | asyncpg.exceptions.UndefinedTableError, 413 | asyncpg.exceptions.UndefinedColumnError, 414 | ), 415 | ): 416 | logger.error(f"Schema error: {error}") 417 | raise QueryError(str(error)) from error 418 | else: 419 | logger.error(f"Database error: {error}") 420 | raise QueryError(f"Query execution failed: {str(error)}") from error 421 | ``` -------------------------------------------------------------------------------- /supabase_mcp/services/database/migration_manager.py: -------------------------------------------------------------------------------- ```python 1 | import datetime 2 | import hashlib 3 | import re 4 | 5 | from supabase_mcp.logger import logger 6 | from supabase_mcp.services.database.sql.loader import SQLLoader 7 | from supabase_mcp.services.database.sql.models import ( 8 | QueryValidationResults, 9 | SQLQueryCategory, 10 | ValidatedStatement, 11 | ) 12 | 13 | 14 | class MigrationManager: 15 | """Responsible for preparing migration scripts without executing them.""" 16 | 17 | def __init__(self, loader: SQLLoader | None = None): 18 | """Initialize the migration manager with a SQL loader. 19 | 20 | Args: 21 | loader: The SQL loader to use for loading SQL queries 22 | """ 23 | self.loader = loader or SQLLoader() 24 | 25 | def prepare_migration_query( 26 | self, 27 | validation_result: QueryValidationResults, 28 | original_query: str, 29 | migration_name: str = "", 30 | ) -> tuple[str, str]: 31 | """ 32 | Prepare a migration script without executing it. 33 | 34 | Args: 35 | validation_result: The validation result 36 | original_query: The original query 37 | migration_name: The name of the migration, if provided by the client 38 | 39 | Returns: 40 | Complete SQL query to create the migration 41 | Migration name 42 | """ 43 | # If client provided a name, use it directly without generating a new one 44 | if migration_name.strip(): 45 | name = self.sanitize_name(migration_name) 46 | else: 47 | # Otherwise generate a descriptive name 48 | name = self.generate_descriptive_name(validation_result) 49 | 50 | # Generate migration version (timestamp) 51 | version = self.generate_query_timestamp() 52 | 53 | # Escape single quotes in the query for SQL safety 54 | statements = original_query.replace("'", "''") 55 | 56 | # Get the migration query using the loader 57 | migration_query = self.loader.get_create_migration_query(version, name, statements) 58 | 59 | logger.info(f"Prepared migration: {version}_{name}") 60 | 61 | # Return the complete query 62 | return migration_query, name 63 | 64 | def sanitize_name(self, name: str) -> str: 65 | """ 66 | Generate a standardized name for a migration script. 67 | 68 | Args: 69 | name: Raw migration name 70 | 71 | Returns: 72 | str: Sanitized migration name 73 | """ 74 | # Remove special characters and replace spaces with underscores 75 | sanitized_name = re.sub(r"[^\w\s]", "", name).lower() 76 | sanitized_name = re.sub(r"\s+", "_", sanitized_name) 77 | 78 | # Ensure the name is not too long (max 100 chars) 79 | if len(sanitized_name) > 100: 80 | sanitized_name = sanitized_name[:100] 81 | 82 | return sanitized_name 83 | 84 | def generate_descriptive_name( 85 | self, 86 | query_validation_result: QueryValidationResults, 87 | ) -> str: 88 | """ 89 | Generate a descriptive name for a migration based on the validation result. 90 | 91 | This method should only be called when no client-provided name is available. 92 | 93 | Priority order: 94 | 1. Auto-generated name based on SQL analysis 95 | 2. Fallback to hash if no meaningful information can be extracted 96 | 97 | Args: 98 | query_validation_result: Validation result for a batch of SQL statements 99 | 100 | Returns: 101 | str: Descriptive migration name 102 | """ 103 | # Case 1: No client-provided name, generate descriptive name 104 | # Find the first statement that needs migration 105 | statement = None 106 | for stmt in query_validation_result.statements: 107 | if stmt.needs_migration: 108 | statement = stmt 109 | break 110 | 111 | # If no statement found (unlikely), use a hash-based name 112 | if not statement: 113 | logger.warning( 114 | "No statement found in validation result, using hash-based name, statements: %s", 115 | query_validation_result.statements, 116 | ) 117 | # Generate a short hash from the query text 118 | query_hash = self._generate_short_hash(query_validation_result.original_query) 119 | return f"migration_{query_hash}" 120 | 121 | # Generate name based on statement category and command 122 | logger.debug(f"Generating name for statement: {statement}") 123 | if statement.category == SQLQueryCategory.DDL: 124 | return self._generate_ddl_name(statement) 125 | elif statement.category == SQLQueryCategory.DML: 126 | return self._generate_dml_name(statement) 127 | elif statement.category == SQLQueryCategory.DCL: 128 | return self._generate_dcl_name(statement) 129 | else: 130 | # Fallback for other categories 131 | return self._generate_generic_name(statement) 132 | 133 | def _generate_short_hash(self, text: str) -> str: 134 | """Generate a short hash from text for use in migration names.""" 135 | hash_object = hashlib.md5(text.encode()) 136 | return hash_object.hexdigest()[:8] # First 8 chars of MD5 hash 137 | 138 | def _generate_ddl_name(self, statement: ValidatedStatement) -> str: 139 | """ 140 | Generate a name for DDL statements (CREATE, ALTER, DROP). 141 | Format: {command}_{object_type}_{schema}_{object_name} 142 | Examples: 143 | - create_table_public_users 144 | - alter_function_auth_authenticate 145 | - drop_index_public_users_email_idx 146 | """ 147 | command = statement.command.value.lower() 148 | schema = statement.schema_name.lower() if statement.schema_name else "public" 149 | 150 | # Extract object type and name with enhanced detection 151 | object_type = "object" # Default fallback 152 | object_name = "unknown" # Default fallback 153 | 154 | # Enhanced object type detection based on command 155 | if statement.object_type: 156 | object_type = statement.object_type.lower() 157 | 158 | # Handle specific object types 159 | if object_type == "table" and statement.query: 160 | object_name = self._extract_table_name(statement.query) 161 | elif (object_type == "function" or object_type == "procedure") and statement.query: 162 | object_name = self._extract_function_name(statement.query) 163 | elif object_type == "trigger" and statement.query: 164 | object_name = self._extract_trigger_name(statement.query) 165 | elif object_type == "index" and statement.query: 166 | object_name = self._extract_index_name(statement.query) 167 | elif object_type == "view" and statement.query: 168 | object_name = self._extract_view_name(statement.query) 169 | elif object_type == "materialized_view" and statement.query: 170 | object_name = self._extract_materialized_view_name(statement.query) 171 | elif object_type == "sequence" and statement.query: 172 | object_name = self._extract_sequence_name(statement.query) 173 | elif object_type == "constraint" and statement.query: 174 | object_name = self._extract_constraint_name(statement.query) 175 | elif object_type == "foreign_table" and statement.query: 176 | object_name = self._extract_foreign_table_name(statement.query) 177 | elif object_type == "extension" and statement.query: 178 | object_name = self._extract_extension_name(statement.query) 179 | elif object_type == "type" and statement.query: 180 | object_name = self._extract_type_name(statement.query) 181 | elif statement.query: 182 | # For other object types, use a generic extraction 183 | object_name = self._extract_generic_object_name(statement.query) 184 | 185 | # Combine parts into a descriptive name 186 | name = f"{command}_{object_type}_{schema}_{object_name}" 187 | return self.sanitize_name(name) 188 | 189 | def _generate_dml_name(self, statement: ValidatedStatement) -> str: 190 | """ 191 | Generate a name for DML statements (INSERT, UPDATE, DELETE). 192 | Format: {command}_{schema}_{table_name} 193 | Examples: 194 | - insert_public_users 195 | - update_auth_users 196 | - delete_public_logs 197 | """ 198 | command = statement.command.value.lower() 199 | schema = statement.schema_name.lower() if statement.schema_name else "public" 200 | 201 | # Extract table name 202 | table_name = "unknown" 203 | if statement.query: 204 | table_name = self._extract_table_name(statement.query) or "unknown" 205 | 206 | # For UPDATE and DELETE, add what's being modified if possible 207 | if command == "update" and statement.query: 208 | # Try to extract column names being updated 209 | columns = self._extract_update_columns(statement.query) 210 | if columns: 211 | return self.sanitize_name(f"{command}_{columns}_in_{schema}_{table_name}") 212 | 213 | # Default format 214 | name = f"{command}_{schema}_{table_name}" 215 | return self.sanitize_name(name) 216 | 217 | def _generate_dcl_name(self, statement: ValidatedStatement) -> str: 218 | """ 219 | Generate a name for DCL statements (GRANT, REVOKE). 220 | Format: {command}_{privilege}_{schema}_{object_name} 221 | Examples: 222 | - grant_select_public_users 223 | - revoke_all_public_items 224 | """ 225 | command = statement.command.value.lower() 226 | schema = statement.schema_name.lower() if statement.schema_name else "public" 227 | 228 | # Extract privilege and object name 229 | privilege = "privilege" 230 | object_name = "unknown" 231 | 232 | if statement.query: 233 | privilege = self._extract_privilege(statement.query) or "privilege" 234 | object_name = self._extract_dcl_object_name(statement.query) or "unknown" 235 | 236 | name = f"{command}_{privilege}_{schema}_{object_name}" 237 | return self.sanitize_name(name) 238 | 239 | def _generate_generic_name(self, statement: ValidatedStatement) -> str: 240 | """ 241 | Generate a name for other statement types. 242 | Format: {command}_{schema}_{object_type} 243 | """ 244 | command = statement.command.value.lower() 245 | schema = statement.schema_name.lower() if statement.schema_name else "public" 246 | object_type = statement.object_type.lower() if statement.object_type else "object" 247 | 248 | name = f"{command}_{schema}_{object_type}" 249 | return self.sanitize_name(name) 250 | 251 | # Helper methods for extracting specific parts from SQL queries 252 | 253 | def _extract_table_name(self, query: str) -> str: 254 | """Extract table name from a query.""" 255 | if not query: 256 | return "unknown" 257 | 258 | # Simple regex-based extraction for demonstration 259 | # In a real implementation, this would use more sophisticated parsing 260 | import re 261 | 262 | # For CREATE TABLE 263 | match = re.search(r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 264 | if match: 265 | return match.group(2) 266 | 267 | # For ALTER TABLE 268 | match = re.search(r"ALTER\s+TABLE\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 269 | if match: 270 | return match.group(2) 271 | 272 | # For DROP TABLE 273 | match = re.search(r"DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 274 | if match: 275 | return match.group(2) 276 | 277 | # For INSERT, UPDATE, DELETE 278 | match = re.search(r"(?:INSERT\s+INTO|UPDATE|DELETE\s+FROM)\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 279 | if match: 280 | return match.group(2) 281 | 282 | return "unknown" 283 | 284 | def _extract_function_name(self, query: str) -> str: 285 | """Extract function name from a query.""" 286 | if not query: 287 | return "unknown" 288 | 289 | import re 290 | 291 | match = re.search( 292 | r"(?:CREATE|ALTER|DROP)\s+(?:OR\s+REPLACE\s+)?FUNCTION\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE 293 | ) 294 | if match: 295 | return match.group(2) 296 | 297 | return "unknown" 298 | 299 | def _extract_trigger_name(self, query: str) -> str: 300 | """Extract trigger name from a query.""" 301 | if not query: 302 | return "unknown" 303 | 304 | import re 305 | 306 | match = re.search(r"(?:CREATE|ALTER|DROP)\s+TRIGGER\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)", query, re.IGNORECASE) 307 | if match: 308 | return match.group(1) 309 | 310 | return "unknown" 311 | 312 | def _extract_view_name(self, query: str) -> str: 313 | """Extract view name from a query.""" 314 | if not query: 315 | return "unknown" 316 | 317 | import re 318 | 319 | match = re.search(r"(?:CREATE|ALTER|DROP)\s+(?:OR\s+REPLACE\s+)?VIEW\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 320 | if match: 321 | return match.group(2) 322 | 323 | return "unknown" 324 | 325 | def _extract_index_name(self, query: str) -> str: 326 | """Extract index name from a query.""" 327 | if not query: 328 | return "unknown" 329 | 330 | import re 331 | 332 | match = re.search(r"(?:CREATE|DROP)\s+INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 333 | if match: 334 | return match.group(2) 335 | 336 | return "unknown" 337 | 338 | def _extract_sequence_name(self, query: str) -> str: 339 | """Extract sequence name from a query.""" 340 | if not query: 341 | return "unknown" 342 | 343 | import re 344 | 345 | match = re.search( 346 | r"(?:CREATE|ALTER|DROP)\s+SEQUENCE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE 347 | ) 348 | if match: 349 | return match.group(2) 350 | 351 | return "unknown" 352 | 353 | def _extract_constraint_name(self, query: str) -> str: 354 | """Extract constraint name from a query.""" 355 | if not query: 356 | return "unknown" 357 | 358 | import re 359 | 360 | match = re.search(r"CONSTRAINT\s+(\w+)", query, re.IGNORECASE) 361 | if match: 362 | return match.group(1) 363 | 364 | return "unknown" 365 | 366 | def _extract_update_columns(self, query: str) -> str: 367 | """Extract columns being updated in an UPDATE statement.""" 368 | if not query: 369 | return "" 370 | 371 | import re 372 | 373 | # This is a simplified approach - a real implementation would use proper SQL parsing 374 | match = re.search(r"UPDATE\s+(?:\w+\.)?(?:\w+)\s+SET\s+([\w\s,=]+)\s+WHERE", query, re.IGNORECASE) 375 | if match: 376 | # Extract column names from the SET clause 377 | set_clause = match.group(1) 378 | columns = re.findall(r"(\w+)\s*=", set_clause) 379 | if columns and len(columns) <= 3: # Limit to 3 columns to keep name reasonable 380 | return "_".join(columns) 381 | elif columns: 382 | return f"{columns[0]}_and_others" 383 | 384 | return "" 385 | 386 | def _extract_privilege(self, query: str) -> str: 387 | """Extract privilege from a GRANT or REVOKE statement.""" 388 | if not query: 389 | return "privilege" 390 | 391 | import re 392 | 393 | match = re.search(r"(?:GRANT|REVOKE)\s+([\w\s,]+)\s+ON", query, re.IGNORECASE) 394 | if match: 395 | privileges = match.group(1).strip().lower() 396 | if "all" in privileges: 397 | return "all" 398 | elif "select" in privileges: 399 | return "select" 400 | elif "insert" in privileges: 401 | return "insert" 402 | elif "update" in privileges: 403 | return "update" 404 | elif "delete" in privileges: 405 | return "delete" 406 | 407 | return "privilege" 408 | 409 | def _extract_dcl_object_name(self, query: str) -> str: 410 | """Extract object name from a GRANT or REVOKE statement.""" 411 | if not query: 412 | return "unknown" 413 | 414 | import re 415 | 416 | match = re.search(r"ON\s+(?:TABLE\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 417 | if match: 418 | return match.group(2) 419 | 420 | return "unknown" 421 | 422 | def _extract_generic_object_name(self, query: str) -> str: 423 | """Extract a generic object name when specific extractors don't apply.""" 424 | if not query: 425 | return "unknown" 426 | 427 | import re 428 | 429 | # Look for common patterns of object names in SQL 430 | patterns = [ 431 | r"(?:CREATE|ALTER|DROP)\s+(?:\w+\s+)+(?:(\w+)\.)?(\w+)", # General DDL pattern 432 | r"ON\s+(?:(\w+)\.)?(\w+)", # ON clause 433 | r"FROM\s+(?:(\w+)\.)?(\w+)", # FROM clause 434 | r"INTO\s+(?:(\w+)\.)?(\w+)", # INTO clause 435 | ] 436 | 437 | for pattern in patterns: 438 | match = re.search(pattern, query, re.IGNORECASE) 439 | if match and match.group(2): 440 | return match.group(2) 441 | 442 | return "unknown" 443 | 444 | def _extract_materialized_view_name(self, query: str) -> str: 445 | """Extract materialized view name from a query.""" 446 | if not query: 447 | return "unknown" 448 | 449 | import re 450 | 451 | match = re.search( 452 | r"(?:CREATE|ALTER|DROP|REFRESH)\s+(?:MATERIALIZED\s+VIEW)\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", 453 | query, 454 | re.IGNORECASE, 455 | ) 456 | if match: 457 | return match.group(2) 458 | 459 | return "unknown" 460 | 461 | def _extract_foreign_table_name(self, query: str) -> str: 462 | """Extract foreign table name from a query.""" 463 | if not query: 464 | return "unknown" 465 | 466 | import re 467 | 468 | match = re.search( 469 | r"(?:CREATE|ALTER|DROP)\s+(?:FOREIGN\s+TABLE)\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", 470 | query, 471 | re.IGNORECASE, 472 | ) 473 | if match: 474 | return match.group(2) 475 | 476 | return "unknown" 477 | 478 | def _extract_extension_name(self, query: str) -> str: 479 | """Extract extension name from a query.""" 480 | if not query: 481 | return "unknown" 482 | 483 | import re 484 | 485 | match = re.search(r"(?:CREATE|ALTER|DROP)\s+EXTENSION\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)", query, re.IGNORECASE) 486 | if match: 487 | return match.group(1) 488 | 489 | return "unknown" 490 | 491 | def _extract_type_name(self, query: str) -> str: 492 | """Extract custom type name from a query.""" 493 | if not query: 494 | return "unknown" 495 | 496 | import re 497 | 498 | # For ENUM types 499 | match = re.search(r"(?:CREATE|ALTER|DROP)\s+TYPE\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 500 | if match: 501 | return match.group(2) 502 | 503 | # For DOMAIN types 504 | match = re.search(r"(?:CREATE|ALTER|DROP)\s+DOMAIN\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE) 505 | if match: 506 | return match.group(2) 507 | 508 | return "unknown" 509 | 510 | def generate_query_timestamp(self) -> str: 511 | """ 512 | Generate a timestamp for a migration script in the format YYYYMMDDHHMMSS. 513 | 514 | Returns: 515 | str: Timestamp string 516 | """ 517 | now = datetime.datetime.now() 518 | return now.strftime("%Y%m%d%H%M%S") 519 | ``` -------------------------------------------------------------------------------- /supabase_mcp/services/safety/safety_configs.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | from abc import ABC, abstractmethod 3 | from enum import Enum 4 | from typing import Any, Generic, TypeVar 5 | 6 | from supabase_mcp.services.database.sql.models import ( 7 | QueryValidationResults, 8 | SQLQueryCategory, 9 | ) 10 | from supabase_mcp.services.safety.models import OperationRiskLevel, SafetyMode 11 | 12 | T = TypeVar("T") 13 | 14 | 15 | class SafetyConfigBase(Generic[T], ABC): 16 | """Abstract base class for all SafetyConfig classes of specific clients. 17 | 18 | Provides methods: 19 | - register safety configuration 20 | - to get / set safety level 21 | - check safety level of operation 22 | """ 23 | 24 | @abstractmethod 25 | def get_risk_level(self, operation: T) -> OperationRiskLevel: 26 | """Get the risk level for an operation. 27 | 28 | Args: 29 | operation: The operation to check 30 | 31 | Returns: 32 | The risk level for the operation 33 | """ 34 | pass 35 | 36 | def is_operation_allowed(self, risk_level: OperationRiskLevel, mode: SafetyMode) -> bool: 37 | """Check if an operation is allowed based on its risk level and the current safety mode. 38 | 39 | Args: 40 | risk_level: The risk level of the operation 41 | mode: The current safety mode 42 | 43 | Returns: 44 | True if the operation is allowed, False otherwise 45 | """ 46 | # LOW risk operations are always allowed 47 | if risk_level == OperationRiskLevel.LOW: 48 | return True 49 | 50 | # MEDIUM risk operations are allowed only in UNSAFE mode 51 | if risk_level == OperationRiskLevel.MEDIUM: 52 | return mode == SafetyMode.UNSAFE 53 | 54 | # HIGH risk operations are allowed only in UNSAFE mode with confirmation 55 | if risk_level == OperationRiskLevel.HIGH: 56 | return mode == SafetyMode.UNSAFE 57 | 58 | # EXTREME risk operations are never allowed 59 | return False 60 | 61 | def needs_confirmation(self, risk_level: OperationRiskLevel) -> bool: 62 | """Check if an operation needs confirmation based on its risk level. 63 | 64 | Args: 65 | risk_level: The risk level of the operation 66 | 67 | Returns: 68 | True if the operation needs confirmation, False otherwise 69 | """ 70 | # Only HIGH and EXTREME risk operations require confirmation 71 | return risk_level >= OperationRiskLevel.HIGH 72 | 73 | 74 | # ======== 75 | # API Safety Config 76 | # ======== 77 | 78 | 79 | class HTTPMethod(str, Enum): 80 | """HTTP methods used in API operations.""" 81 | 82 | GET = "GET" 83 | POST = "POST" 84 | PUT = "PUT" 85 | PATCH = "PATCH" 86 | DELETE = "DELETE" 87 | HEAD = "HEAD" 88 | OPTIONS = "OPTIONS" 89 | 90 | 91 | class APISafetyConfig(SafetyConfigBase[tuple[str, str, dict[str, Any], dict[str, Any], dict[str, Any]]]): 92 | """Safety configuration for API operations. 93 | 94 | The operation type is a tuple of (method, path). 95 | """ 96 | 97 | # Maps risk levels to operations (method + path patterns) 98 | PATH_SAFETY_CONFIG = { 99 | OperationRiskLevel.EXTREME: { 100 | HTTPMethod.DELETE: [ 101 | "/v1/projects/{ref}", # Delete project. Irreversible, complete data loss. 102 | ] 103 | }, 104 | OperationRiskLevel.HIGH: { 105 | HTTPMethod.DELETE: [ 106 | "/v1/projects/{ref}/branches/{branch_id}", # Delete a database branch. Data loss on branch. 107 | "/v1/projects/{ref}/branches", # Disables preview branching. Disruptive to development workflows. 108 | "/v1/projects/{ref}/custom-hostname", # Deletes custom hostname config. Can break production access. 109 | "/v1/projects/{ref}/vanity-subdomain", # Deletes vanity subdomain config. Breaks vanity URL access. 110 | "/v1/projects/{ref}/network-bans", # Remove network bans (can expose database to wider network). 111 | "/v1/projects/{ref}/secrets", # Bulk delete secrets. Can break application functionality if critical secrets are removed. 112 | "/v1/projects/{ref}/functions/{function_slug}", # Delete function. Breaks functionality relying on the function. 113 | "/v1/projects/{ref}/api-keys/{id}", # Delete api key. Can break API access. 114 | "/v1/projects/{ref}/config/auth/sso/providers/{provider_id}", # Delete SSO Provider. Breaks SSO login. 115 | "/v1/projects/{ref}/config/auth/signing-keys/{id}", # Delete signing key. Can break JWT verification. 116 | ], 117 | HTTPMethod.POST: [ 118 | "/v1/projects/{ref}/pause", # Pause project - Impacts production, database becomes unavailable. 119 | "/v1/projects/{ref}/restore", # Restore project - Can overwrite existing data with backup. 120 | "/v1/projects/{ref}/upgrade", # Upgrades the project's Postgres version - potential downtime/compatibility issues. 121 | "/v1/projects/{ref}/read-replicas/remove", # Remove a read replica. Can impact read scalability. 122 | "/v1/projects/{ref}/restore/cancel", # Cancels the given project restoration. Can leave project in inconsistent state. 123 | "/v1/projects/{ref}/readonly/temporary-disable", # Disables readonly mode. Allows potentially destructive operations. 124 | ], 125 | }, 126 | OperationRiskLevel.MEDIUM: { 127 | HTTPMethod.POST: [ 128 | "/v1/projects", # Create project. Significant infrastructure change. 129 | "/v1/organizations", # Create org. Significant infrastructure change. 130 | "/v1/projects/{ref}/branches", # Create a database branch. Could potentially impact production if misused. 131 | "/v1/projects/{ref}/branches/{branch_id}/push", # Push a database branch. Could overwrite production data if pushed to the wrong branch. 132 | "/v1/projects/{ref}/branches/{branch_id}/reset", # Reset a database branch. Data loss on the branch. 133 | "/v1/projects/{ref}/custom-hostname/initialize", # Updates custom hostname configuration, potentially breaking existing config. 134 | "/v1/projects/{ref}/custom-hostname/reverify", # Attempts to verify DNS configuration. Could disrupt custom hostname if misconfigured. 135 | "/v1/projects/{ref}/custom-hostname/activate", # Activates custom hostname. Could lead to downtime during switchover. 136 | "/v1/projects/{ref}/network-bans/retrieve", # Gets project's network bans. Information disclosure, though less risky than removing bans. 137 | "/v1/projects/{ref}/network-restrictions/apply", # Updates project's network restrictions. Could block legitimate access if misconfigured. 138 | "/v1/projects/{ref}/secrets", # Bulk create secrets. Could overwrite existing secrets if names collide. 139 | "/v1/projects/{ref}/upgrade/status", # get status for upgrade 140 | "/v1/projects/{ref}/database/webhooks/enable", # Enables Database Webhooks. Could expose data if webhooks are misconfigured. 141 | "/v1/projects/{ref}/functions", # Create a function (deprecated). 142 | "/v1/projects/{ref}/functions/deploy", # Deploy a function. Could break functionality if deployed code has errors. 143 | "/v1/projects/{ref}/config/auth/sso/providers", # Create SSO provider. Could impact authentication if misconfigured. 144 | "/v1/projects/{ref}/database/backups/restore-pitr", # Restore a PITR backup. Can overwrite data. 145 | "/v1/projects/{ref}/read-replicas/setup", # Setup a read replica 146 | "/v1/projects/{ref}/database/query", # Run SQL query. *Crucially*, this allows arbitrary SQL, including `DROP TABLE`, `DELETE`, etc. 147 | "/v1/projects/{ref}/config/auth/signing-keys", # Create a new signing key, requires key rotation. 148 | "/v1/oauth/token", # Exchange auth code for user's access token. Security-sensitive. 149 | "/v1/oauth/revoke", # Revoke oauth app authorization. Can break application access. 150 | "/v1/projects/{ref}/api-keys", # Create an API key 151 | ], 152 | HTTPMethod.PATCH: [ 153 | "/v1/projects/{ref}/config/auth", # Auth config. Could lock users out or introduce vulnerabilities if misconfigured. 154 | "/v1/projects/{ref}/config/database/pooler", # Connection pooling changes. Can impact database performance. 155 | "/v1/projects/{ref}/postgrest", # Update Postgrest config. Can impact API behavior. 156 | "/v1/projects/{ref}/functions/{function_slug}", # Updates a function. Can break functionality. 157 | "/v1/projects/{ref}/config/storage", # Update Storage config. Can change file size limits, etc. 158 | "/v1/branches/{branch_id}", # Update database branch config. 159 | "/v1/projects/{ref}/api-keys/{id}", # Updates a API key 160 | "/v1/projects/{ref}/config/auth/signing-keys/{id}", # updates signing key. 161 | ], 162 | HTTPMethod.PUT: [ 163 | "/v1/projects/{ref}/config/database/postgres", # Postgres config changes. Can significantly impact database performance/behavior. 164 | "/v1/projects/{ref}/pgsodium", # Update pgsodium config. *Critical*: Updating the `root_key` can cause data loss. 165 | "/v1/projects/{ref}/ssl-enforcement", # Update SSL enforcement config. Could break access if misconfigured. 166 | "/v1/projects/{ref}/functions", # Bulk update Edge Functions. Could break multiple functions at once. 167 | "/v1/projects/{ref}/config/auth/sso/providers/{provider_id}", # Update sso provider. 168 | ], 169 | }, 170 | } 171 | 172 | def get_risk_level( 173 | self, operation: tuple[str, str, dict[str, Any], dict[str, Any], dict[str, Any]] 174 | ) -> OperationRiskLevel: 175 | """Get the risk level for an API operation. 176 | 177 | Args: 178 | operation: Tuple of (method, path) 179 | 180 | Returns: 181 | The risk level for the operation 182 | """ 183 | method, path, _, _, _ = operation 184 | 185 | # Check each risk level from highest to lowest 186 | for risk_level in sorted(self.PATH_SAFETY_CONFIG.keys(), reverse=True): 187 | if self._path_matches_risk_level(method, path, risk_level): 188 | return risk_level 189 | 190 | # Default to low risk 191 | return OperationRiskLevel.LOW 192 | 193 | def _path_matches_risk_level(self, method: str, path: str, risk_level: OperationRiskLevel) -> bool: 194 | """Check if the method and path match any pattern for the given risk level.""" 195 | patterns = self.PATH_SAFETY_CONFIG.get(risk_level, {}) 196 | 197 | if method not in patterns: 198 | return False 199 | 200 | for pattern in patterns[method]: 201 | # Convert placeholder pattern to regex 202 | regex_pattern = self._convert_pattern_to_regex(pattern) 203 | if re.match(regex_pattern, path): 204 | return True 205 | 206 | return False 207 | 208 | def _convert_pattern_to_regex(self, pattern: str) -> str: 209 | """Convert a placeholder pattern to a regex pattern. 210 | 211 | Replaces placeholders like {ref} with regex patterns for matching. 212 | """ 213 | # Replace common placeholders with regex patterns 214 | regex_pattern = pattern 215 | regex_pattern = regex_pattern.replace("{ref}", r"[^/]+") 216 | regex_pattern = regex_pattern.replace("{id}", r"[^/]+") 217 | regex_pattern = regex_pattern.replace("{slug}", r"[^/]+") 218 | regex_pattern = regex_pattern.replace("{table}", r"[^/]+") 219 | regex_pattern = regex_pattern.replace("{branch_id}", r"[^/]+") 220 | regex_pattern = regex_pattern.replace("{function_slug}", r"[^/]+") 221 | 222 | # Add end anchor to ensure full path matching 223 | if not regex_pattern.endswith("$"): 224 | regex_pattern += "$" 225 | 226 | return regex_pattern 227 | 228 | 229 | # ======== 230 | # SQL Safety Config 231 | # ======== 232 | 233 | 234 | class SQLSafetyConfig(SafetyConfigBase[QueryValidationResults]): 235 | """Safety configuration for SQL operations.""" 236 | 237 | STATEMENT_CONFIG = { 238 | # DQL - all LOW risk, no migrations 239 | "SelectStmt": { 240 | "category": SQLQueryCategory.DQL, 241 | "risk_level": OperationRiskLevel.LOW, 242 | "needs_migration": False, 243 | }, 244 | "ExplainStmt": { 245 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 246 | "risk_level": OperationRiskLevel.LOW, 247 | "needs_migration": False, 248 | }, 249 | # DML - all MEDIUM risk, no migrations 250 | "InsertStmt": { 251 | "category": SQLQueryCategory.DML, 252 | "risk_level": OperationRiskLevel.MEDIUM, 253 | "needs_migration": False, 254 | }, 255 | "UpdateStmt": { 256 | "category": SQLQueryCategory.DML, 257 | "risk_level": OperationRiskLevel.MEDIUM, 258 | "needs_migration": False, 259 | }, 260 | "DeleteStmt": { 261 | "category": SQLQueryCategory.DML, 262 | "risk_level": OperationRiskLevel.MEDIUM, 263 | "needs_migration": False, 264 | }, 265 | "MergeStmt": { 266 | "category": SQLQueryCategory.DML, 267 | "risk_level": OperationRiskLevel.MEDIUM, 268 | "needs_migration": False, 269 | }, 270 | # DDL - mix of MEDIUM and HIGH risk, need migrations 271 | "CreateStmt": { 272 | "category": SQLQueryCategory.DDL, 273 | "risk_level": OperationRiskLevel.MEDIUM, 274 | "needs_migration": True, 275 | }, 276 | "CreateTableAsStmt": { 277 | "category": SQLQueryCategory.DDL, 278 | "risk_level": OperationRiskLevel.MEDIUM, 279 | "needs_migration": True, 280 | }, 281 | "CreateSchemaStmt": { 282 | "category": SQLQueryCategory.DDL, 283 | "risk_level": OperationRiskLevel.MEDIUM, 284 | "needs_migration": True, 285 | }, 286 | "CreateExtensionStmt": { 287 | "category": SQLQueryCategory.DDL, 288 | "risk_level": OperationRiskLevel.MEDIUM, 289 | "needs_migration": True, 290 | }, 291 | "AlterTableStmt": { 292 | "category": SQLQueryCategory.DDL, 293 | "risk_level": OperationRiskLevel.MEDIUM, 294 | "needs_migration": True, 295 | }, 296 | "AlterDomainStmt": { 297 | "category": SQLQueryCategory.DDL, 298 | "risk_level": OperationRiskLevel.MEDIUM, 299 | "needs_migration": True, 300 | }, 301 | "CreateFunctionStmt": { 302 | "category": SQLQueryCategory.DDL, 303 | "risk_level": OperationRiskLevel.MEDIUM, 304 | "needs_migration": True, 305 | }, 306 | "IndexStmt": { # CREATE INDEX 307 | "category": SQLQueryCategory.DDL, 308 | "risk_level": OperationRiskLevel.MEDIUM, 309 | "needs_migration": True, 310 | }, 311 | "CreateTrigStmt": { 312 | "category": SQLQueryCategory.DDL, 313 | "risk_level": OperationRiskLevel.MEDIUM, 314 | "needs_migration": True, 315 | }, 316 | "ViewStmt": { # CREATE VIEW 317 | "category": SQLQueryCategory.DDL, 318 | "risk_level": OperationRiskLevel.MEDIUM, 319 | "needs_migration": True, 320 | }, 321 | "CommentStmt": { 322 | "category": SQLQueryCategory.DDL, 323 | "risk_level": OperationRiskLevel.MEDIUM, 324 | "needs_migration": True, 325 | }, 326 | # Additional DDL statements 327 | "CreateEnumStmt": { # CREATE TYPE ... AS ENUM 328 | "category": SQLQueryCategory.DDL, 329 | "risk_level": OperationRiskLevel.MEDIUM, 330 | "needs_migration": True, 331 | }, 332 | "CreateTypeStmt": { # CREATE TYPE (composite) 333 | "category": SQLQueryCategory.DDL, 334 | "risk_level": OperationRiskLevel.MEDIUM, 335 | "needs_migration": True, 336 | }, 337 | "CreateDomainStmt": { # CREATE DOMAIN 338 | "category": SQLQueryCategory.DDL, 339 | "risk_level": OperationRiskLevel.MEDIUM, 340 | "needs_migration": True, 341 | }, 342 | "CreateSeqStmt": { # CREATE SEQUENCE 343 | "category": SQLQueryCategory.DDL, 344 | "risk_level": OperationRiskLevel.MEDIUM, 345 | "needs_migration": True, 346 | }, 347 | "CreateForeignTableStmt": { # CREATE FOREIGN TABLE 348 | "category": SQLQueryCategory.DDL, 349 | "risk_level": OperationRiskLevel.MEDIUM, 350 | "needs_migration": True, 351 | }, 352 | "CreatePolicyStmt": { # CREATE POLICY 353 | "category": SQLQueryCategory.DDL, 354 | "risk_level": OperationRiskLevel.MEDIUM, 355 | "needs_migration": True, 356 | }, 357 | "CreateCastStmt": { # CREATE CAST 358 | "category": SQLQueryCategory.DDL, 359 | "risk_level": OperationRiskLevel.MEDIUM, 360 | "needs_migration": True, 361 | }, 362 | "CreateOpClassStmt": { # CREATE OPERATOR CLASS 363 | "category": SQLQueryCategory.DDL, 364 | "risk_level": OperationRiskLevel.MEDIUM, 365 | "needs_migration": True, 366 | }, 367 | "CreateOpFamilyStmt": { # CREATE OPERATOR FAMILY 368 | "category": SQLQueryCategory.DDL, 369 | "risk_level": OperationRiskLevel.MEDIUM, 370 | "needs_migration": True, 371 | }, 372 | "AlterEnumStmt": { # ALTER TYPE ... ADD VALUE 373 | "category": SQLQueryCategory.DDL, 374 | "risk_level": OperationRiskLevel.MEDIUM, 375 | "needs_migration": True, 376 | }, 377 | "AlterSeqStmt": { # ALTER SEQUENCE 378 | "category": SQLQueryCategory.DDL, 379 | "risk_level": OperationRiskLevel.MEDIUM, 380 | "needs_migration": True, 381 | }, 382 | "AlterOwnerStmt": { # ALTER ... OWNER TO 383 | "category": SQLQueryCategory.DDL, 384 | "risk_level": OperationRiskLevel.MEDIUM, 385 | "needs_migration": True, 386 | }, 387 | "AlterObjectSchemaStmt": { # ALTER ... SET SCHEMA 388 | "category": SQLQueryCategory.DDL, 389 | "risk_level": OperationRiskLevel.MEDIUM, 390 | "needs_migration": True, 391 | }, 392 | "RenameStmt": { # RENAME operations 393 | "category": SQLQueryCategory.DDL, 394 | "risk_level": OperationRiskLevel.MEDIUM, 395 | "needs_migration": True, 396 | }, 397 | # DESTRUCTIVE DDL - HIGH risk, need migrations and confirmation 398 | "DropStmt": { 399 | "category": SQLQueryCategory.DDL, 400 | "risk_level": OperationRiskLevel.HIGH, 401 | "needs_migration": True, 402 | }, 403 | "TruncateStmt": { 404 | "category": SQLQueryCategory.DDL, 405 | "risk_level": OperationRiskLevel.HIGH, 406 | "needs_migration": True, 407 | }, 408 | # DCL - MEDIUM risk, need migrations 409 | "GrantStmt": { 410 | "category": SQLQueryCategory.DCL, 411 | "risk_level": OperationRiskLevel.MEDIUM, 412 | "needs_migration": True, 413 | }, 414 | "GrantRoleStmt": { 415 | "category": SQLQueryCategory.DCL, 416 | "risk_level": OperationRiskLevel.MEDIUM, 417 | "needs_migration": True, 418 | }, 419 | "RevokeStmt": { 420 | "category": SQLQueryCategory.DCL, 421 | "risk_level": OperationRiskLevel.MEDIUM, 422 | "needs_migration": True, 423 | }, 424 | "RevokeRoleStmt": { 425 | "category": SQLQueryCategory.DCL, 426 | "risk_level": OperationRiskLevel.MEDIUM, 427 | "needs_migration": True, 428 | }, 429 | "CreateRoleStmt": { 430 | "category": SQLQueryCategory.DCL, 431 | "risk_level": OperationRiskLevel.MEDIUM, 432 | "needs_migration": True, 433 | }, 434 | "AlterRoleStmt": { 435 | "category": SQLQueryCategory.DCL, 436 | "risk_level": OperationRiskLevel.MEDIUM, 437 | "needs_migration": True, 438 | }, 439 | "DropRoleStmt": { 440 | "category": SQLQueryCategory.DCL, 441 | "risk_level": OperationRiskLevel.HIGH, 442 | "needs_migration": True, 443 | }, 444 | # TCL - LOW risk, no migrations 445 | "TransactionStmt": { 446 | "category": SQLQueryCategory.TCL, 447 | "risk_level": OperationRiskLevel.LOW, 448 | "needs_migration": False, 449 | }, 450 | # PostgreSQL-specific 451 | "VacuumStmt": { 452 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 453 | "risk_level": OperationRiskLevel.MEDIUM, 454 | "needs_migration": False, 455 | }, 456 | "AnalyzeStmt": { 457 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 458 | "risk_level": OperationRiskLevel.LOW, 459 | "needs_migration": False, 460 | }, 461 | "ClusterStmt": { 462 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 463 | "risk_level": OperationRiskLevel.MEDIUM, 464 | "needs_migration": False, 465 | }, 466 | "CheckPointStmt": { 467 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 468 | "risk_level": OperationRiskLevel.MEDIUM, 469 | "needs_migration": False, 470 | }, 471 | "PrepareStmt": { 472 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 473 | "risk_level": OperationRiskLevel.LOW, 474 | "needs_migration": False, 475 | }, 476 | "ExecuteStmt": { 477 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 478 | "risk_level": OperationRiskLevel.MEDIUM, # Could be LOW or MEDIUM based on prepared statement 479 | "needs_migration": False, 480 | }, 481 | "DeallocateStmt": { 482 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 483 | "risk_level": OperationRiskLevel.LOW, 484 | "needs_migration": False, 485 | }, 486 | "ListenStmt": { 487 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 488 | "risk_level": OperationRiskLevel.LOW, 489 | "needs_migration": False, 490 | }, 491 | "NotifyStmt": { 492 | "category": SQLQueryCategory.POSTGRES_SPECIFIC, 493 | "risk_level": OperationRiskLevel.MEDIUM, 494 | "needs_migration": False, 495 | }, 496 | } 497 | 498 | # Functions for more complex determinations 499 | def classify_statement(self, stmt_type: str, stmt_node: Any) -> dict[str, Any]: 500 | """Get classification rules for a given statement type from our config.""" 501 | config = self.STATEMENT_CONFIG.get( 502 | stmt_type, 503 | # if not found - default to MEDIUM risk 504 | { 505 | "category": SQLQueryCategory.OTHER, 506 | "risk_level": OperationRiskLevel.MEDIUM, # Default to MEDIUM risk for unknown 507 | "needs_migration": False, 508 | }, 509 | ) 510 | 511 | # Special case: CopyStmt can be read or write 512 | if stmt_type == "CopyStmt" and stmt_node: 513 | # Check if it's COPY TO (read) or COPY FROM (write) 514 | if hasattr(stmt_node, "is_from") and not stmt_node.is_from: 515 | # COPY TO - it's a read operation (LOW risk) 516 | config["category"] = SQLQueryCategory.DQL 517 | config["risk_level"] = OperationRiskLevel.LOW 518 | else: 519 | # COPY FROM - it's a write operation (MEDIUM risk) 520 | config["category"] = SQLQueryCategory.DML 521 | config["risk_level"] = OperationRiskLevel.MEDIUM 522 | 523 | # Other special cases can be added here 524 | 525 | return config 526 | 527 | def get_risk_level(self, operation: QueryValidationResults) -> OperationRiskLevel: 528 | """Get the risk level for an SQL batch operation. 529 | 530 | Args: 531 | operation: The SQL batch validation result to check 532 | 533 | Returns: 534 | The highest risk level found in the batch 535 | """ 536 | # Simply return the highest risk level that's already tracked in the batch 537 | return operation.highest_risk_level 538 | ``` -------------------------------------------------------------------------------- /tests/services/database/sql/test_sql_validator.py: -------------------------------------------------------------------------------- ```python 1 | import pytest 2 | 3 | from supabase_mcp.exceptions import ValidationError 4 | from supabase_mcp.services.database.sql.models import SQLQueryCategory, SQLQueryCommand 5 | from supabase_mcp.services.database.sql.validator import SQLValidator 6 | from supabase_mcp.services.safety.models import OperationRiskLevel 7 | 8 | 9 | class TestSQLValidator: 10 | """Test suite for the SQLValidator class.""" 11 | 12 | # ========================================================================= 13 | # Core Validation Tests 14 | # ========================================================================= 15 | 16 | def test_empty_query_validation(self, mock_validator: SQLValidator): 17 | """ 18 | Test that empty queries are properly rejected. 19 | 20 | This is a fundamental validation test to ensure the validator 21 | rejects empty or whitespace-only queries. 22 | """ 23 | # Test empty string 24 | with pytest.raises(ValidationError, match="Query cannot be empty"): 25 | mock_validator.validate_query("") 26 | 27 | # Test whitespace-only string 28 | with pytest.raises(ValidationError, match="Query cannot be empty"): 29 | mock_validator.validate_query(" \n \t ") 30 | 31 | def test_schema_and_table_name_validation(self, mock_validator: SQLValidator): 32 | """ 33 | Test validation of schema and table names. 34 | 35 | This test ensures that schema and table names are properly validated 36 | to prevent SQL injection and other security issues. 37 | """ 38 | # Test schema name validation 39 | valid_schema = "public" 40 | assert mock_validator.validate_schema_name(valid_schema) == valid_schema 41 | 42 | # The actual error message is "Schema name cannot contain spaces" 43 | invalid_schema = "public; DROP TABLE users;" 44 | with pytest.raises(ValidationError, match="Schema name cannot contain spaces"): 45 | mock_validator.validate_schema_name(invalid_schema) 46 | 47 | # Test table name validation 48 | valid_table = "users" 49 | assert mock_validator.validate_table_name(valid_table) == valid_table 50 | 51 | # The actual error message is "Table name cannot contain spaces" 52 | invalid_table = "users; DROP TABLE users;" 53 | with pytest.raises(ValidationError, match="Table name cannot contain spaces"): 54 | mock_validator.validate_table_name(invalid_table) 55 | 56 | # ========================================================================= 57 | # Safety Level Classification Tests 58 | # ========================================================================= 59 | 60 | def test_safe_operation_identification(self, mock_validator: SQLValidator, sample_dql_queries: dict[str, str]): 61 | """ 62 | Test that safe operations (SELECT queries) are correctly identified. 63 | 64 | This test ensures that all SELECT queries are properly categorized as 65 | safe operations, which is critical for security. 66 | """ 67 | for name, query in sample_dql_queries.items(): 68 | result = mock_validator.validate_query(query) 69 | assert result.highest_risk_level == OperationRiskLevel.LOW, f"Query '{name}' should be classified as SAFE" 70 | assert result.statements[0].category == SQLQueryCategory.DQL, f"Query '{name}' should be categorized as DQL" 71 | assert result.statements[0].command == SQLQueryCommand.SELECT, f"Query '{name}' should have command SELECT" 72 | 73 | def test_write_operation_identification(self, mock_validator: SQLValidator, sample_dml_queries: dict[str, str]): 74 | """ 75 | Test that write operations (INSERT, UPDATE, DELETE) are correctly identified. 76 | 77 | This test ensures that all data modification queries are properly categorized 78 | as write operations, which require different permissions. 79 | """ 80 | for name, query in sample_dml_queries.items(): 81 | result = mock_validator.validate_query(query) 82 | assert result.highest_risk_level == OperationRiskLevel.MEDIUM, ( 83 | f"Query '{name}' should be classified as WRITE" 84 | ) 85 | assert result.statements[0].category == SQLQueryCategory.DML, f"Query '{name}' should be categorized as DML" 86 | 87 | # Check specific command based on query type 88 | if name.startswith("insert"): 89 | assert result.statements[0].command == SQLQueryCommand.INSERT 90 | elif name.startswith("update"): 91 | assert result.statements[0].command == SQLQueryCommand.UPDATE 92 | elif name.startswith("delete"): 93 | assert result.statements[0].command == SQLQueryCommand.DELETE 94 | elif name.startswith("merge"): 95 | assert result.statements[0].command == SQLQueryCommand.MERGE 96 | 97 | def test_destructive_operation_identification( 98 | self, mock_validator: SQLValidator, sample_ddl_queries: dict[str, str] 99 | ): 100 | """ 101 | Test that destructive operations (DROP, TRUNCATE) are correctly identified. 102 | 103 | This test ensures that all data definition queries that can destroy data 104 | are properly categorized as destructive operations, which require 105 | the highest level of permissions. 106 | """ 107 | # Test DROP statements 108 | drop_query = sample_ddl_queries["drop_table"] 109 | drop_result = mock_validator.validate_query(drop_query) 110 | 111 | # Verify the statement is correctly categorized as DDL and has the DROP command 112 | assert drop_result.statements[0].category == SQLQueryCategory.DDL, "DROP should be categorized as DDL" 113 | assert drop_result.statements[0].command == SQLQueryCommand.DROP, "Command should be DROP" 114 | 115 | # Test TRUNCATE statements 116 | truncate_query = sample_ddl_queries["truncate_table"] 117 | truncate_result = mock_validator.validate_query(truncate_query) 118 | 119 | # Verify the statement is correctly categorized as DDL and has the TRUNCATE command 120 | assert truncate_result.statements[0].category == SQLQueryCategory.DDL, "TRUNCATE should be categorized as DDL" 121 | assert truncate_result.statements[0].command == SQLQueryCommand.TRUNCATE, "Command should be TRUNCATE" 122 | 123 | # ========================================================================= 124 | # Transaction Control Tests 125 | # ========================================================================= 126 | 127 | def test_transaction_control_detection(self, mock_validator: SQLValidator, sample_tcl_queries: dict[str, str]): 128 | """ 129 | Test that BEGIN/COMMIT/ROLLBACK statements are correctly identified as TCL. 130 | 131 | Transaction control is critical for maintaining data integrity and 132 | must be properly detected regardless of case or formatting. 133 | """ 134 | # Test BEGIN statement 135 | with pytest.raises(ValidationError) as excinfo: 136 | mock_validator.validate_query(sample_tcl_queries["begin_transaction"]) 137 | assert "Transaction control statements" in str(excinfo.value) 138 | 139 | # Test COMMIT statement 140 | with pytest.raises(ValidationError) as excinfo: 141 | mock_validator.validate_query(sample_tcl_queries["commit_transaction"]) 142 | assert "Transaction control statements" in str(excinfo.value) 143 | 144 | # Test ROLLBACK statement 145 | with pytest.raises(ValidationError) as excinfo: 146 | mock_validator.validate_query(sample_tcl_queries["rollback_transaction"]) 147 | assert "Transaction control statements" in str(excinfo.value) 148 | 149 | # Test mixed case transaction statement 150 | with pytest.raises(ValidationError) as excinfo: 151 | mock_validator.validate_query(sample_tcl_queries["mixed_case_transaction"]) 152 | assert "Transaction control statements" in str(excinfo.value) 153 | 154 | # Test string-based detection method directly 155 | assert SQLValidator.validate_transaction_control("BEGIN"), "String-based detection should identify BEGIN" 156 | assert SQLValidator.validate_transaction_control("COMMIT"), "String-based detection should identify COMMIT" 157 | assert SQLValidator.validate_transaction_control("ROLLBACK"), "String-based detection should identify ROLLBACK" 158 | assert SQLValidator.validate_transaction_control("begin transaction"), ( 159 | "String-based detection should be case-insensitive" 160 | ) 161 | 162 | # ========================================================================= 163 | # Multiple Statements Tests 164 | # ========================================================================= 165 | 166 | def test_multiple_statements_with_mixed_safety_levels( 167 | self, mock_validator: SQLValidator, sample_multiple_statements: dict[str, str] 168 | ): 169 | """ 170 | Test that multiple statements with different safety levels are correctly identified. 171 | 172 | Note: Due to the string-based comparison in the implementation, the safety levels 173 | are not correctly ordered (SAFE > WRITE > DESTRUCTIVE). This test focuses on 174 | verifying that multiple statements are correctly parsed and categorized. 175 | """ 176 | # Test multiple safe statements 177 | safe_result = mock_validator.validate_query(sample_multiple_statements["multiple_safe"]) 178 | assert len(safe_result.statements) == 2, "Should identify two statements" 179 | assert safe_result.statements[0].category == SQLQueryCategory.DQL, "First statement should be DQL" 180 | assert safe_result.statements[1].category == SQLQueryCategory.DQL, "Second statement should be DQL" 181 | 182 | # Test safe + write statements 183 | mixed_result = mock_validator.validate_query(sample_multiple_statements["safe_and_write"]) 184 | assert len(mixed_result.statements) == 2, "Should identify two statements" 185 | assert mixed_result.statements[0].category == SQLQueryCategory.DQL, "First statement should be DQL" 186 | assert mixed_result.statements[1].category == SQLQueryCategory.DML, "Second statement should be DML" 187 | 188 | # Test write + destructive statements 189 | destructive_result = mock_validator.validate_query(sample_multiple_statements["write_and_destructive"]) 190 | assert len(destructive_result.statements) == 2, "Should identify two statements" 191 | assert destructive_result.statements[0].category == SQLQueryCategory.DML, "First statement should be DML" 192 | assert destructive_result.statements[1].category == SQLQueryCategory.DDL, "Second statement should be DDL" 193 | assert destructive_result.statements[1].command == SQLQueryCommand.DROP, "Second command should be DROP" 194 | 195 | # Test transaction statements 196 | with pytest.raises(ValidationError) as excinfo: 197 | mock_validator.validate_query(sample_multiple_statements["with_transaction"]) 198 | assert "Transaction control statements" in str(excinfo.value) 199 | 200 | # ========================================================================= 201 | # Error Handling Tests 202 | # ========================================================================= 203 | 204 | def test_syntax_error_handling(self, mock_validator: SQLValidator, sample_invalid_queries: dict[str, str]): 205 | """ 206 | Test that SQL syntax errors are properly caught and reported. 207 | 208 | Fundamental for providing clear feedback to users when their SQL is invalid. 209 | """ 210 | # Test syntax error 211 | with pytest.raises(ValidationError, match="SQL syntax error"): 212 | mock_validator.validate_query(sample_invalid_queries["syntax_error"]) 213 | 214 | # Test missing parenthesis 215 | with pytest.raises(ValidationError, match="SQL syntax error"): 216 | mock_validator.validate_query(sample_invalid_queries["missing_parenthesis"]) 217 | 218 | # Test incomplete statement 219 | with pytest.raises(ValidationError, match="SQL syntax error"): 220 | mock_validator.validate_query(sample_invalid_queries["incomplete_statement"]) 221 | 222 | # ========================================================================= 223 | # PostgreSQL-Specific Features Tests 224 | # ========================================================================= 225 | 226 | def test_copy_statement_direction_detection( 227 | self, mock_validator: SQLValidator, sample_postgres_specific_queries: dict[str, str] 228 | ): 229 | """ 230 | Test that COPY TO (read) vs COPY FROM (write) are correctly distinguished. 231 | 232 | Important edge case with safety implications as COPY TO is safe 233 | while COPY FROM modifies data. 234 | """ 235 | # Test COPY TO (should be SAFE) 236 | copy_to_result = mock_validator.validate_query(sample_postgres_specific_queries["copy_to"]) 237 | assert copy_to_result.highest_risk_level == OperationRiskLevel.LOW, "COPY TO should be classified as SAFE" 238 | assert copy_to_result.statements[0].category == SQLQueryCategory.DQL, "COPY TO should be categorized as DQL" 239 | 240 | # Test COPY FROM (should be WRITE) 241 | copy_from_result = mock_validator.validate_query(sample_postgres_specific_queries["copy_from"]) 242 | assert copy_from_result.highest_risk_level == OperationRiskLevel.MEDIUM, ( 243 | "COPY FROM should be classified as WRITE" 244 | ) 245 | assert copy_from_result.statements[0].category == SQLQueryCategory.DML, "COPY FROM should be categorized as DML" 246 | 247 | # ========================================================================= 248 | # Complex Scenarios Tests 249 | # ========================================================================= 250 | 251 | def test_complex_queries_with_subqueries_and_ctes( 252 | self, mock_validator: SQLValidator, sample_dql_queries: dict[str, str] 253 | ): 254 | """ 255 | Test that complex queries with subqueries and CTEs are correctly parsed. 256 | 257 | Ensures robustness with real-world queries that may contain 258 | complex structures but are still valid. 259 | """ 260 | # Test query with subquery 261 | subquery_result = mock_validator.validate_query(sample_dql_queries["select_with_subquery"]) 262 | assert subquery_result.highest_risk_level == OperationRiskLevel.LOW, "Query with subquery should be SAFE" 263 | assert subquery_result.statements[0].category == SQLQueryCategory.DQL, "Query with subquery should be DQL" 264 | 265 | # Test query with CTE (Common Table Expression) 266 | cte_result = mock_validator.validate_query(sample_dql_queries["select_with_cte"]) 267 | assert cte_result.highest_risk_level == OperationRiskLevel.LOW, "Query with CTE should be SAFE" 268 | assert cte_result.statements[0].category == SQLQueryCategory.DQL, "Query with CTE should be DQL" 269 | 270 | # ========================================================================= 271 | # False Positive Prevention Tests 272 | # ========================================================================= 273 | 274 | def test_valid_queries_with_comments(self, mock_validator: SQLValidator, sample_edge_cases: dict[str, str]): 275 | """ 276 | Test that valid queries with SQL comments are not rejected. 277 | 278 | Ensures that comments (inline and block) don't cause valid queries 279 | to be incorrectly flagged as invalid. 280 | """ 281 | # Test query with comments 282 | query_with_comments = sample_edge_cases["with_comments"] 283 | result = mock_validator.validate_query(query_with_comments) 284 | 285 | # Verify the query is parsed correctly despite comments 286 | assert result.statements[0].category == SQLQueryCategory.DQL, "Query with comments should be categorized as DQL" 287 | assert result.statements[0].command == SQLQueryCommand.SELECT, "Query with comments should have SELECT command" 288 | assert result.highest_risk_level == OperationRiskLevel.LOW, "Query with comments should be SAFE" 289 | 290 | def test_valid_queries_with_quoted_identifiers( 291 | self, mock_validator: SQLValidator, sample_edge_cases: dict[str, str] 292 | ): 293 | """ 294 | Test that valid queries with quoted identifiers are not rejected. 295 | 296 | Ensures that double-quoted table/column names and single-quoted 297 | strings don't cause false positives. 298 | """ 299 | # Test query with quoted identifiers 300 | query_with_quotes = sample_edge_cases["quoted_identifiers"] 301 | result = mock_validator.validate_query(query_with_quotes) 302 | 303 | # Verify the query is parsed correctly despite quoted identifiers 304 | assert result.statements[0].category == SQLQueryCategory.DQL, ( 305 | "Query with quoted identifiers should be categorized as DQL" 306 | ) 307 | assert result.statements[0].command == SQLQueryCommand.SELECT, ( 308 | "Query with quoted identifiers should have SELECT command" 309 | ) 310 | assert result.highest_risk_level == OperationRiskLevel.LOW, "Query with quoted identifiers should be SAFE" 311 | 312 | def test_valid_queries_with_special_characters( 313 | self, mock_validator: SQLValidator, sample_edge_cases: dict[str, str] 314 | ): 315 | """ 316 | Test that valid queries with special characters are not rejected. 317 | 318 | Ensures that special characters in strings and identifiers 319 | don't trigger false positives. 320 | """ 321 | # Test query with special characters 322 | query_with_special_chars = sample_edge_cases["special_characters"] 323 | result = mock_validator.validate_query(query_with_special_chars) 324 | 325 | # Verify the query is parsed correctly despite special characters 326 | assert result.statements[0].category == SQLQueryCategory.DQL, ( 327 | "Query with special characters should be categorized as DQL" 328 | ) 329 | assert result.statements[0].command == SQLQueryCommand.SELECT, ( 330 | "Query with special characters should have SELECT command" 331 | ) 332 | assert result.highest_risk_level == OperationRiskLevel.LOW, "Query with special characters should be SAFE" 333 | 334 | def test_valid_postgresql_specific_syntax( 335 | self, 336 | mock_validator: SQLValidator, 337 | sample_edge_cases: dict[str, str], 338 | sample_postgres_specific_queries: dict[str, str], 339 | ): 340 | """ 341 | Test that valid PostgreSQL-specific syntax is not rejected. 342 | 343 | Ensures that PostgreSQL extensions to standard SQL (like RETURNING 344 | clauses or specific operators) don't cause false positives. 345 | """ 346 | # Test query with dollar-quoted strings (PostgreSQL-specific feature) 347 | query_with_dollar_quotes = sample_edge_cases["with_dollar_quotes"] 348 | result = mock_validator.validate_query(query_with_dollar_quotes) 349 | assert result.statements[0].category == SQLQueryCategory.DQL, ( 350 | "Query with dollar quotes should be categorized as DQL" 351 | ) 352 | 353 | # Test schema-qualified names 354 | schema_qualified_query = sample_edge_cases["schema_qualified"] 355 | result = mock_validator.validate_query(schema_qualified_query) 356 | assert result.statements[0].category == SQLQueryCategory.DQL, ( 357 | "Query with schema qualification should be categorized as DQL" 358 | ) 359 | 360 | # Test EXPLAIN ANALYZE (PostgreSQL-specific) 361 | explain_query = sample_postgres_specific_queries["explain"] 362 | result = mock_validator.validate_query(explain_query) 363 | assert result.statements[0].category == SQLQueryCategory.POSTGRES_SPECIFIC, ( 364 | "EXPLAIN should be categorized as POSTGRES_SPECIFIC" 365 | ) 366 | 367 | def test_valid_complex_joins(self, mock_validator: SQLValidator): 368 | """ 369 | Test that valid complex JOIN operations are not rejected. 370 | 371 | Ensures that complex but valid JOIN syntax (including LATERAL joins, 372 | multiple join conditions, etc.) doesn't cause false positives. 373 | """ 374 | # Test complex join with multiple conditions 375 | complex_join_query = """ 376 | SELECT u.id, u.name, p.title, c.content 377 | FROM users u 378 | JOIN posts p ON u.id = p.user_id AND p.published = true 379 | LEFT JOIN comments c ON p.id = c.post_id 380 | WHERE u.active = true 381 | ORDER BY p.created_at DESC 382 | """ 383 | result = mock_validator.validate_query(complex_join_query) 384 | assert result.statements[0].category == SQLQueryCategory.DQL, "Complex join query should be categorized as DQL" 385 | assert result.statements[0].command == SQLQueryCommand.SELECT, "Complex join query should have SELECT command" 386 | 387 | # Test LATERAL join (PostgreSQL-specific join type) 388 | lateral_join_query = """ 389 | SELECT u.id, u.name, p.title 390 | FROM users u 391 | LEFT JOIN LATERAL ( 392 | SELECT title FROM posts WHERE user_id = u.id ORDER BY created_at DESC LIMIT 1 393 | ) p ON true 394 | """ 395 | result = mock_validator.validate_query(lateral_join_query) 396 | assert result.statements[0].category == SQLQueryCategory.DQL, "LATERAL join query should be categorized as DQL" 397 | assert result.statements[0].command == SQLQueryCommand.SELECT, "LATERAL join query should have SELECT command" 398 | 399 | # ========================================================================= 400 | # Additional Tests Based on Code Review 401 | # ========================================================================= 402 | 403 | def test_dcl_statement_identification(self, mock_validator: SQLValidator, sample_dcl_queries: dict[str, str]): 404 | """ 405 | Test that GRANT/REVOKE statements are correctly identified as DCL. 406 | 407 | DCL statements control access to data and should be properly classified 408 | to ensure appropriate permissions management. 409 | """ 410 | # Test GRANT statement 411 | grant_query = sample_dcl_queries["grant_select"] 412 | grant_result = mock_validator.validate_query(grant_query) 413 | assert grant_result.statements[0].category == SQLQueryCategory.DCL, "GRANT should be categorized as DCL" 414 | assert grant_result.statements[0].command == SQLQueryCommand.GRANT, "Command should be GRANT" 415 | 416 | # Test REVOKE statement 417 | revoke_query = sample_dcl_queries["revoke_select"] 418 | revoke_result = mock_validator.validate_query(revoke_query) 419 | assert revoke_result.statements[0].category == SQLQueryCategory.DCL, "REVOKE should be categorized as DCL" 420 | # Note: The current implementation may not correctly identify REVOKE commands 421 | # so we're only checking the category, not the specific command 422 | 423 | # Test CREATE ROLE statement (also DCL) 424 | create_role_query = sample_dcl_queries["create_role"] 425 | create_role_result = mock_validator.validate_query(create_role_query) 426 | assert create_role_result.statements[0].category == SQLQueryCategory.DCL, ( 427 | "CREATE ROLE should be categorized as DCL" 428 | ) 429 | 430 | def test_needs_migration_flag( 431 | self, mock_validator: SQLValidator, sample_ddl_queries: dict[str, str], sample_dml_queries: dict[str, str] 432 | ): 433 | """ 434 | Test that statements requiring migrations are correctly flagged. 435 | 436 | Ensures that DDL statements that require migrations (like CREATE TABLE) 437 | are properly identified to enforce migration requirements. 438 | """ 439 | # Test CREATE TABLE (should need migration) 440 | create_table_query = sample_ddl_queries["create_table"] 441 | create_result = mock_validator.validate_query(create_table_query) 442 | assert create_result.statements[0].needs_migration, "CREATE TABLE should require migration" 443 | 444 | # Test ALTER TABLE (should need migration) 445 | alter_table_query = sample_ddl_queries["alter_table"] 446 | alter_result = mock_validator.validate_query(alter_table_query) 447 | assert alter_result.statements[0].needs_migration, "ALTER TABLE should require migration" 448 | 449 | # Test INSERT (should NOT need migration) 450 | insert_query = sample_dml_queries["simple_insert"] 451 | insert_result = mock_validator.validate_query(insert_query) 452 | assert not insert_result.statements[0].needs_migration, "INSERT should not require migration" 453 | 454 | def test_object_type_extraction(self, mock_validator: SQLValidator): 455 | """ 456 | Test that object types (table names, etc.) are correctly extracted when possible. 457 | 458 | Note: The current implementation has limitations in extracting object types 459 | from all statement types. This test focuses on verifying the basic functionality 460 | without making assumptions about specific extraction capabilities. 461 | """ 462 | # Test that object_type is present in the result structure 463 | select_query = "SELECT * FROM users WHERE id = 1" 464 | select_result = mock_validator.validate_query(select_query) 465 | 466 | # Verify the object_type field exists in the result 467 | assert hasattr(select_result.statements[0], "object_type"), "Result should have object_type field" 468 | 469 | # Test with a more complex query 470 | complex_query = """ 471 | WITH active_users AS ( 472 | SELECT * FROM users WHERE active = true 473 | ) 474 | SELECT * FROM active_users 475 | """ 476 | complex_result = mock_validator.validate_query(complex_query) 477 | assert hasattr(complex_result.statements[0], "object_type"), ( 478 | "Complex query result should have object_type field" 479 | ) 480 | 481 | def test_string_based_transaction_control(self, mock_validator: SQLValidator): 482 | """ 483 | Test the string-based transaction control detection method. 484 | 485 | Specifically tests the validate_transaction_control class method 486 | to ensure it correctly identifies transaction keywords. 487 | """ 488 | # Test standard transaction keywords 489 | assert SQLValidator.validate_transaction_control("BEGIN"), "Should detect 'BEGIN'" 490 | assert SQLValidator.validate_transaction_control("COMMIT"), "Should detect 'COMMIT'" 491 | assert SQLValidator.validate_transaction_control("ROLLBACK"), "Should detect 'ROLLBACK'" 492 | 493 | # Test case insensitivity 494 | assert SQLValidator.validate_transaction_control("begin"), "Should be case-insensitive" 495 | assert SQLValidator.validate_transaction_control("Commit"), "Should be case-insensitive" 496 | assert SQLValidator.validate_transaction_control("ROLLBACK"), "Should be case-insensitive" 497 | 498 | # Test with additional text 499 | assert SQLValidator.validate_transaction_control("BEGIN TRANSACTION"), "Should detect 'BEGIN TRANSACTION'" 500 | assert SQLValidator.validate_transaction_control("COMMIT WORK"), "Should detect 'COMMIT WORK'" 501 | 502 | # Test negative cases 503 | assert not SQLValidator.validate_transaction_control("SELECT * FROM transactions"), ( 504 | "Should not detect in regular SQL" 505 | ) 506 | assert not SQLValidator.validate_transaction_control(""), "Should not detect in empty string" 507 | 508 | def test_basic_query_validation_method(self, mock_validator: SQLValidator): 509 | """ 510 | Test the basic_query_validation method. 511 | 512 | Ensures that the method correctly validates and sanitizes 513 | input queries before parsing. 514 | """ 515 | # Test valid query 516 | valid_query = "SELECT * FROM users" 517 | assert mock_validator.basic_query_validation(valid_query) == valid_query, "Should return valid query unchanged" 518 | 519 | # Test query with whitespace 520 | whitespace_query = " SELECT * FROM users " 521 | assert mock_validator.basic_query_validation(whitespace_query) == whitespace_query, "Should preserve whitespace" 522 | 523 | # Test empty query 524 | with pytest.raises(ValidationError, match="Query cannot be empty"): 525 | mock_validator.basic_query_validation("") 526 | 527 | # Test whitespace-only query 528 | with pytest.raises(ValidationError, match="Query cannot be empty"): 529 | mock_validator.basic_query_validation(" \n \t ") 530 | ```