This is page 3 of 6. Use http://codebase.md/alexander-zuev/supabase-mcp-server?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .claude
│   └── settings.local.json
├── .dockerignore
├── .env.example
├── .env.test.example
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── roadmap_item.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── ci.yaml
│       ├── docs
│       │   └── release-checklist.md
│       └── publish.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── CHANGELOG.MD
├── codecov.yml
├── CONTRIBUTING.MD
├── Dockerfile
├── LICENSE
├── llms-full.txt
├── pyproject.toml
├── README.md
├── smithery.yaml
├── supabase_mcp
│   ├── __init__.py
│   ├── clients
│   │   ├── api_client.py
│   │   ├── base_http_client.py
│   │   ├── management_client.py
│   │   └── sdk_client.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── container.py
│   │   └── feature_manager.py
│   ├── exceptions.py
│   ├── logger.py
│   ├── main.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── api_manager.py
│   │   │   ├── spec_manager.py
│   │   │   └── specs
│   │   │       └── api_spec.json
│   │   ├── database
│   │   │   ├── __init__.py
│   │   │   ├── migration_manager.py
│   │   │   ├── postgres_client.py
│   │   │   ├── query_manager.py
│   │   │   └── sql
│   │   │       ├── loader.py
│   │   │       ├── models.py
│   │   │       ├── queries
│   │   │       │   ├── create_migration.sql
│   │   │       │   ├── get_migrations.sql
│   │   │       │   ├── get_schemas.sql
│   │   │       │   ├── get_table_schema.sql
│   │   │       │   ├── get_tables.sql
│   │   │       │   ├── init_migrations.sql
│   │   │       │   └── logs
│   │   │       │       ├── auth_logs.sql
│   │   │       │       ├── cron_logs.sql
│   │   │       │       ├── edge_logs.sql
│   │   │       │       ├── function_edge_logs.sql
│   │   │       │       ├── pgbouncer_logs.sql
│   │   │       │       ├── postgres_logs.sql
│   │   │       │       ├── postgrest_logs.sql
│   │   │       │       ├── realtime_logs.sql
│   │   │       │       ├── storage_logs.sql
│   │   │       │       └── supavisor_logs.sql
│   │   │       └── validator.py
│   │   ├── logs
│   │   │   ├── __init__.py
│   │   │   └── log_manager.py
│   │   ├── safety
│   │   │   ├── __init__.py
│   │   │   ├── models.py
│   │   │   ├── safety_configs.py
│   │   │   └── safety_manager.py
│   │   └── sdk
│   │       ├── __init__.py
│   │       ├── auth_admin_models.py
│   │       └── auth_admin_sdk_spec.py
│   ├── settings.py
│   └── tools
│       ├── __init__.py
│       ├── descriptions
│       │   ├── api_tools.yaml
│       │   ├── database_tools.yaml
│       │   ├── logs_and_analytics_tools.yaml
│       │   ├── safety_tools.yaml
│       │   └── sdk_tools.yaml
│       ├── manager.py
│       └── registry.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── test_api_client.py
│   │   │   ├── test_api_manager.py
│   │   │   └── test_spec_manager.py
│   │   ├── database
│   │   │   ├── sql
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conftest.py
│   │   │   │   ├── test_loader.py
│   │   │   │   ├── test_sql_validator_integration.py
│   │   │   │   └── test_sql_validator.py
│   │   │   ├── test_migration_manager.py
│   │   │   ├── test_postgres_client.py
│   │   │   └── test_query_manager.py
│   │   ├── logs
│   │   │   └── test_log_manager.py
│   │   ├── safety
│   │   │   ├── test_api_safety_config.py
│   │   │   ├── test_safety_manager.py
│   │   │   └── test_sql_safety_config.py
│   │   └── sdk
│   │       ├── test_auth_admin_models.py
│   │       └── test_sdk_client.py
│   ├── test_container.py
│   ├── test_main.py
│   ├── test_settings.py
│   ├── test_tool_manager.py
│   ├── test_tools_integration.py.bak
│   └── test_tools.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/supabase_mcp/services/api/api_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | 
  3 | from enum import Enum
  4 | from typing import Any
  5 | 
  6 | from supabase_mcp.clients.management_client import ManagementAPIClient
  7 | from supabase_mcp.logger import logger
  8 | from supabase_mcp.services.api.spec_manager import ApiSpecManager
  9 | from supabase_mcp.services.logs.log_manager import LogManager
 10 | from supabase_mcp.services.safety.models import ClientType
 11 | from supabase_mcp.services.safety.safety_manager import SafetyManager
 12 | from supabase_mcp.settings import settings
 13 | 
 14 | 
 15 | class PathPlaceholder(str, Enum):
 16 |     """Enum of all possible path placeholders in the Supabase Management API."""
 17 | 
 18 |     REF = "ref"
 19 |     FUNCTION_SLUG = "function_slug"
 20 |     ID = "id"
 21 |     SLUG = "slug"
 22 |     BRANCH_ID = "branch_id"
 23 |     PROVIDER_ID = "provider_id"
 24 |     TPA_ID = "tpa_id"
 25 | 
 26 | 
 27 | class SupabaseApiManager:
 28 |     """
 29 |     Manages the Supabase Management API.
 30 |     """
 31 | 
 32 |     _instance: SupabaseApiManager | None = None
 33 | 
 34 |     def __init__(
 35 |         self,
 36 |         api_client: ManagementAPIClient,
 37 |         safety_manager: SafetyManager,
 38 |         spec_manager: ApiSpecManager | None = None,
 39 |         log_manager: LogManager | None = None,
 40 |     ) -> None:
 41 |         """Initialize the API manager."""
 42 |         self.spec_manager = spec_manager or ApiSpecManager()  # fall back to a default instance so callers don't have to inject one
 43 |         self.client = api_client
 44 |         self.safety_manager = safety_manager
 45 |         self.log_manager = log_manager or LogManager()
 46 | 
 47 |     @classmethod
 48 |     def get_instance(
 49 |         cls,
 50 |         api_client: ManagementAPIClient,
 51 |         safety_manager: SafetyManager,
 52 |         spec_manager: ApiSpecManager | None = None,
 53 |     ) -> SupabaseApiManager:
 54 |         """Get the singleton instance"""
 55 |         if cls._instance is None:
 56 |             cls._instance = SupabaseApiManager(api_client, safety_manager, spec_manager)
 57 |         return cls._instance
 58 | 
 59 |     @classmethod
 60 |     def reset(cls) -> None:
 61 |         """Reset the singleton instance"""
 62 |         if cls._instance is not None:
 63 |             cls._instance = None
 64 |             logger.info("SupabaseApiManager instance reset complete")
 65 | 
 66 |     def get_safety_rules(self) -> str:
 67 |         """
 68 |         Get safety rules with human-readable descriptions.
 69 | 
 70 |         Returns:
 71 |             str: Human readable safety rules explanation
 72 |         """
 73 |         # Get safety configuration from the safety manager
 74 |         safety_manager = self.safety_manager
 75 | 
 76 |         # Get risk levels and operations by risk level
 77 |         extreme_risk_ops = safety_manager.get_operations_by_risk_level("extreme", ClientType.API)
 78 |         high_risk_ops = safety_manager.get_operations_by_risk_level("high", ClientType.API)
 79 |         medium_risk_ops = safety_manager.get_operations_by_risk_level("medium", ClientType.API)
 80 | 
 81 |         # Create human-readable explanations
 82 |         extreme_risk_summary = (
 83 |             "\n".join([f"- {method} {path}" for method, paths in extreme_risk_ops.items() for path in paths])
 84 |             if extreme_risk_ops
 85 |             else "None"
 86 |         )
 87 | 
 88 |         high_risk_summary = (
 89 |             "\n".join([f"- {method} {path}" for method, paths in high_risk_ops.items() for path in paths])
 90 |             if high_risk_ops
 91 |             else "None"
 92 |         )
 93 | 
 94 |         medium_risk_summary = (
 95 |             "\n".join([f"- {method} {path}" for method, paths in medium_risk_ops.items() for path in paths])
 96 |             if medium_risk_ops
 97 |             else "None"
 98 |         )
 99 | 
100 |         current_mode = safety_manager.get_current_mode(ClientType.API)
101 | 
102 |         return f"""MCP Server Safety Rules:
103 | 
104 |             EXTREME RISK Operations (never allowed by the server):
105 |             {extreme_risk_summary}
106 | 
107 |             HIGH RISK Operations (require unsafe mode):
108 |             {high_risk_summary}
109 | 
110 |             MEDIUM RISK Operations (require unsafe mode):
111 |             {medium_risk_summary}
112 | 
113 |             All other operations are LOW RISK (always allowed).
114 | 
115 |             Current mode: {current_mode}
116 |             In safe mode, only low risk operations are allowed.
117 |             Use live_dangerously() to enable unsafe mode for medium and high risk operations.
118 |             """
119 | 
120 |     def replace_path_params(self, path: str, path_params: dict[str, Any] | None = None) -> str:
121 |         """
122 |         Replace path parameters in the path string with actual values.
123 | 
124 |         This method:
125 |         1. Automatically injects the project ref from settings
126 |         2. Replaces all placeholders in the path with values from path_params
127 |         3. Validates that all placeholders are replaced
128 | 
129 |         Args:
130 |             path: The API path with placeholders (e.g., "/v1/projects/{ref}/functions/{function_slug}")
131 |             path_params: Dictionary of path parameters to replace (e.g., {"function_slug": "my-function"})
132 | 
133 |         Returns:
134 |             The path with all placeholders replaced
135 | 
136 |         Raises:
137 |             ValueError: If any placeholders remain after replacement or if invalid placeholders are provided
138 |         """
139 |         # Create a working copy of path_params to avoid modifying the original
140 |         working_params = {} if path_params is None else path_params.copy()
141 | 
142 |         # Check if user provided ref and raise an error
143 |         if working_params and PathPlaceholder.REF.value in working_params:
144 |             raise ValueError(
145 |                 "Do not provide 'ref' in path_params. The project reference is automatically injected from settings. "
146 |                 "If you need to change the project reference, modify the environment variables instead."
147 |             )
148 | 
149 |         # Validate that all provided path parameters are known placeholders
150 |         if working_params:
151 |             for key in working_params:
152 |                 try:
153 |                     PathPlaceholder(key)
154 |                 except ValueError as e:
155 |                     raise ValueError(
156 |                         f"Unknown path parameter: '{key}'. Valid placeholders are: "
157 |                         f"{', '.join([p.value for p in PathPlaceholder])}"
158 |                     ) from e
159 | 
160 |         # Get project ref from settings and add it to working_params
161 |         working_params[PathPlaceholder.REF.value] = settings.supabase_project_ref
162 | 
163 |         logger.info(f"Substituting path parameters: {working_params}")
164 | 
165 |         # Replace all placeholders in the path
166 |         for key, value in working_params.items():
167 |             placeholder = "{" + key + "}"
168 |             if placeholder in path:
169 |                 path = path.replace(placeholder, str(value))
170 |                 logger.debug(f"Replaced {placeholder} with {value}")
171 | 
172 |         # Check if any placeholders remain
173 |         import re
174 | 
175 |         remaining_placeholders = re.findall(r"\{([^}]+)\}", path)
176 |         if remaining_placeholders:
177 |             raise ValueError(
178 |                 f"Missing path parameters: {', '.join(remaining_placeholders)}. "
179 |                 f"Please provide values for these placeholders in the path_params dictionary."
180 |             )
181 | 
182 |         return path
183 | 
184 |     async def execute_request(
185 |         self,
186 |         method: str,
187 |         path: str,
188 |         path_params: dict[str, Any] | None = None,
189 |         request_params: dict[str, Any] | None = None,
190 |         request_body: dict[str, Any] | None = None,
191 |         has_confirmation: bool = False,
192 |     ) -> dict[str, Any]:
193 |         """
194 |         Execute Management API request with safety validation.
195 | 
196 |         Args:
197 |             method: HTTP method to use
198 |             path: API path to call
199 |             path_params: Path parameters to substitute into the path placeholders
200 |             request_params: Query parameters to include
201 |             request_body: Request body to send
202 |             has_confirmation: Whether the operation has been confirmed by the user
203 |         Returns:
204 |             API response as a dictionary
205 |         Raises:
206 |             SafetyError: If the operation is not allowed by safety rules
207 |         """
208 |         # Log the request with proper formatting
209 |         logger.info(
210 |             f"API Request: {method} {path} | Path params: {path_params or {}} | Query params: {request_params or {}} | Body: {request_body or {}}"
211 |         )
212 | 
213 |         # Create an operation object for validation
214 |         operation = (method, path, path_params, request_params, request_body)
215 | 
216 |         # Use the safety manager to validate the operation
217 |         logger.debug(f"Validating operation safety: {method} {path}")
218 |         self.safety_manager.validate_operation(ClientType.API, operation, has_confirmation=has_confirmation)
219 | 
220 |         # Replace path parameters in the path string with actual values
221 |         path = self.replace_path_params(path, path_params)
222 | 
223 |         # Execute the request using the API client
224 |         return await self.client.execute_request(method, path, request_params, request_body)
225 | 
226 |     async def handle_confirmation(self, confirmation_id: str) -> dict[str, Any]:
227 |         """Handle a confirmation request."""
228 |         # retrieve the operation from the confirmation id
229 |         operation = self.safety_manager.get_stored_operation(confirmation_id)
230 |         if not operation:
231 |             raise ValueError("No operation found for confirmation id")
232 | 
233 |         # execute the operation
234 |         return await self.execute_request(
235 |             method=operation[0],
236 |             path=operation[1],
237 |             path_params=operation[2],
238 |             request_params=operation[3],
239 |             request_body=operation[4],
240 |             has_confirmation=True,
241 |         )
242 | 
243 |     async def handle_spec_request(
244 |         self,
245 |         path: str | None = None,
246 |         method: str | None = None,
247 |         domain: str | None = None,
248 |         all_paths: bool | None = False,
249 |     ) -> dict[str, Any]:
250 |         """Handle a spec request.
251 | 
252 |         Args:
253 |             path: Optional API path
254 |             method: Optional HTTP method
255 |             domain: Optional domain/tag name
256 |             all_paths: If True, returns all paths and methods
257 | 
258 |         Returns:
259 |             API specification based on the provided parameters
260 |         """
261 |         spec_manager = self.spec_manager
262 | 
263 |         if spec_manager is None:
264 |             raise RuntimeError("API spec manager is not initialized")
265 | 
266 |         # Ensure spec is loaded
267 |         await spec_manager.get_spec()
268 | 
269 |         # Option 1: Get spec for specific path and method
270 |         if path and method:
271 |             method = method.lower()  # Normalize method to lowercase
272 |             result = spec_manager.get_spec_for_path_and_method(path, method)
273 |             if result is None:
274 |                 return {"error": f"No specification found for {method.upper()} {path}"}
275 |             return result
276 | 
277 |         # Option 2: Get all paths and methods for a specific domain
278 |         elif domain:
279 |             result = spec_manager.get_paths_and_methods_by_domain(domain)
280 |             if not result:
281 |                 # Check if the domain exists
282 |                 all_domains = spec_manager.get_all_domains()
283 |                 if domain not in all_domains:
284 |                     return {"error": f"Domain '{domain}' not found", "available_domains": all_domains}
285 |             return {"domain": domain, "paths": result}
286 | 
287 |         # Option 3: Get all paths and methods
288 |         elif all_paths:
289 |             return {"paths": spec_manager.get_all_paths_and_methods()}
290 | 
291 |         # Option 4: Get all domains (default)
292 |         else:
293 |             return {"domains": spec_manager.get_all_domains()}
294 | 
295 |     async def retrieve_logs(
296 |         self,
297 |         collection: str,
298 |         limit: int = 20,
299 |         hours_ago: int | None = 1,
300 |         filters: list[dict[str, Any]] | None = None,
301 |         search: str | None = None,
302 |         custom_query: str | None = None,
303 |     ) -> dict[str, Any]:
304 |         """Retrieve logs from a Supabase service.
305 | 
306 |         Args:
307 |             collection: The log collection to query
308 |             limit: Maximum number of log entries to return
309 |             hours_ago: Retrieve logs from the last N hours
310 |             filters: List of filter objects with field, operator, and value
311 |             search: Text to search for in event messages
312 |             custom_query: Complete custom SQL query to execute
313 | 
314 |         Returns:
315 |             The query result
316 | 
317 |         Raises:
318 |             ValueError: If the collection is unknown
319 |         """
320 |         log_manager = self.log_manager
321 | 
322 |         # Build the SQL query using LogManager
323 |         sql = log_manager.build_logs_query(
324 |             collection=collection,
325 |             limit=limit,
326 |             hours_ago=hours_ago,
327 |             filters=filters,
328 |             search=search,
329 |             custom_query=custom_query,
330 |         )
331 | 
332 |         logger.debug(f"Executing log query: {sql}")
333 | 
334 |         # Make the API request
335 |         try:
336 |             response = await self.execute_request(
337 |                 method="GET",
338 |                 path="/v1/projects/{ref}/analytics/endpoints/logs.all",
339 |                 path_params={},
340 |                 request_params={"sql": sql},
341 |                 request_body={},
342 |             )
343 | 
344 |             return response
345 |         except Exception as e:
346 |             logger.error(f"Error retrieving logs: {e}")
347 |             raise
348 | 
```
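
To make the path-templating contract concrete, here is a minimal standalone sketch of the same substitution flow that runs without Supabase settings. The placeholder set mirrors `PathPlaceholder` above, and `my-project-ref` is a made-up stand-in for `settings.supabase_project_ref`, which the real method injects automatically:

```python
import re

# Standalone sketch of the substitution flow in replace_path_params() above.
VALID_PLACEHOLDERS = {"ref", "function_slug", "id", "slug", "branch_id", "provider_id", "tpa_id"}


def substitute_path_params(path: str, path_params: dict[str, str] | None = None) -> str:
    params = dict(path_params or {})
    if "ref" in params:
        raise ValueError("Do not provide 'ref'; the project reference is injected from settings.")
    unknown = set(params) - VALID_PLACEHOLDERS
    if unknown:
        raise ValueError(f"Unknown path parameters: {sorted(unknown)}")
    params["ref"] = "my-project-ref"  # injected from settings in the real code
    for key, value in params.items():
        path = path.replace("{" + key + "}", str(value))
    remaining = re.findall(r"\{([^}]+)\}", path)
    if remaining:
        raise ValueError(f"Missing path parameters: {', '.join(remaining)}")
    return path


print(substitute_path_params("/v1/projects/{ref}/functions/{function_slug}", {"function_slug": "my-fn"}))
# -> /v1/projects/my-project-ref/functions/my-fn
```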

--------------------------------------------------------------------------------
/tests/services/safety/test_safety_manager.py:
--------------------------------------------------------------------------------

```python
  1 | import time
  2 | 
  3 | import pytest
  4 | 
  5 | from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError
  6 | from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel, SafetyMode
  7 | from supabase_mcp.services.safety.safety_configs import SafetyConfigBase
  8 | from supabase_mcp.services.safety.safety_manager import SafetyManager
  9 | 
 10 | 
 11 | class MockSafetyConfig(SafetyConfigBase[str]):
 12 |     """Mock safety configuration for testing."""
 13 | 
 14 |     def get_risk_level(self, operation: str) -> OperationRiskLevel:
 15 |         """Get the risk level for an operation."""
 16 |         if operation == "low_risk":
 17 |             return OperationRiskLevel.LOW
 18 |         elif operation == "medium_risk":
 19 |             return OperationRiskLevel.MEDIUM
 20 |         elif operation == "high_risk":
 21 |             return OperationRiskLevel.HIGH
 22 |         elif operation == "extreme_risk":
 23 |             return OperationRiskLevel.EXTREME
 24 |         else:
 25 |             return OperationRiskLevel.LOW
 26 | 
 27 | 
 28 | @pytest.mark.unit
 29 | class TestSafetyManager:
 30 |     """Unit test cases for the SafetyManager class."""
 31 | 
 32 |     @pytest.fixture(autouse=True)
 33 |     def setup_and_teardown(self):
 34 |         """Setup and teardown for each test."""
 35 |         # Reset the singleton before each test
 36 |         # pylint: disable=protected-access
 37 |         SafetyManager._instance = None  # type: ignore
 38 |         yield
 39 |         # Reset the singleton after each test
 40 |         SafetyManager._instance = None  # type: ignore
 41 | 
 42 |     def test_singleton_pattern(self):
 43 |         """Test that SafetyManager follows the singleton pattern."""
 44 |         # Get two instances of the SafetyManager
 45 |         manager1 = SafetyManager.get_instance()
 46 |         manager2 = SafetyManager.get_instance()
 47 | 
 48 |         # Verify they are the same instance
 49 |         assert manager1 is manager2
 50 | 
 51 |         # Verify that creating a new instance directly doesn't affect the singleton
 52 |         direct_instance = SafetyManager()
 53 |         assert direct_instance is not manager1
 54 | 
 55 |     def test_register_config(self):
 56 |         """Test registering a safety configuration."""
 57 |         manager = SafetyManager.get_instance()
 58 |         mock_config = MockSafetyConfig()
 59 | 
 60 |         # Register the config for DATABASE client type
 61 |         manager.register_config(ClientType.DATABASE, mock_config)
 62 | 
 63 |         # Verify the config was registered
 64 |         assert manager._safety_configs[ClientType.DATABASE] is mock_config
 65 | 
 66 |         # Test that registering a config for the same client type overwrites the previous config
 67 |         new_mock_config = MockSafetyConfig()
 68 |         manager.register_config(ClientType.DATABASE, new_mock_config)
 69 |         assert manager._safety_configs[ClientType.DATABASE] is new_mock_config
 70 | 
 71 |     def test_get_safety_mode_default(self):
 72 |         """Test getting the default safety mode for an unregistered client type."""
 73 |         manager = SafetyManager.get_instance()
 74 | 
 75 |         # Create a custom client type that hasn't been registered
 76 |         class CustomClientType(str):
 77 |             pass
 78 | 
 79 |         custom_type = CustomClientType("custom")
 80 | 
 81 |         # Verify that getting a safety mode for an unregistered client type returns SafetyMode.SAFE
 82 |         assert manager.get_safety_mode(custom_type) == SafetyMode.SAFE  # type: ignore
 83 | 
 84 |     def test_get_safety_mode_registered(self):
 85 |         """Test getting the safety mode for a registered client type."""
 86 |         manager = SafetyManager.get_instance()
 87 | 
 88 |         # Set a safety mode for a client type
 89 |         manager._safety_modes[ClientType.API] = SafetyMode.UNSAFE
 90 | 
 91 |         # Verify it's returned correctly
 92 |         assert manager.get_safety_mode(ClientType.API) == SafetyMode.UNSAFE
 93 | 
 94 |     def test_set_safety_mode(self):
 95 |         """Test setting the safety mode for a client type."""
 96 |         manager = SafetyManager.get_instance()
 97 | 
 98 |         # Set a safety mode for a client type
 99 |         manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)
100 | 
101 |         # Verify it was updated
102 |         assert manager._safety_modes[ClientType.DATABASE] == SafetyMode.UNSAFE
103 | 
104 |         # Change it back to SAFE
105 |         manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE)
106 | 
107 |         # Verify it was updated again
108 |         assert manager._safety_modes[ClientType.DATABASE] == SafetyMode.SAFE
109 | 
110 |     def test_validate_operation_allowed(self):
111 |         """Test validating an operation that is allowed."""
112 |         manager = SafetyManager.get_instance()
113 |         mock_config = MockSafetyConfig()
114 | 
115 |         # Register the config
116 |         manager.register_config(ClientType.DATABASE, mock_config)
117 | 
118 |         # Set safety mode to SAFE
119 |         manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE)
120 | 
121 |         # Validate a low risk operation (should be allowed in SAFE mode)
122 |         # This should not raise an exception
123 |         manager.validate_operation(ClientType.DATABASE, "low_risk")
124 | 
125 |         # Set safety mode to UNSAFE
126 |         manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)
127 | 
128 |         # Validate medium risk operation (should be allowed in UNSAFE mode)
129 |         # This should not raise an exception
130 |         manager.validate_operation(ClientType.DATABASE, "medium_risk")
131 | 
132 |         # High risk operations require confirmation, so we test with confirmation=True
133 |         manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=True)
134 | 
135 |     def test_validate_operation_not_allowed(self):
136 |         """Test validating an operation that is not allowed."""
137 |         manager = SafetyManager.get_instance()
138 |         mock_config = MockSafetyConfig()
139 | 
140 |         # Register the config
141 |         manager.register_config(ClientType.DATABASE, mock_config)
142 | 
143 |         # Set safety mode to SAFE
144 |         manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE)
145 | 
146 |         # Validate medium risk operation (should not be allowed in SAFE mode)
147 |         with pytest.raises(OperationNotAllowedError):
148 |             manager.validate_operation(ClientType.DATABASE, "medium_risk")
149 | 
150 |         # Validate high risk operation (should not be allowed in SAFE mode)
151 |         with pytest.raises(OperationNotAllowedError):
152 |             manager.validate_operation(ClientType.DATABASE, "high_risk")
153 | 
154 |         # Validate extreme risk operation (should not be allowed in SAFE mode)
155 |         with pytest.raises(OperationNotAllowedError):
156 |             manager.validate_operation(ClientType.DATABASE, "extreme_risk")
157 | 
158 |     def test_validate_operation_requires_confirmation(self):
159 |         """Test validating an operation that requires confirmation."""
160 |         manager = SafetyManager.get_instance()
161 |         mock_config = MockSafetyConfig()
162 | 
163 |         # Register the config
164 |         manager.register_config(ClientType.DATABASE, mock_config)
165 | 
166 |         # Set safety mode to UNSAFE
167 |         manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)
168 | 
169 |         # Validate high risk operation without confirmation
170 |         # Should raise ConfirmationRequiredError
171 |         with pytest.raises(ConfirmationRequiredError):
172 |             manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=False)
173 | 
174 |         # Extreme risk operations are not allowed even in UNSAFE mode
175 |         with pytest.raises(OperationNotAllowedError):
176 |             manager.validate_operation(ClientType.DATABASE, "extreme_risk", has_confirmation=False)
177 | 
178 |         # Even with confirmation, extreme risk operations are not allowed
179 |         with pytest.raises(OperationNotAllowedError):
180 |             manager.validate_operation(ClientType.DATABASE, "extreme_risk", has_confirmation=True)
181 | 
182 |     def test_store_confirmation(self):
183 |         """Test storing a confirmation for an operation."""
184 |         manager = SafetyManager.get_instance()
185 | 
186 |         # Store a confirmation
187 |         confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME)
188 | 
189 |         # Verify that a confirmation ID is returned
190 |         assert confirmation_id is not None
191 |         assert confirmation_id.startswith("conf_")
192 | 
193 |         # Verify that the confirmation can be retrieved
194 |         confirmation = manager._get_confirmation(confirmation_id)
195 |         assert confirmation is not None
196 |         assert confirmation["operation"] == "test_operation"
197 |         assert confirmation["client_type"] == ClientType.DATABASE
198 |         assert confirmation["risk_level"] == OperationRiskLevel.EXTREME
199 |         assert "timestamp" in confirmation
200 | 
201 |     def test_get_confirmation_valid(self):
202 |         """Test getting a valid confirmation."""
203 |         manager = SafetyManager.get_instance()
204 | 
205 |         # Store a confirmation
206 |         confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME)
207 | 
208 |         # Retrieve the confirmation
209 |         confirmation = manager._get_confirmation(confirmation_id)
210 | 
211 |         # Verify it matches what was stored
212 |         assert confirmation is not None
213 |         assert confirmation["operation"] == "test_operation"
214 |         assert confirmation["client_type"] == ClientType.DATABASE
215 |         assert confirmation["risk_level"] == OperationRiskLevel.EXTREME
216 | 
217 |     def test_get_confirmation_invalid(self):
218 |         """Test getting an invalid confirmation."""
219 |         manager = SafetyManager.get_instance()
220 | 
221 |         # Try to retrieve a confirmation with an invalid ID
222 |         confirmation = manager._get_confirmation("invalid_id")
223 | 
224 |         # Verify that None is returned
225 |         assert confirmation is None
226 | 
227 |     def test_get_confirmation_expired(self):
228 |         """Test getting an expired confirmation."""
229 |         manager = SafetyManager.get_instance()
230 | 
231 |         # Store a confirmation with a past expiration time
232 |         confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME)
233 | 
234 |         # Manually set the timestamp to be older than the expiry time
235 |         manager._pending_confirmations[confirmation_id]["timestamp"] = time.time() - manager._confirmation_expiry - 10
236 | 
237 |         # Try to retrieve the confirmation
238 |         confirmation = manager._get_confirmation(confirmation_id)
239 | 
240 |         # Verify that None is returned
241 |         assert confirmation is None
242 | 
243 |     def test_cleanup_expired_confirmations(self):
244 |         """Test cleaning up expired confirmations."""
245 |         manager = SafetyManager.get_instance()
246 | 
247 |         # Store multiple confirmations with different expiration times
248 |         valid_id = manager._store_confirmation(ClientType.DATABASE, "valid_operation", OperationRiskLevel.EXTREME)
249 | 
250 |         expired_id = manager._store_confirmation(ClientType.DATABASE, "expired_operation", OperationRiskLevel.EXTREME)
251 | 
252 |         # Manually set the timestamp of the expired confirmation to be older than the expiry time
253 |         manager._pending_confirmations[expired_id]["timestamp"] = time.time() - manager._confirmation_expiry - 10
254 | 
255 |         # Call cleanup
256 |         manager._cleanup_expired_confirmations()
257 | 
258 |         # Verify that expired confirmations are removed and valid ones remain
259 |         assert valid_id in manager._pending_confirmations
260 |         assert expired_id not in manager._pending_confirmations
261 | 
262 |     def test_get_stored_operation(self):
263 |         """Test getting a stored operation."""
264 |         manager = SafetyManager.get_instance()
265 | 
266 |         # Store a confirmation for an operation
267 |         confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME)
268 | 
269 |         # Retrieve the operation
270 |         operation = manager.get_stored_operation(confirmation_id)
271 | 
272 |         # Verify that the retrieved operation matches the original
273 |         assert operation == "test_operation"
274 | 
275 |         # Test with an invalid ID
276 |         assert manager.get_stored_operation("invalid_id") is None
277 | 
278 |     def test_integration_validate_and_confirm(self):
279 |         """Test the full flow of validating an operation that requires confirmation and then confirming it."""
280 |         manager = SafetyManager.get_instance()
281 |         mock_config = MockSafetyConfig()
282 | 
283 |         # Register the config
284 |         manager.register_config(ClientType.DATABASE, mock_config)
285 | 
286 |         # Set safety mode to UNSAFE
287 |         manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)
288 | 
289 |         # Try to validate a high risk operation and catch the ConfirmationRequiredError
290 |         confirmation_id = None
291 |         try:
292 |             manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=False)
293 |         except ConfirmationRequiredError as e:
294 |             # Extract the confirmation ID from the error message
295 |             error_message = str(e)
296 |             # Find the confirmation ID in the message
297 |             import re
298 | 
299 |             match = re.search(r"ID: (conf_[a-f0-9]+)", error_message)
300 |             if match:
301 |                 confirmation_id = match.group(1)
302 | 
303 |         # Verify that we got a confirmation ID
304 |         assert confirmation_id is not None
305 | 
306 |         # Now validate the operation again, this time with confirmation granted
307 |         # This should not raise an exception
308 |         manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=True)
309 | 
```
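
For context, the caller-side flow exercised by `test_integration_validate_and_confirm` looks roughly like this. `HighRiskConfig` and the `"drop_index"` operation string are illustrative stand-ins, and the ID-parsing regex simply mirrors the one used in the test above:

```python
import re

from supabase_mcp.exceptions import ConfirmationRequiredError
from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel, SafetyMode
from supabase_mcp.services.safety.safety_configs import SafetyConfigBase
from supabase_mcp.services.safety.safety_manager import SafetyManager


class HighRiskConfig(SafetyConfigBase[str]):
    """Toy config for illustration: every operation is HIGH risk."""

    def get_risk_level(self, operation: str) -> OperationRiskLevel:
        return OperationRiskLevel.HIGH


manager = SafetyManager.get_instance()
manager.register_config(ClientType.DATABASE, HighRiskConfig())
manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)

try:
    # First attempt: HIGH risk without confirmation raises and stores the operation.
    manager.validate_operation(ClientType.DATABASE, "drop_index", has_confirmation=False)
except ConfirmationRequiredError as e:
    # The error message embeds an ID like "conf_ab12..."; parsed as in the test above.
    match = re.search(r"ID: (conf_[a-f0-9]+)", str(e))
    assert match is not None
    stored = manager.get_stored_operation(match.group(1))
    # After the user approves, the stored operation is replayed with confirmation.
    manager.validate_operation(ClientType.DATABASE, stored, has_confirmation=True)
```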

--------------------------------------------------------------------------------
/supabase_mcp/services/database/sql/validator.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Any
  2 | 
  3 | from pglast.parser import ParseError, parse_sql
  4 | 
  5 | from supabase_mcp.exceptions import ValidationError
  6 | from supabase_mcp.logger import logger
  7 | from supabase_mcp.services.database.sql.models import (
  8 |     QueryValidationResults,
  9 |     SQLQueryCategory,
 10 |     SQLQueryCommand,
 11 |     ValidatedStatement,
 12 | )
 13 | from supabase_mcp.services.safety.safety_configs import SQLSafetyConfig
 14 | 
 15 | 
 16 | class SQLValidator:
 17 |     """SQL validator class that is based on pglast library.
 18 | 
 19 |     Responsible for:
 20 |     - SQL query syntax validation
 21 |     - SQL query categorization"""
 22 | 
 23 |     # Mapping from statement types to object types
 24 |     STATEMENT_TYPE_TO_OBJECT_TYPE = {
 25 |         "CreateFunctionStmt": "function",
 26 |         "ViewStmt": "view",
 27 |         "CreateTableAsStmt": "materialized_view",  # When relkind is 'm', otherwise 'table'
 28 |         "CreateEnumStmt": "type",
 29 |         "CreateTypeStmt": "type",
 30 |         "CreateExtensionStmt": "extension",
 31 |         "CreateForeignTableStmt": "foreign_table",
 32 |         "CreatePolicyStmt": "policy",
 33 |         "CreateTrigStmt": "trigger",
 34 |         "IndexStmt": "index",
 35 |         "CreateStmt": "table",
 36 |         "AlterTableStmt": "table",
 37 |         "GrantStmt": "privilege",
 38 |         "RevokeStmt": "privilege",
 39 |         "CreateProcStmt": "procedure",  # For CREATE PROCEDURE
 40 |     }
 41 | 
 42 |     def __init__(self, safety_config: SQLSafetyConfig | None = None) -> None:
 43 |         self.safety_config = safety_config or SQLSafetyConfig()
 44 | 
 45 |     def validate_schema_name(self, schema_name: str) -> str:
 46 |         """Validate schema name.
 47 | 
 48 |         Rules:
 49 |         - Must be a string
 50 |         - Cannot be empty
 51 |         - Cannot contain spaces (special characters are not currently rejected)
 52 |         """
 53 |         if not schema_name.strip():
 54 |             raise ValidationError("Schema name cannot be empty")
 55 |         if " " in schema_name:
 56 |             raise ValidationError("Schema name cannot contain spaces")
 57 |         return schema_name
 58 | 
 59 |     def validate_table_name(self, table: str) -> str:
 60 |         """Validate table name.
 61 | 
 62 |         Rules:
 63 |         - Must be a string
 64 |         - Cannot be empty
 65 |         - Cannot contain spaces (special characters are not currently rejected)
 66 |         """
 67 |         if not table.strip():
 68 |             raise ValidationError("Table name cannot be empty")
 69 |         if " " in table:
 70 |             raise ValidationError("Table name cannot contain spaces")
 71 |         return table
 72 | 
 73 |     def basic_query_validation(self, query: str) -> str:
 74 |         """Validate SQL query.
 75 | 
 76 |         Rules:
 77 |         - Must be a string
 78 |         - Cannot be empty
 79 |         """
 80 |         if not query.strip():
 81 |             raise ValidationError("Query cannot be empty")
 82 |         return query
 83 | 
 84 |     @classmethod
 85 |     def validate_transaction_control(cls, query: str) -> bool:
 86 |         """Check if the query contains transaction control statements.
 87 | 
 88 |         Args:
 89 |             query: SQL query string
 90 | 
 91 |         Returns:
 92 |             bool: True if any transaction control keyword appears (naive substring check, so identifiers containing BEGIN/COMMIT/ROLLBACK also match)
 93 |         """
 94 |         return any(x in query.upper() for x in ["BEGIN", "COMMIT", "ROLLBACK"])
 95 | 
 96 |     def validate_query(self, sql_query: str) -> QueryValidationResults:
 97 |         """
 98 |         Identify the type of SQL query using PostgreSQL's parser.
 99 | 
100 |         Args:
101 |             sql_query: A SQL query string to parse
102 | 
103 |         Returns:
104 |             QueryValidationResults: A validation result object containing information about the SQL statements
105 |         Raises:
106 |             ValidationError: If the query is not valid or contains TCL statements
107 |         """
108 |         try:
109 |             # Validate raw input
110 |             sql_query = self.basic_query_validation(sql_query)
111 | 
112 |             # Parse the SQL using PostgreSQL's parser
113 |             parse_tree = parse_sql(sql_query)
114 |             if parse_tree is None:
115 |                 logger.debug("No statements found in the query")
116 |             # logger.debug(f"Parse tree generated with {parse_tree} statements")
117 | 
118 |             # Validate statements
119 |             result = self.validate_statements(original_query=sql_query, parse_tree=parse_tree)
120 | 
121 |             # Check if the query contains transaction control statements and reject them
122 |             for statement in result.statements:
123 |                 if statement.category == SQLQueryCategory.TCL:
124 |                     logger.warning(f"Transaction control statement detected: {statement.command}")
125 |                     raise ValidationError(
126 |                         "Transaction control statements (BEGIN, COMMIT, ROLLBACK) are not allowed. "
127 |                         "Queries will be automatically wrapped in transactions by the system."
128 |                     )
129 | 
130 |             return result
131 |         except ParseError as e:
132 |             logger.exception(f"SQL syntax error: {str(e)}")
133 |             raise ValidationError(f"SQL syntax error: {str(e)}") from e
134 |         except ValidationError:
135 |             # let it propagate
136 |             raise
137 |         except Exception as e:
138 |             logger.exception(f"Unexpected error during SQL validation: {str(e)}")
139 |             raise ValidationError(f"Unexpected error during SQL validation: {str(e)}") from e
140 | 
141 |     def _map_to_command(self, stmt_type: str) -> SQLQueryCommand:
142 |         """Map a pglast statement type to our SQLQueryCommand enum."""
143 | 
144 |         mapping = {
145 |             # DQL Commands
146 |             "SelectStmt": SQLQueryCommand.SELECT,
147 |             # DML Commands
148 |             "InsertStmt": SQLQueryCommand.INSERT,
149 |             "UpdateStmt": SQLQueryCommand.UPDATE,
150 |             "DeleteStmt": SQLQueryCommand.DELETE,
151 |             "MergeStmt": SQLQueryCommand.MERGE,
152 |             # DDL Commands
153 |             "CreateStmt": SQLQueryCommand.CREATE,
154 |             "CreateTableAsStmt": SQLQueryCommand.CREATE,
155 |             "CreateSchemaStmt": SQLQueryCommand.CREATE,
156 |             "CreateExtensionStmt": SQLQueryCommand.CREATE,
157 |             "CreateFunctionStmt": SQLQueryCommand.CREATE,
158 |             "CreateTrigStmt": SQLQueryCommand.CREATE,
159 |             "ViewStmt": SQLQueryCommand.CREATE,
160 |             "IndexStmt": SQLQueryCommand.CREATE,
161 |             # Additional DDL Commands
162 |             "CreateEnumStmt": SQLQueryCommand.CREATE,
163 |             "CreateTypeStmt": SQLQueryCommand.CREATE,
164 |             "CreateDomainStmt": SQLQueryCommand.CREATE,
165 |             "CreateSeqStmt": SQLQueryCommand.CREATE,
166 |             "CreateForeignTableStmt": SQLQueryCommand.CREATE,
167 |             "CreatePolicyStmt": SQLQueryCommand.CREATE,
168 |             "CreateCastStmt": SQLQueryCommand.CREATE,
169 |             "CreateOpClassStmt": SQLQueryCommand.CREATE,
170 |             "CreateOpFamilyStmt": SQLQueryCommand.CREATE,
171 |             "AlterTableStmt": SQLQueryCommand.ALTER,
172 |             "AlterDomainStmt": SQLQueryCommand.ALTER,
173 |             "AlterEnumStmt": SQLQueryCommand.ALTER,
174 |             "AlterSeqStmt": SQLQueryCommand.ALTER,
175 |             "AlterOwnerStmt": SQLQueryCommand.ALTER,
176 |             "AlterObjectSchemaStmt": SQLQueryCommand.ALTER,
177 |             "DropStmt": SQLQueryCommand.DROP,
178 |             "TruncateStmt": SQLQueryCommand.TRUNCATE,
179 |             "CommentStmt": SQLQueryCommand.COMMENT,
180 |             "RenameStmt": SQLQueryCommand.RENAME,
181 |             # DCL Commands
182 |             "GrantStmt": SQLQueryCommand.GRANT,
183 |             "GrantRoleStmt": SQLQueryCommand.GRANT,
184 |             "RevokeStmt": SQLQueryCommand.REVOKE,
185 |             "RevokeRoleStmt": SQLQueryCommand.REVOKE,
186 |             "CreateRoleStmt": SQLQueryCommand.CREATE,
187 |             "AlterRoleStmt": SQLQueryCommand.ALTER,
188 |             "DropRoleStmt": SQLQueryCommand.DROP,
189 |             # TCL Commands
190 |             "TransactionStmt": SQLQueryCommand.BEGIN,  # Will need refinement for different transaction types
191 |             # PostgreSQL-specific Commands
192 |             "VacuumStmt": SQLQueryCommand.VACUUM,
193 |             "ExplainStmt": SQLQueryCommand.EXPLAIN,
194 |             "CopyStmt": SQLQueryCommand.COPY,
195 |             "ListenStmt": SQLQueryCommand.LISTEN,
196 |             "NotifyStmt": SQLQueryCommand.NOTIFY,
197 |             "PrepareStmt": SQLQueryCommand.PREPARE,
198 |             "ExecuteStmt": SQLQueryCommand.EXECUTE,
199 |             "DeallocateStmt": SQLQueryCommand.DEALLOCATE,
200 |         }
201 | 
202 |         # Try to map the statement type, default to UNKNOWN
203 |         return mapping.get(stmt_type, SQLQueryCommand.UNKNOWN)
204 | 
205 |     def validate_statements(self, original_query: str, parse_tree: Any) -> QueryValidationResults:
206 |         """Validate the statements in the parse tree.
207 | 
208 |         Args:
209 |             original_query: The original SQL query string
210 |             parse_tree: The parse tree to validate
211 |         Returns:
212 |             QueryValidationResults: A validation result object containing information about the SQL statements
213 |         Raises:
214 |             ValidationError: If the query is not valid
215 |         """
216 |         result = QueryValidationResults(original_query=original_query)
217 | 
218 |         if parse_tree is None:
219 |             return result
220 | 
221 |         try:
222 |             for stmt in parse_tree:
223 |                 if not hasattr(stmt, "stmt"):
224 |                     continue
225 | 
226 |                 stmt_node = stmt.stmt
227 |                 stmt_type = stmt_node.__class__.__name__
228 |                 logger.debug(f"Processing statement node type: {stmt_type}")
229 |                 # logger.debug(f"DEBUGGING stmt_node: {stmt_node}")
230 |                 logger.debug(f"DEBUGGING stmt_node.stmt_location: {stmt.stmt_location}")
231 | 
232 |                 # Extract the object type if available
233 |                 object_type = None
234 |                 schema_name = None
235 |                 if hasattr(stmt_node, "relation") and stmt_node.relation is not None:
236 |                     if hasattr(stmt_node.relation, "relname"):
237 |                         object_type = stmt_node.relation.relname
238 |                     if hasattr(stmt_node.relation, "schemaname"):
239 |                         schema_name = stmt_node.relation.schemaname
240 |                 # For statements with 'relations' list (like TRUNCATE)
241 |                 elif hasattr(stmt_node, "relations") and stmt_node.relations:
242 |                     for relation in stmt_node.relations:
243 |                         if hasattr(relation, "relname"):
244 |                             object_type = relation.relname
245 |                         if hasattr(relation, "schemaname"):
246 |                             schema_name = relation.schemaname
247 |                         break
248 | 
249 |                 # Simple approach: Set object_type based on statement type if not already set
250 |                 if object_type is None and stmt_type in self.STATEMENT_TYPE_TO_OBJECT_TYPE:
251 |                     object_type = self.STATEMENT_TYPE_TO_OBJECT_TYPE[stmt_type]
252 | 
253 |                 # Default schema to public if not set
254 |                 if schema_name is None:
255 |                     schema_name = "public"
256 | 
257 |                 # Get classification for this statement type
258 |                 classification = self.safety_config.classify_statement(stmt_type, stmt_node)
259 |                 logger.debug(
260 |                     f"Statement category classified as: {classification.get('category', 'UNKNOWN')} - risk level: {classification.get('risk_level', 'UNKNOWN')}"
261 |                 )
262 |                 logger.debug(f"DEBUGGING QUERY EXTRACTION LOCATION: {stmt.stmt_location} - {stmt.stmt_len}")
263 | 
264 |                 # Create validation result
265 |                 query_result = ValidatedStatement(
266 |                     category=classification["category"],
267 |                     command=self._map_to_command(stmt_type),
268 |                     risk_level=classification["risk_level"],
269 |                     needs_migration=classification["needs_migration"],
270 |                     object_type=object_type,
271 |                     schema_name=schema_name,
272 |                     query=original_query[stmt.stmt_location : stmt.stmt_location + stmt.stmt_len]
273 |                     if hasattr(stmt, "stmt_location") and hasattr(stmt, "stmt_len")
274 |                     else None,
275 |                 )
276 |                 # logger.debug(f"Isolated query: {query_result.query}")
277 |                 logger.debug(
278 |                     "Query validation result: %s",
279 |                     {
280 |                         "statement_category": query_result.category,
281 |                         "risk_level": query_result.risk_level,
282 |                         "needs_migration": query_result.needs_migration,
283 |                         "object_type": query_result.object_type,
284 |                         "schema_name": query_result.schema_name,
285 |                         "query": query_result.query,
286 |                     },
287 |                 )
288 | 
289 |                 # Add result to the batch
290 |                 result.statements.append(query_result)
291 | 
292 |                 # Update highest risk level
293 |                 if query_result.risk_level > result.highest_risk_level:
294 |                     result.highest_risk_level = query_result.risk_level
295 |                     logger.debug(f"Updated batch validation result to: {query_result.risk_level}")
296 |             if len(result.statements) == 0:
297 |                 logger.debug("No valid statements found in the query")
298 |                 raise ValidationError("No queries were parsed - please check correctness of your query")
299 |             logger.debug(
300 |                 f"Validated a total of {len(result.statements)} statements with the highest risk level of: {result.highest_risk_level}"
301 |             )
302 |             return result
303 | 
304 |         except AttributeError as e:
305 |             # Handle attempting to access missing attributes in the parse tree
306 |             raise ValidationError(f"Error accessing parse tree structure: {str(e)}") from e
307 |         except KeyError as e:
308 |             # Handle missing keys in classification dictionary
309 |             raise ValidationError(f"Missing classification key: {str(e)}") from e
310 | 
```
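
A short usage sketch of the validator, assuming the repository package and `pglast` are installed (the table and index names are illustrative):

```python
from supabase_mcp.exceptions import ValidationError
from supabase_mcp.services.database.sql.validator import SQLValidator

validator = SQLValidator()

# A two-statement batch: one DQL SELECT and one DDL CREATE INDEX.
result = validator.validate_query(
    "SELECT id FROM public.users; CREATE INDEX users_id_idx ON public.users (id);"
)
for stmt in result.statements:
    print(stmt.category, stmt.command, stmt.risk_level, stmt.needs_migration)

# Transaction control statements are rejected outright:
try:
    validator.validate_query("BEGIN; SELECT 1; COMMIT;")
except ValidationError as e:
    print(f"rejected: {e}")
```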

--------------------------------------------------------------------------------
/tests/services/database/test_postgres_client.py:
--------------------------------------------------------------------------------

```python
  1 | import asyncpg
  2 | import pytest
  3 | from unittest.mock import AsyncMock, MagicMock, patch
  4 | 
  5 | from supabase_mcp.exceptions import ConnectionError, QueryError, PermissionError as SupabasePermissionError
  6 | from supabase_mcp.services.database.postgres_client import PostgresClient, QueryResult, StatementResult
  7 | from supabase_mcp.services.database.sql.validator import (
  8 |     QueryValidationResults,
  9 |     SQLQueryCategory,
 10 |     SQLQueryCommand,
 11 |     ValidatedStatement,
 12 | )
 13 | from supabase_mcp.services.safety.models import OperationRiskLevel
 14 | from supabase_mcp.settings import Settings
 15 | 
 16 | 
 17 | @pytest.mark.asyncio(loop_scope="class")
 18 | class TestPostgresClient:
 19 |     """Unit tests for the Postgres client."""
 20 | 
 21 |     @pytest.fixture
 22 |     def mock_settings(self):
 23 |         """Create mock settings for testing."""
 24 |         settings = MagicMock(spec=Settings)
 25 |         settings.supabase_project_ref = "test-project-ref"
 26 |         settings.supabase_db_password = "test-password"
 27 |         settings.supabase_region = "us-east-1"
 28 |         settings.database_url = "postgresql://test:test@localhost:5432/test"
 29 |         return settings
 30 | 
 31 |     @pytest.fixture
 32 |     async def mock_postgres_client(self, mock_settings):
 33 |         """Create a mock Postgres client for testing."""
 34 |         # Reset the singleton first
 35 |         await PostgresClient.reset()
 36 |         
 37 |         # Create the client; execute_query is patched per-test below
 38 |         client = PostgresClient(settings=mock_settings)
 39 |         return client
 40 | 
 41 |     async def test_execute_simple_select(self, mock_postgres_client: PostgresClient):
 42 |         """Test executing a simple SELECT query."""
 43 |         # Create a simple validation result with a SELECT query
 44 |         query = "SELECT 1 as number;"
 45 |         statement = ValidatedStatement(
 46 |             query=query,
 47 |             command=SQLQueryCommand.SELECT,
 48 |             category=SQLQueryCategory.DQL,
 49 |             risk_level=OperationRiskLevel.LOW,
 50 |             needs_migration=False,
 51 |             object_type=None,
 52 |             schema_name=None,
 53 |         )
 54 |         validation_result = QueryValidationResults(
 55 |             statements=[statement],
 56 |             original_query=query,
 57 |             highest_risk_level=OperationRiskLevel.LOW,
 58 |         )
 59 | 
 60 |         # Mock the query result
 61 |         expected_result = QueryResult(results=[
 62 |             StatementResult(rows=[{"number": 1}])
 63 |         ])
 64 |         
 65 |         with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
 66 |             # Execute the query
 67 |             result = await mock_postgres_client.execute_query(validation_result)
 68 | 
 69 |             # Verify the result
 70 |             assert isinstance(result, QueryResult)
 71 |             assert len(result.results) == 1
 72 |             assert isinstance(result.results[0], StatementResult)
 73 |             assert len(result.results[0].rows) == 1
 74 |             assert result.results[0].rows[0]["number"] == 1
 75 | 
 76 |     async def test_execute_multiple_statements(self, mock_postgres_client: PostgresClient):
 77 |         """Test executing multiple SQL statements in a single query."""
 78 |         # Create validation result with multiple statements
 79 |         query = "SELECT 1 as first; SELECT 2 as second;"
 80 |         statements = [
 81 |             ValidatedStatement(
 82 |                 query="SELECT 1 as first;",
 83 |                 command=SQLQueryCommand.SELECT,
 84 |                 category=SQLQueryCategory.DQL,
 85 |                 risk_level=OperationRiskLevel.LOW,
 86 |                 needs_migration=False,
 87 |                 object_type=None,
 88 |                 schema_name=None,
 89 |             ),
 90 |             ValidatedStatement(
 91 |                 query="SELECT 2 as second;",
 92 |                 command=SQLQueryCommand.SELECT,
 93 |                 category=SQLQueryCategory.DQL,
 94 |                 risk_level=OperationRiskLevel.LOW,
 95 |                 needs_migration=False,
 96 |                 object_type=None,
 97 |                 schema_name=None,
 98 |             ),
 99 |         ]
100 |         validation_result = QueryValidationResults(
101 |             statements=statements,
102 |             original_query=query,
103 |             highest_risk_level=OperationRiskLevel.LOW,
104 |         )
105 | 
106 |         # Mock the query result
107 |         expected_result = QueryResult(results=[
108 |             StatementResult(rows=[{"first": 1}]),
109 |             StatementResult(rows=[{"second": 2}])
110 |         ])
111 |         
112 |         with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
113 |             # Execute the query
114 |             result = await mock_postgres_client.execute_query(validation_result)
115 | 
116 |             # Verify the result
117 |             assert isinstance(result, QueryResult)
118 |             assert len(result.results) == 2
119 |             assert result.results[0].rows[0]["first"] == 1
120 |             assert result.results[1].rows[0]["second"] == 2
121 | 
122 |     async def test_execute_query_with_parameters(self, mock_postgres_client: PostgresClient):
123 |         """Test executing a query with parameters."""
124 |         query = "SELECT 'test' as name, 42 as value;"
125 |         statement = ValidatedStatement(
126 |             query=query,
127 |             command=SQLQueryCommand.SELECT,
128 |             category=SQLQueryCategory.DQL,
129 |             risk_level=OperationRiskLevel.LOW,
130 |             needs_migration=False,
131 |             object_type=None,
132 |             schema_name=None,
133 |         )
134 |         validation_result = QueryValidationResults(
135 |             statements=[statement],
136 |             original_query=query,
137 |             highest_risk_level=OperationRiskLevel.LOW,
138 |         )
139 | 
140 |         # Mock the query result
141 |         expected_result = QueryResult(results=[
142 |             StatementResult(rows=[{"name": "test", "value": 42}])
143 |         ])
144 |         
145 |         with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
146 |             # Execute the query
147 |             result = await mock_postgres_client.execute_query(validation_result)
148 | 
149 |             # Verify the result
150 |             assert isinstance(result, QueryResult)
151 |             assert len(result.results) == 1
152 |             assert result.results[0].rows[0]["name"] == "test"
153 |             assert result.results[0].rows[0]["value"] == 42
154 | 
155 |     async def test_permission_error(self, mock_postgres_client: PostgresClient):
156 |         """Test handling a permission error."""
157 |         # Create a mock error
158 |         error = asyncpg.exceptions.InsufficientPrivilegeError("Permission denied")
159 | 
160 |         # Verify that the method raises SupabasePermissionError with the expected message
161 |         with pytest.raises(SupabasePermissionError) as exc_info:
162 |             await mock_postgres_client._handle_postgres_error(error)
163 |         
164 |         # Verify the error message
165 |         assert "Access denied" in str(exc_info.value)
166 |         assert "Permission denied" in str(exc_info.value)
167 |         assert "live_dangerously" in str(exc_info.value)
168 | 
169 |     async def test_query_error(self, mock_postgres_client: PostgresClient):
170 |         """Test handling a query error."""
171 |         # Create a validation result with a syntactically valid but semantically incorrect query
172 |         query = "SELECT * FROM nonexistent_table;"
173 |         statement = ValidatedStatement(
174 |             query=query,
175 |             command=SQLQueryCommand.SELECT,
176 |             category=SQLQueryCategory.DQL,
177 |             risk_level=OperationRiskLevel.LOW,
178 |             needs_migration=False,
179 |             object_type="TABLE",
180 |             schema_name="public",
181 |         )
182 |         validation_result = QueryValidationResults(
183 |             statements=[statement],
184 |             original_query=query,
185 |             highest_risk_level=OperationRiskLevel.LOW,
186 |         )
187 | 
188 |         # Mock execute_query to raise a QueryError
189 |         with patch.object(mock_postgres_client, 'execute_query', 
190 |                          side_effect=QueryError("relation \"nonexistent_table\" does not exist")):
191 |             # Execute the query - should raise a QueryError
192 |             with pytest.raises(QueryError) as excinfo:
193 |                 await mock_postgres_client.execute_query(validation_result)
194 | 
195 |             # Verify the error message contains the specific error
196 |             assert "nonexistent_table" in str(excinfo.value)
197 | 
198 |     async def test_schema_error(self, mock_postgres_client: PostgresClient):
199 |         """Test handling a schema error."""
200 |         # Create a validation result with a query referencing a non-existent column
201 |         query = "SELECT nonexistent_column FROM information_schema.tables;"
202 |         statement = ValidatedStatement(
203 |             query=query,
204 |             command=SQLQueryCommand.SELECT,
205 |             category=SQLQueryCategory.DQL,
206 |             risk_level=OperationRiskLevel.LOW,
207 |             needs_migration=False,
208 |             object_type="TABLE",
209 |             schema_name="information_schema",
210 |         )
211 |         validation_result = QueryValidationResults(
212 |             statements=[statement],
213 |             original_query=query,
214 |             highest_risk_level=OperationRiskLevel.LOW,
215 |         )
216 | 
217 |         # Mock execute_query to raise a QueryError
218 |         with patch.object(mock_postgres_client, 'execute_query',
219 |                          side_effect=QueryError("column \"nonexistent_column\" does not exist")):
220 |             # Execute the query - should raise a QueryError
221 |             with pytest.raises(QueryError) as excinfo:
222 |                 await mock_postgres_client.execute_query(validation_result)
223 | 
224 |             # Verify the error message contains the specific error
225 |             assert "nonexistent_column" in str(excinfo.value)
226 | 
227 |     async def test_write_operation(self, mock_postgres_client: PostgresClient):
228 |         """Test a basic write operation (INSERT)."""
229 |         # Create insert query
230 |         insert_query = "INSERT INTO test_write (name) VALUES ('test_value') RETURNING id, name;"
231 |         insert_statement = ValidatedStatement(
232 |             query=insert_query,
233 |             command=SQLQueryCommand.INSERT,
234 |             category=SQLQueryCategory.DML,
235 |             risk_level=OperationRiskLevel.MEDIUM,
236 |             needs_migration=False,
237 |             object_type="TABLE",
238 |             schema_name="public",
239 |         )
240 |         insert_validation = QueryValidationResults(
241 |             statements=[insert_statement],
242 |             original_query=insert_query,
243 |             highest_risk_level=OperationRiskLevel.MEDIUM,
244 |         )
245 | 
246 |         # Mock the query result
247 |         expected_result = QueryResult(results=[
248 |             StatementResult(rows=[{"id": 1, "name": "test_value"}])
249 |         ])
250 |         
251 |         with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
252 |             # Execute the insert query
253 |             result = await mock_postgres_client.execute_query(insert_validation, readonly=False)
254 | 
255 |             # Verify the result
256 |             assert isinstance(result, QueryResult)
257 |             assert len(result.results) == 1
258 |             assert result.results[0].rows[0]["name"] == "test_value"
259 |             assert result.results[0].rows[0]["id"] == 1
260 | 
261 |     async def test_ddl_operation(self, mock_postgres_client: PostgresClient):
262 |         """Test a basic DDL operation (CREATE TABLE)."""
263 |         # Create a test table
264 |         create_query = "CREATE TEMPORARY TABLE test_ddl (id SERIAL PRIMARY KEY, value TEXT);"
265 |         create_statement = ValidatedStatement(
266 |             query=create_query,
267 |             command=SQLQueryCommand.CREATE,
268 |             category=SQLQueryCategory.DDL,
269 |             risk_level=OperationRiskLevel.MEDIUM,
270 |             needs_migration=False,
271 |             object_type="TABLE",
272 |             schema_name="public",
273 |         )
274 |         create_validation = QueryValidationResults(
275 |             statements=[create_statement],
276 |             original_query=create_query,
277 |             highest_risk_level=OperationRiskLevel.MEDIUM,
278 |         )
279 | 
280 |         # Mock the query result - DDL typically returns empty results
281 |         expected_result = QueryResult(results=[
282 |             StatementResult(rows=[])
283 |         ])
284 |         
285 |         with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
286 |             # Execute the create table query
287 |             result = await mock_postgres_client.execute_query(create_validation, readonly=False)
288 | 
289 |             # Verify the result
290 |             assert isinstance(result, QueryResult)
291 |             assert len(result.results) == 1
292 |             # DDL operations typically don't return rows
293 |             assert result.results[0].rows == []
294 | 
295 |     async def test_execute_metadata_query(self, mock_postgres_client: PostgresClient):
296 |         """Test executing a metadata query."""
297 |         # Create a simple validation result with a SELECT query
298 |         query = "SELECT schema_name FROM information_schema.schemata LIMIT 5;"
299 |         statement = ValidatedStatement(
300 |             query=query,
301 |             command=SQLQueryCommand.SELECT,
302 |             category=SQLQueryCategory.DQL,
303 |             risk_level=OperationRiskLevel.LOW,
304 |             needs_migration=False,
305 |             object_type="schemata",
306 |             schema_name="information_schema",
307 |         )
308 |         validation_result = QueryValidationResults(
309 |             statements=[statement],
310 |             original_query=query,
311 |             highest_risk_level=OperationRiskLevel.LOW,
312 |         )
313 | 
314 |         # Mock the query result
315 |         expected_result = QueryResult(results=[
316 |             StatementResult(rows=[
317 |                 {"schema_name": "public"},
318 |                 {"schema_name": "information_schema"},
319 |                 {"schema_name": "pg_catalog"},
320 |                 {"schema_name": "auth"},
321 |                 {"schema_name": "storage"}
322 |             ])
323 |         ])
324 |         
325 |         with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
326 |             # Execute the query
327 |             result = await mock_postgres_client.execute_query(validation_result)
328 | 
329 |             # Verify the result
330 |             assert isinstance(result, QueryResult)
331 |             assert len(result.results) == 1
332 |             assert len(result.results[0].rows) == 5
333 |             assert "schema_name" in result.results[0].rows[0]
334 | 
335 |     async def test_connection_failure_propagation(self, mock_postgres_client: PostgresClient):
336 |         """Test that a failed pool creation surfaces as a ConnectionError from ensure_pool."""
337 |         # Reset the pool
338 |         mock_postgres_client._pool = None
339 |         
340 |         # Mock create_pool to always raise a connection error
341 |         with patch.object(mock_postgres_client, 'create_pool', 
342 |                          side_effect=ConnectionError("Could not connect to database")):
343 |             # The mocked pool-creation failure should propagate out of ensure_pool
344 |             with pytest.raises(ConnectionError) as exc_info:
345 |                 await mock_postgres_client.ensure_pool()
346 | 
347 |             # Verify the error message indicates a connection failure
348 |             assert "Could not connect to database" in str(exc_info.value)
```
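
The tests above rebuild the same single-statement `ValidatedStatement`/`QueryValidationResults` scaffolding in nearly every case. A small factory helper along the following lines could remove that repetition; this is a sketch, not part of the repo, and the import paths are assumed from the modules referenced elsewhere in this document.

```python
# Hypothetical test helper -- not part of the repository.
# Import paths are assumed from the modules referenced in this document.
from supabase_mcp.services.database.sql.models import (
    QueryValidationResults,
    SQLQueryCategory,
    SQLQueryCommand,
    ValidatedStatement,
)
from supabase_mcp.services.safety.models import OperationRiskLevel


def single_statement_validation(
    query: str,
    command: SQLQueryCommand = SQLQueryCommand.SELECT,
    category: SQLQueryCategory = SQLQueryCategory.DQL,
    risk_level: OperationRiskLevel = OperationRiskLevel.LOW,
) -> QueryValidationResults:
    """Build a one-statement validation result for use in test fixtures."""
    statement = ValidatedStatement(
        query=query,
        command=command,
        category=category,
        risk_level=risk_level,
        needs_migration=False,
        object_type=None,
        schema_name=None,
    )
    return QueryValidationResults(
        statements=[statement],
        original_query=query,
        highest_risk_level=risk_level,
    )
```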

--------------------------------------------------------------------------------
/supabase_mcp/services/sdk/auth_admin_sdk_spec.py:
--------------------------------------------------------------------------------

```python
  1 | def get_auth_admin_methods_spec() -> dict:
  2 |     """Returns a detailed specification of all Auth Admin methods."""
  3 |     return {
  4 |         "get_user_by_id": {
  5 |             "description": "Retrieve a user by their ID",
  6 |             "parameters": {"uid": {"type": "string", "description": "The user's UUID", "required": True}},
  7 |             "returns": {"type": "object", "description": "User object containing all user data"},
  8 |             "example": {
  9 |                 "request": {"uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"},
 10 |                 "response": {
 11 |                     "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
 12 |                     "email": "user@example.com",
 13 |                     "phone": "",
 14 |                     "created_at": "2023-01-01T00:00:00Z",
 15 |                     "confirmed_at": "2023-01-01T00:00:00Z",
 16 |                     "last_sign_in_at": "2023-01-01T00:00:00Z",
 17 |                     "user_metadata": {"name": "John Doe"},
 18 |                     "app_metadata": {},
 19 |                 },
 20 |             },
 21 |         },
 22 |         "list_users": {
 23 |             "description": "List all users with pagination",
 24 |             "parameters": {
 25 |                 "page": {
 26 |                     "type": "integer",
 27 |                     "description": "Page number (starts at 1)",
 28 |                     "required": False,
 29 |                     "default": 1,
 30 |                 },
 31 |                 "per_page": {
 32 |                     "type": "integer",
 33 |                     "description": "Number of users per page",
 34 |                     "required": False,
 35 |                     "default": 50,
 36 |                 },
 37 |             },
 38 |             "returns": {"type": "object", "description": "Paginated list of users with metadata"},
 39 |             "example": {
 40 |                 "request": {"page": 1, "per_page": 10},
 41 |                 "response": {
 42 |                     "users": [
 43 |                         {
 44 |                             "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
 45 |                             "email": "user@example.com",
 46 |                             "user_metadata": {"name": "John Doe"},
 47 |                         }
 48 |                     ],
 49 |                     "aud": "authenticated",
 50 |                     "total_count": 100,
 51 |                     "next_page": 2,
 52 |                 },
 53 |             },
 54 |         },
 55 |         "create_user": {
 56 |             "description": "Create a new user. Does not send a confirmation email by default.",
 57 |             "parameters": {
 58 |                 "email": {"type": "string", "description": "The user's email address"},
 59 |                 "password": {"type": "string", "description": "The user's password"},
 60 |                 "email_confirm": {
 61 |                     "type": "boolean",
 62 |                     "description": "Confirms the user's email address if set to true",
 63 |                     "default": False,
 64 |                 },
 65 |                 "phone": {"type": "string", "description": "The user's phone number with country code"},
 66 |                 "phone_confirm": {
 67 |                     "type": "boolean",
 68 |                     "description": "Confirms the user's phone number if set to true",
 69 |                     "default": False,
 70 |                 },
 71 |                 "user_metadata": {
 72 |                     "type": "object",
 73 |                     "description": "A custom data object to store the user's metadata",
 74 |                 },
 75 |                 "app_metadata": {
 76 |                     "type": "object",
 77 |                     "description": "A custom data object to store the user's application specific metadata",
 78 |                 },
 79 |                 "role": {"type": "string", "description": "The role claim set in the user's access token JWT"},
 80 |                 "ban_duration": {"type": "string", "description": "Determines how long a user is banned for"},
 81 |                 "nonce": {
 82 |                     "type": "string",
 83 |                     "description": "The nonce (required for reauthentication if updating password)",
 84 |                 },
 85 |             },
 86 |             "returns": {"type": "object", "description": "Created user object"},
 87 |             "example": {
 88 |                 "request": {
 89 |                     "email": "new_user@example.com",
 90 |                     "password": "secure-password",
 91 |                     "email_confirm": True,
 92 |                     "user_metadata": {"name": "New User"},
 93 |                 },
 94 |                 "response": {
 95 |                     "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
 96 |                     "email": "new_user@example.com",
 97 |                     "email_confirmed_at": "2023-01-01T00:00:00Z",
 98 |                     "user_metadata": {"name": "New User"},
 99 |                 },
100 |             },
101 |             "notes": "Either email or phone must be provided. Use invite_user_by_email() if you want to send an email invite.",
102 |         },
103 |         "delete_user": {
104 |             "description": "Delete a user by their ID. Requires a service_role key.",
105 |             "parameters": {
106 |                 "id": {"type": "string", "description": "The user's UUID", "required": True},
107 |                 "should_soft_delete": {
108 |                     "type": "boolean",
109 |                     "description": "If true, the user will be soft-deleted (preserving their data but disabling the account). Defaults to false.",
110 |                     "required": False,
111 |                     "default": False,
112 |                 },
113 |             },
114 |             "returns": {"type": "object", "description": "Success message"},
115 |             "example": {
116 |                 "request": {"id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"},
117 |                 "response": {"message": "User deleted successfully"},
118 |             },
119 |             "notes": "This function should only be called on a server. Never expose your service_role key in the browser.",
120 |         },
121 |         "invite_user_by_email": {
122 |             "description": "Sends an invite link to a user's email address. Typically used by administrators to invite users to join the application.",
123 |             "parameters": {
124 |                 "email": {"type": "string", "description": "The email address of the user", "required": True},
125 |                 "options": {
126 |                     "type": "object",
127 |                     "description": "Optional settings for the invite",
128 |                     "required": False,
129 |                     "properties": {
130 |                         "data": {
131 |                             "type": "object",
132 |                             "description": "A custom data object to store additional metadata about the user. Maps to auth.users.user_metadata",
133 |                             "required": False,
134 |                         },
135 |                         "redirect_to": {
136 |                             "type": "string",
137 |                             "description": "The URL which will be appended to the email link. Once clicked the user will end up on this URL",
138 |                             "required": False,
139 |                         },
140 |                     },
141 |                 },
142 |             },
143 |             "returns": {"type": "object", "description": "User object for the invited user"},
144 |             "example": {
145 |                 "request": {
146 |                     "email": "invited_user@example.com",
147 |                     "options": {"data": {"name": "John Doe"}, "redirect_to": "https://example.com/welcome"},
148 |                 },
149 |                 "response": {
150 |                     "id": "a1a1a1a1-a1a1-a1a1-a1a1-a1a1a1a1a1a1",
151 |                     "email": "invited_user@example.com",
152 |                     "role": "authenticated",
153 |                     "email_confirmed_at": None,
154 |                     "invited_at": "2023-01-01T00:00:00Z",
155 |                 },
156 |             },
157 |             "notes": "Note that PKCE is not supported when using invite_user_by_email. This is because the browser initiating the invite is often different from the browser accepting the invite.",
158 |         },
159 |         "generate_link": {
160 |             "description": "Generate an email link for various authentication purposes. Handles user creation for signup, invite and magiclink types.",
161 |             "parameters": {
162 |                 "type": {
163 |                     "type": "string",
164 |                     "description": "Link type: 'signup', 'invite', 'magiclink', 'recovery', 'email_change_current', 'email_change_new', 'phone_change'",
165 |                     "required": True,
166 |                     "enum": [
167 |                         "signup",
168 |                         "invite",
169 |                         "magiclink",
170 |                         "recovery",
171 |                         "email_change_current",
172 |                         "email_change_new",
173 |                         "phone_change",
174 |                     ],
175 |                 },
176 |                 "email": {"type": "string", "description": "User's email address", "required": True},
177 |                 "password": {
178 |                     "type": "string",
179 |                     "description": "User's password. Only required if type is signup",
180 |                     "required": False,
181 |                 },
182 |                 "new_email": {
183 |                     "type": "string",
184 |                     "description": "New email address. Only required if type is email_change_current or email_change_new",
185 |                     "required": False,
186 |                 },
187 |                 "options": {
188 |                     "type": "object",
189 |                     "description": "Additional options for the link",
190 |                     "required": False,
191 |                     "properties": {
192 |                         "data": {
193 |                             "type": "object",
194 |                             "description": "Custom JSON object containing user metadata. Only accepted if type is signup, invite, or magiclink",
195 |                             "required": False,
196 |                         },
197 |                         "redirect_to": {
198 |                             "type": "string",
199 |                             "description": "A redirect URL which will be appended to the generated email link",
200 |                             "required": False,
201 |                         },
202 |                     },
203 |                 },
204 |             },
205 |             "returns": {"type": "object", "description": "Generated link details"},
206 |             "example": {
207 |                 "request": {
208 |                     "type": "signup",
209 |                     "email": "user@example.com",
210 |                     "password": "secure-password",
211 |                     "options": {"data": {"name": "John Doe"}, "redirect_to": "https://example.com/welcome"},
212 |                 },
213 |                 "response": {
214 |                     "action_link": "https://your-project.supabase.co/auth/v1/verify?token=...",
215 |                     "email_otp": "123456",
216 |                     "hashed_token": "...",
217 |                     "redirect_to": "https://example.com/welcome",
218 |                     "verification_type": "signup",
219 |                 },
220 |             },
221 |             "notes": "generate_link() only generates the email link for email_change_email if the Secure email change is enabled in your project's email auth provider settings.",
222 |         },
223 |         "update_user_by_id": {
224 |             "description": "Update user attributes by ID. Requires a service_role key.",
225 |             "parameters": {
226 |                 "uid": {"type": "string", "description": "The user's UUID", "required": True},
227 |                 "attributes": {
228 |                     "type": "object",
229 |                     "description": "The user attributes to update.",
230 |                     "required": True,
231 |                     "properties": {
232 |                         "email": {"type": "string", "description": "The user's email"},
233 |                         "phone": {"type": "string", "description": "The user's phone"},
234 |                         "password": {"type": "string", "description": "The user's password"},
235 |                         "email_confirm": {
236 |                             "type": "boolean",
237 |                             "description": "Confirms the user's email address if set to true",
238 |                         },
239 |                         "phone_confirm": {
240 |                             "type": "boolean",
241 |                             "description": "Confirms the user's phone number if set to true",
242 |                         },
243 |                         "user_metadata": {
244 |                             "type": "object",
245 |                             "description": "A custom data object to store the user's metadata.",
246 |                         },
247 |                         "app_metadata": {
248 |                             "type": "object",
249 |                             "description": "A custom data object to store the user's application specific metadata.",
250 |                         },
251 |                         "role": {
252 |                             "type": "string",
253 |                             "description": "The role claim set in the user's access token JWT",
254 |                         },
255 |                         "ban_duration": {
256 |                             "type": "string",
257 |                             "description": "Determines how long a user is banned for",
258 |                         },
259 |                         "nonce": {
260 |                             "type": "string",
261 |                             "description": "The nonce sent for reauthentication if the user's password is to be updated",
262 |                         },
263 |                     },
264 |                 },
265 |             },
266 |             "returns": {"type": "object", "description": "Updated user object"},
267 |             "example": {
268 |                 "request": {
269 |                     "uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
270 |                     "attributes": {"email": "new_email@example.com", "user_metadata": {"name": "Updated Name"}},
271 |                 },
272 |                 "response": {
273 |                     "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
274 |                     "email": "new_email@example.com",
275 |                     "user_metadata": {"name": "Updated Name"},
276 |                 },
277 |             },
278 |             "notes": "This function should only be called on a server. Never expose your service_role key in the browser.",
279 |         },
280 |         "delete_factor": {
281 |             "description": "Deletes a factor on a user. This will log the user out of all active sessions if the deleted factor was verified.",
282 |             "parameters": {
283 |                 "user_id": {
284 |                     "type": "string",
285 |                     "description": "ID of the user whose factor is being deleted",
286 |                     "required": True,
287 |                 },
288 |                 "id": {"type": "string", "description": "ID of the MFA factor to delete", "required": True},
289 |             },
290 |             "returns": {"type": "object", "description": "Success message"},
291 |             "example": {
292 |                 "request": {"user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", "id": "totp-factor-id-123"},
293 |                 "response": {"message": "Factor deleted successfully"},
294 |             },
295 |             "notes": "This will log the user out of all active sessions if the deleted factor was verified.",
296 |         },
297 |     }
298 | 
```
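
Because the spec is a plain nested dictionary, it can be rendered into tool descriptions or quick-reference docs in a few lines. A minimal sketch (the helper name and output format are illustrative, not part of the repo):

```python
# Illustrative helper -- not part of the repository.
from supabase_mcp.services.sdk.auth_admin_sdk_spec import get_auth_admin_methods_spec


def summarize_auth_admin_methods() -> list[str]:
    """Render a one-line summary per Auth Admin method from the spec above."""
    lines: list[str] = []
    for name, spec in get_auth_admin_methods_spec().items():
        # Collect only the parameters explicitly marked required in the spec
        required = [
            param
            for param, meta in spec.get("parameters", {}).items()
            if meta.get("required")
        ]
        lines.append(f"{name}({', '.join(required)}): {spec['description']}")
    return lines


if __name__ == "__main__":
    print("\n".join(summarize_auth_admin_methods()))
```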

--------------------------------------------------------------------------------
/tests/services/sdk/test_sdk_client.py:
--------------------------------------------------------------------------------

```python
  1 | import time
  2 | import uuid
  3 | from datetime import datetime
  4 | from unittest.mock import AsyncMock, MagicMock, patch
  5 | 
  6 | import pytest
  7 | 
  8 | from supabase_mcp.clients.sdk_client import SupabaseSDKClient
  9 | from supabase_mcp.exceptions import PythonSDKError
 10 | from supabase_mcp.settings import Settings
 11 | 
 12 | # Unique identifier for test users to avoid conflicts
 13 | TEST_ID = f"test-{int(time.time())}-{uuid.uuid4().hex[:6]}"
 14 | 
 15 | 
 16 | # Create unique test emails
 17 | def get_test_email(prefix: str = "user"):
 18 |     """Generate a unique test email"""
 19 |     return f"a.zuev+{prefix}-{TEST_ID}@outlook.com"
 20 | 
 21 | 
 22 | @pytest.mark.asyncio(loop_scope="module")
 23 | class TestSDKClient:
 24 |     """
 25 |     Unit tests for the SupabaseSDKClient with all Supabase SDK calls mocked.
 26 |     """
 27 | 
 28 |     @pytest.fixture
 29 |     def mock_settings(self):
 30 |         """Create mock settings for testing."""
 31 |         settings = MagicMock(spec=Settings)
 32 |         settings.supabase_project_ref = "test-project-ref"
 33 |         settings.supabase_service_role_key = "test-service-role-key"
 34 |         settings.supabase_region = "us-east-1"
 35 |         settings.supabase_url = "https://test-project-ref.supabase.co"
 36 |         return settings
 37 | 
 38 |     @pytest.fixture
 39 |     async def mock_sdk_client(self, mock_settings):
 40 |         """Create a mock SDK client for testing."""
 41 |         # Reset singleton
 42 |         SupabaseSDKClient.reset()
 43 |         
 44 |         # Mock the Supabase client
 45 |         mock_supabase = MagicMock()
 46 |         mock_auth_admin = MagicMock()
 47 |         mock_supabase.auth.admin = mock_auth_admin
 48 |         
 49 |         # Mock the create_async_client function to return our mock client
 50 |         with patch('supabase_mcp.clients.sdk_client.create_async_client', return_value=mock_supabase):
 51 |             # Create client - this will now use our mocked create_async_client
 52 |             client = SupabaseSDKClient.get_instance(settings=mock_settings)
 53 |             # Manually set the client to ensure it's available
 54 |             client.client = mock_supabase
 55 |             
 56 |         return client
 57 | 
 58 |     async def test_list_users(self, mock_sdk_client: SupabaseSDKClient):
 59 |         """Test listing users with pagination"""
 60 |         # Mock user data
 61 |         mock_users = [
 62 |             MagicMock(id="user1", email="[email protected]", user_metadata={}),
 63 |             MagicMock(id="user2", email="[email protected]", user_metadata={})
 64 |         ]
 65 |         
 66 |         # Mock the list_users method as an async function
 67 |         mock_sdk_client.client.auth.admin.list_users = AsyncMock(return_value=mock_users)
 68 |         
 69 |         # Create test parameters
 70 |         list_params = {"page": 1, "per_page": 10}
 71 | 
 72 |         # List users
 73 |         result = await mock_sdk_client.call_auth_admin_method("list_users", list_params)
 74 | 
 75 |         # Verify response format
 76 |         assert result is not None
 77 |         assert hasattr(result, "__iter__")  # Should be iterable (list of users)
 78 |         assert len(result) == 2
 79 | 
 80 |         # Check that the first user has expected attributes
 81 |         first_user = result[0]
 82 |         assert hasattr(first_user, "id")
 83 |         assert hasattr(first_user, "email")
 84 |         assert hasattr(first_user, "user_metadata")
 85 | 
 86 |         # Test with invalid parameters - mock the validation error
 87 |         mock_sdk_client.client.auth.admin.list_users = AsyncMock(side_effect=Exception("Bad Pagination Parameters"))
 88 |         
 89 |         invalid_params = {"page": -1, "per_page": 10}
 90 |         with pytest.raises(PythonSDKError) as excinfo:
 91 |             await mock_sdk_client.call_auth_admin_method("list_users", invalid_params)
 92 | 
 93 |         assert "Bad Pagination Parameters" in str(excinfo.value)
 94 | 
 95 |     async def test_get_user_by_id(self, mock_sdk_client: SupabaseSDKClient):
 96 |         """Test retrieving a user by ID"""
 97 |         # Mock user data
 98 |         test_email = get_test_email("get")
 99 |         user_id = str(uuid.uuid4())
100 |         
101 |         mock_user = MagicMock(
102 |             id=user_id,
103 |             email=test_email,
104 |             user_metadata={"name": "Test User", "test_id": TEST_ID}
105 |         )
106 |         mock_response = MagicMock(user=mock_user)
107 |         
108 |         # Mock the get_user_by_id method as an async function
109 |         mock_sdk_client.client.auth.admin.get_user_by_id = AsyncMock(return_value=mock_response)
110 |         
111 |         # Get the user by ID
112 |         get_params = {"uid": user_id}
113 |         get_result = await mock_sdk_client.call_auth_admin_method("get_user_by_id", get_params)
114 | 
115 |         # Verify user data
116 |         assert get_result is not None
117 |         assert hasattr(get_result, "user")
118 |         assert get_result.user.id == user_id
119 |         assert get_result.user.email == test_email
120 | 
121 |         # Test with invalid parameters (non-existent user ID)
122 |         mock_sdk_client.client.auth.admin.get_user_by_id = AsyncMock(side_effect=Exception("user_id must be an UUID"))
123 |         
124 |         invalid_params = {"uid": "non-existent-user-id"}
125 |         with pytest.raises(PythonSDKError) as excinfo:
126 |             await mock_sdk_client.call_auth_admin_method("get_user_by_id", invalid_params)
127 | 
128 |         assert "user_id must be an UUID" in str(excinfo.value)
129 | 
130 |     async def test_create_user(self, mock_sdk_client: SupabaseSDKClient):
131 |         """Test creating a new user"""
132 |         # Create a new test user
133 |         test_email = get_test_email("create")
134 |         user_id = str(uuid.uuid4())
135 |         
136 |         mock_user = MagicMock(
137 |             id=user_id,
138 |             email=test_email,
139 |             user_metadata={"name": "Test User", "test_id": TEST_ID}
140 |         )
141 |         mock_response = MagicMock(user=mock_user)
142 |         
143 |         # Mock the create_user method as an async function
144 |         mock_sdk_client.client.auth.admin.create_user = AsyncMock(return_value=mock_response)
145 |         
146 |         create_params = {
147 |             "email": test_email,
148 |             "password": f"Password123!{TEST_ID}",
149 |             "email_confirm": True,
150 |             "user_metadata": {"name": "Test User", "test_id": TEST_ID},
151 |         }
152 | 
153 |         # Create the user
154 |         create_result = await mock_sdk_client.call_auth_admin_method("create_user", create_params)
155 |         assert create_result is not None
156 |         assert hasattr(create_result, "user")
157 |         assert hasattr(create_result.user, "id")
158 |         assert create_result.user.id == user_id
159 | 
160 |         # Test with invalid parameters (missing required fields)
161 |         mock_sdk_client.client.auth.admin.create_user = AsyncMock(side_effect=Exception("Invalid parameters"))
162 |         
163 |         invalid_params = {"user_metadata": {"name": "Invalid User"}}
164 |         with pytest.raises(PythonSDKError) as excinfo:
165 |             await mock_sdk_client.call_auth_admin_method("create_user", invalid_params)
166 | 
167 |         assert "Invalid parameters" in str(excinfo.value)
168 | 
169 |     async def test_update_user_by_id(self, mock_sdk_client: SupabaseSDKClient):
170 |         """Test updating a user's attributes"""
171 |         # Mock user data
172 |         test_email = get_test_email("update")
173 |         user_id = str(uuid.uuid4())
174 |         
175 |         mock_user = MagicMock(
176 |             id=user_id,
177 |             email=test_email,
178 |             user_metadata={"email": "updated@example.com"}
179 |         )
180 |         mock_response = MagicMock(user=mock_user)
181 |         
182 |         # Mock the update_user_by_id method as an async function
183 |         mock_sdk_client.client.auth.admin.update_user_by_id = AsyncMock(return_value=mock_response)
184 |         
185 |         # Update the user
186 |         update_params = {
187 |             "uid": user_id,
188 |             "attributes": {
189 |                 "user_metadata": {
190 |                     "email": "updated@example.com",
191 |                 }
192 |             },
193 |         }
194 | 
195 |         update_result = await mock_sdk_client.call_auth_admin_method("update_user_by_id", update_params)
196 | 
197 |         # Verify user was updated
198 |         assert update_result is not None
199 |         assert hasattr(update_result, "user")
200 |         assert update_result.user.id == user_id
201 |         assert update_result.user.user_metadata["email"] == "updated@example.com"
202 | 
203 |         # Test with invalid parameters (non-existent user ID)
204 |         mock_sdk_client.client.auth.admin.update_user_by_id = AsyncMock(side_effect=Exception("user_id must be an uuid"))
205 |         
206 |         invalid_params = {
207 |             "uid": "non-existent-user-id",
208 |             "attributes": {"user_metadata": {"name": "Invalid Update"}},
209 |         }
210 |         with pytest.raises(PythonSDKError) as excinfo:
211 |             await mock_sdk_client.call_auth_admin_method("update_user_by_id", invalid_params)
212 | 
213 |         assert "user_id must be an uuid" in str(excinfo.value).lower()
214 | 
215 |     async def test_delete_user(self, mock_sdk_client: SupabaseSDKClient):
216 |         """Test deleting a user"""
217 |         # Mock user data
218 |         user_id = str(uuid.uuid4())
219 |         
220 |         # Mock the delete_user method as an async function to return None (success)
221 |         mock_sdk_client.client.auth.admin.delete_user = AsyncMock(return_value=None)
222 |         
223 |         # Delete the user
224 |         delete_params = {"id": user_id}
225 |         # The delete_user method returns None on success
226 |         result = await mock_sdk_client.call_auth_admin_method("delete_user", delete_params)
227 |         assert result is None
228 | 
229 |         # Test with invalid parameters (non-UUID format user ID)
230 |         mock_sdk_client.client.auth.admin.delete_user = AsyncMock(side_effect=Exception("user_id must be an uuid"))
231 |         
232 |         invalid_params = {"id": "non-existent-user-id"}
233 |         with pytest.raises(PythonSDKError) as excinfo:
234 |             await mock_sdk_client.call_auth_admin_method("delete_user", invalid_params)
235 | 
236 |         assert "user_id must be an uuid" in str(excinfo.value).lower()
237 | 
238 |     async def test_invite_user_by_email(self, mock_sdk_client: SupabaseSDKClient):
239 |         """Test inviting a user by email"""
240 |         # Mock user data
241 |         test_email = get_test_email("invite")
242 |         user_id = str(uuid.uuid4())
243 |         
244 |         mock_user = MagicMock(
245 |             id=user_id,
246 |             email=test_email,
247 |             invited_at=datetime.now().isoformat()
248 |         )
249 |         mock_response = MagicMock(user=mock_user)
250 |         
251 |         # Mock the invite_user_by_email method as an async function
252 |         mock_sdk_client.client.auth.admin.invite_user_by_email = AsyncMock(return_value=mock_response)
253 |         
254 |         # Create invite parameters
255 |         invite_params = {
256 |             "email": test_email,
257 |             "options": {"data": {"name": "Invited User", "test_id": TEST_ID, "invited_at": datetime.now().isoformat()}},
258 |         }
259 | 
260 |         # Invite the user
261 |         result = await mock_sdk_client.call_auth_admin_method("invite_user_by_email", invite_params)
262 | 
263 |         # Verify response
264 |         assert result is not None
265 |         assert hasattr(result, "user")
266 |         assert result.user.email == test_email
267 |         assert hasattr(result.user, "invited_at")
268 | 
269 |         # Test with invalid parameters (missing email)
270 |         mock_sdk_client.client.auth.admin.invite_user_by_email = AsyncMock(side_effect=Exception("Invalid parameters"))
271 |         
272 |         invalid_params = {"options": {"data": {"name": "Invalid Invite"}}}
273 |         with pytest.raises(PythonSDKError) as excinfo:
274 |             await mock_sdk_client.call_auth_admin_method("invite_user_by_email", invalid_params)
275 | 
276 |         assert "Invalid parameters" in str(excinfo.value)
277 | 
278 |     async def test_generate_link(self, mock_sdk_client: SupabaseSDKClient):
279 |         """Test generating authentication links"""
280 |         # Mock response for generate_link
281 |         mock_properties = MagicMock(action_link="https://example.com/auth/link")
282 |         mock_response = MagicMock(properties=mock_properties)
283 |         
284 |         # Mock the generate_link method as an async function
285 |         mock_sdk_client.client.auth.admin.generate_link = AsyncMock(return_value=mock_response)
286 |         
287 |         # Test signup link
288 |         link_params = {
289 |             "type": "signup",
290 |             "email": get_test_email("signup"),
291 |             "password": f"Password123!{TEST_ID}",
292 |             "options": {
293 |                 "data": {"name": "Signup User", "test_id": TEST_ID},
294 |                 "redirect_to": "https://example.com/welcome",
295 |             },
296 |         }
297 | 
298 |         # Generate link
299 |         result = await mock_sdk_client.call_auth_admin_method("generate_link", link_params)
300 | 
301 |         # Verify response
302 |         assert result is not None
303 |         assert hasattr(result, "properties")
304 |         assert hasattr(result.properties, "action_link")
305 | 
306 |         # Test with invalid parameters (invalid link type)
307 |         mock_sdk_client.client.auth.admin.generate_link = AsyncMock(side_effect=Exception("Invalid parameters"))
308 |         
309 |         invalid_params = {"type": "invalid_type", "email": get_test_email("invalid")}
310 |         with pytest.raises(PythonSDKError) as excinfo:
311 |             await mock_sdk_client.call_auth_admin_method("generate_link", invalid_params)
312 | 
313 |         assert "Invalid parameters" in str(excinfo.value) or "invalid type" in str(excinfo.value).lower()
314 | 
315 |     async def test_delete_factor(self, mock_sdk_client: SupabaseSDKClient):
316 |         """Test deleting an MFA factor"""
317 |         # Mock the delete_factor method as an async function to raise not implemented
318 |         mock_sdk_client.client.auth.admin.delete_factor = AsyncMock(side_effect=AttributeError("method not found"))
319 |         
320 |         # Attempt to delete a factor
321 |         delete_factor_params = {"user_id": str(uuid.uuid4()), "id": "non-existent-factor-id"}
322 | 
323 |         with pytest.raises(PythonSDKError) as excinfo:
324 |             await mock_sdk_client.call_auth_admin_method("delete_factor", delete_factor_params)
325 |         
326 |         # We expect this to fail with a specific error message
327 |         assert "not implemented" in str(excinfo.value).lower() or "method not found" in str(excinfo.value).lower()
328 | 
329 |     async def test_empty_parameters(self, mock_sdk_client: SupabaseSDKClient):
330 |         """Test validation errors with empty parameters for various methods"""
331 |         # Test methods with empty parameters
332 |         methods = ["get_user_by_id", "create_user", "update_user_by_id", "delete_user", "generate_link"]
333 | 
334 |         for method in methods:
335 |             empty_params = {}
336 |             
337 |             # Mock the method to raise validation error
338 |             setattr(mock_sdk_client.client.auth.admin, method, AsyncMock(side_effect=Exception("Invalid parameters")))
339 | 
340 |             # Should raise PythonSDKError containing validation error details
341 |             with pytest.raises(PythonSDKError) as excinfo:
342 |                 await mock_sdk_client.call_auth_admin_method(method, empty_params)
343 | 
344 |             # Verify error message contains validation details
345 |             assert "Invalid parameters" in str(excinfo.value) or "validation error" in str(excinfo.value).lower()
346 | 
347 |     async def test_client_without_service_role_key(self, mock_settings):
348 |         """Test that an exception is raised when attempting to use the SDK client without a service role key."""
349 |         # Create settings without service role key
350 |         mock_settings.supabase_service_role_key = None
351 |         
352 |         # Reset singleton
353 |         SupabaseSDKClient.reset()
354 |         
355 |         # Create client
356 |         client = SupabaseSDKClient.get_instance(settings=mock_settings)
357 | 
358 |         # Attempt to call a method - should raise an exception
359 |         with pytest.raises(PythonSDKError) as excinfo:
360 |             await client.call_auth_admin_method("list_users", {})
361 |             
362 |         assert "service role key is not configured" in str(excinfo.value)
```
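
Outside the tests, calling the client follows the same shape the mocks exercise. A minimal usage sketch, assuming `Settings` is populated with a project ref and service role key (without the key, `call_auth_admin_method` raises `PythonSDKError`, as the last test verifies):

```python
# Illustrative usage -- not part of the repository. Assumes the required
# SUPABASE_* configuration (including the service role key) is available
# to Settings, e.g. via environment variables.
import asyncio

from supabase_mcp.clients.sdk_client import SupabaseSDKClient
from supabase_mcp.settings import Settings


async def main() -> None:
    client = SupabaseSDKClient.get_instance(settings=Settings())
    # Same call shape the tests above exercise: method name + params dict
    users = await client.call_auth_admin_method("list_users", {"page": 1, "per_page": 10})
    for user in users:
        print(user.id, user.email)


asyncio.run(main())
```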

--------------------------------------------------------------------------------
/supabase_mcp/services/database/postgres_client.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | 
  3 | import urllib.parse
  4 | from collections.abc import Awaitable, Callable
  5 | from typing import Any, TypeVar
  6 | 
  7 | import asyncpg
  8 | from pydantic import BaseModel, Field
  9 | from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 10 | 
 11 | from supabase_mcp.exceptions import ConnectionError, PermissionError, QueryError
 12 | from supabase_mcp.logger import logger
 13 | from supabase_mcp.services.database.sql.models import QueryValidationResults
 14 | from supabase_mcp.services.database.sql.validator import SQLValidator
 15 | from supabase_mcp.settings import Settings
 16 | 
 17 | # Define a type variable for generic return types
 18 | T = TypeVar("T")
 19 | 
 20 | # TODO: Use a context manager to properly handle the connection pool
 21 | 
 22 | 
 23 | class StatementResult(BaseModel):
 24 |     """Represents the result of a single SQL statement."""
 25 | 
 26 |     rows: list[dict[str, Any]] = Field(
 27 |         default_factory=list,
 28 |         description="List of rows returned by the statement. Is empty if the statement is a DDL statement.",
 29 |     )
 30 | 
 31 | 
 32 | class QueryResult(BaseModel):
 33 |     """Represents results of query execution, consisting of one or more statements."""
 34 | 
 35 |     results: list[StatementResult] = Field(
 36 |         description="List of results from the statements in the query.",
 37 |     )
 38 | 
 39 | 
 40 | # Helper function for retry decorator to safely log exceptions
 41 | def log_db_retry_attempt(retry_state: RetryCallState) -> None:
 42 |     """Log database retry attempts.
 43 | 
 44 |     Args:
 45 |         retry_state: Current retry state from tenacity
 46 |     """
 47 |     if retry_state.outcome is not None and retry_state.outcome.failed:
 48 |         exception = retry_state.outcome.exception()
 49 |         exception_str = str(exception)
 50 |         logger.warning(f"Database error, retrying ({retry_state.attempt_number}/3): {exception_str}")
 51 | 
 52 | 
 53 | # Asynchronous PostgreSQL client (singleton, connects lazily)
 54 | class PostgresClient:
 55 |     """Asynchronous client for interacting with Supabase PostgreSQL database."""
 56 | 
 57 |     _instance: PostgresClient | None = None  # Singleton instance
 58 | 
 59 |     def __init__(
 60 |         self,
 61 |         settings: Settings,
 62 |         project_ref: str | None = None,
 63 |         db_password: str | None = None,
 64 |         db_region: str | None = None,
 65 |     ):
 66 |         """Initialize client configuration (but don't connect yet).
 67 | 
 68 |         Args:
 69 |             settings: Settings instance to use for configuration.
 70 |             project_ref: Optional Supabase project reference. If not provided, will be taken from settings.
 71 |             db_password: Optional database password. If not provided, will be taken from settings.
 72 |             db_region: Optional database region. If not provided, will be taken from settings.
 73 |         """
 74 |         self._pool: asyncpg.Pool[asyncpg.Record] | None = None
 75 |         self._settings = settings
 76 |         self.project_ref = project_ref or self._settings.supabase_project_ref
 77 |         self.db_password = db_password or self._settings.supabase_db_password
 78 |         self.db_region = db_region or self._settings.supabase_region
 79 |         self.db_url = self._build_connection_string()
 80 |         self.sql_validator: SQLValidator = SQLValidator()
 81 | 
 82 |         # Only log once during initialization with clear project info
 83 |         is_local = self.project_ref.startswith("127.0.0.1")
 84 |         logger.info(
 85 |             f"✔️ PostgreSQL client initialized successfully for {'local' if is_local else 'remote'} "
 86 |             f"project: {self.project_ref} (region: {self.db_region})"
 87 |         )
 88 | 
 89 |     @classmethod
 90 |     def get_instance(
 91 |         cls,
 92 |         settings: Settings,
 93 |         project_ref: str | None = None,
 94 |         db_password: str | None = None,
 95 |     ) -> PostgresClient:
 96 |         """Create and return a configured AsyncSupabaseClient instance.
 97 | 
 98 |         This is the recommended way to create a client instance.
 99 | 
100 |         Args:
101 |             settings: Settings instance to use for configuration
102 |             project_ref: Optional Supabase project reference
103 |             db_password: Optional database password
104 | 
105 |         Returns:
106 |             Configured PostgresClient instance
107 |         """
108 |         if cls._instance is None:
109 |             cls._instance = cls(
110 |                 settings=settings,
111 |                 project_ref=project_ref,
112 |                 db_password=db_password,
113 |             )
114 |             # Doesn't connect yet - will connect lazily when needed
115 |         return cls._instance
116 | 
117 |     def _build_connection_string(self) -> str:
118 |         """Build the database connection string for asyncpg.
119 | 
120 |         Returns:
121 |             PostgreSQL connection string compatible with asyncpg
122 |         """
123 |         encoded_password = urllib.parse.quote_plus(self.db_password)
124 | 
125 |         if self.project_ref.startswith("127.0.0.1"):
126 |             # Local development
127 |             connection_string = f"postgresql://postgres:{encoded_password}@{self.project_ref}/postgres"
128 |             return connection_string
129 | 
130 |         # Production Supabase - via transaction pooler
131 |         connection_string = (
132 |             f"postgresql://postgres.{self.project_ref}:{encoded_password}"
133 |             f"@aws-0-{self._settings.supabase_region}.pooler.supabase.com:6543/postgres"
134 |         )
135 |         return connection_string
136 | 
137 |     @retry(
138 |         retry=retry_if_exception_type(
139 |             (
140 |                 asyncpg.exceptions.ConnectionDoesNotExistError,  # Connection lost
141 |                 asyncpg.exceptions.InterfaceError,  # Connection disruption
142 |                 asyncpg.exceptions.TooManyConnectionsError,  # Temporary connection limit
143 |                 OSError,  # Network issues
144 |             )
145 |         ),
146 |         stop=stop_after_attempt(3),
147 |         wait=wait_exponential(multiplier=1, min=2, max=10),
148 |         before_sleep=log_db_retry_attempt,
149 |     )
150 |     async def create_pool(self) -> asyncpg.Pool[asyncpg.Record]:
151 |         """Create and configure a database connection pool.
152 | 
153 |         Returns:
154 |             Configured asyncpg connection pool
155 | 
156 |         Raises:
157 |             ConnectionError: If unable to establish a connection to the database
158 |         """
159 |         try:
160 |             logger.debug(f"Creating connection pool for project: {self.project_ref}")
161 | 
162 |             # Create the pool with optimal settings
163 |             pool = await asyncpg.create_pool(
164 |                 self.db_url,
165 |                 min_size=2,  # Minimum connections to keep ready
166 |                 max_size=10,  # Maximum connections allowed
167 |                 statement_cache_size=0,
168 |                 command_timeout=30.0,  # Command timeout in seconds
169 |                 max_inactive_connection_lifetime=300.0,  # 5 minutes
170 |             )
171 | 
172 |             # Test the connection with a simple query
173 |             async with pool.acquire() as conn:
174 |                 await conn.execute("SELECT 1")
175 | 
176 |             logger.info("✓ Database connection established successfully")
177 |             return pool
178 | 
179 |         except asyncpg.PostgresError as e:
180 |             # Extract connection details for better error reporting
181 |             host_part = self.db_url.split("@")[1].split("/")[0] if "@" in self.db_url else "unknown"
182 | 
183 |             # Check specifically for the "Tenant or user not found" error which is often caused by region mismatch
184 |             if "Tenant or user not found" in str(e):
185 |                 error_message = (
186 |                     "CONNECTION ERROR: Region mismatch detected!\n\n"
187 |                     f"Could not connect to Supabase project '{self.project_ref}'.\n\n"
188 |                     "This error typically occurs when your SUPABASE_REGION setting doesn't match your project's actual region.\n"
189 |                     f"Your configuration is using region: '{self.db_region}' (default: us-east-1)\n\n"
190 |                     "ACTION REQUIRED: Please set the correct SUPABASE_REGION in your MCP server configuration.\n"
191 |                     "You can find your project's region in the Supabase dashboard under Project Settings."
192 |                 )
193 |             else:
194 |                 error_message = (
195 |                     f"Could not connect to database: {e}\n"
196 |                     f"Connection attempted to: {host_part}\n via Transaction Pooler\n"
197 |                     f"Project ref: {self.project_ref}\n"
198 |                     f"Region: {self.db_region}\n\n"
199 |                     f"Please check:\n"
200 |                     f"1. Your Supabase project reference is correct\n"
201 |                     f"2. Your database password is correct\n"
202 |                     f"3. Your region setting matches your Supabase project region\n"
203 |                     f"4. Your Supabase project is active and the database is online\n"
204 |                 )
205 | 
206 |             logger.error(f"Failed to connect to database: {e}")
207 |             logger.error(f"Connection details: {host_part}, Project: {self.project_ref}, Region: {self.db_region}")
208 | 
209 |             raise ConnectionError(error_message) from e
210 | 
211 |         except OSError as e:
212 |             # For network-related errors, provide a different message that clearly indicates
213 |             # this is a network/system issue rather than a database configuration problem
214 |             host_part = self.db_url.split("@")[1].split("/")[0] if "@" in self.db_url else "unknown"
215 | 
216 |             error_message = (
217 |                 f"Network error while connecting to database: {e}\n"
218 |                 f"Connection attempted to: {host_part}\n\n"
219 |                 f"This appears to be a network or system issue rather than a database configuration problem.\n"
220 |                 f"Please check:\n"
221 |                 f"1. Your internet connection is working\n"
222 |                 f"2. Any firewalls or network security settings allow connections to {host_part}\n"
223 |                 f"3. DNS resolution is working correctly\n"
224 |                 f"4. The Supabase service is not experiencing an outage\n"
225 |             )
226 | 
227 |             logger.error(f"Network error connecting to database: {e}")
228 |             logger.error(f"Connection details: {host_part}")
229 |             raise ConnectionError(error_message) from e
230 | 
231 |     async def ensure_pool(self) -> None:
232 |         """Ensure a valid connection pool exists.
233 | 
234 |         This method is called before executing queries to make sure
235 |         we have an active connection pool.
236 |         """
237 |         if self._pool is None:
238 |             logger.debug("No active connection pool, creating one")
239 |             self._pool = await self.create_pool()
240 |         else:
241 |             logger.debug("Using existing connection pool")
242 | 
243 |     async def close(self) -> None:
244 |         """Close the connection pool and release all resources.
245 | 
246 |         This should be called when shutting down the application.
247 |         """
248 |         import asyncio
249 | 
250 |         if self._pool:
251 |             await asyncio.wait_for(self._pool.close(), timeout=5.0)
252 |             self._pool = None
253 |         else:
254 |             logger.debug("No PostgreSQL connection pool to close")
255 | 
256 |     @classmethod
257 |     async def reset(cls) -> None:
258 |         """Reset the singleton instance cleanly.
259 | 
260 |         This closes any open connections and resets the singleton instance.
261 |         """
262 |         if cls._instance is not None:
263 |             await cls._instance.close()
264 |             cls._instance = None
265 |             logger.info("AsyncSupabaseClient instance reset complete")
266 | 
267 |     async def with_connection(self, operation_func: Callable[[asyncpg.Connection[Any]], Awaitable[T]]) -> T:
268 |         """Execute an operation with a database connection.
269 | 
270 |         Args:
271 |             operation_func: Async function that takes a connection and returns a result
272 | 
273 |         Returns:
274 |             The result of the operation function
275 | 
276 |         Raises:
277 |             ConnectionError: If a database connection issue occurs
278 |         """
279 |         # Ensure we have an active connection pool
280 |         await self.ensure_pool()
281 | 
282 |         # Acquire a connection from the pool and execute the operation
283 |         async with self._pool.acquire() as conn:  # pool guaranteed by ensure_pool()
284 |             return await operation_func(conn)
285 | 
286 |     async def with_transaction(
287 |         self, conn: asyncpg.Connection[Any], operation_func: Callable[[], Awaitable[T]], readonly: bool = False
288 |     ) -> T:
289 |         """Execute an operation within a transaction.
290 | 
291 |         Args:
292 |             conn: Database connection
293 |             operation_func: Async function that executes within the transaction
294 |             readonly: Whether the transaction is read-only
295 | 
296 |         Returns:
297 |             The result of the operation function
298 | 
299 |         Raises:
300 |             QueryError: If the query execution fails
301 |         """
302 |         # Execute the operation within a transaction
303 |         async with conn.transaction(readonly=readonly):
304 |             return await operation_func()
305 | 
306 |     async def execute_statement(self, conn: asyncpg.Connection[Any], query: str) -> StatementResult:
307 |         """Execute a single SQL statement.
308 | 
309 |         Args:
310 |             conn: Database connection
311 |             query: SQL query to execute
312 | 
313 |         Returns:
314 |             StatementResult containing the rows returned by the statement
315 | 
316 |         Raises:
317 |             QueryError: If the statement execution fails
318 |         """
319 |         try:
320 |             # Execute the query
321 |             result = await conn.fetch(query)
322 | 
323 |             # Convert records to dictionaries
324 |             rows = [dict(record) for record in result]
325 | 
326 |             # Log success
327 |             logger.debug(f"Statement executed successfully, rows: {len(rows)}")
328 | 
329 |             # Return the result
330 |             return StatementResult(rows=rows)
331 | 
332 |         except asyncpg.PostgresError as e:
333 |             await self._handle_postgres_error(e)  # always raises (PermissionError or QueryError)
334 | 
335 |     @retry(
336 |         retry=retry_if_exception_type(
337 |             (
338 |                 asyncpg.exceptions.ConnectionDoesNotExistError,  # Connection lost
339 |                 asyncpg.exceptions.InterfaceError,  # Connection disruption
340 |                 asyncpg.exceptions.TooManyConnectionsError,  # Temporary connection limit
341 |                 OSError,  # Network issues
342 |             )
343 |         ),
344 |         stop=stop_after_attempt(3),
345 |         wait=wait_exponential(multiplier=1, min=2, max=10),
346 |         before_sleep=log_db_retry_attempt,
347 |     )
348 |     async def execute_query(
349 |         self,
350 |         validated_query: QueryValidationResults,
351 |         readonly: bool = True,  # Default to read-only for safety
352 |     ) -> QueryResult:
353 |         """Execute a SQL query asynchronously with proper transaction management.
354 | 
355 |         Args:
356 |             validated_query: Validated query containing statements to execute
357 |             readonly: Whether to execute in read-only mode
358 | 
359 |         Returns:
360 |             QueryResult containing the results of all statements
361 | 
362 |         Raises:
363 |             ConnectionError: If a database connection issue occurs
364 |             QueryError: If the query execution fails
365 |             PermissionError: When user lacks required privileges
366 |         """
367 |         # Log query execution (truncate long queries for readability)
368 |         truncated_query = (
369 |             validated_query.original_query[:100] + "..."
370 |             if len(validated_query.original_query) > 100
371 |             else validated_query.original_query
372 |         )
373 |         logger.debug(f"Executing query (readonly={readonly}): {truncated_query}")
374 | 
375 |         # Define the operation to execute all statements within a transaction
376 |         async def execute_all_statements(conn: asyncpg.Connection[Any]) -> QueryResult:
377 |             async def transaction_operation() -> list[StatementResult]:
378 |                 results = []
379 |                 for statement in validated_query.statements:
380 |                     if statement.query:  # Skip statements with no query
381 |                         result = await self.execute_statement(conn, statement.query)
382 |                         results.append(result)
383 |                     else:
384 |                         logger.warning(f"Statement has no query, statement: {statement}")
385 |                 return results
386 | 
387 |             # Execute the operation within a transaction
388 |             results = await self.with_transaction(conn, transaction_operation, readonly)
389 |             return QueryResult(results=results)
390 | 
391 |         # Execute the operation with a connection
392 |         return await self.with_connection(execute_all_statements)
393 | 
394 |     async def _handle_postgres_error(self, error: asyncpg.PostgresError) -> None:
395 |         """Handle PostgreSQL errors and convert to appropriate exceptions.
396 | 
397 |         Args:
398 |             error: PostgreSQL error
399 | 
400 |         Raises:
401 |             PermissionError: When user lacks required privileges
402 |             QueryError: For schema errors or general query errors
403 |         """
404 |         if isinstance(error, asyncpg.exceptions.InsufficientPrivilegeError):
405 |             logger.error(f"Permission denied: {error}")
406 |             raise PermissionError(
407 |                 f"Access denied: {str(error)}. Use live_dangerously('database', True) for write operations."
408 |             ) from error
409 |         elif isinstance(
410 |             error,
411 |             (
412 |                 asyncpg.exceptions.UndefinedTableError,
413 |                 asyncpg.exceptions.UndefinedColumnError,
414 |             ),
415 |         ):
416 |             logger.error(f"Schema error: {error}")
417 |             raise QueryError(str(error)) from error
418 |         else:
419 |             logger.error(f"Database error: {error}")
420 |             raise QueryError(f"Query execution failed: {str(error)}") from error
421 | 
```
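
For orientation, here is a minimal usage sketch of the query-execution path above. It is not part of the repository: the client class name (the file is `postgres_client.py`, though the `reset()` log mentions "AsyncSupabaseClient") and the `get_instance()` accessor are assumptions inferred from the singleton handling, while `SQLValidator` comes from `supabase_mcp/services/database/sql/validator.py`.

```python
# Hypothetical usage sketch; the class name and accessor are assumptions.
import asyncio

from supabase_mcp.services.database.postgres_client import PostgresClient  # name assumed
from supabase_mcp.services.database.sql.validator import SQLValidator


async def main() -> None:
    validator = SQLValidator()
    validated = validator.validate_query("SELECT id, created_at FROM public.users LIMIT 5;")

    client = PostgresClient.get_instance()  # assumed singleton accessor
    try:
        # readonly=True wraps every statement in a read-only transaction;
        # transient connection errors are retried up to 3 times with backoff.
        result = await client.execute_query(validated, readonly=True)
        for statement_result in result.results:
            print(statement_result.rows)
    finally:
        await client.close()  # release the asyncpg pool


asyncio.run(main())
```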

--------------------------------------------------------------------------------
/supabase_mcp/services/database/migration_manager.py:
--------------------------------------------------------------------------------

```python
  1 | import datetime
  2 | import hashlib
  3 | import re
  4 | 
  5 | from supabase_mcp.logger import logger
  6 | from supabase_mcp.services.database.sql.loader import SQLLoader
  7 | from supabase_mcp.services.database.sql.models import (
  8 |     QueryValidationResults,
  9 |     SQLQueryCategory,
 10 |     ValidatedStatement,
 11 | )
 12 | 
 13 | 
 14 | class MigrationManager:
 15 |     """Responsible for preparing migration scripts without executing them."""
 16 | 
 17 |     def __init__(self, loader: SQLLoader | None = None):
 18 |         """Initialize the migration manager with a SQL loader.
 19 | 
 20 |         Args:
 21 |             loader: The SQL loader to use for loading SQL queries
 22 |         """
 23 |         self.loader = loader or SQLLoader()
 24 | 
 25 |     def prepare_migration_query(
 26 |         self,
 27 |         validation_result: QueryValidationResults,
 28 |         original_query: str,
 29 |         migration_name: str = "",
 30 |     ) -> tuple[str, str]:
 31 |         """
 32 |         Prepare a migration script without executing it.
 33 | 
 34 |         Args:
 35 |             validation_result: The validation result
 36 |             original_query: The original query
 37 |             migration_name: The name of the migration, if provided by the client
 38 | 
 39 |         Returns:
 40 |             A tuple of (complete SQL query to create the migration,
 41 |             final migration name)
 42 |         """
 43 |         # If client provided a name, use it directly without generating a new one
 44 |         if migration_name.strip():
 45 |             name = self.sanitize_name(migration_name)
 46 |         else:
 47 |             # Otherwise generate a descriptive name
 48 |             name = self.generate_descriptive_name(validation_result)
 49 | 
 50 |         # Generate migration version (timestamp)
 51 |         version = self.generate_query_timestamp()
 52 | 
 53 |         # Escape single quotes in the query for SQL safety
 54 |         statements = original_query.replace("'", "''")
 55 | 
 56 |         # Get the migration query using the loader
 57 |         migration_query = self.loader.get_create_migration_query(version, name, statements)
 58 | 
 59 |         logger.info(f"Prepared migration: {version}_{name}")
 60 | 
 61 |         # Return the complete query
 62 |         return migration_query, name
 63 | 
 64 |     def sanitize_name(self, name: str) -> str:
 65 |         """
 66 |         Sanitize a raw migration name into a standardized, safe form.
 67 | 
 68 |         Args:
 69 |             name: Raw migration name
 70 | 
 71 |         Returns:
 72 |             str: Sanitized migration name
 73 |         """
 74 |         # Remove special characters and replace spaces with underscores
 75 |         sanitized_name = re.sub(r"[^\w\s]", "", name).lower()
 76 |         sanitized_name = re.sub(r"\s+", "_", sanitized_name)
 77 | 
 78 |         # Ensure the name is not too long (max 100 chars)
 79 |         if len(sanitized_name) > 100:
 80 |             sanitized_name = sanitized_name[:100]
 81 | 
 82 |         return sanitized_name
 83 | 
 84 |     def generate_descriptive_name(
 85 |         self,
 86 |         query_validation_result: QueryValidationResults,
 87 |     ) -> str:
 88 |         """
 89 |         Generate a descriptive name for a migration based on the validation result.
 90 | 
 91 |         This method should only be called when no client-provided name is available.
 92 | 
 93 |         Priority order:
 94 |         1. Auto-generated name based on SQL analysis
 95 |         2. Fallback to hash if no meaningful information can be extracted
 96 | 
 97 |         Args:
 98 |             query_validation_result: Validation result for a batch of SQL statements
 99 | 
100 |         Returns:
101 |             str: Descriptive migration name
102 |         """
103 |         # Case 1: No client-provided name, generate descriptive name
104 |         # Find the first statement that needs migration
105 |         statement = None
106 |         for stmt in query_validation_result.statements:
107 |             if stmt.needs_migration:
108 |                 statement = stmt
109 |                 break
110 | 
111 |         # If no statement found (unlikely), use a hash-based name
112 |         if not statement:
113 |             logger.warning(
114 |                 "No statement found in validation result, using hash-based name, statements: %s",
115 |                 query_validation_result.statements,
116 |             )
117 |             # Generate a short hash from the query text
118 |             query_hash = self._generate_short_hash(query_validation_result.original_query)
119 |             return f"migration_{query_hash}"
120 | 
121 |         # Generate name based on statement category and command
122 |         logger.debug(f"Generating name for statement: {statement}")
123 |         if statement.category == SQLQueryCategory.DDL:
124 |             return self._generate_ddl_name(statement)
125 |         elif statement.category == SQLQueryCategory.DML:
126 |             return self._generate_dml_name(statement)
127 |         elif statement.category == SQLQueryCategory.DCL:
128 |             return self._generate_dcl_name(statement)
129 |         else:
130 |             # Fallback for other categories
131 |             return self._generate_generic_name(statement)
132 | 
133 |     def _generate_short_hash(self, text: str) -> str:
134 |         """Generate a short hash from text for use in migration names."""
135 |         hash_object = hashlib.md5(text.encode())  # non-cryptographic: just a short, stable identifier
136 |         return hash_object.hexdigest()[:8]  # First 8 chars of MD5 hash
137 | 
138 |     def _generate_ddl_name(self, statement: ValidatedStatement) -> str:
139 |         """
140 |         Generate a name for DDL statements (CREATE, ALTER, DROP).
141 |         Format: {command}_{object_type}_{schema}_{object_name}
142 |         Examples:
143 |         - create_table_public_users
144 |         - alter_function_auth_authenticate
145 |         - drop_index_public_users_email_idx
146 |         """
147 |         command = statement.command.value.lower()
148 |         schema = statement.schema_name.lower() if statement.schema_name else "public"
149 | 
150 |         # Extract object type and name with enhanced detection
151 |         object_type = "object"  # Default fallback
152 |         object_name = "unknown"  # Default fallback
153 | 
154 |         # Enhanced object type detection based on command
155 |         if statement.object_type:
156 |             object_type = statement.object_type.lower()
157 | 
158 |             # Handle specific object types
159 |             if object_type == "table" and statement.query:
160 |                 object_name = self._extract_table_name(statement.query)
161 |             elif (object_type == "function" or object_type == "procedure") and statement.query:
162 |                 object_name = self._extract_function_name(statement.query)
163 |             elif object_type == "trigger" and statement.query:
164 |                 object_name = self._extract_trigger_name(statement.query)
165 |             elif object_type == "index" and statement.query:
166 |                 object_name = self._extract_index_name(statement.query)
167 |             elif object_type == "view" and statement.query:
168 |                 object_name = self._extract_view_name(statement.query)
169 |             elif object_type == "materialized_view" and statement.query:
170 |                 object_name = self._extract_materialized_view_name(statement.query)
171 |             elif object_type == "sequence" and statement.query:
172 |                 object_name = self._extract_sequence_name(statement.query)
173 |             elif object_type == "constraint" and statement.query:
174 |                 object_name = self._extract_constraint_name(statement.query)
175 |             elif object_type == "foreign_table" and statement.query:
176 |                 object_name = self._extract_foreign_table_name(statement.query)
177 |             elif object_type == "extension" and statement.query:
178 |                 object_name = self._extract_extension_name(statement.query)
179 |             elif object_type == "type" and statement.query:
180 |                 object_name = self._extract_type_name(statement.query)
181 |             elif statement.query:
182 |                 # For other object types, use a generic extraction
183 |                 object_name = self._extract_generic_object_name(statement.query)
184 | 
185 |         # Combine parts into a descriptive name
186 |         name = f"{command}_{object_type}_{schema}_{object_name}"
187 |         return self.sanitize_name(name)
188 | 
189 |     def _generate_dml_name(self, statement: ValidatedStatement) -> str:
190 |         """
191 |         Generate a name for DML statements (INSERT, UPDATE, DELETE).
192 |         Format: {command}_{schema}_{table_name}
193 |         Examples:
194 |         - insert_public_users
195 |         - update_auth_users
196 |         - delete_public_logs
197 |         """
198 |         command = statement.command.value.lower()
199 |         schema = statement.schema_name.lower() if statement.schema_name else "public"
200 | 
201 |         # Extract table name
202 |         table_name = "unknown"
203 |         if statement.query:
204 |             table_name = self._extract_table_name(statement.query) or "unknown"
205 | 
206 |         # For UPDATE and DELETE, add what's being modified if possible
207 |         if command == "update" and statement.query:
208 |             # Try to extract column names being updated
209 |             columns = self._extract_update_columns(statement.query)
210 |             if columns:
211 |                 return self.sanitize_name(f"{command}_{columns}_in_{schema}_{table_name}")
212 | 
213 |         # Default format
214 |         name = f"{command}_{schema}_{table_name}"
215 |         return self.sanitize_name(name)
216 | 
217 |     def _generate_dcl_name(self, statement: ValidatedStatement) -> str:
218 |         """
219 |         Generate a name for DCL statements (GRANT, REVOKE).
220 |         Format: {command}_{privilege}_{schema}_{object_name}
221 |         Examples:
222 |         - grant_select_public_users
223 |         - revoke_all_public_items
224 |         """
225 |         command = statement.command.value.lower()
226 |         schema = statement.schema_name.lower() if statement.schema_name else "public"
227 | 
228 |         # Extract privilege and object name
229 |         privilege = "privilege"
230 |         object_name = "unknown"
231 | 
232 |         if statement.query:
233 |             privilege = self._extract_privilege(statement.query) or "privilege"
234 |             object_name = self._extract_dcl_object_name(statement.query) or "unknown"
235 | 
236 |         name = f"{command}_{privilege}_{schema}_{object_name}"
237 |         return self.sanitize_name(name)
238 | 
239 |     def _generate_generic_name(self, statement: ValidatedStatement) -> str:
240 |         """
241 |         Generate a name for other statement types.
242 |         Format: {command}_{schema}_{object_type}
243 |         """
244 |         command = statement.command.value.lower()
245 |         schema = statement.schema_name.lower() if statement.schema_name else "public"
246 |         object_type = statement.object_type.lower() if statement.object_type else "object"
247 | 
248 |         name = f"{command}_{schema}_{object_type}"
249 |         return self.sanitize_name(name)
250 | 
251 |     # Helper methods for extracting specific parts from SQL queries
252 | 
253 |     def _extract_table_name(self, query: str) -> str:
254 |         """Extract table name from a query."""
255 |         if not query:
256 |             return "unknown"
257 | 
258 |         # Simple regex-based extraction for demonstration
259 |         # In a real implementation, this would use more sophisticated parsing
260 |         import re
261 | 
262 |         # For CREATE TABLE
263 |         match = re.search(r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
264 |         if match:
265 |             return match.group(2)
266 | 
267 |         # For ALTER TABLE
268 |         match = re.search(r"ALTER\s+TABLE\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
269 |         if match:
270 |             return match.group(2)
271 | 
272 |         # For DROP TABLE
273 |         match = re.search(r"DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
274 |         if match:
275 |             return match.group(2)
276 | 
277 |         # For INSERT, UPDATE, DELETE
278 |         match = re.search(r"(?:INSERT\s+INTO|UPDATE|DELETE\s+FROM)\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
279 |         if match:
280 |             return match.group(2)
281 | 
282 |         return "unknown"
283 | 
284 |     def _extract_function_name(self, query: str) -> str:
285 |         """Extract function name from a query."""
286 |         if not query:
287 |             return "unknown"
288 | 
289 |         import re
290 | 
291 |         match = re.search(
292 |             r"(?:CREATE|ALTER|DROP)\s+(?:OR\s+REPLACE\s+)?FUNCTION\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE
293 |         )
294 |         if match:
295 |             return match.group(2)
296 | 
297 |         return "unknown"
298 | 
299 |     def _extract_trigger_name(self, query: str) -> str:
300 |         """Extract trigger name from a query."""
301 |         if not query:
302 |             return "unknown"
303 | 
304 |         import re
305 | 
306 |         match = re.search(r"(?:CREATE|ALTER|DROP)\s+TRIGGER\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)", query, re.IGNORECASE)
307 |         if match:
308 |             return match.group(1)
309 | 
310 |         return "unknown"
311 | 
312 |     def _extract_view_name(self, query: str) -> str:
313 |         """Extract view name from a query."""
314 |         if not query:
315 |             return "unknown"
316 | 
317 |         import re
318 | 
319 |         match = re.search(r"(?:CREATE|ALTER|DROP)\s+(?:OR\s+REPLACE\s+)?VIEW\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
320 |         if match:
321 |             return match.group(2)
322 | 
323 |         return "unknown"
324 | 
325 |     def _extract_index_name(self, query: str) -> str:
326 |         """Extract index name from a query."""
327 |         if not query:
328 |             return "unknown"
329 | 
330 |         import re
331 | 
332 |         match = re.search(r"(?:CREATE|DROP)\s+INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
333 |         if match:
334 |             return match.group(2)
335 | 
336 |         return "unknown"
337 | 
338 |     def _extract_sequence_name(self, query: str) -> str:
339 |         """Extract sequence name from a query."""
340 |         if not query:
341 |             return "unknown"
342 | 
343 |         import re
344 | 
345 |         match = re.search(
346 |             r"(?:CREATE|ALTER|DROP)\s+SEQUENCE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE
347 |         )
348 |         if match:
349 |             return match.group(2)
350 | 
351 |         return "unknown"
352 | 
353 |     def _extract_constraint_name(self, query: str) -> str:
354 |         """Extract constraint name from a query."""
355 |         if not query:
356 |             return "unknown"
357 | 
358 |         import re
359 | 
360 |         match = re.search(r"CONSTRAINT\s+(\w+)", query, re.IGNORECASE)
361 |         if match:
362 |             return match.group(1)
363 | 
364 |         return "unknown"
365 | 
366 |     def _extract_update_columns(self, query: str) -> str:
367 |         """Extract columns being updated in an UPDATE statement."""
368 |         if not query:
369 |             return ""
370 | 
371 |         import re
372 | 
373 |         # This is a simplified approach - a real implementation would use proper SQL parsing
374 |         match = re.search(r"UPDATE\s+(?:\w+\.)?(?:\w+)\s+SET\s+([\w\s,=]+)\s+WHERE", query, re.IGNORECASE)
375 |         if match:
376 |             # Extract column names from the SET clause
377 |             set_clause = match.group(1)
378 |             columns = re.findall(r"(\w+)\s*=", set_clause)
379 |             if columns and len(columns) <= 3:  # Limit to 3 columns to keep name reasonable
380 |                 return "_".join(columns)
381 |             elif columns:
382 |                 return f"{columns[0]}_and_others"
383 | 
384 |         return ""
385 | 
386 |     def _extract_privilege(self, query: str) -> str:
387 |         """Extract privilege from a GRANT or REVOKE statement."""
388 |         if not query:
389 |             return "privilege"
390 | 
391 |         import re
392 | 
393 |         match = re.search(r"(?:GRANT|REVOKE)\s+([\w\s,]+)\s+ON", query, re.IGNORECASE)
394 |         if match:
395 |             privileges = match.group(1).strip().lower()
396 |             if "all" in privileges:
397 |                 return "all"
398 |             elif "select" in privileges:
399 |                 return "select"
400 |             elif "insert" in privileges:
401 |                 return "insert"
402 |             elif "update" in privileges:
403 |                 return "update"
404 |             elif "delete" in privileges:
405 |                 return "delete"
406 | 
407 |         return "privilege"
408 | 
409 |     def _extract_dcl_object_name(self, query: str) -> str:
410 |         """Extract object name from a GRANT or REVOKE statement."""
411 |         if not query:
412 |             return "unknown"
413 | 
414 |         import re
415 | 
416 |         match = re.search(r"ON\s+(?:TABLE\s+)?(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
417 |         if match:
418 |             return match.group(2)
419 | 
420 |         return "unknown"
421 | 
422 |     def _extract_generic_object_name(self, query: str) -> str:
423 |         """Extract a generic object name when specific extractors don't apply."""
424 |         if not query:
425 |             return "unknown"
426 | 
427 |         import re
428 | 
429 |         # Look for common patterns of object names in SQL
430 |         patterns = [
431 |             r"(?:CREATE|ALTER|DROP)\s+(?:\w+\s+)+(?:(\w+)\.)?(\w+)",  # General DDL pattern
432 |             r"ON\s+(?:(\w+)\.)?(\w+)",  # ON clause
433 |             r"FROM\s+(?:(\w+)\.)?(\w+)",  # FROM clause
434 |             r"INTO\s+(?:(\w+)\.)?(\w+)",  # INTO clause
435 |         ]
436 | 
437 |         for pattern in patterns:
438 |             match = re.search(pattern, query, re.IGNORECASE)
439 |             if match and match.group(2):
440 |                 return match.group(2)
441 | 
442 |         return "unknown"
443 | 
444 |     def _extract_materialized_view_name(self, query: str) -> str:
445 |         """Extract materialized view name from a query."""
446 |         if not query:
447 |             return "unknown"
448 | 
449 |         import re
450 | 
451 |         match = re.search(
452 |             r"(?:CREATE|ALTER|DROP|REFRESH)\s+(?:MATERIALIZED\s+VIEW)\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)",
453 |             query,
454 |             re.IGNORECASE,
455 |         )
456 |         if match:
457 |             return match.group(2)
458 | 
459 |         return "unknown"
460 | 
461 |     def _extract_foreign_table_name(self, query: str) -> str:
462 |         """Extract foreign table name from a query."""
463 |         if not query:
464 |             return "unknown"
465 | 
466 |         import re
467 | 
468 |         match = re.search(
469 |             r"(?:CREATE|ALTER|DROP)\s+(?:FOREIGN\s+TABLE)\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)",
470 |             query,
471 |             re.IGNORECASE,
472 |         )
473 |         if match:
474 |             return match.group(2)
475 | 
476 |         return "unknown"
477 | 
478 |     def _extract_extension_name(self, query: str) -> str:
479 |         """Extract extension name from a query."""
480 |         if not query:
481 |             return "unknown"
482 | 
483 |         import re
484 | 
485 |         match = re.search(r"(?:CREATE|ALTER|DROP)\s+EXTENSION\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)", query, re.IGNORECASE)
486 |         if match:
487 |             return match.group(1)
488 | 
489 |         return "unknown"
490 | 
491 |     def _extract_type_name(self, query: str) -> str:
492 |         """Extract custom type name from a query."""
493 |         if not query:
494 |             return "unknown"
495 | 
496 |         import re
497 | 
498 |         # For ENUM types
499 |         match = re.search(r"(?:CREATE|ALTER|DROP)\s+TYPE\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
500 |         if match:
501 |             return match.group(2)
502 | 
503 |         # For DOMAIN types
504 |         match = re.search(r"(?:CREATE|ALTER|DROP)\s+DOMAIN\s+(?:(\w+)\.)?(\w+)", query, re.IGNORECASE)
505 |         if match:
506 |             return match.group(2)
507 | 
508 |         return "unknown"
509 | 
510 |     def generate_query_timestamp(self) -> str:
511 |         """
512 |         Generate a timestamp for a migration script in the format YYYYMMDDHHMMSS.
513 | 
514 |         Returns:
515 |             str: Timestamp string
516 |         """
517 |         now = datetime.datetime.now()
518 |         return now.strftime("%Y%m%d%H%M%S")
519 | 
```
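
A brief illustrative sketch of the naming helpers above (not part of the repository). The input strings are invented; the expected results follow directly from the regexes in this file, and the underscore-prefixed extractors are called here purely for demonstration.

```python
from supabase_mcp.services.database.migration_manager import MigrationManager

mm = MigrationManager()

# sanitize_name strips punctuation, lowercases, and joins words with underscores
assert mm.sanitize_name("Add Users Table!") == "add_users_table"

# The table-name extractor covers CREATE/ALTER/DROP TABLE plus basic DML
assert mm._extract_table_name("CREATE TABLE IF NOT EXISTS public.users (id int)") == "users"
assert mm._extract_table_name("UPDATE auth.users SET email = 'x'") == "users"

# Migration versions are local timestamps in YYYYMMDDHHMMSS form, e.g. "20250101123045"
print(mm.generate_query_timestamp())
```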

--------------------------------------------------------------------------------
/supabase_mcp/services/safety/safety_configs.py:
--------------------------------------------------------------------------------

```python
  1 | import re
  2 | from abc import ABC, abstractmethod
  3 | from enum import Enum
  4 | from typing import Any, Generic, TypeVar
  5 | 
  6 | from supabase_mcp.services.database.sql.models import (
  7 |     QueryValidationResults,
  8 |     SQLQueryCategory,
  9 | )
 10 | from supabase_mcp.services.safety.models import OperationRiskLevel, SafetyMode
 11 | 
 12 | T = TypeVar("T")
 13 | 
 14 | 
 15 | class SafetyConfigBase(Generic[T], ABC):
 16 |     """Abstract base class for the safety configuration of a specific client.
 17 | 
 18 |     Subclasses implement get_risk_level() for their operation type; the base
 19 |     class then decides:
 20 |     - whether an operation is allowed under the current safety mode
 21 |     - whether an operation requires explicit confirmation
 22 |     """
 23 | 
 24 |     @abstractmethod
 25 |     def get_risk_level(self, operation: T) -> OperationRiskLevel:
 26 |         """Get the risk level for an operation.
 27 | 
 28 |         Args:
 29 |             operation: The operation to check
 30 | 
 31 |         Returns:
 32 |             The risk level for the operation
 33 |         """
 34 |         pass
 35 | 
 36 |     def is_operation_allowed(self, risk_level: OperationRiskLevel, mode: SafetyMode) -> bool:
 37 |         """Check if an operation is allowed based on its risk level and the current safety mode.
 38 | 
 39 |         Args:
 40 |             risk_level: The risk level of the operation
 41 |             mode: The current safety mode
 42 | 
 43 |         Returns:
 44 |             True if the operation is allowed, False otherwise
 45 |         """
 46 |         # LOW risk operations are always allowed
 47 |         if risk_level == OperationRiskLevel.LOW:
 48 |             return True
 49 | 
 50 |         # MEDIUM risk operations are allowed only in UNSAFE mode
 51 |         if risk_level == OperationRiskLevel.MEDIUM:
 52 |             return mode == SafetyMode.UNSAFE
 53 | 
 54 |         # HIGH risk operations are allowed only in UNSAFE mode with confirmation
 55 |         if risk_level == OperationRiskLevel.HIGH:
 56 |             return mode == SafetyMode.UNSAFE
 57 | 
 58 |         # EXTREME risk operations are never allowed
 59 |         return False
 60 | 
 61 |     def needs_confirmation(self, risk_level: OperationRiskLevel) -> bool:
 62 |         """Check if an operation needs confirmation based on its risk level.
 63 | 
 64 |         Args:
 65 |             risk_level: The risk level of the operation
 66 | 
 67 |         Returns:
 68 |             True if the operation needs confirmation, False otherwise
 69 |         """
 70 |         # Only HIGH and EXTREME risk operations require confirmation
 71 |         return risk_level >= OperationRiskLevel.HIGH
 72 | 
 73 | 
 74 | # ========
 75 | # API Safety Config
 76 | # ========
 77 | 
 78 | 
 79 | class HTTPMethod(str, Enum):
 80 |     """HTTP methods used in API operations."""
 81 | 
 82 |     GET = "GET"
 83 |     POST = "POST"
 84 |     PUT = "PUT"
 85 |     PATCH = "PATCH"
 86 |     DELETE = "DELETE"
 87 |     HEAD = "HEAD"
 88 |     OPTIONS = "OPTIONS"
 89 | 
 90 | 
 91 | class APISafetyConfig(SafetyConfigBase[tuple[str, str, dict[str, Any], dict[str, Any], dict[str, Any]]]):
 92 |     """Safety configuration for API operations.
 93 | 
 94 |     The operation is a 5-tuple; only its first two elements (method, path) drive risk checks.
 95 |     """
 96 | 
 97 |     # Maps risk levels to operations (method + path patterns)
 98 |     PATH_SAFETY_CONFIG = {
 99 |         OperationRiskLevel.EXTREME: {
100 |             HTTPMethod.DELETE: [
101 |                 "/v1/projects/{ref}",  # Delete project.  Irreversible, complete data loss.
102 |             ]
103 |         },
104 |         OperationRiskLevel.HIGH: {
105 |             HTTPMethod.DELETE: [
106 |                 "/v1/projects/{ref}/branches/{branch_id}",  # Delete a database branch.  Data loss on branch.
107 |                 "/v1/projects/{ref}/branches",  # Disables preview branching. Disruptive to development workflows.
108 |                 "/v1/projects/{ref}/custom-hostname",  # Deletes custom hostname config.  Can break production access.
109 |                 "/v1/projects/{ref}/vanity-subdomain",  # Deletes vanity subdomain config.  Breaks vanity URL access.
110 |                 "/v1/projects/{ref}/network-bans",  # Remove network bans (can expose database to wider network).
111 |                 "/v1/projects/{ref}/secrets",  # Bulk delete secrets. Can break application functionality if critical secrets are removed.
112 |                 "/v1/projects/{ref}/functions/{function_slug}",  # Delete function.  Breaks functionality relying on the function.
113 |                 "/v1/projects/{ref}/api-keys/{id}",  # Delete api key.  Can break API access.
114 |                 "/v1/projects/{ref}/config/auth/sso/providers/{provider_id}",  # Delete SSO Provider.  Breaks SSO login.
115 |                 "/v1/projects/{ref}/config/auth/signing-keys/{id}",  # Delete signing key. Can break JWT verification.
116 |             ],
117 |             HTTPMethod.POST: [
118 |                 "/v1/projects/{ref}/pause",  # Pause project - Impacts production, database becomes unavailable.
119 |                 "/v1/projects/{ref}/restore",  # Restore project - Can overwrite existing data with backup.
120 |                 "/v1/projects/{ref}/upgrade",  # Upgrades the project's Postgres version - potential downtime/compatibility issues.
121 |                 "/v1/projects/{ref}/read-replicas/remove",  # Remove a read replica.  Can impact read scalability.
122 |                 "/v1/projects/{ref}/restore/cancel",  # Cancels the given project restoration. Can leave project in inconsistent state.
123 |                 "/v1/projects/{ref}/readonly/temporary-disable",  # Disables readonly mode. Allows potentially destructive operations.
124 |             ],
125 |         },
126 |         OperationRiskLevel.MEDIUM: {
127 |             HTTPMethod.POST: [
128 |                 "/v1/projects",  # Create project.  Significant infrastructure change.
129 |                 "/v1/organizations",  # Create org. Significant infrastructure change.
130 |                 "/v1/projects/{ref}/branches",  # Create a database branch.  Could potentially impact production if misused.
131 |                 "/v1/projects/{ref}/branches/{branch_id}/push",  # Push a database branch.  Could overwrite production data if pushed to the wrong branch.
132 |                 "/v1/projects/{ref}/branches/{branch_id}/reset",  # Reset a database branch. Data loss on the branch.
133 |                 "/v1/projects/{ref}/custom-hostname/initialize",  # Updates custom hostname configuration, potentially breaking existing config.
134 |                 "/v1/projects/{ref}/custom-hostname/reverify",  # Attempts to verify DNS configuration.  Could disrupt custom hostname if misconfigured.
135 |                 "/v1/projects/{ref}/custom-hostname/activate",  # Activates custom hostname. Could lead to downtime during switchover.
136 |                 "/v1/projects/{ref}/network-bans/retrieve",  # Gets project's network bans. Information disclosure, though less risky than removing bans.
137 |                 "/v1/projects/{ref}/network-restrictions/apply",  # Updates project's network restrictions. Could block legitimate access if misconfigured.
138 |                 "/v1/projects/{ref}/secrets",  # Bulk create secrets.  Could overwrite existing secrets if names collide.
139 |                 "/v1/projects/{ref}/upgrade/status",  # Get the status of a project upgrade.
140 |                 "/v1/projects/{ref}/database/webhooks/enable",  # Enables Database Webhooks.  Could expose data if webhooks are misconfigured.
141 |                 "/v1/projects/{ref}/functions",  # Create a function (deprecated).
142 |                 "/v1/projects/{ref}/functions/deploy",  # Deploy a function. Could break functionality if deployed code has errors.
143 |                 "/v1/projects/{ref}/config/auth/sso/providers",  # Create SSO provider.  Could impact authentication if misconfigured.
144 |                 "/v1/projects/{ref}/database/backups/restore-pitr",  # Restore a PITR backup.  Can overwrite data.
145 |                 "/v1/projects/{ref}/read-replicas/setup",  # Setup a read replica
146 |                 "/v1/projects/{ref}/database/query",  # Run SQL query.  *Crucially*, this allows arbitrary SQL, including `DROP TABLE`, `DELETE`, etc.
147 |                 "/v1/projects/{ref}/config/auth/signing-keys",  # Create a new signing key, requires key rotation.
148 |                 "/v1/oauth/token",  # Exchange auth code for user's access token. Security-sensitive.
149 |                 "/v1/oauth/revoke",  # Revoke oauth app authorization.  Can break application access.
150 |                 "/v1/projects/{ref}/api-keys",  # Create an API key
151 |             ],
152 |             HTTPMethod.PATCH: [
153 |                 "/v1/projects/{ref}/config/auth",  # Auth config.  Could lock users out or introduce vulnerabilities if misconfigured.
154 |                 "/v1/projects/{ref}/config/database/pooler",  # Connection pooling changes.  Can impact database performance.
155 |                 "/v1/projects/{ref}/postgrest",  # Update Postgrest config.  Can impact API behavior.
156 |                 "/v1/projects/{ref}/functions/{function_slug}",  # Updates a function.  Can break functionality.
157 |                 "/v1/projects/{ref}/config/storage",  # Update Storage config.  Can change file size limits, etc.
158 |                 "/v1/branches/{branch_id}",  # Update database branch config.
159 |                 "/v1/projects/{ref}/api-keys/{id}",  # Updates an API key.
160 |                 "/v1/projects/{ref}/config/auth/signing-keys/{id}",  # Updates a signing key.
161 |             ],
162 |             HTTPMethod.PUT: [
163 |                 "/v1/projects/{ref}/config/database/postgres",  # Postgres config changes.  Can significantly impact database performance/behavior.
164 |                 "/v1/projects/{ref}/pgsodium",  # Update pgsodium config.  *Critical*: Updating the `root_key` can cause data loss.
165 |                 "/v1/projects/{ref}/ssl-enforcement",  # Update SSL enforcement config.  Could break access if misconfigured.
166 |                 "/v1/projects/{ref}/functions",  # Bulk update Edge Functions. Could break multiple functions at once.
167 |                 "/v1/projects/{ref}/config/auth/sso/providers/{provider_id}",  # Update sso provider.
168 |             ],
169 |         },
170 |     }
171 | 
172 |     def get_risk_level(
173 |         self, operation: tuple[str, str, dict[str, Any], dict[str, Any], dict[str, Any]]
174 |     ) -> OperationRiskLevel:
175 |         """Get the risk level for an API operation.
176 | 
177 |         Args:
178 |             operation: 5-tuple whose first two elements are (method, path); the rest are ignored here
179 | 
180 |         Returns:
181 |             The risk level for the operation
182 |         """
183 |         method, path, _, _, _ = operation
184 | 
185 |         # Check each risk level from highest to lowest
186 |         for risk_level in sorted(self.PATH_SAFETY_CONFIG.keys(), reverse=True):
187 |             if self._path_matches_risk_level(method, path, risk_level):
188 |                 return risk_level
189 | 
190 |         # Default to low risk
191 |         return OperationRiskLevel.LOW
192 | 
193 |     def _path_matches_risk_level(self, method: str, path: str, risk_level: OperationRiskLevel) -> bool:
194 |         """Check if the method and path match any pattern for the given risk level."""
195 |         patterns = self.PATH_SAFETY_CONFIG.get(risk_level, {})
196 | 
197 |         if method not in patterns:  # HTTPMethod is a str Enum, so plain "GET"/"DELETE" strings match
198 |             return False
199 | 
200 |         for pattern in patterns[method]:
201 |             # Convert placeholder pattern to regex
202 |             regex_pattern = self._convert_pattern_to_regex(pattern)
203 |             if re.match(regex_pattern, path):
204 |                 return True
205 | 
206 |         return False
207 | 
208 |     def _convert_pattern_to_regex(self, pattern: str) -> str:
209 |         """Convert a placeholder pattern to a regex pattern.
210 | 
211 |         Replaces placeholders like {ref} with regex patterns for matching.
212 |         """
213 |         # Replace common placeholders with regex patterns
214 |         regex_pattern = pattern
215 |         regex_pattern = regex_pattern.replace("{ref}", r"[^/]+")
216 |         regex_pattern = regex_pattern.replace("{id}", r"[^/]+")
217 |         regex_pattern = regex_pattern.replace("{slug}", r"[^/]+")
218 |         regex_pattern = regex_pattern.replace("{table}", r"[^/]+")
219 |         regex_pattern = regex_pattern.replace("{branch_id}", r"[^/]+")
220 |         regex_pattern = regex_pattern.replace("{function_slug}", r"[^/]+")
221 | 
222 |         # Add end anchor to ensure full path matching
223 |         if not regex_pattern.endswith("$"):
224 |             regex_pattern += "$"
225 | 
226 |         return regex_pattern
227 | 
228 | 
229 | # ========
230 | # SQL Safety Config
231 | # ========
232 | 
233 | 
234 | class SQLSafetyConfig(SafetyConfigBase[QueryValidationResults]):
235 |     """Safety configuration for SQL operations."""
236 | 
237 |     STATEMENT_CONFIG = {
238 |         # DQL - all LOW risk, no migrations
239 |         "SelectStmt": {
240 |             "category": SQLQueryCategory.DQL,
241 |             "risk_level": OperationRiskLevel.LOW,
242 |             "needs_migration": False,
243 |         },
244 |         "ExplainStmt": {
245 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
246 |             "risk_level": OperationRiskLevel.LOW,
247 |             "needs_migration": False,
248 |         },
249 |         # DML - all MEDIUM risk, no migrations
250 |         "InsertStmt": {
251 |             "category": SQLQueryCategory.DML,
252 |             "risk_level": OperationRiskLevel.MEDIUM,
253 |             "needs_migration": False,
254 |         },
255 |         "UpdateStmt": {
256 |             "category": SQLQueryCategory.DML,
257 |             "risk_level": OperationRiskLevel.MEDIUM,
258 |             "needs_migration": False,
259 |         },
260 |         "DeleteStmt": {
261 |             "category": SQLQueryCategory.DML,
262 |             "risk_level": OperationRiskLevel.MEDIUM,
263 |             "needs_migration": False,
264 |         },
265 |         "MergeStmt": {
266 |             "category": SQLQueryCategory.DML,
267 |             "risk_level": OperationRiskLevel.MEDIUM,
268 |             "needs_migration": False,
269 |         },
270 |         # DDL - mix of MEDIUM and HIGH risk, need migrations
271 |         "CreateStmt": {
272 |             "category": SQLQueryCategory.DDL,
273 |             "risk_level": OperationRiskLevel.MEDIUM,
274 |             "needs_migration": True,
275 |         },
276 |         "CreateTableAsStmt": {
277 |             "category": SQLQueryCategory.DDL,
278 |             "risk_level": OperationRiskLevel.MEDIUM,
279 |             "needs_migration": True,
280 |         },
281 |         "CreateSchemaStmt": {
282 |             "category": SQLQueryCategory.DDL,
283 |             "risk_level": OperationRiskLevel.MEDIUM,
284 |             "needs_migration": True,
285 |         },
286 |         "CreateExtensionStmt": {
287 |             "category": SQLQueryCategory.DDL,
288 |             "risk_level": OperationRiskLevel.MEDIUM,
289 |             "needs_migration": True,
290 |         },
291 |         "AlterTableStmt": {
292 |             "category": SQLQueryCategory.DDL,
293 |             "risk_level": OperationRiskLevel.MEDIUM,
294 |             "needs_migration": True,
295 |         },
296 |         "AlterDomainStmt": {
297 |             "category": SQLQueryCategory.DDL,
298 |             "risk_level": OperationRiskLevel.MEDIUM,
299 |             "needs_migration": True,
300 |         },
301 |         "CreateFunctionStmt": {
302 |             "category": SQLQueryCategory.DDL,
303 |             "risk_level": OperationRiskLevel.MEDIUM,
304 |             "needs_migration": True,
305 |         },
306 |         "IndexStmt": {  # CREATE INDEX
307 |             "category": SQLQueryCategory.DDL,
308 |             "risk_level": OperationRiskLevel.MEDIUM,
309 |             "needs_migration": True,
310 |         },
311 |         "CreateTrigStmt": {
312 |             "category": SQLQueryCategory.DDL,
313 |             "risk_level": OperationRiskLevel.MEDIUM,
314 |             "needs_migration": True,
315 |         },
316 |         "ViewStmt": {  # CREATE VIEW
317 |             "category": SQLQueryCategory.DDL,
318 |             "risk_level": OperationRiskLevel.MEDIUM,
319 |             "needs_migration": True,
320 |         },
321 |         "CommentStmt": {
322 |             "category": SQLQueryCategory.DDL,
323 |             "risk_level": OperationRiskLevel.MEDIUM,
324 |             "needs_migration": True,
325 |         },
326 |         # Additional DDL statements
327 |         "CreateEnumStmt": {  # CREATE TYPE ... AS ENUM
328 |             "category": SQLQueryCategory.DDL,
329 |             "risk_level": OperationRiskLevel.MEDIUM,
330 |             "needs_migration": True,
331 |         },
332 |         "CreateTypeStmt": {  # CREATE TYPE (composite)
333 |             "category": SQLQueryCategory.DDL,
334 |             "risk_level": OperationRiskLevel.MEDIUM,
335 |             "needs_migration": True,
336 |         },
337 |         "CreateDomainStmt": {  # CREATE DOMAIN
338 |             "category": SQLQueryCategory.DDL,
339 |             "risk_level": OperationRiskLevel.MEDIUM,
340 |             "needs_migration": True,
341 |         },
342 |         "CreateSeqStmt": {  # CREATE SEQUENCE
343 |             "category": SQLQueryCategory.DDL,
344 |             "risk_level": OperationRiskLevel.MEDIUM,
345 |             "needs_migration": True,
346 |         },
347 |         "CreateForeignTableStmt": {  # CREATE FOREIGN TABLE
348 |             "category": SQLQueryCategory.DDL,
349 |             "risk_level": OperationRiskLevel.MEDIUM,
350 |             "needs_migration": True,
351 |         },
352 |         "CreatePolicyStmt": {  # CREATE POLICY
353 |             "category": SQLQueryCategory.DDL,
354 |             "risk_level": OperationRiskLevel.MEDIUM,
355 |             "needs_migration": True,
356 |         },
357 |         "CreateCastStmt": {  # CREATE CAST
358 |             "category": SQLQueryCategory.DDL,
359 |             "risk_level": OperationRiskLevel.MEDIUM,
360 |             "needs_migration": True,
361 |         },
362 |         "CreateOpClassStmt": {  # CREATE OPERATOR CLASS
363 |             "category": SQLQueryCategory.DDL,
364 |             "risk_level": OperationRiskLevel.MEDIUM,
365 |             "needs_migration": True,
366 |         },
367 |         "CreateOpFamilyStmt": {  # CREATE OPERATOR FAMILY
368 |             "category": SQLQueryCategory.DDL,
369 |             "risk_level": OperationRiskLevel.MEDIUM,
370 |             "needs_migration": True,
371 |         },
372 |         "AlterEnumStmt": {  # ALTER TYPE ... ADD VALUE
373 |             "category": SQLQueryCategory.DDL,
374 |             "risk_level": OperationRiskLevel.MEDIUM,
375 |             "needs_migration": True,
376 |         },
377 |         "AlterSeqStmt": {  # ALTER SEQUENCE
378 |             "category": SQLQueryCategory.DDL,
379 |             "risk_level": OperationRiskLevel.MEDIUM,
380 |             "needs_migration": True,
381 |         },
382 |         "AlterOwnerStmt": {  # ALTER ... OWNER TO
383 |             "category": SQLQueryCategory.DDL,
384 |             "risk_level": OperationRiskLevel.MEDIUM,
385 |             "needs_migration": True,
386 |         },
387 |         "AlterObjectSchemaStmt": {  # ALTER ... SET SCHEMA
388 |             "category": SQLQueryCategory.DDL,
389 |             "risk_level": OperationRiskLevel.MEDIUM,
390 |             "needs_migration": True,
391 |         },
392 |         "RenameStmt": {  # RENAME operations
393 |             "category": SQLQueryCategory.DDL,
394 |             "risk_level": OperationRiskLevel.MEDIUM,
395 |             "needs_migration": True,
396 |         },
397 |         # DESTRUCTIVE DDL - HIGH risk, need migrations and confirmation
398 |         "DropStmt": {
399 |             "category": SQLQueryCategory.DDL,
400 |             "risk_level": OperationRiskLevel.HIGH,
401 |             "needs_migration": True,
402 |         },
403 |         "TruncateStmt": {
404 |             "category": SQLQueryCategory.DDL,
405 |             "risk_level": OperationRiskLevel.HIGH,
406 |             "needs_migration": True,
407 |         },
408 |         # DCL - MEDIUM risk, need migrations
409 |         "GrantStmt": {
410 |             "category": SQLQueryCategory.DCL,
411 |             "risk_level": OperationRiskLevel.MEDIUM,
412 |             "needs_migration": True,
413 |         },
414 |         "GrantRoleStmt": {
415 |             "category": SQLQueryCategory.DCL,
416 |             "risk_level": OperationRiskLevel.MEDIUM,
417 |             "needs_migration": True,
418 |         },
419 |         "RevokeStmt": {
420 |             "category": SQLQueryCategory.DCL,
421 |             "risk_level": OperationRiskLevel.MEDIUM,
422 |             "needs_migration": True,
423 |         },
424 |         "RevokeRoleStmt": {
425 |             "category": SQLQueryCategory.DCL,
426 |             "risk_level": OperationRiskLevel.MEDIUM,
427 |             "needs_migration": True,
428 |         },
429 |         "CreateRoleStmt": {
430 |             "category": SQLQueryCategory.DCL,
431 |             "risk_level": OperationRiskLevel.MEDIUM,
432 |             "needs_migration": True,
433 |         },
434 |         "AlterRoleStmt": {
435 |             "category": SQLQueryCategory.DCL,
436 |             "risk_level": OperationRiskLevel.MEDIUM,
437 |             "needs_migration": True,
438 |         },
439 |         "DropRoleStmt": {
440 |             "category": SQLQueryCategory.DCL,
441 |             "risk_level": OperationRiskLevel.HIGH,
442 |             "needs_migration": True,
443 |         },
444 |         # TCL - LOW risk, no migrations
445 |         "TransactionStmt": {
446 |             "category": SQLQueryCategory.TCL,
447 |             "risk_level": OperationRiskLevel.LOW,
448 |             "needs_migration": False,
449 |         },
450 |         # PostgreSQL-specific
451 |         "VacuumStmt": {
452 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
453 |             "risk_level": OperationRiskLevel.MEDIUM,
454 |             "needs_migration": False,
455 |         },
456 |         "AnalyzeStmt": {
457 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
458 |             "risk_level": OperationRiskLevel.LOW,
459 |             "needs_migration": False,
460 |         },
461 |         "ClusterStmt": {
462 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
463 |             "risk_level": OperationRiskLevel.MEDIUM,
464 |             "needs_migration": False,
465 |         },
466 |         "CheckPointStmt": {
467 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
468 |             "risk_level": OperationRiskLevel.MEDIUM,
469 |             "needs_migration": False,
470 |         },
471 |         "PrepareStmt": {
472 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
473 |             "risk_level": OperationRiskLevel.LOW,
474 |             "needs_migration": False,
475 |         },
476 |         "ExecuteStmt": {
477 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
478 |             "risk_level": OperationRiskLevel.MEDIUM,  # Could be LOW or MEDIUM based on prepared statement
479 |             "needs_migration": False,
480 |         },
481 |         "DeallocateStmt": {
482 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
483 |             "risk_level": OperationRiskLevel.LOW,
484 |             "needs_migration": False,
485 |         },
486 |         "ListenStmt": {
487 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
488 |             "risk_level": OperationRiskLevel.LOW,
489 |             "needs_migration": False,
490 |         },
491 |         "NotifyStmt": {
492 |             "category": SQLQueryCategory.POSTGRES_SPECIFIC,
493 |             "risk_level": OperationRiskLevel.MEDIUM,
494 |             "needs_migration": False,
495 |         },
496 |     }
497 | 
498 |     # Functions for more complex determinations
499 |     def classify_statement(self, stmt_type: str, stmt_node: Any) -> dict[str, Any]:
500 |         """Get classification rules for a given statement type from our config."""
501 |         config = self.STATEMENT_CONFIG.get(
502 |             stmt_type,
503 |             # if not found - default to MEDIUM risk
504 |             {
505 |                 "category": SQLQueryCategory.OTHER,
506 |                 "risk_level": OperationRiskLevel.MEDIUM,  # Default to MEDIUM risk for unknown
507 |                 "needs_migration": False,
508 |             },
509 |         ).copy()  # copy so the CopyStmt special case below cannot mutate the shared config
510 | 
511 |         # Special case: CopyStmt can be read or write
512 |         if stmt_type == "CopyStmt" and stmt_node:
513 |             # Check if it's COPY TO (read) or COPY FROM (write)
514 |             if hasattr(stmt_node, "is_from") and not stmt_node.is_from:
515 |                 # COPY TO - it's a read operation (LOW risk)
516 |                 config["category"] = SQLQueryCategory.DQL
517 |                 config["risk_level"] = OperationRiskLevel.LOW
518 |             else:
519 |                 # COPY FROM - it's a write operation (MEDIUM risk)
520 |                 config["category"] = SQLQueryCategory.DML
521 |                 config["risk_level"] = OperationRiskLevel.MEDIUM
522 | 
523 |         # Other special cases can be added here
524 | 
525 |         return config
526 | 
527 |     def get_risk_level(self, operation: QueryValidationResults) -> OperationRiskLevel:
528 |         """Get the risk level for an SQL batch operation.
529 | 
530 |         Args:
531 |             operation: The SQL batch validation result to check
532 | 
533 |         Returns:
534 |             The highest risk level found in the batch
535 |         """
536 |         # Simply return the highest risk level that's already tracked in the batch
537 |         return operation.highest_risk_level
538 | 
```
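
An illustrative sketch (not part of the repository) of how a caller could use the two configs above; the project ref is a made-up placeholder, and the expected values follow from the tables in this file.

```python
from supabase_mcp.services.safety.models import OperationRiskLevel, SafetyMode
from supabase_mcp.services.safety.safety_configs import APISafetyConfig, SQLSafetyConfig

api_config = APISafetyConfig()

# DELETE /v1/projects/{ref} is EXTREME: never allowed, even in UNSAFE mode
operation = ("DELETE", "/v1/projects/my-project-ref", {}, {}, {})
risk = api_config.get_risk_level(operation)
assert risk == OperationRiskLevel.EXTREME
assert not api_config.is_operation_allowed(risk, SafetyMode.UNSAFE)

# An unlisted GET falls through to the LOW default, which is allowed in any mode
risk = api_config.get_risk_level(("GET", "/v1/projects/my-project-ref/functions", {}, {}, {}))
assert risk == OperationRiskLevel.LOW
assert api_config.is_operation_allowed(risk, SafetyMode.UNSAFE)

# SQL: DROP statements are HIGH risk, so they also require confirmation
sql_config = SQLSafetyConfig()
assert sql_config.classify_statement("DropStmt", None)["risk_level"] == OperationRiskLevel.HIGH
assert sql_config.needs_confirmation(OperationRiskLevel.HIGH)
```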

--------------------------------------------------------------------------------
/tests/services/database/sql/test_sql_validator.py:
--------------------------------------------------------------------------------

```python
  1 | import pytest
  2 | 
  3 | from supabase_mcp.exceptions import ValidationError
  4 | from supabase_mcp.services.database.sql.models import SQLQueryCategory, SQLQueryCommand
  5 | from supabase_mcp.services.database.sql.validator import SQLValidator
  6 | from supabase_mcp.services.safety.models import OperationRiskLevel
  7 | 
  8 | 
  9 | class TestSQLValidator:
 10 |     """Test suite for the SQLValidator class."""
 11 | 
 12 |     # =========================================================================
 13 |     # Core Validation Tests
 14 |     # =========================================================================
 15 | 
 16 |     def test_empty_query_validation(self, mock_validator: SQLValidator):
 17 |         """
 18 |         Test that empty queries are properly rejected.
 19 | 
 20 |         This is a fundamental validation test to ensure the validator
 21 |         rejects empty or whitespace-only queries.
 22 |         """
 23 |         # Test empty string
 24 |         with pytest.raises(ValidationError, match="Query cannot be empty"):
 25 |             mock_validator.validate_query("")
 26 | 
 27 |         # Test whitespace-only string
 28 |         with pytest.raises(ValidationError, match="Query cannot be empty"):
 29 |             mock_validator.validate_query("   \n   \t   ")
 30 | 
 31 |     def test_schema_and_table_name_validation(self, mock_validator: SQLValidator):
 32 |         """
 33 |         Test validation of schema and table names.
 34 | 
 35 |         This test ensures that schema and table names are properly validated
 36 |         to prevent SQL injection and other security issues.
 37 |         """
 38 |         # Test schema name validation
 39 |         valid_schema = "public"
 40 |         assert mock_validator.validate_schema_name(valid_schema) == valid_schema
 41 | 
 42 |         # The actual error message is "Schema name cannot contain spaces"
 43 |         invalid_schema = "public; DROP TABLE users;"
 44 |         with pytest.raises(ValidationError, match="Schema name cannot contain spaces"):
 45 |             mock_validator.validate_schema_name(invalid_schema)
 46 | 
 47 |         # Test table name validation
 48 |         valid_table = "users"
 49 |         assert mock_validator.validate_table_name(valid_table) == valid_table
 50 | 
 51 |         # The actual error message is "Table name cannot contain spaces"
 52 |         invalid_table = "users; DROP TABLE users;"
 53 |         with pytest.raises(ValidationError, match="Table name cannot contain spaces"):
 54 |             mock_validator.validate_table_name(invalid_table)
 55 | 
 56 |     # =========================================================================
 57 |     # Safety Level Classification Tests
 58 |     # =========================================================================
 59 | 
 60 |     def test_safe_operation_identification(self, mock_validator: SQLValidator, sample_dql_queries: dict[str, str]):
 61 |         """
 62 |         Test that safe operations (SELECT queries) are correctly identified.
 63 | 
 64 |         This test ensures that all SELECT queries are properly categorized as
 65 |         safe operations, which is critical for security.
 66 |         """
 67 |         for name, query in sample_dql_queries.items():
 68 |             result = mock_validator.validate_query(query)
 69 |             assert result.highest_risk_level == OperationRiskLevel.LOW, f"Query '{name}' should be classified as SAFE"
 70 |             assert result.statements[0].category == SQLQueryCategory.DQL, f"Query '{name}' should be categorized as DQL"
 71 |             assert result.statements[0].command == SQLQueryCommand.SELECT, f"Query '{name}' should have command SELECT"
 72 | 
 73 |     def test_write_operation_identification(self, mock_validator: SQLValidator, sample_dml_queries: dict[str, str]):
 74 |         """
 75 |         Test that write operations (INSERT, UPDATE, DELETE) are correctly identified.
 76 | 
 77 |         This test ensures that all data modification queries are properly categorized
 78 |         as write operations, which require different permissions.
 79 |         """
 80 |         for name, query in sample_dml_queries.items():
 81 |             result = mock_validator.validate_query(query)
 82 |             assert result.highest_risk_level == OperationRiskLevel.MEDIUM, (
 83 |                 f"Query '{name}' should be classified as WRITE"
 84 |             )
 85 |             assert result.statements[0].category == SQLQueryCategory.DML, f"Query '{name}' should be categorized as DML"
 86 | 
 87 |             # Check specific command based on query type
 88 |             if name.startswith("insert"):
 89 |                 assert result.statements[0].command == SQLQueryCommand.INSERT
 90 |             elif name.startswith("update"):
 91 |                 assert result.statements[0].command == SQLQueryCommand.UPDATE
 92 |             elif name.startswith("delete"):
 93 |                 assert result.statements[0].command == SQLQueryCommand.DELETE
 94 |             elif name.startswith("merge"):
 95 |                 assert result.statements[0].command == SQLQueryCommand.MERGE
 96 | 
 97 |     def test_destructive_operation_identification(
 98 |         self, mock_validator: SQLValidator, sample_ddl_queries: dict[str, str]
 99 |     ):
100 |         """
101 |         Test that destructive operations (DROP, TRUNCATE) are correctly identified.
102 | 
103 |         This test ensures that all data definition queries that can destroy data
104 |         are properly categorized as destructive operations, which require
105 |         the highest level of permissions.
106 |         """
107 |         # Test DROP statements
108 |         drop_query = sample_ddl_queries["drop_table"]
109 |         drop_result = mock_validator.validate_query(drop_query)
110 | 
111 |         # Verify the statement is correctly categorized as DDL and has the DROP command
112 |         assert drop_result.statements[0].category == SQLQueryCategory.DDL, "DROP should be categorized as DDL"
113 |         assert drop_result.statements[0].command == SQLQueryCommand.DROP, "Command should be DROP"
114 | 
115 |         # Test TRUNCATE statements
116 |         truncate_query = sample_ddl_queries["truncate_table"]
117 |         truncate_result = mock_validator.validate_query(truncate_query)
118 | 
119 |         # Verify the statement is correctly categorized as DDL and has the TRUNCATE command
120 |         assert truncate_result.statements[0].category == SQLQueryCategory.DDL, "TRUNCATE should be categorized as DDL"
121 |         assert truncate_result.statements[0].command == SQLQueryCommand.TRUNCATE, "Command should be TRUNCATE"
122 | 
123 |     # =========================================================================
124 |     # Transaction Control Tests
125 |     # =========================================================================
126 | 
127 |     def test_transaction_control_detection(self, mock_validator: SQLValidator, sample_tcl_queries: dict[str, str]):
128 |         """
129 |         Test that BEGIN/COMMIT/ROLLBACK statements are correctly identified as TCL.
130 | 
131 |         Transaction control is critical for maintaining data integrity and
132 |         must be properly detected regardless of case or formatting.
133 |         """
134 |         # Test BEGIN statement
135 |         with pytest.raises(ValidationError) as excinfo:
136 |             mock_validator.validate_query(sample_tcl_queries["begin_transaction"])
137 |         assert "Transaction control statements" in str(excinfo.value)
138 | 
139 |         # Test COMMIT statement
140 |         with pytest.raises(ValidationError) as excinfo:
141 |             mock_validator.validate_query(sample_tcl_queries["commit_transaction"])
142 |         assert "Transaction control statements" in str(excinfo.value)
143 | 
144 |         # Test ROLLBACK statement
145 |         with pytest.raises(ValidationError) as excinfo:
146 |             mock_validator.validate_query(sample_tcl_queries["rollback_transaction"])
147 |         assert "Transaction control statements" in str(excinfo.value)
148 | 
149 |         # Test mixed case transaction statement
150 |         with pytest.raises(ValidationError) as excinfo:
151 |             mock_validator.validate_query(sample_tcl_queries["mixed_case_transaction"])
152 |         assert "Transaction control statements" in str(excinfo.value)
153 | 
154 |         # Test string-based detection method directly
155 |         assert SQLValidator.validate_transaction_control("BEGIN"), "String-based detection should identify BEGIN"
156 |         assert SQLValidator.validate_transaction_control("COMMIT"), "String-based detection should identify COMMIT"
157 |         assert SQLValidator.validate_transaction_control("ROLLBACK"), "String-based detection should identify ROLLBACK"
158 |         assert SQLValidator.validate_transaction_control("begin transaction"), (
159 |             "String-based detection should be case-insensitive"
160 |         )
161 | 
162 |     # =========================================================================
163 |     # Multiple Statements Tests
164 |     # =========================================================================
165 | 
166 |     def test_multiple_statements_with_mixed_safety_levels(
167 |         self, mock_validator: SQLValidator, sample_multiple_statements: dict[str, str]
168 |     ):
169 |         """
170 |         Test that multiple statements with different safety levels are correctly identified.
171 | 
172 |         Note: because the implementation compares safety levels as strings, the
173 |         intended ordering (SAFE < WRITE < DESTRUCTIVE) is not reliably enforced.
174 |         This test focuses on verifying that multiple statements are parsed and categorized correctly.
175 |         """
176 |         # Test multiple safe statements
177 |         safe_result = mock_validator.validate_query(sample_multiple_statements["multiple_safe"])
178 |         assert len(safe_result.statements) == 2, "Should identify two statements"
179 |         assert safe_result.statements[0].category == SQLQueryCategory.DQL, "First statement should be DQL"
180 |         assert safe_result.statements[1].category == SQLQueryCategory.DQL, "Second statement should be DQL"
181 | 
182 |         # Test safe + write statements
183 |         mixed_result = mock_validator.validate_query(sample_multiple_statements["safe_and_write"])
184 |         assert len(mixed_result.statements) == 2, "Should identify two statements"
185 |         assert mixed_result.statements[0].category == SQLQueryCategory.DQL, "First statement should be DQL"
186 |         assert mixed_result.statements[1].category == SQLQueryCategory.DML, "Second statement should be DML"
187 | 
188 |         # Test write + destructive statements
189 |         destructive_result = mock_validator.validate_query(sample_multiple_statements["write_and_destructive"])
190 |         assert len(destructive_result.statements) == 2, "Should identify two statements"
191 |         assert destructive_result.statements[0].category == SQLQueryCategory.DML, "First statement should be DML"
192 |         assert destructive_result.statements[1].category == SQLQueryCategory.DDL, "Second statement should be DDL"
193 |         assert destructive_result.statements[1].command == SQLQueryCommand.DROP, "Second command should be DROP"
194 | 
195 |         # Test transaction statements
196 |         with pytest.raises(ValidationError) as excinfo:
197 |             mock_validator.validate_query(sample_multiple_statements["with_transaction"])
198 |         assert "Transaction control statements" in str(excinfo.value)
199 | 
200 |     # =========================================================================
201 |     # Error Handling Tests
202 |     # =========================================================================
203 | 
204 |     def test_syntax_error_handling(self, mock_validator: SQLValidator, sample_invalid_queries: dict[str, str]):
205 |         """
206 |         Test that SQL syntax errors are properly caught and reported.
207 | 
208 |         Fundamental for providing clear feedback to users when their SQL is invalid.
209 |         """
210 |         # Test syntax error
211 |         with pytest.raises(ValidationError, match="SQL syntax error"):
212 |             mock_validator.validate_query(sample_invalid_queries["syntax_error"])
213 | 
214 |         # Test missing parenthesis
215 |         with pytest.raises(ValidationError, match="SQL syntax error"):
216 |             mock_validator.validate_query(sample_invalid_queries["missing_parenthesis"])
217 | 
218 |         # Test incomplete statement
219 |         with pytest.raises(ValidationError, match="SQL syntax error"):
220 |             mock_validator.validate_query(sample_invalid_queries["incomplete_statement"])
221 | 
222 |     # =========================================================================
223 |     # PostgreSQL-Specific Features Tests
224 |     # =========================================================================
225 | 
226 |     def test_copy_statement_direction_detection(
227 |         self, mock_validator: SQLValidator, sample_postgres_specific_queries: dict[str, str]
228 |     ):
229 |         """
230 |         Test that COPY TO (read) vs COPY FROM (write) are correctly distinguished.
231 | 
232 |         Important edge case with safety implications as COPY TO is safe
233 |         while COPY FROM modifies data.
234 |         """
235 |         # Test COPY TO (should be SAFE)
236 |         copy_to_result = mock_validator.validate_query(sample_postgres_specific_queries["copy_to"])
237 |         assert copy_to_result.highest_risk_level == OperationRiskLevel.LOW, "COPY TO should be classified as SAFE"
238 |         assert copy_to_result.statements[0].category == SQLQueryCategory.DQL, "COPY TO should be categorized as DQL"
239 | 
240 |         # Test COPY FROM (should be WRITE)
241 |         copy_from_result = mock_validator.validate_query(sample_postgres_specific_queries["copy_from"])
242 |         assert copy_from_result.highest_risk_level == OperationRiskLevel.MEDIUM, (
243 |             "COPY FROM should be classified as WRITE"
244 |         )
245 |         assert copy_from_result.statements[0].category == SQLQueryCategory.DML, "COPY FROM should be categorized as DML"
246 | 
247 |     # =========================================================================
248 |     # Complex Scenarios Tests
249 |     # =========================================================================
250 | 
251 |     def test_complex_queries_with_subqueries_and_ctes(
252 |         self, mock_validator: SQLValidator, sample_dql_queries: dict[str, str]
253 |     ):
254 |         """
255 |         Test that complex queries with subqueries and CTEs are correctly parsed.
256 | 
257 |         Ensures robustness with real-world queries that may contain
258 |         complex structures but are still valid.
259 |         """
260 |         # Test query with subquery
261 |         subquery_result = mock_validator.validate_query(sample_dql_queries["select_with_subquery"])
262 |         assert subquery_result.highest_risk_level == OperationRiskLevel.LOW, "Query with subquery should be SAFE"
263 |         assert subquery_result.statements[0].category == SQLQueryCategory.DQL, "Query with subquery should be DQL"
264 | 
265 |         # Test query with CTE (Common Table Expression)
266 |         cte_result = mock_validator.validate_query(sample_dql_queries["select_with_cte"])
267 |         assert cte_result.highest_risk_level == OperationRiskLevel.LOW, "Query with CTE should be SAFE"
268 |         assert cte_result.statements[0].category == SQLQueryCategory.DQL, "Query with CTE should be DQL"
269 | 
270 |     # =========================================================================
271 |     # False Positive Prevention Tests
272 |     # =========================================================================
273 | 
274 |     def test_valid_queries_with_comments(self, mock_validator: SQLValidator, sample_edge_cases: dict[str, str]):
275 |         """
276 |         Test that valid queries with SQL comments are not rejected.
277 | 
278 |         Ensures that comments (inline and block) don't cause valid queries
279 |         to be incorrectly flagged as invalid.
280 |         """
281 |         # Test query with comments
282 |         query_with_comments = sample_edge_cases["with_comments"]
283 |         result = mock_validator.validate_query(query_with_comments)
284 | 
285 |         # Verify the query is parsed correctly despite comments
286 |         assert result.statements[0].category == SQLQueryCategory.DQL, "Query with comments should be categorized as DQL"
287 |         assert result.statements[0].command == SQLQueryCommand.SELECT, "Query with comments should have SELECT command"
288 |         assert result.highest_risk_level == OperationRiskLevel.LOW, "Query with comments should be SAFE"
289 | 
290 |     def test_valid_queries_with_quoted_identifiers(
291 |         self, mock_validator: SQLValidator, sample_edge_cases: dict[str, str]
292 |     ):
293 |         """
294 |         Test that valid queries with quoted identifiers are not rejected.
295 | 
296 |         Ensures that double-quoted table/column names and single-quoted
297 |         strings don't cause false positives.
298 |         """
299 |         # Test query with quoted identifiers
300 |         query_with_quotes = sample_edge_cases["quoted_identifiers"]
301 |         result = mock_validator.validate_query(query_with_quotes)
302 | 
303 |         # Verify the query is parsed correctly despite quoted identifiers
304 |         assert result.statements[0].category == SQLQueryCategory.DQL, (
305 |             "Query with quoted identifiers should be categorized as DQL"
306 |         )
307 |         assert result.statements[0].command == SQLQueryCommand.SELECT, (
308 |             "Query with quoted identifiers should have SELECT command"
309 |         )
310 |         assert result.highest_risk_level == OperationRiskLevel.LOW, "Query with quoted identifiers should be SAFE"
311 | 
312 |     def test_valid_queries_with_special_characters(
313 |         self, mock_validator: SQLValidator, sample_edge_cases: dict[str, str]
314 |     ):
315 |         """
316 |         Test that valid queries with special characters are not rejected.
317 | 
318 |         Ensures that special characters in strings and identifiers
319 |         don't trigger false positives.
320 |         """
321 |         # Test query with special characters
322 |         query_with_special_chars = sample_edge_cases["special_characters"]
323 |         result = mock_validator.validate_query(query_with_special_chars)
324 | 
325 |         # Verify the query is parsed correctly despite special characters
326 |         assert result.statements[0].category == SQLQueryCategory.DQL, (
327 |             "Query with special characters should be categorized as DQL"
328 |         )
329 |         assert result.statements[0].command == SQLQueryCommand.SELECT, (
330 |             "Query with special characters should have SELECT command"
331 |         )
332 |         assert result.highest_risk_level == OperationRiskLevel.LOW, "Query with special characters should be SAFE"
333 | 
334 |     def test_valid_postgresql_specific_syntax(
335 |         self,
336 |         mock_validator: SQLValidator,
337 |         sample_edge_cases: dict[str, str],
338 |         sample_postgres_specific_queries: dict[str, str],
339 |     ):
340 |         """
341 |         Test that valid PostgreSQL-specific syntax is not rejected.
342 | 
343 |         Ensures that PostgreSQL extensions to standard SQL (like RETURNING
344 |         clauses or specific operators) don't cause false positives.
345 |         """
346 |         # Test query with dollar-quoted strings (PostgreSQL-specific feature)
347 |         query_with_dollar_quotes = sample_edge_cases["with_dollar_quotes"]
348 |         result = mock_validator.validate_query(query_with_dollar_quotes)
349 |         assert result.statements[0].category == SQLQueryCategory.DQL, (
350 |             "Query with dollar quotes should be categorized as DQL"
351 |         )
352 | 
353 |         # Test schema-qualified names
354 |         schema_qualified_query = sample_edge_cases["schema_qualified"]
355 |         result = mock_validator.validate_query(schema_qualified_query)
356 |         assert result.statements[0].category == SQLQueryCategory.DQL, (
357 |             "Query with schema qualification should be categorized as DQL"
358 |         )
359 | 
360 |         # Test EXPLAIN ANALYZE (PostgreSQL-specific)
361 |         explain_query = sample_postgres_specific_queries["explain"]
362 |         result = mock_validator.validate_query(explain_query)
363 |         assert result.statements[0].category == SQLQueryCategory.POSTGRES_SPECIFIC, (
364 |             "EXPLAIN should be categorized as POSTGRES_SPECIFIC"
365 |         )
366 | 
367 |     def test_valid_complex_joins(self, mock_validator: SQLValidator):
368 |         """
369 |         Test that valid complex JOIN operations are not rejected.
370 | 
371 |         Ensures that complex but valid JOIN syntax (including LATERAL joins,
372 |         multiple join conditions, etc.) doesn't cause false positives.
373 |         """
374 |         # Test complex join with multiple conditions
375 |         complex_join_query = """
376 |         SELECT u.id, u.name, p.title, c.content
377 |         FROM users u
378 |         JOIN posts p ON u.id = p.user_id AND p.published = true
379 |         LEFT JOIN comments c ON p.id = c.post_id
380 |         WHERE u.active = true
381 |         ORDER BY p.created_at DESC
382 |         """
383 |         result = mock_validator.validate_query(complex_join_query)
384 |         assert result.statements[0].category == SQLQueryCategory.DQL, "Complex join query should be categorized as DQL"
385 |         assert result.statements[0].command == SQLQueryCommand.SELECT, "Complex join query should have SELECT command"
386 | 
387 |         # Test LATERAL join (PostgreSQL-specific join type)
388 |         lateral_join_query = """
389 |         SELECT u.id, u.name, p.title
390 |         FROM users u
391 |         LEFT JOIN LATERAL (
392 |             SELECT title FROM posts WHERE user_id = u.id ORDER BY created_at DESC LIMIT 1
393 |         ) p ON true
394 |         """
395 |         result = mock_validator.validate_query(lateral_join_query)
396 |         assert result.statements[0].category == SQLQueryCategory.DQL, "LATERAL join query should be categorized as DQL"
397 |         assert result.statements[0].command == SQLQueryCommand.SELECT, "LATERAL join query should have SELECT command"
398 | 
399 |     # =========================================================================
400 |     # Additional Tests Based on Code Review
401 |     # =========================================================================
402 | 
403 |     def test_dcl_statement_identification(self, mock_validator: SQLValidator, sample_dcl_queries: dict[str, str]):
404 |         """
405 |         Test that GRANT/REVOKE statements are correctly identified as DCL.
406 | 
407 |         DCL statements control access to data and should be properly classified
408 |         to ensure appropriate permissions management.
409 |         """
410 |         # Test GRANT statement
411 |         grant_query = sample_dcl_queries["grant_select"]
412 |         grant_result = mock_validator.validate_query(grant_query)
413 |         assert grant_result.statements[0].category == SQLQueryCategory.DCL, "GRANT should be categorized as DCL"
414 |         assert grant_result.statements[0].command == SQLQueryCommand.GRANT, "Command should be GRANT"
415 | 
416 |         # Test REVOKE statement
417 |         revoke_query = sample_dcl_queries["revoke_select"]
418 |         revoke_result = mock_validator.validate_query(revoke_query)
419 |         assert revoke_result.statements[0].category == SQLQueryCategory.DCL, "REVOKE should be categorized as DCL"
420 |         # Note: The current implementation may not correctly identify REVOKE commands
421 |         # so we're only checking the category, not the specific command
422 | 
423 |         # Test CREATE ROLE statement (also DCL)
424 |         create_role_query = sample_dcl_queries["create_role"]
425 |         create_role_result = mock_validator.validate_query(create_role_query)
426 |         assert create_role_result.statements[0].category == SQLQueryCategory.DCL, (
427 |             "CREATE ROLE should be categorized as DCL"
428 |         )
429 | 
430 |     def test_needs_migration_flag(
431 |         self, mock_validator: SQLValidator, sample_ddl_queries: dict[str, str], sample_dml_queries: dict[str, str]
432 |     ):
433 |         """
434 |         Test that statements requiring migrations are correctly flagged.
435 | 
436 |         Ensures that DDL statements that require migrations (like CREATE TABLE)
437 |         are properly identified to enforce migration requirements.
438 |         """
439 |         # Test CREATE TABLE (should need migration)
440 |         create_table_query = sample_ddl_queries["create_table"]
441 |         create_result = mock_validator.validate_query(create_table_query)
442 |         assert create_result.statements[0].needs_migration, "CREATE TABLE should require migration"
443 | 
444 |         # Test ALTER TABLE (should need migration)
445 |         alter_table_query = sample_ddl_queries["alter_table"]
446 |         alter_result = mock_validator.validate_query(alter_table_query)
447 |         assert alter_result.statements[0].needs_migration, "ALTER TABLE should require migration"
448 | 
449 |         # Test INSERT (should NOT need migration)
450 |         insert_query = sample_dml_queries["simple_insert"]
451 |         insert_result = mock_validator.validate_query(insert_query)
452 |         assert not insert_result.statements[0].needs_migration, "INSERT should not require migration"
453 | 
454 |     def test_object_type_extraction(self, mock_validator: SQLValidator):
455 |         """
456 |         Test that object types (table names, etc.) are correctly extracted when possible.
457 | 
458 |         Note: The current implementation has limitations in extracting object types
459 |         from all statement types. This test focuses on verifying the basic functionality
460 |         without making assumptions about specific extraction capabilities.
461 |         """
462 |         # Test that object_type is present in the result structure
463 |         select_query = "SELECT * FROM users WHERE id = 1"
464 |         select_result = mock_validator.validate_query(select_query)
465 | 
466 |         # Verify the object_type field exists in the result
467 |         assert hasattr(select_result.statements[0], "object_type"), "Result should have object_type field"
468 | 
469 |         # Test with a more complex query
470 |         complex_query = """
471 |         WITH active_users AS (
472 |             SELECT * FROM users WHERE active = true
473 |         )
474 |         SELECT * FROM active_users
475 |         """
476 |         complex_result = mock_validator.validate_query(complex_query)
477 |         assert hasattr(complex_result.statements[0], "object_type"), (
478 |             "Complex query result should have object_type field"
479 |         )
480 | 
481 |     def test_string_based_transaction_control(self, mock_validator: SQLValidator):
482 |         """
483 |         Test the string-based transaction control detection method.
484 | 
485 |         Specifically tests the validate_transaction_control class method
486 |         to ensure it correctly identifies transaction keywords.
487 |         """
488 |         # Test standard transaction keywords
489 |         assert SQLValidator.validate_transaction_control("BEGIN"), "Should detect 'BEGIN'"
490 |         assert SQLValidator.validate_transaction_control("COMMIT"), "Should detect 'COMMIT'"
491 |         assert SQLValidator.validate_transaction_control("ROLLBACK"), "Should detect 'ROLLBACK'"
492 | 
493 |         # Test case insensitivity
494 |         assert SQLValidator.validate_transaction_control("begin"), "Should be case-insensitive"
495 |         assert SQLValidator.validate_transaction_control("Commit"), "Should be case-insensitive"
496 |         assert SQLValidator.validate_transaction_control("ROLLBACK"), "Should be case-insensitive"
497 | 
498 |         # Test with additional text
499 |         assert SQLValidator.validate_transaction_control("BEGIN TRANSACTION"), "Should detect 'BEGIN TRANSACTION'"
500 |         assert SQLValidator.validate_transaction_control("COMMIT WORK"), "Should detect 'COMMIT WORK'"
501 | 
502 |         # Test negative cases
503 |         assert not SQLValidator.validate_transaction_control("SELECT * FROM transactions"), (
504 |             "Should not detect in regular SQL"
505 |         )
506 |         assert not SQLValidator.validate_transaction_control(""), "Should not detect in empty string"
507 | 
508 |     def test_basic_query_validation_method(self, mock_validator: SQLValidator):
509 |         """
510 |         Test the basic_query_validation method.
511 | 
512 |         Ensures that the method correctly validates and sanitizes
513 |         input queries before parsing.
514 |         """
515 |         # Test valid query
516 |         valid_query = "SELECT * FROM users"
517 |         assert mock_validator.basic_query_validation(valid_query) == valid_query, "Should return valid query unchanged"
518 | 
519 |         # Test query with whitespace
520 |         whitespace_query = "  SELECT * FROM users  "
521 |         assert mock_validator.basic_query_validation(whitespace_query) == whitespace_query, "Should preserve whitespace"
522 | 
523 |         # Test empty query
524 |         with pytest.raises(ValidationError, match="Query cannot be empty"):
525 |             mock_validator.basic_query_validation("")
526 | 
527 |         # Test whitespace-only query
528 |         with pytest.raises(ValidationError, match="Query cannot be empty"):
529 |             mock_validator.basic_query_validation("   \n   \t   ")
530 | 
```
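The suite above relies on a `mock_validator` fixture and several `sample_*_queries` dictionaries defined in tests/services/database/sql/conftest.py (not shown on this page). The following is a minimal sketch of their likely shape, assuming `SQLValidator` can be constructed without arguments; the dictionary keys are ones the tests actually reference, but the query bodies are illustrative.

```python
import pytest

from supabase_mcp.services.database.sql.validator import SQLValidator


@pytest.fixture
def mock_validator() -> SQLValidator:
    # Assumption: a bare SQLValidator suffices for parsing/classification tests.
    return SQLValidator()


@pytest.fixture
def sample_dql_queries() -> dict[str, str]:
    # Keys referenced by the DQL tests above; query text is illustrative.
    return {
        "simple_select": "SELECT id, email FROM users",
        "select_with_subquery": "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)",
        "select_with_cte": "WITH active AS (SELECT * FROM users WHERE active) SELECT * FROM active",
    }


@pytest.fixture
def sample_tcl_queries() -> dict[str, str]:
    # Keys referenced by test_transaction_control_detection.
    return {
        "begin_transaction": "BEGIN",
        "commit_transaction": "COMMIT",
        "rollback_transaction": "ROLLBACK",
        "mixed_case_transaction": "BeGiN TrAnSaCtIoN",
    }
```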