This is page 2 of 5. Use http://codebase.md/alexander-zuev/supabase-mcp-server?lines=false&page={x} to view the full context.

# Directory Structure

```
├── .claude
│   └── settings.local.json
├── .dockerignore
├── .env.example
├── .env.test.example
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── roadmap_item.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── ci.yaml
│       ├── docs
│       │   └── release-checklist.md
│       └── publish.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── CHANGELOG.MD
├── codecov.yml
├── CONTRIBUTING.MD
├── Dockerfile
├── LICENSE
├── llms-full.txt
├── pyproject.toml
├── README.md
├── smithery.yaml
├── supabase_mcp
│   ├── __init__.py
│   ├── clients
│   │   ├── api_client.py
│   │   ├── base_http_client.py
│   │   ├── management_client.py
│   │   └── sdk_client.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── container.py
│   │   └── feature_manager.py
│   ├── exceptions.py
│   ├── logger.py
│   ├── main.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── api_manager.py
│   │   │   ├── spec_manager.py
│   │   │   └── specs
│   │   │       └── api_spec.json
│   │   ├── database
│   │   │   ├── __init__.py
│   │   │   ├── migration_manager.py
│   │   │   ├── postgres_client.py
│   │   │   ├── query_manager.py
│   │   │   └── sql
│   │   │       ├── loader.py
│   │   │       ├── models.py
│   │   │       ├── queries
│   │   │       │   ├── create_migration.sql
│   │   │       │   ├── get_migrations.sql
│   │   │       │   ├── get_schemas.sql
│   │   │       │   ├── get_table_schema.sql
│   │   │       │   ├── get_tables.sql
│   │   │       │   ├── init_migrations.sql
│   │   │       │   └── logs
│   │   │       │       ├── auth_logs.sql
│   │   │       │       ├── cron_logs.sql
│   │   │       │       ├── edge_logs.sql
│   │   │       │       ├── function_edge_logs.sql
│   │   │       │       ├── pgbouncer_logs.sql
│   │   │       │       ├── postgres_logs.sql
│   │   │       │       ├── postgrest_logs.sql
│   │   │       │       ├── realtime_logs.sql
│   │   │       │       ├── storage_logs.sql
│   │   │       │       └── supavisor_logs.sql
│   │   │       └── validator.py
│   │   ├── logs
│   │   │   ├── __init__.py
│   │   │   └── log_manager.py
│   │   ├── safety
│   │   │   ├── __init__.py
│   │   │   ├── models.py
│   │   │   ├── safety_configs.py
│   │   │   └── safety_manager.py
│   │   └── sdk
│   │       ├── __init__.py
│   │       ├── auth_admin_models.py
│   │       └── auth_admin_sdk_spec.py
│   ├── settings.py
│   └── tools
│       ├── __init__.py
│       ├── descriptions
│       │   ├── api_tools.yaml
│       │   ├── database_tools.yaml
│       │   ├── logs_and_analytics_tools.yaml
│       │   ├── safety_tools.yaml
│       │   └── sdk_tools.yaml
│       ├── manager.py
│       └── registry.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── test_api_client.py
│   │   │   ├── test_api_manager.py
│   │   │   └── test_spec_manager.py
│   │   ├── database
│   │   │   ├── sql
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conftest.py
│   │   │   │   ├── test_loader.py
│   │   │   │   ├── test_sql_validator_integration.py
│   │   │   │   └── test_sql_validator.py
│   │   │   ├── test_migration_manager.py
│   │   │   ├── test_postgres_client.py
│   │   │   └── test_query_manager.py
│   │   ├── logs
│   │   │   └── test_log_manager.py
│   │   ├── safety
│   │   │   ├── test_api_safety_config.py
│   │   │   ├── test_safety_manager.py
│   │   │   └── test_sql_safety_config.py
│   │   └── sdk
│   │       ├── test_auth_admin_models.py
│   │       └── test_sdk_client.py
│   ├── test_container.py
│   ├── test_main.py
│   ├── test_settings.py
│   ├── test_tool_manager.py
│   ├── test_tools_integration.py.bak
│   └── test_tools.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/supabase_mcp/services/database/query_manager.py:
--------------------------------------------------------------------------------

```python
from supabase_mcp.exceptions import OperationNotAllowedError
from supabase_mcp.logger import logger
from supabase_mcp.services.database.migration_manager import MigrationManager
from supabase_mcp.services.database.postgres_client import PostgresClient, QueryResult
from supabase_mcp.services.database.sql.loader import SQLLoader
from supabase_mcp.services.database.sql.models import QueryValidationResults
from supabase_mcp.services.database.sql.validator import SQLValidator
from supabase_mcp.services.safety.models import ClientType, SafetyMode
from supabase_mcp.services.safety.safety_manager import SafetyManager


class QueryManager:
    """
    Manages SQL query execution with validation and migration handling.

    This class is responsible for:
    1. Validating SQL queries for safety
    2. Executing queries through the database client
    3. Managing migrations for queries that require them
    4. Loading SQL queries from files

    It acts as a central point for all SQL operations, ensuring consistent
    validation and execution patterns.
    """

    def __init__(
        self,
        postgres_client: PostgresClient,
        safety_manager: SafetyManager,
        sql_validator: SQLValidator | None = None,
        migration_manager: MigrationManager | None = None,
        sql_loader: SQLLoader | None = None,
    ):
        """
        Initialize the QueryManager.

        Args:
            postgres_client: The database client to use for executing queries
            safety_manager: The safety manager to use for validating operations
            sql_validator: Optional SQL validator to use
            migration_manager: Optional migration manager to use
            sql_loader: Optional SQL loader to use
        """
        self.db_client = postgres_client
        self.safety_manager = safety_manager
        self.validator = sql_validator or SQLValidator()
        self.sql_loader = sql_loader or SQLLoader()
        self.migration_manager = migration_manager or MigrationManager(loader=self.sql_loader)

    def check_readonly(self) -> bool:
        """Returns true if current safety mode is SAFE."""
        result = self.safety_manager.get_safety_mode(ClientType.DATABASE) == SafetyMode.SAFE
        logger.debug(f"Check readonly result: {result}")
        return result

    async def handle_query(self, query: str, has_confirmation: bool = False, migration_name: str = "") -> QueryResult:
        """
        Handle a SQL query with validation and potential migration. Uses migration name, if provided.

        This method:
        1. Validates the query for safety
        2. Checks if the query requires migration
        3. Handles migration if needed
        4. Executes the query

        Args:
            query: SQL query to execute
            has_confirmation: Whether the operation has been confirmed by the user
            migration_name: Optional name to use for the migration, if one is created

        Returns:
            QueryResult: The result of the query execution

        Raises:
            OperationNotAllowedError: If the query is not allowed in the current safety mode
            ConfirmationRequiredError: If the query requires confirmation and has_confirmation is False
        """
        # 1. Run through the validator
        validated_query = self.validator.validate_query(query)

        # 2. Ensure execution is allowed
        self.safety_manager.validate_operation(ClientType.DATABASE, validated_query, has_confirmation)
        logger.debug(f"Operation with risk level {validated_query.highest_risk_level} validated successfully")

        # 3. Handle migration if needed
        await self.handle_migration(validated_query, query, migration_name)

        # 4. Execute the query
        return await self.handle_query_execution(validated_query)

    async def handle_query_execution(self, validated_query: QueryValidationResults) -> QueryResult:
        """
        Handle query execution with validation and potential migration.

        This method:
        1. Checks the readonly mode
        2. Executes the query
        3. Returns the result

        Args:
            validated_query: The validation result

        Returns:
            QueryResult: The result of the query execution
        """
        readonly = self.check_readonly()
        result = await self.db_client.execute_query(validated_query, readonly)
        logger.debug(f"Query result: {result}")
        return result

    async def handle_migration(
        self, validation_result: QueryValidationResults, original_query: str, migration_name: str = ""
    ) -> None:
        """
        Handle migration for a query that requires it.

        Args:
            validation_result: The validation result
            original_query: The original query
            migration_name: Migration name to use, if provided
        """
        # 1. Check if migration is needed
        if not validation_result.needs_migration():
            logger.debug("No migration needed for this query")
            return

        # 2. Prepare migration query
        migration_query, name = self.migration_manager.prepare_migration_query(
            validation_result, original_query, migration_name
        )
        logger.debug("Migration query prepared")

        # 3. Execute migration query
        try:
            # First, ensure the migration schema exists
            await self.init_migration_schema()

            # Then execute the migration query
            migration_validation = self.validator.validate_query(migration_query)
            await self.db_client.execute_query(migration_validation, readonly=False)
            logger.info(f"Migration '{name}' executed successfully")
        except Exception as e:
            logger.debug(f"Migration failure details: {str(e)}")
            # We don't want to fail the main query if migration fails
            # Just log the error and continue
            logger.warning(f"Failed to record migration '{name}': {e}")

    async def init_migration_schema(self) -> None:
        """Initialize the migrations schema and table if they don't exist."""
        try:
            # Get the initialization query
            init_query = self.sql_loader.get_init_migrations_query()

            # Validate and execute it
            init_validation = self.validator.validate_query(init_query)
            await self.db_client.execute_query(init_validation, readonly=False)
            logger.debug("Migrations schema initialized successfully")
        except Exception as e:
            logger.warning(f"Failed to initialize migrations schema: {e}")

    async def handle_confirmation(self, confirmation_id: str) -> QueryResult:
        """
        Handle a confirmed operation using its confirmation ID.

        This method retrieves the stored operation and passes it to handle_query.

        Args:
            confirmation_id: The unique ID of the confirmation to process

        Returns:
            QueryResult: The result of the query execution
        """
        # Get the stored operation
        operation = self.safety_manager.get_stored_operation(confirmation_id)
        if not operation:
            raise OperationNotAllowedError(f"Invalid or expired confirmation ID: {confirmation_id}")

        # Get the query from the operation
        query = operation.original_query
        logger.debug(f"Processing confirmed operation with ID {confirmation_id}")

        # Call handle_query with the query and has_confirmation=True
        return await self.handle_query(query, has_confirmation=True)

    def get_schemas_query(self) -> str:
        """Get a query to list all schemas."""
        return self.sql_loader.get_schemas_query()

    def get_tables_query(self, schema_name: str) -> str:
        """Get a query to list all tables in a schema."""
        return self.sql_loader.get_tables_query(schema_name)

    def get_table_schema_query(self, schema_name: str, table: str) -> str:
        """Get a query to get the schema of a table."""
        return self.sql_loader.get_table_schema_query(schema_name, table)

    def get_migrations_query(
        self, limit: int = 50, offset: int = 0, name_pattern: str = "", include_full_queries: bool = False
    ) -> str:
        """Get a query to list migrations."""
        return self.sql_loader.get_migrations_query(
            limit=limit, offset=offset, name_pattern=name_pattern, include_full_queries=include_full_queries
        )
```
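To illustrate how the validation, safety, and migration steps above fit together, here is a minimal, hypothetical usage sketch of `QueryManager`. It is not part of the repo: `postgres_client` and `safety_manager` are assumed to be already-constructed instances (their factories live elsewhere in this codebase), and the table names are placeholders.

```python
# Hypothetical usage sketch of QueryManager (not part of the repo).
# Assumes `postgres_client` and `safety_manager` were built elsewhere.
from supabase_mcp.services.database.query_manager import QueryManager


async def run_example(postgres_client, safety_manager) -> None:
    qm = QueryManager(postgres_client=postgres_client, safety_manager=safety_manager)

    # Low-risk DQL: validated, then executed read-only while in SAFE mode.
    result = await qm.handle_query("SELECT * FROM users LIMIT 10")
    print(result)

    # DDL is rejected in SAFE mode; with unsafe mode enabled (and, for
    # high-risk statements, has_confirmation=True), a migration record is
    # prepared and written before the statement itself runs.
    await qm.handle_query(
        "CREATE TABLE notes (id serial PRIMARY KEY)",
        has_confirmation=True,
        migration_name="create_notes_table",
    )
```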
--------------------------------------------------------------------------------
/supabase_mcp/clients/management_client.py:
--------------------------------------------------------------------------------

```python
from __future__ import annotations

from json.decoder import JSONDecodeError
from typing import Any

import httpx
from httpx import Request, Response
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from supabase_mcp.exceptions import (
    APIClientError,
    APIConnectionError,
    APIResponseError,
    APIServerError,
    UnexpectedError,
)
from supabase_mcp.logger import logger
from supabase_mcp.settings import Settings


# Helper function for retry decorator to safely log exceptions
def log_retry_attempt(retry_state: RetryCallState) -> None:
    """Log retry attempts with exception details if available."""
    exception = retry_state.outcome.exception() if retry_state.outcome and retry_state.outcome.failed else None
    exception_str = str(exception) if exception else "Unknown error"
    logger.warning(f"Network error, retrying ({retry_state.attempt_number}/3): {exception_str}")


class ManagementAPIClient:
    """
    Client for Supabase Management API.

    Handles low-level HTTP requests to the Supabase Management API.
    """

    def __init__(self, settings: Settings) -> None:
        """Initialize the API client with default settings."""
        self.settings = settings
        self.client = self.create_httpx_client(settings)

        logger.info("✔️ Management API client initialized successfully")

    def create_httpx_client(self, settings: Settings) -> httpx.AsyncClient:
        """Create and configure an httpx client for API requests."""
        headers = {
            "Authorization": f"Bearer {settings.supabase_access_token}",
            "Content-Type": "application/json",
        }
        return httpx.AsyncClient(
            base_url=settings.supabase_api_url,
            headers=headers,
            timeout=30.0,
        )

    def prepare_request(
        self,
        method: str,
        path: str,
        request_params: dict[str, Any] | None = None,
        request_body: dict[str, Any] | None = None,
    ) -> Request:
        """
        Prepare an HTTP request to the Supabase Management API.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API path
            request_params: Query parameters
            request_body: Request body

        Returns:
            Prepared httpx.Request object

        Raises:
            APIClientError: If request preparation fails
        """
        try:
            return self.client.build_request(method=method, url=path, params=request_params, json=request_body)
        except Exception as e:
            raise APIClientError(
                message=f"Failed to build request: {str(e)}",
                status_code=None,
            ) from e

    @retry(
        retry=retry_if_exception_type(httpx.NetworkError),  # This includes ConnectError and TimeoutException
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,  # Ensure the original exception is raised
        before_sleep=log_retry_attempt,
    )
    async def send_request(self, request: Request) -> Response:
        """
        Send an HTTP request with retry logic for transient errors.

        Args:
            request: Prepared httpx.Request object

        Returns:
            httpx.Response object

        Raises:
            APIConnectionError: For connection issues
            APIClientError: For other request errors
        """
        try:
            return await self.client.send(request)
        except httpx.NetworkError as e:
            # All NetworkErrors will be retried by the decorator
            # This will only be reached after all retries are exhausted
            logger.error(f"Network error after all retry attempts: {str(e)}")
            raise APIConnectionError(
                message=f"Network error after 3 retry attempts: {str(e)}",
                status_code=None,
            ) from e
        except Exception as e:
            # Other exceptions won't be retried
            raise APIClientError(
                message=f"Request failed: {str(e)}",
                status_code=None,
            ) from e

    def parse_response(self, response: Response) -> dict[str, Any]:
        """
        Parse an HTTP response as JSON.

        Args:
            response: httpx.Response object

        Returns:
            Parsed response body as dictionary

        Raises:
            APIResponseError: If response cannot be parsed as JSON
        """
        if not response.content:
            return {}

        try:
            return response.json()
        except JSONDecodeError as e:
            raise APIResponseError(
                message=f"Failed to parse response as JSON: {str(e)}",
                status_code=response.status_code,
                response_body={"raw_content": response.text},
            ) from e

    def handle_error_response(self, response: Response, parsed_body: dict[str, Any] | None = None) -> None:
        """
        Handle error responses based on status code.

        Args:
            response: httpx.Response object
            parsed_body: Parsed response body if available

        Raises:
            APIClientError: For client errors (4xx)
            APIServerError: For server errors (5xx)
            UnexpectedError: For unexpected status codes
        """
        # Extract error message
        error_message = f"API request failed: {response.status_code}"
        if parsed_body and "message" in parsed_body:
            error_message = parsed_body["message"]

        # Determine error type based on status code
        if 400 <= response.status_code < 500:
            raise APIClientError(
                message=error_message,
                status_code=response.status_code,
                response_body=parsed_body,
            )
        elif response.status_code >= 500:
            raise APIServerError(
                message=error_message,
                status_code=response.status_code,
                response_body=parsed_body,
            )
        else:
            # This should not happen, but just in case
            raise UnexpectedError(
                message=f"Unexpected status code: {response.status_code}",
                status_code=response.status_code,
                response_body=parsed_body,
            )

    async def execute_request(
        self,
        method: str,
        path: str,
        request_params: dict[str, Any] | None = None,
        request_body: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Execute an HTTP request to the Supabase Management API.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API path
            request_params: Query parameters
            request_body: Request body

        Returns:
            API response as a dictionary

        Raises:
            APIClientError: For client errors (4xx)
            APIConnectionError: For connection issues
            APIResponseError: For response parsing errors
            UnexpectedError: For unexpected errors
        """
        # Check if access token is available
        if not self.settings.supabase_access_token:
            raise APIClientError(
                "Supabase access token is not configured. Set SUPABASE_ACCESS_TOKEN environment variable to use Management API tools."
            )

        # Log detailed request information
        logger.info(f"API Client: Executing {method} request to {path}")
        if request_params:
            logger.debug(f"Request params: {request_params}")
        if request_body:
            logger.debug(f"Request body: {request_body}")

        # Prepare request
        request = self.prepare_request(method, path, request_params, request_body)

        # Send request
        response = await self.send_request(request)

        # Parse response (for both success and error cases)
        parsed_body = self.parse_response(response)

        # Check if successful
        if not response.is_success:
            logger.warning(f"Request failed: {method} {path} - Status {response.status_code}")
            self.handle_error_response(response, parsed_body)

        # Log success and return
        logger.info(f"Request successful: {method} {path} - Status {response.status_code}")
        return parsed_body

    async def close(self) -> None:
        """Close the HTTP client and release resources."""
        if self.client:
            await self.client.aclose()
            logger.info("HTTP API client closed")
```
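A minimal sketch of how this client is typically driven, not taken from the repo: it assumes a `Settings` instance with `supabase_access_token` and `supabase_api_url` populated (the fields the constructor above reads), and `project_ref` is a placeholder value.

```python
# Hypothetical usage sketch of ManagementAPIClient (not part of the repo).
from supabase_mcp.clients.management_client import ManagementAPIClient
from supabase_mcp.settings import Settings


async def list_functions(settings: Settings, project_ref: str) -> None:
    client = ManagementAPIClient(settings)
    try:
        # Requests retry up to 3 times on httpx.NetworkError with exponential
        # backoff (2-10s); once exhausted, APIConnectionError is raised.
        functions = await client.execute_request(
            method="GET",
            path=f"/v1/projects/{project_ref}/functions",
        )
        print(functions)
    finally:
        await client.close()
```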
""" try: with open(LOCAL_SPEC_PATH) as f: return json.load(f) except FileNotFoundError: logger.error(f"Local spec not found at {LOCAL_SPEC_PATH}") raise except json.JSONDecodeError as e: logger.error(f"Invalid JSON in local spec: {e}") raise async def get_spec(self) -> dict[str, Any]: """Retrieve the enriched spec.""" if self.spec is None: raw_spec = await self._fetch_remote_spec() if not raw_spec: # If remote fetch fails, use our fallback spec logger.info("Using fallback API spec") raw_spec = self._load_local_spec() self.spec = raw_spec return self.spec def get_all_paths_and_methods(self) -> dict[str, dict[str, str]]: """ Returns a dictionary of all paths and their methods with operation IDs. Returns: Dict[str, Dict[str, str]]: {path: {method: operationId}} """ if self._paths_cache is None: self._build_caches() return self._paths_cache or {} def get_paths_and_methods_by_domain(self, domain: str) -> dict[str, dict[str, str]]: """ Returns paths and methods within a specific domain (tag). Args: domain (str): The domain name (e.g., "Auth", "Projects"). Returns: Dict[str, Dict[str, str]]: {path: {method: operationId}} """ if self._paths_cache is None: self._build_caches() # Validate domain using enum try: valid_domain = ApiDomain(domain).value except ValueError as e: raise ValueError(f"Invalid domain: {domain}") from e domain_paths: dict[str, dict[str, str]] = {} if self.spec: for path, methods in self.spec.get("paths", {}).items(): for method, details in methods.items(): if valid_domain in details.get("tags", []): if path not in domain_paths: domain_paths[path] = {} domain_paths[path][method] = details.get("operationId", "") return domain_paths def get_all_domains(self) -> list[str]: """ Returns a list of all available domains (tags). Returns: List[str]: List of domain names. """ if self._domains_cache is None: self._build_caches() return self._domains_cache or [] def get_spec_for_path_and_method(self, path: str, method: str) -> dict[str, Any] | None: """ Returns the full specification for a given path and HTTP method. Args: path (str): The API path (e.g., "/v1/projects"). method (str): The HTTP method (e.g., "get", "post"). Returns: Optional[Dict[str, Any]]: The full spec for the operation, or None if not found. """ if self.spec is None: return None path_spec = self.spec.get("paths", {}).get(path) if path_spec: return path_spec.get(method.lower()) # Ensure lowercase method return None def get_spec_part(self, part: str, *args: str | int) -> Any: """ Safely retrieves a nested part of the OpenAPI spec. Args: part: The top-level key (e.g., 'paths', 'components'). *args: Subsequent keys or indices to traverse the spec. Returns: The value at the specified location in the spec, or None if not found. """ if self.spec is None: return None current = self.spec.get(part) for key in args: if isinstance(current, dict) and key in current: current = current[key] elif isinstance(current, list) and isinstance(key, int) and 0 <= key < len(current): current = current[key] else: return None # Key not found or invalid index return current def _build_caches(self) -> None: """ Build internal caches for faster lookups. This populates _paths_cache and _domains_cache. 
""" if self.spec is None: logger.error("Cannot build caches: OpenAPI spec not loaded") return # Build paths cache paths_cache: dict[str, dict[str, str]] = {} domains_set = set() for path, methods in self.spec.get("paths", {}).items(): for method, details in methods.items(): # Add to paths cache if path not in paths_cache: paths_cache[path] = {} paths_cache[path][method] = details.get("operationId", "") # Collect domains (tags) for tag in details.get("tags", []): domains_set.add(tag) self._paths_cache = paths_cache self._domains_cache = sorted(list(domains_set)) # Example usage (assuming you have an instance of ApiSpecManager called 'spec_manager'): async def main() -> None: """Test function to demonstrate ApiSpecManager functionality.""" # Create a new instance of ApiSpecManager spec_manager = ApiSpecManager() # Load the spec await spec_manager.get_spec() # Print the path to help debug print(f"Looking for spec at: {LOCAL_SPEC_PATH}") # 1. Get all domains all_domains = spec_manager.get_all_domains() print("\nAll Domains:") print(all_domains) # 2. Get all paths and methods all_paths = spec_manager.get_all_paths_and_methods() print("\nAll Paths and Methods (sample):") # Just print a few to avoid overwhelming output for i, (path, methods) in enumerate(all_paths.items()): if i >= 5: # Limit to 5 paths break print(f" {path}:") for method, operation_id in methods.items(): print(f" {method}: {operation_id}") # 3. Get paths and methods for the "Edge Functions" domain edge_paths = spec_manager.get_paths_and_methods_by_domain("Edge Functions") print("\nEdge Functions Paths and Methods:") for path, methods in edge_paths.items(): print(f" {path}:") for method, operation_id in methods.items(): print(f" {method}: {operation_id}") # 4. Get the full spec for a specific path and method path = "/v1/projects/{ref}/functions" method = "GET" full_spec = spec_manager.get_spec_for_path_and_method(path, method) print(f"\nFull Spec for {method} {path}:") if full_spec: print(json.dumps(full_spec, indent=2)[:500] + "...") # Truncate for readability else: print("Spec not found for this path/method") if __name__ == "__main__": import asyncio asyncio.run(main()) ``` -------------------------------------------------------------------------------- /supabase_mcp/tools/registry.py: -------------------------------------------------------------------------------- ```python from typing import Any, Literal from mcp.server.fastmcp import FastMCP from supabase_mcp.core.container import ServicesContainer from supabase_mcp.services.database.postgres_client import QueryResult from supabase_mcp.tools.manager import ToolName class ToolRegistry: """Responsible for registering tools with the MCP server""" def __init__(self, mcp: FastMCP, services_container: ServicesContainer): self.mcp = mcp self.services_container = services_container def register_tools(self) -> FastMCP: """Register all tools with the MCP server""" mcp = self.mcp services_container = self.services_container tool_manager = services_container.tool_manager feature_manager = services_container.feature_manager @mcp.tool(description=tool_manager.get_description(ToolName.GET_SCHEMAS)) # type: ignore async def get_schemas() -> QueryResult: """List all database schemas with their sizes and table counts.""" return await feature_manager.execute_tool(ToolName.GET_SCHEMAS, services_container=services_container) @mcp.tool(description=tool_manager.get_description(ToolName.GET_TABLES)) # type: ignore async def get_tables(schema_name: str) -> QueryResult: """List all tables, foreign 
--------------------------------------------------------------------------------
/supabase_mcp/tools/registry.py:
--------------------------------------------------------------------------------

```python
from typing import Any, Literal

from mcp.server.fastmcp import FastMCP

from supabase_mcp.core.container import ServicesContainer
from supabase_mcp.services.database.postgres_client import QueryResult
from supabase_mcp.tools.manager import ToolName


class ToolRegistry:
    """Responsible for registering tools with the MCP server"""

    def __init__(self, mcp: FastMCP, services_container: ServicesContainer):
        self.mcp = mcp
        self.services_container = services_container

    def register_tools(self) -> FastMCP:
        """Register all tools with the MCP server"""
        mcp = self.mcp
        services_container = self.services_container
        tool_manager = services_container.tool_manager
        feature_manager = services_container.feature_manager

        @mcp.tool(description=tool_manager.get_description(ToolName.GET_SCHEMAS))  # type: ignore
        async def get_schemas() -> QueryResult:
            """List all database schemas with their sizes and table counts."""
            return await feature_manager.execute_tool(ToolName.GET_SCHEMAS, services_container=services_container)

        @mcp.tool(description=tool_manager.get_description(ToolName.GET_TABLES))  # type: ignore
        async def get_tables(schema_name: str) -> QueryResult:
            """List all tables, foreign tables, and views in a schema with their sizes, row counts, and metadata."""
            return await feature_manager.execute_tool(
                ToolName.GET_TABLES, services_container=services_container, schema_name=schema_name
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.GET_TABLE_SCHEMA))  # type: ignore
        async def get_table_schema(schema_name: str, table: str) -> QueryResult:
            """Get detailed table structure including columns, keys, and relationships."""
            return await feature_manager.execute_tool(
                ToolName.GET_TABLE_SCHEMA,
                services_container=services_container,
                schema_name=schema_name,
                table=table,
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.EXECUTE_POSTGRESQL))  # type: ignore
        async def execute_postgresql(query: str, migration_name: str = "") -> QueryResult:
            """Execute PostgreSQL statements against your Supabase database."""
            return await feature_manager.execute_tool(
                ToolName.EXECUTE_POSTGRESQL,
                services_container=services_container,
                query=query,
                migration_name=migration_name,
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.RETRIEVE_MIGRATIONS))  # type: ignore
        async def retrieve_migrations(
            limit: int = 50,
            offset: int = 0,
            name_pattern: str = "",
            include_full_queries: bool = False,
        ) -> QueryResult:
            """Retrieve a list of all migrations a user has from Supabase.

            SAFETY: This is a low-risk read operation that can be executed in SAFE mode.
            """
            result = await feature_manager.execute_tool(
                ToolName.RETRIEVE_MIGRATIONS,
                services_container=services_container,
                limit=limit,
                offset=offset,
                name_pattern=name_pattern,
                include_full_queries=include_full_queries,
            )
            return QueryResult.model_validate(result)

        @mcp.tool(description=tool_manager.get_description(ToolName.SEND_MANAGEMENT_API_REQUEST))  # type: ignore
        async def send_management_api_request(
            method: str,
            path: str,
            path_params: dict[str, str],
            request_params: dict[str, Any],
            request_body: dict[str, Any],
        ) -> dict[str, Any]:
            """Execute a Supabase Management API request."""
            return await feature_manager.execute_tool(
                ToolName.SEND_MANAGEMENT_API_REQUEST,
                services_container=services_container,
                method=method,
                path=path,
                path_params=path_params,
                request_params=request_params,
                request_body=request_body,
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.GET_MANAGEMENT_API_SPEC))  # type: ignore
        async def get_management_api_spec(params: dict[str, Any] = {}) -> dict[str, Any]:
            """Get the Supabase Management API specification.

            This tool can be used in four different ways (and then some ;)):
            1. Without parameters: Returns all domains (default)
            2. With path and method: Returns the full specification for a specific API endpoint
            3. With domain only: Returns all paths and methods within that domain
            4. With all_paths=True: Returns all paths and methods

            Args:
                params: Dictionary containing optional parameters:
                    - path: Optional API path (e.g., "/v1/projects/{ref}/functions")
                    - method: Optional HTTP method (e.g., "GET", "POST")
                    - domain: Optional domain/tag name (e.g., "Auth", "Storage")
                    - all_paths: If True, returns all paths and methods

            Returns:
                API specification based on the provided parameters
            """
            return await feature_manager.execute_tool(
                ToolName.GET_MANAGEMENT_API_SPEC, services_container=services_container, params=params
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.GET_AUTH_ADMIN_METHODS_SPEC))  # type: ignore
        async def get_auth_admin_methods_spec() -> dict[str, Any]:
            """Get Python SDK methods specification for Auth Admin."""
            return await feature_manager.execute_tool(
                ToolName.GET_AUTH_ADMIN_METHODS_SPEC, services_container=services_container
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.CALL_AUTH_ADMIN_METHOD))  # type: ignore
        async def call_auth_admin_method(method: str, params: dict[str, Any]) -> dict[str, Any]:
            """Call an Auth Admin method from Supabase Python SDK."""
            return await feature_manager.execute_tool(
                ToolName.CALL_AUTH_ADMIN_METHOD, services_container=services_container, method=method, params=params
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.LIVE_DANGEROUSLY))  # type: ignore
        async def live_dangerously(
            service: Literal["api", "database"], enable_unsafe_mode: bool = False
        ) -> dict[str, Any]:
            """
            Toggle between safe and unsafe operation modes for API or Database services.

            This function controls the safety level for operations, allowing you to:
            - Enable write operations for the database (INSERT, UPDATE, DELETE, schema changes)
            - Enable state-changing operations for the Management API
            """
            return await feature_manager.execute_tool(
                ToolName.LIVE_DANGEROUSLY,
                services_container=services_container,
                service=service,
                enable_unsafe_mode=enable_unsafe_mode,
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.CONFIRM_DESTRUCTIVE_OPERATION))  # type: ignore
        async def confirm_destructive_operation(
            operation_type: Literal["api", "database"], confirmation_id: str, user_confirmation: bool = False
        ) -> QueryResult | dict[str, Any]:
            """Execute a destructive operation after confirmation. Use this only after reviewing the risks with the user."""
            return await feature_manager.execute_tool(
                ToolName.CONFIRM_DESTRUCTIVE_OPERATION,
                services_container=services_container,
                operation_type=operation_type,
                confirmation_id=confirmation_id,
                user_confirmation=user_confirmation,
            )

        @mcp.tool(description=tool_manager.get_description(ToolName.RETRIEVE_LOGS))  # type: ignore
        async def retrieve_logs(
            collection: str,
            limit: int = 20,
            hours_ago: int = 1,
            filters: list[dict[str, Any]] = [],
            search: str = "",
            custom_query: str = "",
        ) -> dict[str, Any]:
            """Retrieve logs from your Supabase project's services for debugging and monitoring."""
            return await feature_manager.execute_tool(
                ToolName.RETRIEVE_LOGS,
                services_container=services_container,
                collection=collection,
                limit=limit,
                hours_ago=hours_ago,
                filters=filters,
                search=search,
                custom_query=custom_query,
            )

        return mcp
```
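A minimal wiring sketch for `ToolRegistry`, not taken from the repo: how `ServicesContainer` is actually constructed lives in supabase_mcp/core/container.py (not shown on this page), so the container is passed in here as an already-built argument, and the server name string is a placeholder.

```python
# Hypothetical wiring sketch for ToolRegistry (not part of the repo).
from mcp.server.fastmcp import FastMCP

from supabase_mcp.core.container import ServicesContainer
from supabase_mcp.tools.registry import ToolRegistry


def create_server(services_container: ServicesContainer) -> FastMCP:
    mcp = FastMCP("supabase")  # placeholder server name
    registry = ToolRegistry(mcp=mcp, services_container=services_container)
    # register_tools() attaches every @mcp.tool handler above and
    # returns the same FastMCP instance, ready to serve.
    return registry.register_tools()
```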
--------------------------------------------------------------------------------
/tests/services/database/test_query_manager.py:
--------------------------------------------------------------------------------

```python
from unittest.mock import AsyncMock, MagicMock

import pytest

from supabase_mcp.exceptions import SafetyError
from supabase_mcp.services.database.query_manager import QueryManager
from supabase_mcp.services.database.sql.loader import SQLLoader
from supabase_mcp.services.database.sql.validator import (
    QueryValidationResults,
    SQLQueryCategory,
    SQLQueryCommand,
    SQLValidator,
    ValidatedStatement,
)
from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel


@pytest.mark.asyncio(loop_scope="module")
class TestQueryManager:
    """Tests for the Query Manager."""

    @pytest.mark.unit
    async def test_query_execution(self, mock_query_manager: QueryManager):
        """Test query execution through the Query Manager."""
        query_manager = mock_query_manager

        # Ensure validator and safety_manager are proper mocks
        query_manager.validator = MagicMock()
        query_manager.safety_manager = MagicMock()

        # Create a mock validation result for a SELECT query
        validated_statement = ValidatedStatement(
            category=SQLQueryCategory.DQL,
            command=SQLQueryCommand.SELECT,
            risk_level=OperationRiskLevel.LOW,
            query="SELECT * FROM users",
            needs_migration=False,
            object_type="TABLE",
            schema_name="public",
        )
        validation_result = QueryValidationResults(
            statements=[validated_statement],
            highest_risk_level=OperationRiskLevel.LOW,
            has_transaction_control=False,
            original_query="SELECT * FROM users",
        )

        # Make the validator return our mock validation result
        query_manager.validator.validate_query.return_value = validation_result

        # Make the db_client return a mock query result
        mock_query_result = MagicMock()
        query_manager.db_client.execute_query = AsyncMock(return_value=mock_query_result)

        # Execute a query
        query = "SELECT * FROM users"
        result = await query_manager.handle_query(query)

        # Verify the validator was called with the query
        query_manager.validator.validate_query.assert_called_once_with(query)

        # Verify the db_client was called with the validation result
        query_manager.db_client.execute_query.assert_called_once_with(validation_result, False)

        # Verify the result is what we expect
        assert result == mock_query_result

    @pytest.mark.asyncio
    @pytest.mark.unit
    async def test_safety_validation_blocks_dangerous_query(self, mock_query_manager: QueryManager):
        """Test that the safety validation blocks dangerous queries."""
        # Create a query manager with the mock dependencies
        query_manager = mock_query_manager

        # Ensure validator and safety_manager are proper mocks
        query_manager.validator = MagicMock()
        query_manager.safety_manager = MagicMock()

        # Create a mock validation result for a DROP TABLE query
        validated_statement = ValidatedStatement(
            category=SQLQueryCategory.DDL,
            command=SQLQueryCommand.DROP,
            risk_level=OperationRiskLevel.EXTREME,
            query="DROP TABLE users",
            needs_migration=False,
            object_type="TABLE",
            schema_name="public",
        )
        validation_result = QueryValidationResults(
            statements=[validated_statement],
            highest_risk_level=OperationRiskLevel.EXTREME,
            has_transaction_control=False,
            original_query="DROP TABLE users",
        )

        # Make the validator return our mock validation result
        query_manager.validator.validate_query.return_value = validation_result

        # Make the safety manager raise a SafetyError
        error_message = "Operation not allowed in SAFE mode"
        query_manager.safety_manager.validate_operation.side_effect = SafetyError(error_message)

        # Execute a query - should raise a SafetyError
        query = "DROP TABLE users"
        with pytest.raises(SafetyError) as excinfo:
            await query_manager.handle_query(query)

        # Verify the error message
        assert error_message in str(excinfo.value)

        # Verify the validator was called with the query
        query_manager.validator.validate_query.assert_called_once_with(query)

        # Verify the safety manager was called with the validation result
        query_manager.safety_manager.validate_operation.assert_called_once_with(
            ClientType.DATABASE, validation_result, False
        )

        # Verify the db_client was not called
        query_manager.db_client.execute_query.assert_not_called()

    @pytest.mark.unit
    async def test_get_migrations_query(self, query_manager_integration: QueryManager):
        """Test that get_migrations_query returns a valid query string."""
        # Test with default parameters
        query = query_manager_integration.get_migrations_query()
        assert isinstance(query, str)
        assert "supabase_migrations.schema_migrations" in query
        assert "LIMIT 50" in query

        # Test with custom parameters
        custom_query = query_manager_integration.get_migrations_query(
            limit=10, offset=5, name_pattern="test", include_full_queries=True
        )
        assert isinstance(custom_query, str)
        assert "supabase_migrations.schema_migrations" in custom_query
        assert "LIMIT 10" in custom_query
        assert "OFFSET 5" in custom_query
        assert "name ILIKE" in custom_query
        assert "statements" in custom_query  # Should include statements column when include_full_queries=True

    @pytest.mark.unit
    async def test_init_migration_schema(self):
        """Test that init_migration_schema initializes the migration schema correctly."""
        # Create minimal mocks
        postgres_client = MagicMock()
        postgres_client.execute_query = AsyncMock()
        safety_manager = MagicMock()

        # Create a real SQLLoader and SQLValidator
        sql_loader = SQLLoader()
        sql_validator = SQLValidator()

        # Create the QueryManager with minimal mocking
        query_manager = QueryManager(
            postgres_client=postgres_client,
            safety_manager=safety_manager,
            sql_validator=sql_validator,
            sql_loader=sql_loader,
        )

        # Call the method
        await query_manager.init_migration_schema()

        # Verify that the SQL loader was used to get the init migrations query
        # and that the query was executed
        assert postgres_client.execute_query.called

        # Get the arguments that execute_query was called with
        call_args = postgres_client.execute_query.call_args
        assert call_args is not None

        # The first argument should be a QueryValidationResults object
        args, _ = call_args  # Use _ to ignore unused kwargs
        assert len(args) > 0
        validation_result = args[0]
        assert isinstance(validation_result, QueryValidationResults)

        # Check that the validation result contains the expected SQL
        init_query = sql_loader.get_init_migrations_query()
        assert any(stmt.query and stmt.query in init_query for stmt in validation_result.statements)

    @pytest.mark.unit
    async def test_handle_migration(self):
        """Test that handle_migration correctly handles migrations when needed."""
        # Create minimal mocks
        postgres_client = MagicMock()
        postgres_client.execute_query = AsyncMock()
        safety_manager = MagicMock()

        # Create a real SQLLoader
        sql_loader = SQLLoader()

        # Create a mock MigrationManager
        migration_manager = MagicMock()
        migration_query = "INSERT INTO _migrations.migrations (name) VALUES ('test_migration')"
        migration_name = "test_migration"
        migration_manager.prepare_migration_query.return_value = (migration_query, migration_name)

        # Create a real SQLValidator
        sql_validator = SQLValidator()

        # Create the QueryManager with minimal mocking
        query_manager = QueryManager(
            postgres_client=postgres_client,
            safety_manager=safety_manager,
            sql_validator=sql_validator,
            sql_loader=sql_loader,
            migration_manager=migration_manager,
        )

        # Create a validation result that needs migration
        validated_statement = ValidatedStatement(
            category=SQLQueryCategory.DDL,
            command=SQLQueryCommand.CREATE,
            risk_level=OperationRiskLevel.MEDIUM,
            query="CREATE TABLE test (id INT)",
            needs_migration=True,
            object_type="TABLE",
            schema_name="public",
        )
        validation_result = QueryValidationResults(
            statements=[validated_statement],
            highest_risk_level=OperationRiskLevel.MEDIUM,
            has_transaction_control=False,
            original_query="CREATE TABLE test (id INT)",
        )

        # Call the method
        await query_manager.handle_migration(validation_result, "CREATE TABLE test (id INT)", "test_migration")

        # Verify that the migration manager was called to prepare the migration query
        migration_manager.prepare_migration_query.assert_called_once_with(
            validation_result, "CREATE TABLE test (id INT)", "test_migration"
        )

        # Verify that execute_query was called at least twice
        # Once for init_migration_schema and once for the migration query
        assert postgres_client.execute_query.call_count >= 2
```
--------------------------------------------------------------------------------
/supabase_mcp/services/safety/safety_manager.py:
--------------------------------------------------------------------------------

```python
import time
import uuid
from typing import Any, Optional

from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError
from supabase_mcp.logger import logger
from supabase_mcp.services.safety.models import ClientType, SafetyMode
from supabase_mcp.services.safety.safety_configs import APISafetyConfig, SafetyConfigBase, SQLSafetyConfig


class SafetyManager:
    """A singleton service that maintains current safety state.

    Provides methods to:
    - Get/set safety modes for different clients
    - Register safety configurations
    - Check if operations are allowed

    Serves as the central point for safety decisions"""

    _instance: Optional["SafetyManager"] = None

    def __init__(self) -> None:
        """Initialize the safety manager with default safety modes."""
        self._safety_modes: dict[ClientType, SafetyMode] = {
            ClientType.DATABASE: SafetyMode.SAFE,
            ClientType.API: SafetyMode.SAFE,
        }
        self._safety_configs: dict[ClientType, SafetyConfigBase[Any]] = {}
        self._pending_confirmations: dict[str, dict[str, Any]] = {}
        self._confirmation_expiry = 300  # 5 minutes in seconds

    @classmethod
    def get_instance(cls) -> "SafetyManager":
        """Get the singleton instance of the safety manager."""
        if cls._instance is None:
            cls._instance = SafetyManager()
        return cls._instance

    def register_safety_configs(self) -> bool:
        """Register all safety configurations with the SafetyManager.

        Returns:
            bool: True if all configurations were registered successfully
        """
        # Register SQL safety config
        sql_config = SQLSafetyConfig()
        self.register_config(ClientType.DATABASE, sql_config)

        # Register API safety config
        api_config = APISafetyConfig()
        self.register_config(ClientType.API, api_config)

        logger.info("✓ Safety configurations registered successfully")
        return True

    def register_config(self, client_type: ClientType, config: SafetyConfigBase[Any]) -> None:
        """Register a safety configuration for a client type.

        Args:
            client_type: The client type to register the configuration for
            config: The safety configuration for the client
        """
        self._safety_configs[client_type] = config

    def get_safety_mode(self, client_type: ClientType) -> SafetyMode:
        """Get the current safety mode for a client type.

        Args:
            client_type: The client type to get the safety mode for

        Returns:
            The current safety mode for the client type
        """
        if client_type not in self._safety_modes:
            logger.warning(f"No safety mode registered for {client_type}, defaulting to SAFE")
            return SafetyMode.SAFE
        return self._safety_modes[client_type]

    def set_safety_mode(self, client_type: ClientType, mode: SafetyMode) -> None:
        """Set the safety mode for a client type.

        Args:
            client_type: The client type to set the safety mode for
            mode: The safety mode to set
        """
        self._safety_modes[client_type] = mode
        logger.debug(f"Set safety mode for {client_type} to {mode}")

    def validate_operation(
        self,
        client_type: ClientType,
        operation: Any,
        has_confirmation: bool = False,
    ) -> None:
        """Validate if an operation is allowed for a client type.

        This method will raise appropriate exceptions if the operation
        is not allowed or requires confirmation.

        Args:
            client_type: The client type to check the operation for
            operation: The operation to check
            has_confirmation: Whether the operation has been confirmed by the user

        Raises:
            OperationNotAllowedError: If the operation is not allowed in the current safety mode
            ConfirmationRequiredError: If the operation requires confirmation and has_confirmation is False
        """
        # Get the current safety mode and config
        mode = self.get_safety_mode(client_type)
        config = self._safety_configs.get(client_type)

        if not config:
            message = f"No safety configuration registered for {client_type}"
            logger.warning(message)
            raise OperationNotAllowedError(message)

        # Get the risk level for the operation
        risk_level = config.get_risk_level(operation)
        logger.debug(f"Operation risk level: {risk_level}")

        # Check if the operation is allowed in the current mode
        is_allowed = config.is_operation_allowed(risk_level, mode)
        if not is_allowed:
            message = f"Operation with risk level {risk_level} is not allowed in {mode} mode"
            logger.debug(f"Operation with risk level {risk_level} not allowed in {mode} mode")
            raise OperationNotAllowedError(message)

        # Check if the operation needs confirmation
        needs_confirmation = config.needs_confirmation(risk_level)
        if needs_confirmation and not has_confirmation:
            # Store the operation for later confirmation
            confirmation_id = self._store_confirmation(client_type, operation, risk_level)
            message = (
                f"Operation with risk level {risk_level} requires explicit user confirmation.\n\n"
                f"WHAT HAPPENED: This high-risk operation was rejected for safety reasons.\n"
                f"WHAT TO DO: 1. Review the operation with the user and explain the risks\n"
                f"            2. If the user approves, use the confirmation tool with this ID: {confirmation_id}\n\n"
                f'CONFIRMATION COMMAND: confirm_destructive_postgresql(confirmation_id="{confirmation_id}", user_confirmation=True)'
            )
            logger.debug(
                f"Operation with risk level {risk_level} requires confirmation, stored with ID {confirmation_id}"
            )
            raise ConfirmationRequiredError(message)

        logger.debug(f"Operation with risk level {risk_level} allowed in {mode} mode")

    def _store_confirmation(self, client_type: ClientType, operation: Any, risk_level: int) -> str:
        """Store an operation that needs confirmation.

        Args:
            client_type: The client type the operation is for
            operation: The operation to store
            risk_level: The risk level of the operation

        Returns:
            A unique confirmation ID
        """
        # Generate a unique ID
        confirmation_id = f"conf_{uuid.uuid4().hex[:8]}"

        # Store the operation with metadata
        self._pending_confirmations[confirmation_id] = {
            "operation": operation,
            "client_type": client_type,
            "risk_level": risk_level,
            "timestamp": time.time(),
        }

        # Clean up expired confirmations
        self._cleanup_expired_confirmations()

        return confirmation_id

    def _get_confirmation(self, confirmation_id: str) -> dict[str, Any] | None:
        """Retrieve a stored confirmation by ID.

        Args:
            confirmation_id: The ID of the confirmation to retrieve

        Returns:
            The stored confirmation data or None if not found or expired
        """
        # Clean up expired confirmations first
        self._cleanup_expired_confirmations()

        # Return the stored confirmation if it exists
        return self._pending_confirmations.get(confirmation_id)

    def _cleanup_expired_confirmations(self) -> None:
        """Remove expired confirmations from storage."""
        current_time = time.time()
        expired_ids = [
            conf_id
            for conf_id, data in self._pending_confirmations.items()
            if current_time - data["timestamp"] > self._confirmation_expiry
        ]
        for conf_id in expired_ids:
            logger.debug(f"Removing expired confirmation with ID {conf_id}")
            del self._pending_confirmations[conf_id]

    def get_stored_operation(self, confirmation_id: str) -> Any | None:
        """Get a stored operation by its confirmation ID.

        Args:
            confirmation_id: The confirmation ID to get the operation for

        Returns:
            The stored operation, or None if not found
        """
        confirmation = self._get_confirmation(confirmation_id)
        if confirmation is None:
            return None
        return confirmation.get("operation")

    def get_operations_by_risk_level(
        self, risk_level: str, client_type: ClientType = ClientType.DATABASE
    ) -> dict[str, list[str]]:
        """Get operations for a specific risk level.

        Args:
            risk_level: The risk level to get operations for
            client_type: The client type to get operations for

        Returns:
            A dictionary mapping HTTP methods to lists of paths
        """
        # Get the config for the specified client type
        config = self._safety_configs.get(client_type)
        if not config or not hasattr(config, "PATH_SAFETY_CONFIG"):
            return {}

        # Get the operations for this risk level
        risk_config = getattr(config, "PATH_SAFETY_CONFIG", {})
        if risk_level in risk_config:
            return risk_config[risk_level]
        # Fall back to an empty mapping so the annotated return type always holds
        return {}

    def get_current_mode(self, client_type: ClientType) -> str:
        """Get the current safety mode as a string.

        Args:
            client_type: The client type to get the mode for

        Returns:
            The current safety mode as a string
        """
        mode = self.get_safety_mode(client_type)
        return str(mode)

    @classmethod
    def reset(cls) -> None:
        """Reset the singleton instance cleanly.

        This closes any open connections and resets the singleton instance.
        """
        if cls._instance is not None:
            cls._instance = None
            logger.info("SafetyManager instance reset complete")
```
--------------------------------------------------------------------------------
/supabase_mcp/clients/sdk_client.py:
--------------------------------------------------------------------------------

```python
from __future__ import annotations

from typing import Any, TypeVar

from pydantic import BaseModel, ValidationError
from supabase import AsyncClient, create_async_client
from supabase.lib.client_options import AsyncClientOptions

from supabase_mcp.exceptions import PythonSDKError
from supabase_mcp.logger import logger
from supabase_mcp.services.sdk.auth_admin_models import (
    PARAM_MODELS,
    CreateUserParams,
    DeleteFactorParams,
    DeleteUserParams,
    GenerateLinkParams,
    GetUserByIdParams,
    InviteUserByEmailParams,
    ListUsersParams,
    UpdateUserByIdParams,
)
from supabase_mcp.services.sdk.auth_admin_sdk_spec import get_auth_admin_methods_spec
from supabase_mcp.settings import Settings

T = TypeVar("T", bound=BaseModel)


class IncorrectSDKParamsError(PythonSDKError):
    """Error raised when the parameters passed to the SDK are incorrect."""

    pass


class SupabaseSDKClient:
    """Supabase Python SDK client, which exposes functionality related to Auth admin of the Python SDK."""

    _instance: SupabaseSDKClient | None = None

    def __init__(
        self,
        settings: Settings | None = None,
        project_ref: str | None = None,
        service_role_key: str | None = None,
    ):
        self.client: AsyncClient | None = None
        self.settings = settings
        self.project_ref = settings.supabase_project_ref if settings else project_ref
        self.service_role_key = settings.supabase_service_role_key if settings else service_role_key
        self.supabase_url = self.get_supabase_url()
        logger.info(f"✔️ Supabase SDK client initialized successfully for project {self.project_ref}")

    def get_supabase_url(self) -> str:
        """Returns the Supabase URL based on the project reference"""
        if not self.project_ref:
            raise PythonSDKError("Project reference is not set")
        if self.project_ref.startswith("127.0.0.1"):
            # Return the default Supabase API URL
            return "http://127.0.0.1:54321"
        return f"https://{self.project_ref}.supabase.co"

    @classmethod
    def create(
        cls,
        settings: Settings | None = None,
        project_ref: str | None = None,
        service_role_key: str | None = None,
    ) -> SupabaseSDKClient:
        if cls._instance is None:
            cls._instance = cls(settings, project_ref, service_role_key)
        return cls._instance

    @classmethod
    def get_instance(
        cls,
        settings: Settings | None = None,
        project_ref: str | None = None,
        service_role_key: str | None = None,
    ) -> SupabaseSDKClient:
        """Returns the singleton instance"""
        if cls._instance is None:
            cls.create(settings, project_ref, service_role_key)
        return cls._instance

    async def create_client(self) -> AsyncClient:
        """Creates a new Supabase client"""
        try:
            client = await create_async_client(
                self.supabase_url,
                self.service_role_key,
                options=AsyncClientOptions(
                    auto_refresh_token=False,
                    persist_session=False,
                ),
            )
            return client
        except Exception as e:
            logger.error(f"Error creating Supabase client: {e}")
            raise PythonSDKError(f"Error creating Supabase client: {e}") from e

    async def get_client(self) -> AsyncClient:
        """Returns the Supabase client"""
        if not self.client:
            self.client = await self.create_client()
            logger.info(f"Created Supabase SDK client for project {self.project_ref}")
        return self.client

    async def close(self) -> None:
        """Reset the client reference to allow garbage collection."""
        self.client = None
        logger.info("Supabase SDK client reference cleared")

    def return_python_sdk_spec(self) -> dict:
        """Returns the Python SDK spec"""
        return get_auth_admin_methods_spec()

    def _validate_params(self, method: str, params: dict, param_model_cls: type[T]) -> T:
        """Validate parameters using the appropriate Pydantic model"""
        try:
            return param_model_cls.model_validate(params)
        except ValidationError as e:
            raise PythonSDKError(f"Invalid parameters for method {method}: {str(e)}") from e

    async def _get_user_by_id(self, params: GetUserByIdParams) -> dict:
        """Get user by ID implementation"""
        self.client = await self.get_client()
        admin_auth_client = self.client.auth.admin
        result = await admin_auth_client.get_user_by_id(params.uid)
        return result

    async def _list_users(self, params: ListUsersParams) -> dict:
        """List users implementation"""
        self.client = await self.get_client()
        admin_auth_client = self.client.auth.admin
        result = await admin_auth_client.list_users(page=params.page, per_page=params.per_page)
        return result

    async def _create_user(self, params: CreateUserParams) -> dict:
        """Create user implementation"""
        self.client = await self.get_client()
        admin_auth_client = self.client.auth.admin
        user_data = params.model_dump(exclude_none=True)
        result = await admin_auth_client.create_user(user_data)
        return result

    async def _delete_user(self, params: DeleteUserParams) -> dict:
        """Delete user implementation"""
        self.client = await self.get_client()
        admin_auth_client = self.client.auth.admin
        result = await admin_auth_client.delete_user(params.id, should_soft_delete=params.should_soft_delete)
        return result

    async def _invite_user_by_email(self, params: InviteUserByEmailParams) -> dict:
        """Invite user by email implementation"""
        self.client = await self.get_client()
        admin_auth_client = self.client.auth.admin
        options = params.options if params.options else {}
        result = await admin_auth_client.invite_user_by_email(params.email, options)
        return result

    async def _generate_link(self, params: GenerateLinkParams) -> dict:
        """Generate link implementation"""
        self.client = await self.get_client()
        admin_auth_client = self.client.auth.admin

        # Create a params dictionary as expected by the SDK
        params_dict = params.model_dump(exclude_none=True)

        try:
            # The SDK expects a single 'params' parameter containing all the fields
            result = await admin_auth_client.generate_link(params=params_dict)
            return result
        except TypeError as e:
            # Catch parameter errors and provide a more helpful message
            error_msg = str(e)
            if "unexpected keyword argument" in error_msg:
                raise IncorrectSDKParamsError(
                    f"Incorrect parameters for generate_link: {error_msg}. "
                    f"Please check the SDK specification for the correct parameter structure."
                ) from e
            raise

    async def _update_user_by_id(self, params: UpdateUserByIdParams) -> dict:
        """Update user by ID implementation"""
        self.client = await self.get_client()
        admin_auth_client = self.client.auth.admin
        uid = params.uid
        attributes = params.attributes.model_dump(exclude={"uid"}, exclude_none=True)
        result = await admin_auth_client.update_user_by_id(uid, attributes)
        return result

    async def _delete_factor(self, params: DeleteFactorParams) -> dict:
        """Delete factor implementation"""
        # This method is not implemented in the Supabase SDK yet
        raise NotImplementedError("The delete_factor method is not implemented in the Supabase SDK yet")

    async def call_auth_admin_method(self, method: str, params: dict[str, Any]) -> Any:
        """Calls a method of the Python SDK client"""
        # Check if service role key is available
        if not self.service_role_key:
            raise PythonSDKError(
                "Supabase service role key is not configured. Set SUPABASE_SERVICE_ROLE_KEY environment variable to use Auth Admin tools."
            )

        if not self.client:
            self.client = await self.get_client()
            if not self.client:
                raise PythonSDKError("Python SDK client not initialized")

        # Validate method exists
        if method not in PARAM_MODELS:
            available_methods = ", ".join(PARAM_MODELS.keys())
            raise PythonSDKError(f"Unknown method: {method}. Available methods: {available_methods}")

        # Get the appropriate model class and validate parameters
        param_model_cls = PARAM_MODELS[method]
        validated_params = self._validate_params(method, params, param_model_cls)

        # Method dispatch using a dictionary of method implementations
        method_handlers = {
            "get_user_by_id": self._get_user_by_id,
            "list_users": self._list_users,
            "create_user": self._create_user,
            "delete_user": self._delete_user,
            "invite_user_by_email": self._invite_user_by_email,
            "generate_link": self._generate_link,
            "update_user_by_id": self._update_user_by_id,
            "delete_factor": self._delete_factor,
        }

        # Call the appropriate method handler
        try:
            handler = method_handlers.get(method)
            if not handler:
                raise PythonSDKError(f"Method {method} is not implemented")

            logger.debug(f"Python SDK request params: {validated_params}")
            return await handler(validated_params)
        except Exception as e:
            if isinstance(e, IncorrectSDKParamsError):
                # Re-raise our custom error without wrapping it
                raise e
            logger.error(f"Error calling {method}: {e}")
            raise PythonSDKError(f"Error calling {method}: {str(e)}") from e

    @classmethod
    def reset(cls) -> None:
        """Reset the singleton instance cleanly.

        This closes any open connections and resets the singleton instance.
        """
        if cls._instance is not None:
            cls._instance = None
            logger.info("SupabaseSDKClient instance reset complete")
```

--------------------------------------------------------------------------------
/tests/services/sdk/test_auth_admin_models.py:
--------------------------------------------------------------------------------

```python
import pytest
from pydantic import ValidationError

from supabase_mcp.services.sdk.auth_admin_models import (
    PARAM_MODELS,
    AdminUserAttributes,
    CreateUserParams,
    DeleteFactorParams,
    DeleteUserParams,
    GenerateLinkParams,
    GetUserByIdParams,
    InviteUserByEmailParams,
    ListUsersParams,
    UpdateUserByIdParams,
    UserMetadata,
)


class TestModelConversion:
    """Test conversion from JSON data to models and validation"""

    def test_get_user_by_id_conversion(self):
        """Test conversion of get_user_by_id JSON data"""
        # Valid payload
        valid_payload = {"uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"}
        params = GetUserByIdParams.model_validate(valid_payload)
        assert params.uid == valid_payload["uid"]

        # Invalid payload (missing required uid)
        invalid_payload = {}
        with pytest.raises(ValidationError) as excinfo:
            GetUserByIdParams.model_validate(invalid_payload)
        assert "uid" in str(excinfo.value)

    def test_list_users_conversion(self):
        """Test conversion of list_users JSON data"""
        # Valid payload with custom values
        valid_payload = {"page": 2, "per_page": 20}
        params = ListUsersParams.model_validate(valid_payload)
        assert params.page == valid_payload["page"]
        assert params.per_page == valid_payload["per_page"]

        # Valid payload with defaults
        empty_payload = {}
        params = ListUsersParams.model_validate(empty_payload)
        assert params.page == 1
        assert params.per_page == 50

        # Invalid payload (non-integer values)
        invalid_payload = {"page": "not-a-number", "per_page": "also-not-a-number"}
        with pytest.raises(ValidationError) as excinfo:
            ListUsersParams.model_validate(invalid_payload)
        assert "page" in str(excinfo.value)

    def test_create_user_conversion(self):
        """Test conversion of create_user JSON data"""
        # Valid payload with email
        valid_payload = {
            "email": "[email protected]",
            "password": "secure-password",
            "email_confirm": True,
            "user_metadata": UserMetadata(email="[email protected]"),
        }
        params = CreateUserParams.model_validate(valid_payload)
        assert params.email == valid_payload["email"]
        assert params.password == valid_payload["password"]
        assert params.email_confirm is True
        assert params.user_metadata == valid_payload["user_metadata"]

        # Valid payload with phone
        valid_phone_payload = {
            "phone": "+1234567890",
            "password": "secure-password",
            "phone_confirm": True,
        }
        params = CreateUserParams.model_validate(valid_phone_payload)
        assert params.phone == valid_phone_payload["phone"]
        assert params.password == valid_phone_payload["password"]
        assert params.phone_confirm is True

        # Invalid payload (missing both email and phone)
        invalid_payload = {"password": "secure-password"}
        with pytest.raises(ValidationError) as excinfo:
            CreateUserParams.model_validate(invalid_payload)
        assert "Either email or phone must be provided" in str(excinfo.value)

    def test_delete_user_conversion(self):
        """Test conversion of delete_user JSON data"""
        # Valid payload with custom values
        valid_payload = {"id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", "should_soft_delete": True}
        params = DeleteUserParams.model_validate(valid_payload)
        assert params.id == valid_payload["id"]
        assert params.should_soft_delete is True

        # Valid payload with defaults
        valid_payload = {"id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"}
        params = DeleteUserParams.model_validate(valid_payload)
        assert params.id == valid_payload["id"]
        assert params.should_soft_delete is False

        # Invalid payload (missing id)
        invalid_payload = {"should_soft_delete": True}
        with pytest.raises(ValidationError) as excinfo:
            DeleteUserParams.model_validate(invalid_payload)
        assert "id" in str(excinfo.value)

    def test_invite_user_by_email_conversion(self):
        """Test conversion of invite_user_by_email JSON data"""
        # Valid payload with options
        valid_payload = {
            "email": "[email protected]",
            "options": {"data": {"name": "Invited User"}, "redirect_to": "https://example.com/welcome"},
        }
        params = InviteUserByEmailParams.model_validate(valid_payload)
        assert params.email == valid_payload["email"]
        assert params.options == valid_payload["options"]

        # Valid payload without options
        valid_payload = {"email": "[email protected]"}
        params = InviteUserByEmailParams.model_validate(valid_payload)
        assert params.email == valid_payload["email"]
        assert params.options is None

        # Invalid payload (missing email)
        invalid_payload = {"options": {"data": {"name": "Invited User"}}}
        with pytest.raises(ValidationError) as excinfo:
            InviteUserByEmailParams.model_validate(invalid_payload)
        assert "email" in str(excinfo.value)

    def test_generate_link_conversion(self):
        """Test conversion of generate_link JSON data"""
        # Valid signup link payload
        valid_signup_payload = {
            "type": "signup",
            "email": "[email protected]",
            "password": "secure-password",
            "options": {"data": {"name": "New User"}, "redirect_to": "https://example.com/welcome"},
        }
        params = GenerateLinkParams.model_validate(valid_signup_payload)
        assert params.type == valid_signup_payload["type"]
        assert params.email == valid_signup_payload["email"]
        assert params.password == valid_signup_payload["password"]
        assert params.options == valid_signup_payload["options"]

        # Valid email_change link payload
        valid_email_change_payload = {
            "type": "email_change_current",
            "email": "[email protected]",
            "new_email": "[email protected]",
        }
        params = GenerateLinkParams.model_validate(valid_email_change_payload)
        assert params.type == valid_email_change_payload["type"]
        assert params.email == valid_email_change_payload["email"]
        assert params.new_email == valid_email_change_payload["new_email"]

        # Invalid payload (missing password for signup)
        invalid_signup_payload = {
            "type": "signup",
            "email": "[email protected]",
        }
        with pytest.raises(ValidationError) as excinfo:
            GenerateLinkParams.model_validate(invalid_signup_payload)
        assert "Password is required for signup links" in str(excinfo.value)

        # Invalid payload (missing new_email for email_change)
        invalid_email_change_payload = {
            "type": "email_change_current",
            "email": "[email protected]",
        }
        with pytest.raises(ValidationError) as excinfo:
            GenerateLinkParams.model_validate(invalid_email_change_payload)
        assert "new_email is required for email change links" in str(excinfo.value)

        # Invalid payload (invalid type)
        invalid_type_payload = {
            "type": "invalid-type",
            "email": "[email protected]",
        }
        with pytest.raises(ValidationError) as excinfo:
            GenerateLinkParams.model_validate(invalid_type_payload)
        assert "type" in str(excinfo.value)

    def test_update_user_by_id_conversion(self):
        """Test conversion of update_user_by_id JSON data"""
        # Valid payload
        valid_payload = {
            "uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
            "attributes": AdminUserAttributes(email="[email protected]", email_verified=True),
        }
        params = UpdateUserByIdParams.model_validate(valid_payload)
        assert params.uid == valid_payload["uid"]
        assert params.attributes == valid_payload["attributes"]

        # Invalid payload (flat attributes and missing uid)
        invalid_payload = {
            "email": "[email protected]",
            "user_metadata": {"name": "Updated User"},
        }
        with pytest.raises(ValidationError) as excinfo:
            UpdateUserByIdParams.model_validate(invalid_payload)
        assert "uid" in str(excinfo.value)

    def test_delete_factor_conversion(self):
        """Test conversion of delete_factor JSON data"""
        # Valid payload
        valid_payload = {
            "user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
            "id": "totp-factor-id-123",
        }
        params = DeleteFactorParams.model_validate(valid_payload)
        assert params.user_id == valid_payload["user_id"]
        assert params.id == valid_payload["id"]

        # Invalid payload (missing user_id)
        invalid_payload = {
            "id": "totp-factor-id-123",
        }
        with pytest.raises(ValidationError) as excinfo:
            DeleteFactorParams.model_validate(invalid_payload)
        assert "user_id" in str(excinfo.value)

        # Invalid payload (missing id)
        invalid_payload = {
            "user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
        }
        with pytest.raises(ValidationError) as excinfo:
            DeleteFactorParams.model_validate(invalid_payload)
        assert "id" in str(excinfo.value)

    def test_param_models_mapping(self):
        """Test PARAM_MODELS mapping functionality"""
        # Test that all methods have the correct corresponding model
        method_model_pairs = [
            ("get_user_by_id", GetUserByIdParams),
            ("list_users", ListUsersParams),
            ("create_user", CreateUserParams),
            ("delete_user", DeleteUserParams),
            ("invite_user_by_email", InviteUserByEmailParams),
            ("generate_link", GenerateLinkParams),
            ("update_user_by_id", UpdateUserByIdParams),
            ("delete_factor", DeleteFactorParams),
        ]

        for method, expected_model in method_model_pairs:
            assert method in PARAM_MODELS
            assert PARAM_MODELS[method] == expected_model

        # Test actual validation of data through PARAM_MODELS mapping
        method = "create_user"
        model_class = PARAM_MODELS[method]
        valid_payload = {"email": "[email protected]", "password": "secure-password"}
        params = model_class.model_validate(valid_payload)
        assert params.email == valid_payload["email"]
        assert params.password == valid_payload["password"]
```

--------------------------------------------------------------------------------
/tests/test_tools.py:
--------------------------------------------------------------------------------

```python
"""Unit tests for tools - no external dependencies."""

import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from mcp.server.fastmcp import FastMCP

from supabase_mcp.core.container import ServicesContainer
from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError, PythonSDKError
from supabase_mcp.services.database.postgres_client import QueryResult, StatementResult
from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel, SafetyMode


@pytest.mark.asyncio
class TestDatabaseToolsUnit:
    """Unit tests for database tools."""

    @pytest.fixture
    def mock_container(self):
        """Create a mock container with all necessary services."""
        container = MagicMock(spec=ServicesContainer)

        # Mock query manager
        container.query_manager = MagicMock()
        container.query_manager.handle_query = AsyncMock()
        container.query_manager.get_schemas_query = MagicMock(return_value="SELECT * FROM schemas")
        container.query_manager.get_tables_query = MagicMock(return_value="SELECT * FROM tables")
        container.query_manager.get_table_schema_query = MagicMock(return_value="SELECT * FROM columns")

        # Mock safety manager
        container.safety_manager = MagicMock()
        container.safety_manager.check_permission = MagicMock()
        container.safety_manager.is_unsafe_mode = MagicMock(return_value=False)

        return container

    async def test_get_schemas_returns_query_result(self, mock_container):
        """Test that get_schemas returns proper QueryResult."""
        # Setup mock response
        mock_result = QueryResult(results=[
            StatementResult(rows=[
                {"schema_name": "public", "total_size": "100MB", "table_count": 10},
                {"schema_name": "auth", "total_size": "50MB", "table_count": 5}
            ])
        ])
        mock_container.query_manager.handle_query.return_value = mock_result

        # Execute
        query = mock_container.query_manager.get_schemas_query()
        result = await mock_container.query_manager.handle_query(query)

        # Verify
        assert isinstance(result, QueryResult)
        assert len(result.results[0].rows) == 2
        assert result.results[0].rows[0]["schema_name"] == "public"

    async def test_get_tables_with_schema_filter(self, mock_container):
        """Test that get_tables properly filters by schema."""
        # Setup
        mock_result = QueryResult(results=[
            StatementResult(rows=[
                {"table_name": "users", "table_type": "BASE TABLE", "row_count": 100, "size_bytes": 1024}
            ])
        ])
        mock_container.query_manager.handle_query.return_value = mock_result

        # Execute
        query = mock_container.query_manager.get_tables_query("public")
        result = await mock_container.query_manager.handle_query(query)

        # Verify
        mock_container.query_manager.get_tables_query.assert_called_with("public")
        assert result.results[0].rows[0]["table_name"] == "users"

    async def test_unsafe_query_blocked_in_safe_mode(self, mock_container):
        """Test that unsafe queries are blocked in safe mode."""
        # Setup
        mock_container.safety_manager.is_unsafe_mode.return_value = False
        mock_container.safety_manager.check_permission.side_effect = OperationNotAllowedError(
            "DROP operations are not allowed in safe mode"
        )

        # Execute & Verify
        with pytest.raises(OperationNotAllowedError):
            mock_container.safety_manager.check_permission(
                ClientType.DATABASE, OperationRiskLevel.HIGH
            )


@pytest.mark.asyncio
class TestAPIToolsUnit:
    """Unit tests for API tools."""

    @pytest.fixture
    def mock_container(self):
        """Create a mock container with API services."""
        container = MagicMock(spec=ServicesContainer)

        # Mock API manager
        container.api_manager = MagicMock()
        container.api_manager.send_request = AsyncMock()
        container.api_manager.spec_manager = MagicMock()
        container.api_manager.spec_manager.get_full_spec = MagicMock(return_value={"paths": {}})

        # Mock safety manager
        container.safety_manager = MagicMock()
        container.safety_manager.check_permission = MagicMock()
        container.safety_manager.is_unsafe_mode = MagicMock(return_value=False)

        return container

    async def test_api_request_success(self, mock_container):
        """Test successful API request."""
        # Setup
        mock_response = {"id": "123", "name": "Test Project"}
        mock_container.api_manager.send_request.return_value = mock_response

        # Execute
        result = await mock_container.api_manager.send_request(
            "GET", "/v1/projects", {}
        )

        # Verify
        assert result["id"] == "123"
        assert result["name"] == "Test Project"

    async def test_api_spec_retrieval(self, mock_container):
        """Test API spec retrieval."""
        # Setup
        expected_spec = {
            "paths": {
                "/v1/projects": {
                    "get": {"summary": "List projects"}
                }
            }
        }
        mock_container.api_manager.spec_manager.get_full_spec.return_value = expected_spec

        # Execute
        spec = mock_container.api_manager.spec_manager.get_full_spec()

        # Verify
        assert "paths" in spec
        assert "/v1/projects" in spec["paths"]

    async def test_medium_risk_api_blocked_in_safe_mode(self, mock_container):
        """Test that medium risk API operations are blocked in safe mode."""
        # Setup
        mock_container.safety_manager.check_permission.side_effect = ConfirmationRequiredError(
            "This operation requires confirmation",
            {"method": "POST", "path": "/v1/projects"}
        )

        # Execute & Verify
        with pytest.raises(ConfirmationRequiredError) as exc_info:
            mock_container.safety_manager.check_permission(
                ClientType.API, OperationRiskLevel.MEDIUM
            )
        assert "requires confirmation" in str(exc_info.value)


@pytest.mark.asyncio
class TestAuthToolsUnit:
    """Unit tests for auth tools."""

    @pytest.fixture
    def mock_container(self):
        """Create a mock container with SDK client."""
        container = MagicMock(spec=ServicesContainer)

        # Mock SDK client
        container.sdk_client = MagicMock()
        container.sdk_client.call_auth_admin_method = AsyncMock()
        container.sdk_client.return_python_sdk_spec = MagicMock(return_value={
            "methods": ["list_users", "create_user", "delete_user"]
        })

        return container

    async def test_list_users_success(self, mock_container):
        """Test listing users successfully."""
        # Setup
        mock_users = [
            {"id": "user1", "email": "[email protected]"},
            {"id": "user2", "email": "[email protected]"}
        ]
        mock_container.sdk_client.call_auth_admin_method.return_value = mock_users

        # Execute
        result = await mock_container.sdk_client.call_auth_admin_method(
            "list_users", {"page": 1, "per_page": 10}
        )

        # Verify
        assert len(result) == 2
        assert result[0]["email"] == "[email protected]"

    async def test_invalid_method_raises_error(self, mock_container):
        """Test that invalid method names raise errors."""
        # Setup
        mock_container.sdk_client.call_auth_admin_method.side_effect = PythonSDKError(
            "Unknown method: invalid_method"
        )

        # Execute & Verify
        with pytest.raises(PythonSDKError) as exc_info:
            await mock_container.sdk_client.call_auth_admin_method(
                "invalid_method", {}
            )
        assert "Unknown method" in str(exc_info.value)

    async def test_create_user_validation(self, mock_container):
        """Test user creation with validation."""
        # Setup
        new_user = {"id": str(uuid.uuid4()), "email": "[email protected]"}
        mock_container.sdk_client.call_auth_admin_method.return_value = {"user": new_user}

        # Execute
        result = await mock_container.sdk_client.call_auth_admin_method(
            "create_user", {"email": "[email protected]", "password": "TestPass123!"}
        )

        # Verify
        assert result["user"]["email"] == "[email protected]"
        mock_container.sdk_client.call_auth_admin_method.assert_called_once()


@pytest.mark.asyncio
class TestSafetyToolsUnit:
    """Unit tests for safety tools - these already work well."""

    @pytest.fixture
    def mock_container(self):
        """Create a mock container with safety manager."""
        container = MagicMock(spec=ServicesContainer)

        # Mock safety manager with proper methods
        container.safety_manager = MagicMock()
        container.safety_manager.set_unsafe_mode = MagicMock()
        container.safety_manager.get_mode = MagicMock(return_value=SafetyMode.SAFE)
        container.safety_manager.confirm_operation = MagicMock()
        container.safety_manager.is_unsafe_mode = MagicMock(return_value=False)

        return container

    async def test_live_dangerously_enables_unsafe_mode(self, mock_container):
        """Test that live_dangerously enables unsafe mode."""
        # Execute
        mock_container.safety_manager.set_unsafe_mode(ClientType.DATABASE, True)

        # Verify
        mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.DATABASE, True)

    async def test_confirm_operation_stores_confirmation(self, mock_container):
        """Test that confirm operation stores the confirmation."""
        # Setup
        confirmation_id = str(uuid.uuid4())

        # Execute
        mock_container.safety_manager.confirm_operation(confirmation_id)

        # Verify
        mock_container.safety_manager.confirm_operation.assert_called_with(confirmation_id)

    async def test_safety_mode_switching(self, mock_container):
        """Test switching between safe and unsafe modes."""
        # Test enabling unsafe mode
        mock_container.safety_manager.set_unsafe_mode(ClientType.API, True)
        mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.API, True)

        # Test disabling unsafe mode
        mock_container.safety_manager.set_unsafe_mode(ClientType.API, False)
        mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.API, False)
```
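Each test class above rebuilds a near-identical `mock_container` fixture. A hypothetical shared helper (not in the repository) that captures the common pattern, should new test modules want to reuse it:

```python
from unittest.mock import AsyncMock, MagicMock

from supabase_mcp.core.container import ServicesContainer


def make_mock_container() -> MagicMock:
    """Build a ServicesContainer double with async query handling stubbed."""
    container = MagicMock(spec=ServicesContainer)
    container.query_manager = MagicMock()
    container.query_manager.handle_query = AsyncMock()
    container.safety_manager = MagicMock()
    container.safety_manager.is_unsafe_mode = MagicMock(return_value=False)
    return container
```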

--------------------------------------------------------------------------------
/supabase_mcp/core/feature_manager.py:
--------------------------------------------------------------------------------

```python
from typing import TYPE_CHECKING, Any, Literal

from supabase_mcp.clients.api_client import ApiClient
from supabase_mcp.exceptions import APIError, ConfirmationRequiredError, FeatureAccessError, FeatureTemporaryError
from supabase_mcp.logger import logger
from supabase_mcp.services.database.postgres_client import QueryResult
from supabase_mcp.services.safety.models import ClientType, SafetyMode
from supabase_mcp.tools.manager import ToolName

if TYPE_CHECKING:
    from supabase_mcp.core.container import ServicesContainer


class FeatureManager:
    """Service for managing features, access to them and their configuration."""

    def __init__(self, api_client: ApiClient):
        """Initialize the feature service.

        Args:
            api_client: Client for communicating with the API
        """
        self.api_client = api_client

    async def check_feature_access(self, feature_name: str) -> None:
        """Check if the user has access to a feature.

        Args:
            feature_name: Name of the feature to check

        Raises:
            FeatureAccessError: If the user doesn't have access to the feature
        """
        try:
            # Use the API client to check feature access
            response = await self.api_client.check_feature_access(feature_name)

            # If access is not granted, raise an exception
            if not response.access_granted:
                logger.info(f"Feature access denied: {feature_name}")
                raise FeatureAccessError(feature_name)

            logger.debug(f"Feature access granted: {feature_name}")
        except APIError as e:
            logger.error(f"API error checking feature access: {feature_name} - {e}")
            raise FeatureTemporaryError(feature_name, e.status_code, e.response_body) from e
        except Exception as e:
            if not isinstance(e, FeatureAccessError):
                logger.error(f"Unexpected error checking feature access: {feature_name} - {e}")
                raise FeatureTemporaryError(feature_name) from e
            raise

    async def execute_tool(self, tool_name: ToolName, services_container: "ServicesContainer", **kwargs: Any) -> Any:
        """Execute a tool with feature access check.

        Args:
            tool_name: Name of the tool to execute
            services_container: Container with all services
            **kwargs: Arguments to pass to the tool

        Returns:
            Result of the tool execution
        """
        # Check feature access
        await self.check_feature_access(tool_name.value)

        # Execute the appropriate tool based on name
        if tool_name == ToolName.GET_SCHEMAS:
            return await self.get_schemas(services_container)
        elif tool_name == ToolName.GET_TABLES:
            return await self.get_tables(services_container, **kwargs)
        elif tool_name == ToolName.GET_TABLE_SCHEMA:
            return await self.get_table_schema(services_container, **kwargs)
        elif tool_name == ToolName.EXECUTE_POSTGRESQL:
            return await self.execute_postgresql(services_container, **kwargs)
        elif tool_name == ToolName.RETRIEVE_MIGRATIONS:
            return await self.retrieve_migrations(services_container, **kwargs)
        elif tool_name == ToolName.SEND_MANAGEMENT_API_REQUEST:
            return await self.send_management_api_request(services_container, **kwargs)
        elif tool_name == ToolName.GET_MANAGEMENT_API_SPEC:
            return await self.get_management_api_spec(services_container, **kwargs)
        elif tool_name == ToolName.GET_AUTH_ADMIN_METHODS_SPEC:
            return await self.get_auth_admin_methods_spec(services_container)
        elif tool_name == ToolName.CALL_AUTH_ADMIN_METHOD:
            return await self.call_auth_admin_method(services_container, **kwargs)
        elif tool_name == ToolName.LIVE_DANGEROUSLY:
            return await self.live_dangerously(services_container, **kwargs)
        elif tool_name == ToolName.CONFIRM_DESTRUCTIVE_OPERATION:
            return await self.confirm_destructive_operation(services_container, **kwargs)
        elif tool_name == ToolName.RETRIEVE_LOGS:
            return await self.retrieve_logs(services_container, **kwargs)
        else:
            raise ValueError(f"Unknown tool: {tool_name}")

    async def get_schemas(self, container: "ServicesContainer") -> QueryResult:
        """List all database schemas with their sizes and table counts."""
        query_manager = container.query_manager
        query = query_manager.get_schemas_query()
        return await query_manager.handle_query(query)

    async def get_tables(self, container: "ServicesContainer", schema_name: str) -> QueryResult:
        """List all tables, foreign tables, and views in a schema with their sizes, row counts, and metadata."""
        query_manager = container.query_manager
        query = query_manager.get_tables_query(schema_name)
        return await query_manager.handle_query(query)

    async def get_table_schema(self, container: "ServicesContainer", schema_name: str, table: str) -> QueryResult:
        """Get detailed table structure including columns, keys, and relationships."""
        query_manager = container.query_manager
        query = query_manager.get_table_schema_query(schema_name, table)
        return await query_manager.handle_query(query)

    async def execute_postgresql(
        self, container: "ServicesContainer", query: str, migration_name: str = ""
    ) -> QueryResult:
        """Execute PostgreSQL statements against your Supabase database."""
        query_manager = container.query_manager
        return await query_manager.handle_query(query, has_confirmation=False, migration_name=migration_name)

    async def retrieve_migrations(
        self,
        container: "ServicesContainer",
        limit: int = 50,
        offset: int = 0,
        name_pattern: str = "",
        include_full_queries: bool = False,
    ) -> QueryResult:
        """Retrieve a list of all migrations a user has from Supabase."""
        query_manager = container.query_manager
        query = query_manager.get_migrations_query(
            limit=limit, offset=offset, name_pattern=name_pattern, include_full_queries=include_full_queries
        )
        return await query_manager.handle_query(query)

    async def send_management_api_request(
        self,
        container: "ServicesContainer",
        method: str,
        path: str,
        path_params: dict[str, str],
        request_params: dict[str, Any],
        request_body: dict[str, Any],
    ) -> dict[str, Any]:
        """Execute a Supabase Management API request."""
        api_manager = container.api_manager
        return await api_manager.execute_request(method, path, path_params, request_params, request_body)

    async def get_management_api_spec(
        self, container: "ServicesContainer", params: dict[str, Any] = {}
    ) -> dict[str, Any]:
        """Get the Supabase Management API specification."""
        path = params.get("path")
        method = params.get("method")
        domain = params.get("domain")
        all_paths = params.get("all_paths", False)

        logger.debug(
            f"Getting management API spec with path: {path}, method: {method}, domain: {domain}, all_paths: {all_paths}"
        )
        api_manager = container.api_manager
        return await api_manager.handle_spec_request(path, method, domain, all_paths)

    async def get_auth_admin_methods_spec(self, container: "ServicesContainer") -> dict[str, Any]:
        """Get Python SDK methods specification for Auth Admin."""
        sdk_client = container.sdk_client
        return sdk_client.return_python_sdk_spec()

    async def call_auth_admin_method(
        self, container: "ServicesContainer", method: str, params: dict[str, Any]
    ) -> dict[str, Any]:
        """Call an Auth Admin method from Supabase Python SDK."""
        sdk_client = container.sdk_client
        return await sdk_client.call_auth_admin_method(method, params)

    async def live_dangerously(
        self, container: "ServicesContainer", service: Literal["api", "database"], enable_unsafe_mode: bool = False
    ) -> dict[str, Any]:
        """
        Toggle between safe and unsafe operation modes for API or Database services.

        This function controls the safety level for operations, allowing you to:
        - Enable write operations for the database (INSERT, UPDATE, DELETE, schema changes)
        - Enable state-changing operations for the Management API
        """
        safety_manager = container.safety_manager
        if service == "api":
            # Set the safety mode in the safety manager
            new_mode = SafetyMode.UNSAFE if enable_unsafe_mode else SafetyMode.SAFE
            safety_manager.set_safety_mode(ClientType.API, new_mode)

            # Return the actual mode that was set
            return {"service": "api", "mode": safety_manager.get_safety_mode(ClientType.API)}
        elif service == "database":
            # Set the safety mode in the safety manager
            new_mode = SafetyMode.UNSAFE if enable_unsafe_mode else SafetyMode.SAFE
            safety_manager.set_safety_mode(ClientType.DATABASE, new_mode)

            # Return the actual mode that was set
            return {"service": "database", "mode": safety_manager.get_safety_mode(ClientType.DATABASE)}

    async def confirm_destructive_operation(
        self,
        container: "ServicesContainer",
        operation_type: Literal["api", "database"],
        confirmation_id: str,
        user_confirmation: bool = False,
    ) -> QueryResult | dict[str, Any]:
        """Execute a destructive operation after confirmation.

        Use this only after reviewing the risks with the user."""
        api_manager = container.api_manager
        query_manager = container.query_manager
        if not user_confirmation:
            raise ConfirmationRequiredError("Destructive operation requires explicit user confirmation.")

        if operation_type == "api":
            return await api_manager.handle_confirmation(confirmation_id)
        elif operation_type == "database":
            return await query_manager.handle_confirmation(confirmation_id)

    async def retrieve_logs(
        self,
        container: "ServicesContainer",
        collection: str,
        limit: int = 20,
        hours_ago: int = 1,
        filters: list[dict[str, Any]] = [],
        search: str = "",
        custom_query: str = "",
    ) -> dict[str, Any]:
        """Retrieve logs from your Supabase project's services for debugging and monitoring."""
        logger.info(
            f"Tool called: retrieve_logs(collection={collection}, limit={limit}, hours_ago={hours_ago}, filters={filters}, search={search}, custom_query={'<custom>' if custom_query else None})"
        )
        api_manager = container.api_manager
        result = await api_manager.retrieve_logs(
            collection=collection,
            limit=limit,
            hours_ago=hours_ago,
            filters=filters,
            search=search,
            custom_query=custom_query,
        )
        logger.info(f"Tool completed: retrieve_logs - Retrieved log entries for collection={collection}")
        return result
```
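A brief sketch (assumed wiring, not taken from the repository) of how a single tool call flows through `FeatureManager.execute_tool`: feature access is checked first, then the call is dispatched to the matching handler. `container` and `api_client` are assumed to be an already-initialized `ServicesContainer` and `ApiClient`.

```python
from supabase_mcp.core.feature_manager import FeatureManager
from supabase_mcp.tools.manager import ToolName


async def list_public_tables(container, api_client) -> None:
    feature_manager = FeatureManager(api_client)
    # Raises FeatureAccessError if the feature is gated; otherwise dispatches
    # to FeatureManager.get_tables with the keyword arguments passed through.
    result = await feature_manager.execute_tool(
        ToolName.GET_TABLES, container, schema_name="public"
    )
    print(result)
```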
"SUPABASE_DB_PASSWORD": os.getenv("SUPABASE_DB_PASSWORD"), "SUPABASE_SERVICE_ROLE_KEY": os.getenv("SUPABASE_SERVICE_ROLE_KEY"), "SUPABASE_ACCESS_TOKEN": os.getenv("SUPABASE_ACCESS_TOKEN"), } @pytest.fixture(scope="session") def settings_integration() -> Settings: """Fixture providing settings for integration tests. This fixture loads settings from environment variables or .env.test file. Uses session scope since settings don't change during tests. """ return Settings.with_config(find_config_file(".env.test")) @pytest.fixture def mock_validator() -> SQLValidator: """Fixture providing a mock SQLValidator for integration tests.""" return SQLValidator() @pytest.fixture def settings_integration_custom_env() -> Generator[Settings, None, None]: """Fixture that provides Settings instance for integration tests using .env.test""" # Load custom environment variables test_env = load_test_env() original_env = dict(os.environ) # Set up test environment for key, value in test_env.items(): if value is not None: os.environ[key] = value # Create fresh settings instance settings = Settings() logger.info(f"Custom connection settings initialized: {settings}") yield settings # Restore original environment os.environ.clear() os.environ.update(original_env) # ====================== # Service Fixtures # ====================== @pytest_asyncio.fixture(scope="module") async def postgres_client_integration(settings_integration: Settings) -> AsyncGenerator[PostgresClient, None]: # Reset before creation await PostgresClient.reset() # Create a client client = PostgresClient(settings=settings_integration) try: yield client finally: await client.close() @pytest_asyncio.fixture(scope="module") async def spec_manager_integration() -> AsyncGenerator[ApiSpecManager, None]: """Fixture providing an ApiSpecManager instance for tests.""" manager = ApiSpecManager() yield manager @pytest_asyncio.fixture(scope="module") async def api_client_integration(settings_integration: Settings) -> AsyncGenerator[ManagementAPIClient, None]: # We don't need to reset since it's not a singleton client = ManagementAPIClient(settings=settings_integration) try: yield client finally: await client.close() @pytest_asyncio.fixture(scope="module") async def sdk_client_integration(settings_integration: Settings) -> AsyncGenerator[SupabaseSDKClient, None]: """Fixture providing a SupabaseSDKClient instance for tests. Uses function scope to ensure a fresh client for each test. 
""" client = SupabaseSDKClient.get_instance(settings=settings_integration) try: yield client finally: # Reset the singleton to ensure a fresh client for the next test SupabaseSDKClient.reset() @pytest.fixture(scope="module") def safety_manager_integration() -> SafetyManager: """Fixture providing a safety manager for integration tests.""" # Reset the safety manager singleton SafetyManager.reset() # Create a new safety manager safety_manager = SafetyManager.get_instance() safety_manager.register_safety_configs() return safety_manager @pytest.fixture(scope="module") def tool_manager_integration() -> ToolManager: """Fixture providing a tool manager for integration tests.""" # Reset the tool manager singleton ToolManager.reset() return ToolManager.get_instance() @pytest.fixture(scope="module") def query_manager_integration( postgres_client_integration: PostgresClient, safety_manager_integration: SafetyManager, ) -> QueryManager: """Fixture providing a query manager for integration tests.""" query_manager = QueryManager( postgres_client=postgres_client_integration, safety_manager=safety_manager_integration, ) return query_manager @pytest.fixture(scope="module") def mock_api_manager() -> SupabaseApiManager: """Fixture providing a properly mocked API manager for unit tests.""" # Create mock dependencies mock_client = MagicMock() mock_safety_manager = MagicMock() mock_spec_manager = MagicMock() # Create the API manager with proper constructor arguments api_manager = SupabaseApiManager(api_client=mock_client, safety_manager=mock_safety_manager) # Add the spec_manager attribute api_manager.spec_manager = mock_spec_manager return api_manager @pytest.fixture def mock_query_manager() -> QueryManager: """Fixture providing a properly mocked Query manager for unit tests.""" # Create mock dependencies mock_safety_manager = MagicMock() mock_postgres_client = MagicMock() mock_validator = MagicMock() # Create the Query manager with proper constructor arguments query_manager = QueryManager( postgres_client=mock_postgres_client, safety_manager=mock_safety_manager, ) # Replace the validator with a mock query_manager.validator = mock_validator # Store the postgres client as an attribute for tests to access query_manager.db_client = mock_postgres_client # Make execute_query_async an AsyncMock query_manager.db_client.execute_query_async = AsyncMock() return query_manager @pytest_asyncio.fixture(scope="module") async def api_manager_integration( api_client_integration: ManagementAPIClient, safety_manager_integration: SafetyManager, ) -> AsyncGenerator[SupabaseApiManager, None]: """Fixture providing an API manager for integration tests.""" # Create a new API manager api_manager = SupabaseApiManager.get_instance( api_client=api_client_integration, safety_manager=safety_manager_integration, ) try: yield api_manager finally: # Reset the API manager singleton SupabaseApiManager.reset() # ====================== # Mock MCP Server # ====================== @pytest.fixture def mock_mcp_server() -> Any: """Fixture providing a mock MCP server for integration tests.""" # Create a simple mock MCP server that mimics the FastMCP interface class MockMCP: def __init__(self) -> None: self.tools: dict[str, Any] = {} self.name = "mock_mcp" def register_tool(self, name: str, func: Any, **kwargs: Any) -> None: """Register a tool with the MCP server.""" self.tools[name] = func def run(self) -> None: """Mock run method.""" pass return MockMCP() @pytest.fixture(scope="module") def mock_mcp_server_integration() -> Any: """Fixture providing a 
mock MCP server for integration tests.""" return FastMCP(name="supabase") # ====================== # Container Fixture # ====================== @pytest.fixture(scope="module") def container_integration( postgres_client_integration: PostgresClient, api_client_integration: ManagementAPIClient, sdk_client_integration: SupabaseSDKClient, api_manager_integration: SupabaseApiManager, safety_manager_integration: SafetyManager, query_manager_integration: QueryManager, tool_manager_integration: ToolManager, mock_mcp_server_integration: FastMCP, ) -> ServicesContainer: """Fixture providing a basic Container for integration tests. This container includes all services needed for integration testing, but is not initialized. """ # Create a new container with all the services container = ServicesContainer( mcp_server=mock_mcp_server_integration, postgres_client=postgres_client_integration, api_client=api_client_integration, sdk_client=sdk_client_integration, api_manager=api_manager_integration, safety_manager=safety_manager_integration, query_manager=query_manager_integration, tool_manager=tool_manager_integration, ) logger.info("✓ Integration container created successfully.") return container @pytest.fixture(scope="module") def initialized_container_integration( container_integration: ServicesContainer, settings_integration: Settings, ) -> ServicesContainer: """Fixture providing a fully initialized Container for integration tests. This container is initialized with all services and ready to use. """ container_integration.initialize_services(settings_integration) logger.info("✓ Integration container initialized successfully.") return container_integration @pytest.fixture(scope="module") def tools_registry_integration( initialized_container_integration: ServicesContainer, ) -> ServicesContainer: """Fixture providing a Container with tools registered for integration tests. This container has all tools registered with the MCP server. 
""" container = initialized_container_integration mcp_server = container.mcp_server registry = ToolRegistry(mcp_server, container) registry.register_tools() logger.info("✓ Tools registered with MCP server successfully.") return container @pytest.fixture def sql_loader() -> SQLLoader: """Fixture providing a SQLLoader instance for tests.""" return SQLLoader() @pytest.fixture def migration_manager(sql_loader: SQLLoader) -> MigrationManager: """Fixture providing a MigrationManager instance for tests.""" return MigrationManager(loader=sql_loader) ``` -------------------------------------------------------------------------------- /tests/services/api/test_spec_manager.py: -------------------------------------------------------------------------------- ```python import json from unittest.mock import AsyncMock, MagicMock, mock_open, patch import httpx import pytest from supabase_mcp.services.api.spec_manager import ApiSpecManager # Test data SAMPLE_SPEC = {"openapi": "3.0.0", "paths": {"/v1/test": {"get": {"operationId": "test"}}}} class TestApiSpecManager: """Integration tests for api spec manager tools.""" # Local Spec Tests def test_load_local_spec_success(self, spec_manager_integration: ApiSpecManager): """Test successful loading of local spec file""" mock_file = mock_open(read_data=json.dumps(SAMPLE_SPEC)) with patch("builtins.open", mock_file): result = spec_manager_integration._load_local_spec() assert result == SAMPLE_SPEC mock_file.assert_called_once() def test_load_local_spec_file_not_found(self, spec_manager_integration: ApiSpecManager): """Test handling of missing local spec file""" with patch("builtins.open", side_effect=FileNotFoundError), pytest.raises(FileNotFoundError): spec_manager_integration._load_local_spec() def test_load_local_spec_invalid_json(self, spec_manager_integration: ApiSpecManager): """Test handling of invalid JSON in local spec""" mock_file = mock_open(read_data="invalid json") with patch("builtins.open", mock_file), pytest.raises(json.JSONDecodeError): spec_manager_integration._load_local_spec() # Remote Spec Tests @pytest.mark.asyncio async def test_fetch_remote_spec_success(self, spec_manager_integration: ApiSpecManager): """Test successful remote spec fetch""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = SAMPLE_SPEC mock_client = AsyncMock() mock_client.get.return_value = mock_response mock_client.__aenter__.return_value = mock_client # Mock async context manager with patch("httpx.AsyncClient", return_value=mock_client): result = await spec_manager_integration._fetch_remote_spec() assert result == SAMPLE_SPEC mock_client.get.assert_called_once() @pytest.mark.asyncio async def test_fetch_remote_spec_api_error(self, spec_manager_integration: ApiSpecManager): """Test handling of API error during remote fetch""" mock_response = MagicMock() mock_response.status_code = 500 mock_client = AsyncMock() mock_client.get.return_value = mock_response mock_client.__aenter__.return_value = mock_client # Mock async context manager with patch("httpx.AsyncClient", return_value=mock_client): result = await spec_manager_integration._fetch_remote_spec() assert result is None @pytest.mark.asyncio async def test_fetch_remote_spec_network_error(self, spec_manager_integration: ApiSpecManager): """Test handling of network error during remote fetch""" mock_client = AsyncMock() mock_client.get.side_effect = httpx.NetworkError("Network error") with patch("httpx.AsyncClient", return_value=mock_client): result = await 

--------------------------------------------------------------------------------
/tests/services/api/test_spec_manager.py:
--------------------------------------------------------------------------------

```python
import json
from unittest.mock import AsyncMock, MagicMock, mock_open, patch

import httpx
import pytest

from supabase_mcp.services.api.spec_manager import ApiSpecManager

# Test data
SAMPLE_SPEC = {"openapi": "3.0.0", "paths": {"/v1/test": {"get": {"operationId": "test"}}}}


class TestApiSpecManager:
    """Integration tests for api spec manager tools."""

    # Local Spec Tests
    def test_load_local_spec_success(self, spec_manager_integration: ApiSpecManager):
        """Test successful loading of local spec file"""
        mock_file = mock_open(read_data=json.dumps(SAMPLE_SPEC))
        with patch("builtins.open", mock_file):
            result = spec_manager_integration._load_local_spec()

        assert result == SAMPLE_SPEC
        mock_file.assert_called_once()

    def test_load_local_spec_file_not_found(self, spec_manager_integration: ApiSpecManager):
        """Test handling of missing local spec file"""
        with patch("builtins.open", side_effect=FileNotFoundError), pytest.raises(FileNotFoundError):
            spec_manager_integration._load_local_spec()

    def test_load_local_spec_invalid_json(self, spec_manager_integration: ApiSpecManager):
        """Test handling of invalid JSON in local spec"""
        mock_file = mock_open(read_data="invalid json")
        with patch("builtins.open", mock_file), pytest.raises(json.JSONDecodeError):
            spec_manager_integration._load_local_spec()

    # Remote Spec Tests
    @pytest.mark.asyncio
    async def test_fetch_remote_spec_success(self, spec_manager_integration: ApiSpecManager):
        """Test successful remote spec fetch"""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = SAMPLE_SPEC

        mock_client = AsyncMock()
        mock_client.get.return_value = mock_response
        mock_client.__aenter__.return_value = mock_client  # Mock async context manager

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await spec_manager_integration._fetch_remote_spec()

        assert result == SAMPLE_SPEC
        mock_client.get.assert_called_once()

    @pytest.mark.asyncio
    async def test_fetch_remote_spec_api_error(self, spec_manager_integration: ApiSpecManager):
        """Test handling of API error during remote fetch"""
        mock_response = MagicMock()
        mock_response.status_code = 500

        mock_client = AsyncMock()
        mock_client.get.return_value = mock_response
        mock_client.__aenter__.return_value = mock_client  # Mock async context manager

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await spec_manager_integration._fetch_remote_spec()

        assert result is None

    @pytest.mark.asyncio
    async def test_fetch_remote_spec_network_error(self, spec_manager_integration: ApiSpecManager):
        """Test handling of network error during remote fetch"""
        mock_client = AsyncMock()
        mock_client.get.side_effect = httpx.NetworkError("Network error")

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await spec_manager_integration._fetch_remote_spec()

        assert result is None

    # Startup Flow Tests
    @pytest.mark.asyncio
    async def test_startup_remote_success(self, spec_manager_integration: ApiSpecManager):
        """Test successful startup with remote fetch"""
        # Reset spec to None to ensure we're testing the fetch
        spec_manager_integration.spec = None

        # Mock the fetch method to return sample spec
        mock_fetch = AsyncMock(return_value=SAMPLE_SPEC)
        with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch):
            result = await spec_manager_integration.get_spec()

        assert result == SAMPLE_SPEC
        mock_fetch.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_spec_remote_fail_local_fallback(self, spec_manager_integration: ApiSpecManager):
        """Test get_spec with remote failure and local fallback"""
        # Reset spec to None to ensure we're testing the fetch and fallback
        spec_manager_integration.spec = None

        # Mock fetch to fail and local to succeed
        mock_fetch = AsyncMock(return_value=None)
        mock_local = MagicMock(return_value=SAMPLE_SPEC)

        with (
            patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch),
            patch.object(spec_manager_integration, "_load_local_spec", mock_local),
        ):
            result = await spec_manager_integration.get_spec()

        assert result == SAMPLE_SPEC
        mock_fetch.assert_called_once()
        mock_local.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_spec_both_fail(self, spec_manager_integration: ApiSpecManager):
        """Test get_spec with both remote and local failure"""
        # Reset spec to None to ensure we're testing the fetch and fallback
        spec_manager_integration.spec = None

        # Mock both fetch and local to fail
        mock_fetch = AsyncMock(return_value=None)
        mock_local = MagicMock(side_effect=FileNotFoundError("Test file not found"))

        with (
            patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch),
            patch.object(spec_manager_integration, "_load_local_spec", mock_local),
            pytest.raises(FileNotFoundError),
        ):
            await spec_manager_integration.get_spec()

        mock_fetch.assert_called_once()
        mock_local.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_spec_cached(self, spec_manager_integration: ApiSpecManager):
        """Test that get_spec returns cached spec if available"""
        # Set the spec directly
        spec_manager_integration.spec = SAMPLE_SPEC

        # Mock the fetch method to verify it's not called
        mock_fetch = AsyncMock()
        with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch):
            result = await spec_manager_integration.get_spec()

        assert result == SAMPLE_SPEC
        mock_fetch.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_spec_not_loaded(self, spec_manager_integration: ApiSpecManager):
        """Test behavior when spec is not loaded but can be loaded"""
        # Reset spec to None
        spec_manager_integration.spec = None

        # Mock fetch to succeed
        mock_fetch = AsyncMock(return_value=SAMPLE_SPEC)
        with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch):
            result = await spec_manager_integration.get_spec()

        assert result == SAMPLE_SPEC
        mock_fetch.assert_called_once()

    @pytest.mark.asyncio
    async def test_comprehensive_spec_retrieval(self, spec_manager_integration: ApiSpecManager):
        """
        Comprehensive test of API spec retrieval and functionality.
        This test exactly mirrors the main() function to ensure all aspects work correctly.
        """
        # Create a fresh instance to avoid any cached data from other tests
        from supabase_mcp.services.api.spec_manager import LOCAL_SPEC_PATH, ApiSpecManager

        spec_manager = ApiSpecManager()

        # Print the path being used (for debugging)
        print(f"\nTest is looking for spec at: {LOCAL_SPEC_PATH}")

        # Load the spec
        spec = await spec_manager.get_spec()
        assert spec is not None, "Spec should be loaded successfully"

        # 1. Test get_all_domains
        all_domains = spec_manager.get_all_domains()
        print(f"\nAll domains: {all_domains}")
        assert len(all_domains) > 0, "Should have at least one domain"

        # Verify all expected domains are present
        expected_domains = [
            "Analytics",
            "Auth",
            "Database",
            "Domains",
            "Edge Functions",
            "Environments",
            "OAuth",
            "Organizations",
            "Projects",
            "Rest",
            "Secrets",
            "Storage",
        ]
        for domain in expected_domains:
            assert domain in all_domains, f"Domain '{domain}' should be in the list of domains"

        # 2. Test get_all_paths_and_methods
        all_paths = spec_manager.get_all_paths_and_methods()
        assert len(all_paths) > 0, "Should have at least one path"

        # Sample a few paths to verify
        sample_paths = list(all_paths.keys())[:5]
        print("\nSample paths:")
        for path in sample_paths:
            print(f"  {path}:")
            assert path.startswith("/v1/"), f"Path {path} should start with /v1/"
            assert len(all_paths[path]) > 0, f"Path {path} should have at least one method"
            for method, operation_id in all_paths[path].items():
                print(f"    {method}: {operation_id}")
                assert method.lower() in ["get", "post", "put", "patch", "delete"], f"Method {method} should be valid"
                assert operation_id.startswith("v1-"), f"Operation ID {operation_id} should start with v1-"

        # 3. Test get_paths_and_methods_by_domain for each domain
        for domain in expected_domains:
            domain_paths = spec_manager.get_paths_and_methods_by_domain(domain)
            assert len(domain_paths) > 0, f"Domain {domain} should have at least one path"
            print(f"\n{domain} domain has {len(domain_paths)} paths")

        # 4. Test Edge Functions domain specifically
        edge_paths = spec_manager.get_paths_and_methods_by_domain("Edge Functions")
        print("\nEdge Functions Paths and Methods:")
        for path in edge_paths:
            print(f"  {path}")
            for method, operation_id in edge_paths[path].items():
                print(f"    {method}: {operation_id}")

        # Verify specific Edge Functions paths exist
        expected_edge_paths = [
            "/v1/projects/{ref}/functions",
            "/v1/projects/{ref}/functions/{function_slug}",
            "/v1/projects/{ref}/functions/deploy",
        ]
        for path in expected_edge_paths:
            assert path in edge_paths, f"Expected path {path} should be in Edge Functions domain"

        # 5. Test get_spec_for_path_and_method
        # Test for Edge Functions
        path = "/v1/projects/{ref}/functions"
        method = "GET"
        full_spec = spec_manager.get_spec_for_path_and_method(path, method)
        assert full_spec is not None, f"Should find spec for {method} {path}"
        assert "operationId" in full_spec, "Spec should include operationId"
        assert full_spec["operationId"] == "v1-list-all-functions", "Should have correct operationId"

        # Test for another domain (Auth)
        auth_path = "/v1/projects/{ref}/config/auth"
        auth_method = "GET"
        auth_spec = spec_manager.get_spec_for_path_and_method(auth_path, auth_method)
        assert auth_spec is not None, f"Should find spec for {auth_method} {auth_path}"
        assert "operationId" in auth_spec, "Auth spec should include operationId"

        # 6. Test get_spec_part
        # Get a specific schema
        schema = spec_manager.get_spec_part("components", "schemas", "FunctionResponse")
        assert schema is not None, "Should find FunctionResponse schema"
        assert "properties" in schema, "Schema should have properties"

        # 7. Test caching behavior
        # Call get_spec again - should use cached version
        import time

        start_time = time.time()
        await spec_manager.get_spec()
        end_time = time.time()
        assert (end_time - start_time) < 0.1, "Cached spec retrieval should be fast"
```
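Outside the test suite, the same spec lookups can be driven directly. A minimal sketch, assuming the packaged local spec (or network access for the remote fetch) is available:

```python
import asyncio

from supabase_mcp.services.api.spec_manager import ApiSpecManager


async def main() -> None:
    manager = ApiSpecManager()
    await manager.get_spec()  # remote fetch with local fallback, then cached
    print(manager.get_all_domains())
    print(manager.get_spec_for_path_and_method("/v1/projects/{ref}/functions", "GET"))


asyncio.run(main())
```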

--------------------------------------------------------------------------------
/supabase_mcp/services/api/api_manager.py:
--------------------------------------------------------------------------------

```python
from __future__ import annotations

from enum import Enum
from typing import Any

from supabase_mcp.clients.management_client import ManagementAPIClient
from supabase_mcp.logger import logger
from supabase_mcp.services.api.spec_manager import ApiSpecManager
from supabase_mcp.services.logs.log_manager import LogManager
from supabase_mcp.services.safety.models import ClientType
from supabase_mcp.services.safety.safety_manager import SafetyManager
from supabase_mcp.settings import settings


class PathPlaceholder(str, Enum):
    """Enum of all possible path placeholders in the Supabase Management API."""

    REF = "ref"
    FUNCTION_SLUG = "function_slug"
    ID = "id"
    SLUG = "slug"
    BRANCH_ID = "branch_id"
    PROVIDER_ID = "provider_id"
    TPA_ID = "tpa_id"


class SupabaseApiManager:
    """
    Manages the Supabase Management API.
    """

    _instance: SupabaseApiManager | None = None

    def __init__(
        self,
        api_client: ManagementAPIClient,
        safety_manager: SafetyManager,
        spec_manager: ApiSpecManager | None = None,
        log_manager: LogManager | None = None,
    ) -> None:
        """Initialize the API manager."""
        self.spec_manager = spec_manager or ApiSpecManager()  # this is so that I don't have to pass it
        self.client = api_client
        self.safety_manager = safety_manager
        self.log_manager = log_manager or LogManager()

    @classmethod
    def get_instance(
        cls,
        api_client: ManagementAPIClient,
        safety_manager: SafetyManager,
        spec_manager: ApiSpecManager | None = None,
    ) -> SupabaseApiManager:
        """Get the singleton instance"""
        if cls._instance is None:
            cls._instance = SupabaseApiManager(api_client, safety_manager, spec_manager)
        return cls._instance

    @classmethod
    def reset(cls) -> None:
        """Reset the singleton instance"""
        if cls._instance is not None:
            cls._instance = None
            logger.info("SupabaseApiManager instance reset complete")

    def get_safety_rules(self) -> str:
        """
        Get safety rules with human-readable descriptions.

        Returns:
            str: Human readable safety rules explanation
        """
        # Get safety configuration from the safety manager
        safety_manager = self.safety_manager

        # Get risk levels and operations by risk level
        extreme_risk_ops = safety_manager.get_operations_by_risk_level("extreme", ClientType.API)
        high_risk_ops = safety_manager.get_operations_by_risk_level("high", ClientType.API)
        medium_risk_ops = safety_manager.get_operations_by_risk_level("medium", ClientType.API)

        # Create human-readable explanations
        extreme_risk_summary = (
            "\n".join([f"- {method} {path}" for method, paths in extreme_risk_ops.items() for path in paths])
            if extreme_risk_ops
            else "None"
        )
        high_risk_summary = (
            "\n".join([f"- {method} {path}" for method, paths in high_risk_ops.items() for path in paths])
            if high_risk_ops
            else "None"
        )
        medium_risk_summary = (
            "\n".join([f"- {method} {path}" for method, paths in medium_risk_ops.items() for path in paths])
            if medium_risk_ops
            else "None"
        )

        current_mode = safety_manager.get_current_mode(ClientType.API)

        return f"""MCP Server Safety Rules:

EXTREME RISK Operations (never allowed by the server):
{extreme_risk_summary}

HIGH RISK Operations (require unsafe mode):
{high_risk_summary}

MEDIUM RISK Operations (require unsafe mode):
{medium_risk_summary}

All other operations are LOW RISK (always allowed).

Current mode: {current_mode}
In safe mode, only low risk operations are allowed.
Use live_dangerously() to enable unsafe mode for medium and high risk operations.
"""

    def replace_path_params(self, path: str, path_params: dict[str, Any] | None = None) -> str:
        """
        Replace path parameters in the path string with actual values.

        This method:
        1. Automatically injects the project ref from settings
        2. Replaces all placeholders in the path with values from path_params
        3. Validates that all placeholders are replaced

        Args:
            path: The API path with placeholders (e.g., "/v1/projects/{ref}/functions/{function_slug}")
            path_params: Dictionary of path parameters to replace (e.g., {"function_slug": "my-function"})

        Returns:
            The path with all placeholders replaced

        Raises:
            ValueError: If any placeholders remain after replacement or if invalid placeholders are provided
        """
        # Create a working copy of path_params to avoid modifying the original
        working_params = {} if path_params is None else path_params.copy()

        # Check if user provided ref and raise an error
        if working_params and PathPlaceholder.REF.value in working_params:
            raise ValueError(
                "Do not provide 'ref' in path_params. The project reference is automatically injected from settings. "
                "If you need to change the project reference, modify the environment variables instead."
            )

        # Validate that all provided path parameters are known placeholders
        if working_params:
            for key in working_params:
                try:
                    PathPlaceholder(key)
                except ValueError as e:
                    raise ValueError(
                        f"Unknown path parameter: '{key}'. Valid placeholders are: "
                        f"{', '.join([p.value for p in PathPlaceholder])}"
                    ) from e

        # Get project ref from settings and add it to working_params
        working_params[PathPlaceholder.REF.value] = settings.supabase_project_ref

        logger.info(f"Replacing path parameters in path: {working_params}")

        # Replace all placeholders in the path
        for key, value in working_params.items():
            placeholder = "{" + key + "}"
            if placeholder in path:
                path = path.replace(placeholder, str(value))
                logger.debug(f"Replaced {placeholder} with {value}")

        # Check if any placeholders remain
        import re

        remaining_placeholders = re.findall(r"\{([^}]+)\}", path)
        if remaining_placeholders:
            raise ValueError(
                f"Missing path parameters: {', '.join(remaining_placeholders)}. "
                f"Please provide values for these placeholders in the path_params dictionary."
            )

        return path

    async def execute_request(
        self,
        method: str,
        path: str,
        path_params: dict[str, Any] | None = None,
        request_params: dict[str, Any] | None = None,
        request_body: dict[str, Any] | None = None,
        has_confirmation: bool = False,
    ) -> dict[str, Any]:
        """
        Execute Management API request with safety validation.

        Args:
            method: HTTP method to use
            path: API path to call
            path_params: Path parameters to substitute into the path
            request_params: Query parameters to include
            request_body: Request body to send
            has_confirmation: Whether the operation has been confirmed by the user

        Returns:
            API response as a dictionary

        Raises:
            SafetyError: If the operation is not allowed by safety rules
        """
        # Log the request with proper formatting
        logger.info(
            f"API Request: {method} {path} | Path params: {path_params or {}} | Query params: {request_params or {}} | Body: {request_body or {}}"
        )

        # Create an operation object for validation
        operation = (method, path, path_params, request_params, request_body)

        # Use the safety manager to validate the operation
        logger.debug(f"Validating operation safety: {method} {path}")
        self.safety_manager.validate_operation(ClientType.API, operation, has_confirmation=has_confirmation)

        # Replace path parameters in the path string with actual values
        path = self.replace_path_params(path, path_params)

        # Execute the request using the API client
        return await self.client.execute_request(method, path, request_params, request_body)

    async def handle_confirmation(self, confirmation_id: str) -> dict[str, Any]:
        """Handle a confirmation request."""
        # retrieve the operation from the confirmation id
        operation = self.safety_manager.get_stored_operation(confirmation_id)
        if not operation:
            raise ValueError("No operation found for confirmation id")

        # execute the operation
        return await self.execute_request(
            method=operation[0],
            path=operation[1],
            path_params=operation[2],
            request_params=operation[3],
            request_body=operation[4],
            has_confirmation=True,
        )

    async def handle_spec_request(
        self,
        path: str | None = None,
        method: str | None = None,
        domain: str | None = None,
        all_paths: bool | None = False,
    ) -> dict[str, Any]:
        """Handle a spec request.

        Args:
            path: Optional API path
            method: Optional HTTP method
            domain: Optional domain/tag name
            all_paths: If True, returns all paths and methods

        Returns:
            API specification based on the provided parameters
        """
        spec_manager = self.spec_manager

        if spec_manager is None:
            raise RuntimeError("API spec manager is not initialized")

        # Ensure spec is loaded
        await spec_manager.get_spec()

        # Option 1: Get spec for specific path and method
        if path and method:
            method = method.lower()  # Normalize method to lowercase
            result = spec_manager.get_spec_for_path_and_method(path, method)
            if result is None:
                return {"error": f"No specification found for {method.upper()} {path}"}
            return result

        # Option 2: Get all paths and methods for a specific domain
        elif domain:
            result = spec_manager.get_paths_and_methods_by_domain(domain)
            if not result:
                # Check if the domain exists
                all_domains = spec_manager.get_all_domains()
                if domain not in all_domains:
                    return {"error": f"Domain '{domain}' not found", "available_domains": all_domains}
            return {"domain": domain, "paths": result}

        # Option 3: Get all paths and methods
        elif all_paths:
            return {"paths": spec_manager.get_all_paths_and_methods()}

        # Option 4: Get all domains (default)
        else:
            return {"domains": spec_manager.get_all_domains()}

    async def retrieve_logs(
        self,
        collection: str,
        limit: int = 20,
        hours_ago: int | None = 1,
        filters: list[dict[str, Any]] | None = None,
        search: str | None = None,
        custom_query: str | None = None,
    ) -> dict[str, Any]:
        """Retrieve logs from a Supabase service.

        Args:
            collection: The log collection to query
            limit: Maximum number of log entries to return
            hours_ago: Retrieve logs from the last N hours
            filters: List of filter objects with field, operator, and value
            search: Text to search for in event messages
            custom_query: Complete custom SQL query to execute

        Returns:
            The query result

        Raises:
            ValueError: If the collection is unknown
        """
        log_manager = self.log_manager

        # Build the SQL query using LogManager
        sql = log_manager.build_logs_query(
            collection=collection,
            limit=limit,
            hours_ago=hours_ago,
            filters=filters,
            search=search,
            custom_query=custom_query,
        )
        logger.debug(f"Executing log query: {sql}")

        # Make the API request
        try:
            response = await self.execute_request(
                method="GET",
                path="/v1/projects/{ref}/analytics/endpoints/logs.all",
                path_params={},
                request_params={"sql": sql},
                request_body={},
            )
            return response
        except Exception as e:
            logger.error(f"Error retrieving logs: {e}")
            raise
```
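A short sketch of the path-parameter contract implemented by `replace_path_params`. The manager is constructed with stand-in dependencies since only path handling is exercised here; it assumes `SUPABASE_PROJECT_REF` is configured so `{ref}` can be injected from settings.

```python
from unittest.mock import MagicMock

from supabase_mcp.services.api.api_manager import SupabaseApiManager

# Stand-in dependencies: only path handling is exercised, no requests are sent
manager = SupabaseApiManager(api_client=MagicMock(), safety_manager=MagicMock())

path = manager.replace_path_params(
    "/v1/projects/{ref}/functions/{function_slug}",
    {"function_slug": "my-function"},
)
assert "{" not in path  # every placeholder resolved, otherwise ValueError is raised
```

Passing `ref` explicitly, or any key outside `PathPlaceholder`, raises a `ValueError` by design, which keeps the project reference under the server's control rather than the caller's.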
--------------------------------------------------------------------------------
/tests/services/safety/test_safety_manager.py:
--------------------------------------------------------------------------------

```python
import time

import pytest

from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError
from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel, SafetyMode
from supabase_mcp.services.safety.safety_configs import SafetyConfigBase
from supabase_mcp.services.safety.safety_manager import SafetyManager


class MockSafetyConfig(SafetyConfigBase[str]):
    """Mock safety configuration for testing."""

    def get_risk_level(self, operation: str) -> OperationRiskLevel:
        """Get the risk level for an operation."""
        if operation == "low_risk":
            return OperationRiskLevel.LOW
        elif operation == "medium_risk":
            return OperationRiskLevel.MEDIUM
        elif operation == "high_risk":
            return OperationRiskLevel.HIGH
        elif operation == "extreme_risk":
            return OperationRiskLevel.EXTREME
        else:
            return OperationRiskLevel.LOW


@pytest.mark.unit
class TestSafetyManager:
    """Unit test cases for the SafetyManager class."""

    @pytest.fixture(autouse=True)
    def setup_and_teardown(self):
        """Setup and teardown for each test."""
        # Reset the singleton before each test
        # pylint: disable=protected-access
        SafetyManager._instance = None  # type: ignore
        yield
        # Reset the singleton after each test
        SafetyManager._instance = None  # type: ignore

    def test_singleton_pattern(self):
        """Test that SafetyManager follows the singleton pattern."""
        # Get two instances of the SafetyManager
        manager1 = SafetyManager.get_instance()
        manager2 = SafetyManager.get_instance()

        # Verify they are the same instance
        assert manager1 is manager2

        # Verify that creating a new instance directly doesn't affect the singleton
        direct_instance = SafetyManager()
        assert direct_instance is not manager1

    def test_register_config(self):
        """Test registering a safety configuration."""
        manager = SafetyManager.get_instance()
        mock_config = MockSafetyConfig()

        # Register the config for DATABASE client type
        manager.register_config(ClientType.DATABASE, mock_config)

        # Verify the config was registered
        assert manager._safety_configs[ClientType.DATABASE] is mock_config

        # Test that registering a config for the same client type overwrites the previous config
        new_mock_config = MockSafetyConfig()
        manager.register_config(ClientType.DATABASE, new_mock_config)
        assert manager._safety_configs[ClientType.DATABASE] is new_mock_config

    def test_get_safety_mode_default(self):
        """Test getting the default safety mode for an unregistered client type."""
        manager = SafetyManager.get_instance()

        # Create a custom client type that hasn't been registered
        class CustomClientType(str):
            pass

        custom_type = CustomClientType("custom")

        # Verify that getting a safety mode for an unregistered client type returns SafetyMode.SAFE
        assert manager.get_safety_mode(custom_type) == SafetyMode.SAFE  # type: ignore

    def test_get_safety_mode_registered(self):
        """Test getting the safety mode for a registered client type."""
        manager = SafetyManager.get_instance()

        # Set a safety mode for a client type
        manager._safety_modes[ClientType.API] = SafetyMode.UNSAFE

        # Verify it's returned correctly
        assert manager.get_safety_mode(ClientType.API) == SafetyMode.UNSAFE

    def test_set_safety_mode(self):
        """Test setting the safety mode for a client type."""
        manager = SafetyManager.get_instance()

        # Set a safety mode for a client type
        manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)

        # Verify it was updated
        assert manager._safety_modes[ClientType.DATABASE] == SafetyMode.UNSAFE

        # Change it back to SAFE
        manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE)

        # Verify it was updated again
        assert manager._safety_modes[ClientType.DATABASE] == SafetyMode.SAFE

    def test_validate_operation_allowed(self):
        """Test validating an operation that is allowed."""
        manager = SafetyManager.get_instance()
        mock_config = MockSafetyConfig()

        # Register the config
        manager.register_config(ClientType.DATABASE, mock_config)

        # Set safety mode to SAFE
        manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE)

        # Validate a low risk operation (should be allowed in SAFE mode)
        # This should not raise an exception
        manager.validate_operation(ClientType.DATABASE, "low_risk")

        # Set safety mode to UNSAFE
        manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)

        # Validate medium risk operation (should be allowed in UNSAFE mode)
        # This should not raise an exception
        manager.validate_operation(ClientType.DATABASE, "medium_risk")

        # High risk operations require confirmation, so we test with confirmation=True
        manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=True)

    def test_validate_operation_not_allowed(self):
        """Test validating an operation that is not allowed."""
        manager = SafetyManager.get_instance()
        mock_config = MockSafetyConfig()

        # Register the config
        manager.register_config(ClientType.DATABASE, mock_config)

        # Set safety mode to SAFE
        manager.set_safety_mode(ClientType.DATABASE, SafetyMode.SAFE)

        # Validate medium risk operation (should not be allowed in SAFE mode)
        with pytest.raises(OperationNotAllowedError):
            manager.validate_operation(ClientType.DATABASE, "medium_risk")

        # Validate high risk operation (should not be allowed in SAFE mode)
        with pytest.raises(OperationNotAllowedError):
            manager.validate_operation(ClientType.DATABASE, "high_risk")

        # Validate extreme risk operation (should not be allowed in SAFE mode)
        with pytest.raises(OperationNotAllowedError):
            manager.validate_operation(ClientType.DATABASE, "extreme_risk")

    def test_validate_operation_requires_confirmation(self):
        """Test validating an operation that requires confirmation."""
        manager = SafetyManager.get_instance()
        mock_config = MockSafetyConfig()

        # Register the config
        manager.register_config(ClientType.DATABASE, mock_config)

        # Set safety mode to UNSAFE
        manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)

        # Validate high risk operation without confirmation
        # Should raise ConfirmationRequiredError
        with pytest.raises(ConfirmationRequiredError):
            manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=False)

        # Extreme risk operations are not allowed even in UNSAFE mode
        with pytest.raises(OperationNotAllowedError):
            manager.validate_operation(ClientType.DATABASE, "extreme_risk", has_confirmation=False)

        # Even with confirmation, extreme risk operations are not allowed
        with pytest.raises(OperationNotAllowedError):
            manager.validate_operation(ClientType.DATABASE, "extreme_risk", has_confirmation=True)

    def test_store_confirmation(self):
        """Test storing a confirmation for an operation."""
        manager = SafetyManager.get_instance()

        # Store a confirmation
        confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME)

        # Verify that a confirmation ID is returned
        assert confirmation_id is not None
        assert confirmation_id.startswith("conf_")

        # Verify that the confirmation can be retrieved
        confirmation = manager._get_confirmation(confirmation_id)
        assert confirmation is not None
        assert confirmation["operation"] == "test_operation"
        assert confirmation["client_type"] == ClientType.DATABASE
        assert confirmation["risk_level"] == OperationRiskLevel.EXTREME
        assert "timestamp" in confirmation

    def test_get_confirmation_valid(self):
        """Test getting a valid confirmation."""
        manager = SafetyManager.get_instance()

        # Store a confirmation
        confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME)

        # Retrieve the confirmation
        confirmation = manager._get_confirmation(confirmation_id)

        # Verify it matches what was stored
        assert confirmation is not None
        assert confirmation["operation"] == "test_operation"
        assert confirmation["client_type"] == ClientType.DATABASE
        assert confirmation["risk_level"] == OperationRiskLevel.EXTREME

    def test_get_confirmation_invalid(self):
        """Test getting an invalid confirmation."""
        manager = SafetyManager.get_instance()

        # Try to retrieve a confirmation with an invalid ID
        confirmation = manager._get_confirmation("invalid_id")

        # Verify that None is returned
        assert confirmation is None

    def test_get_confirmation_expired(self):
        """Test getting an expired confirmation."""
        manager = SafetyManager.get_instance()

        # Store a confirmation with a past expiration time
        confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME)

        # Manually set the timestamp to be older than the expiry time
        manager._pending_confirmations[confirmation_id]["timestamp"] = time.time() - manager._confirmation_expiry - 10

        # Try to retrieve the confirmation
        confirmation = manager._get_confirmation(confirmation_id)

        # Verify that None is returned
        assert confirmation is None

    def test_cleanup_expired_confirmations(self):
        """Test cleaning up expired confirmations."""
        manager = SafetyManager.get_instance()

        # Store multiple confirmations with different expiration times
        valid_id = manager._store_confirmation(ClientType.DATABASE, "valid_operation", OperationRiskLevel.EXTREME)
        expired_id = manager._store_confirmation(ClientType.DATABASE, "expired_operation", OperationRiskLevel.EXTREME)

        # Manually set the timestamp of the expired confirmation to be older than the expiry time
        manager._pending_confirmations[expired_id]["timestamp"] = time.time() - manager._confirmation_expiry - 10

        # Call cleanup
        manager._cleanup_expired_confirmations()

        # Verify that expired confirmations are removed and valid ones remain
        assert valid_id in manager._pending_confirmations
        assert expired_id not in manager._pending_confirmations

    def test_get_stored_operation(self):
        """Test getting a stored operation."""
        manager = SafetyManager.get_instance()

        # Store a confirmation for an operation
        confirmation_id = manager._store_confirmation(ClientType.DATABASE, "test_operation", OperationRiskLevel.EXTREME)

        # Retrieve the operation
        operation = manager.get_stored_operation(confirmation_id)

        # Verify that the retrieved operation matches the original
        assert operation == "test_operation"

        # Test with an invalid ID
        assert manager.get_stored_operation("invalid_id") is None

    def test_integration_validate_and_confirm(self):
        """Test the full flow of validating an operation that requires confirmation and then confirming it."""
        manager = SafetyManager.get_instance()
        mock_config = MockSafetyConfig()

        # Register the config
        manager.register_config(ClientType.DATABASE, mock_config)

        # Set safety mode to UNSAFE
        manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)

        # Try to validate a high risk operation and catch the ConfirmationRequiredError
        confirmation_id = None
        try:
            manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=False)
        except ConfirmationRequiredError as e:
            # Extract the confirmation ID from the error message
            error_message = str(e)
            # Find the confirmation ID in the message
            import re

            match = re.search(r"ID: (conf_[a-f0-9]+)", error_message)
            if match:
                confirmation_id = match.group(1)

        # Verify that we got a confirmation ID
        assert confirmation_id is not None

        # Now validate the operation again with the confirmation ID
        # This should not raise an exception
        manager.validate_operation(ClientType.DATABASE, "high_risk", has_confirmation=True)
```
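From calling code, the two-step flow exercised by `test_integration_validate_and_confirm` looks roughly like this (a minimal sketch; the `ID: conf_<hex>` message format is an assumption taken from the regex used in the test above):

```python
# Sketch of the validate -> confirm -> retry flow.
import re

from supabase_mcp.exceptions import ConfirmationRequiredError
from supabase_mcp.services.safety.models import ClientType
from supabase_mcp.services.safety.safety_manager import SafetyManager


def run_with_confirmation(operation: str) -> None:
    manager = SafetyManager.get_instance()
    try:
        manager.validate_operation(ClientType.DATABASE, operation)
    except ConfirmationRequiredError as e:
        # The confirmation ID is assumed to be embedded in the message as "ID: conf_<hex>"
        match = re.search(r"ID: (conf_[a-f0-9]+)", str(e))
        print("Needs confirmation:", match.group(1) if match else "unknown")
        # Once the user approves, retry with confirmation:
        manager.validate_operation(ClientType.DATABASE, operation, has_confirmation=True)
```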
--------------------------------------------------------------------------------
/supabase_mcp/services/database/sql/validator.py:
--------------------------------------------------------------------------------

```python
from typing import Any

from pglast.parser import ParseError, parse_sql

from supabase_mcp.exceptions import ValidationError
from supabase_mcp.logger import logger
from supabase_mcp.services.database.sql.models import (
    QueryValidationResults,
    SQLQueryCategory,
    SQLQueryCommand,
    ValidatedStatement,
)
from supabase_mcp.services.safety.safety_configs import SQLSafetyConfig


class SQLValidator:
    """SQL validator class based on the pglast library.

    Responsible for:
    - SQL query syntax validation
    - SQL query categorization
    """

    # Mapping from statement types to object types
    STATEMENT_TYPE_TO_OBJECT_TYPE = {
        "CreateFunctionStmt": "function",
        "ViewStmt": "view",
        "CreateTableAsStmt": "materialized_view",  # When relkind is 'm', otherwise 'table'
        "CreateEnumStmt": "type",
        "CreateTypeStmt": "type",
        "CreateExtensionStmt": "extension",
        "CreateForeignTableStmt": "foreign_table",
        "CreatePolicyStmt": "policy",
        "CreateTrigStmt": "trigger",
        "IndexStmt": "index",
        "CreateStmt": "table",
        "AlterTableStmt": "table",
        "GrantStmt": "privilege",
        "RevokeStmt": "privilege",
        "CreateProcStmt": "procedure",  # For CREATE PROCEDURE
    }

    def __init__(self, safety_config: SQLSafetyConfig | None = None) -> None:
        self.safety_config = safety_config or SQLSafetyConfig()

    def validate_schema_name(self, schema_name: str) -> str:
        """Validate schema name.

        Rules:
        - Must be a string
        - Cannot be empty
        - Cannot contain spaces or special characters
        """
        if not schema_name.strip():
            raise ValidationError("Schema name cannot be empty")
        if " " in schema_name:
            raise ValidationError("Schema name cannot contain spaces")
        return schema_name

    def validate_table_name(self, table: str) -> str:
        """Validate table name.

        Rules:
        - Must be a string
        - Cannot be empty
        - Cannot contain spaces or special characters
        """
        if not table.strip():
            raise ValidationError("Table name cannot be empty")
        if " " in table:
            raise ValidationError("Table name cannot contain spaces")
        return table

    def basic_query_validation(self, query: str) -> str:
        """Validate SQL query.

        Rules:
        - Must be a string
        - Cannot be empty
        """
        if not query.strip():
            raise ValidationError("Query cannot be empty")
        return query

    @classmethod
    def validate_transaction_control(cls, query: str) -> bool:
        """Check if the query contains transaction control statements.

        Args:
            query: SQL query string

        Returns:
            bool: True if the query contains any transaction control statements
        """
        return any(x in query.upper() for x in ["BEGIN", "COMMIT", "ROLLBACK"])

    def validate_query(self, sql_query: str) -> QueryValidationResults:
        """
        Identify the type of SQL query using PostgreSQL's parser.

        Args:
            sql_query: A SQL query string to parse

        Returns:
            QueryValidationResults: A validation result object containing information about the SQL statements

        Raises:
            ValidationError: If the query is not valid or contains TCL statements
        """
        try:
            # Validate raw input
            sql_query = self.basic_query_validation(sql_query)

            # Parse the SQL using PostgreSQL's parser
            parse_tree = parse_sql(sql_query)
            if parse_tree is None:
                logger.debug("No statements found in the query")
            # logger.debug(f"Parse tree generated with {parse_tree} statements")

            # Validate statements
            result = self.validate_statements(original_query=sql_query, parse_tree=parse_tree)

            # Check if the query contains transaction control statements and reject them
            for statement in result.statements:
                if statement.category == SQLQueryCategory.TCL:
                    logger.warning(f"Transaction control statement detected: {statement.command}")
                    raise ValidationError(
                        "Transaction control statements (BEGIN, COMMIT, ROLLBACK) are not allowed. "
                        "Queries will be automatically wrapped in transactions by the system."
                    )

            return result
        except ParseError as e:
            logger.exception(f"SQL syntax error: {str(e)}")
            raise ValidationError(f"SQL syntax error: {str(e)}") from e
        except ValidationError:
            # Let it propagate
            raise
        except Exception as e:
            logger.exception(f"Unexpected error during SQL validation: {str(e)}")
            raise ValidationError(f"Unexpected error during SQL validation: {str(e)}") from e

    def _map_to_command(self, stmt_type: str) -> SQLQueryCommand:
        """Map a pglast statement type to our SQLQueryCommand enum."""
        mapping = {
            # DQL Commands
            "SelectStmt": SQLQueryCommand.SELECT,
            # DML Commands
            "InsertStmt": SQLQueryCommand.INSERT,
            "UpdateStmt": SQLQueryCommand.UPDATE,
            "DeleteStmt": SQLQueryCommand.DELETE,
            "MergeStmt": SQLQueryCommand.MERGE,
            # DDL Commands
            "CreateStmt": SQLQueryCommand.CREATE,
            "CreateTableAsStmt": SQLQueryCommand.CREATE,
            "CreateSchemaStmt": SQLQueryCommand.CREATE,
            "CreateExtensionStmt": SQLQueryCommand.CREATE,
            "CreateFunctionStmt": SQLQueryCommand.CREATE,
            "CreateTrigStmt": SQLQueryCommand.CREATE,
            "ViewStmt": SQLQueryCommand.CREATE,
            "IndexStmt": SQLQueryCommand.CREATE,
            # Additional DDL Commands
            "CreateEnumStmt": SQLQueryCommand.CREATE,
            "CreateTypeStmt": SQLQueryCommand.CREATE,
            "CreateDomainStmt": SQLQueryCommand.CREATE,
            "CreateSeqStmt": SQLQueryCommand.CREATE,
            "CreateForeignTableStmt": SQLQueryCommand.CREATE,
            "CreatePolicyStmt": SQLQueryCommand.CREATE,
            "CreateCastStmt": SQLQueryCommand.CREATE,
            "CreateOpClassStmt": SQLQueryCommand.CREATE,
            "CreateOpFamilyStmt": SQLQueryCommand.CREATE,
            "AlterTableStmt": SQLQueryCommand.ALTER,
            "AlterDomainStmt": SQLQueryCommand.ALTER,
            "AlterEnumStmt": SQLQueryCommand.ALTER,
            "AlterSeqStmt": SQLQueryCommand.ALTER,
            "AlterOwnerStmt": SQLQueryCommand.ALTER,
            "AlterObjectSchemaStmt": SQLQueryCommand.ALTER,
            "DropStmt": SQLQueryCommand.DROP,
            "TruncateStmt": SQLQueryCommand.TRUNCATE,
            "CommentStmt": SQLQueryCommand.COMMENT,
            "RenameStmt": SQLQueryCommand.RENAME,
            # DCL Commands
            "GrantStmt": SQLQueryCommand.GRANT,
            "GrantRoleStmt": SQLQueryCommand.GRANT,
            "RevokeStmt": SQLQueryCommand.REVOKE,
            "RevokeRoleStmt": SQLQueryCommand.REVOKE,
            "CreateRoleStmt": SQLQueryCommand.CREATE,
            "AlterRoleStmt": SQLQueryCommand.ALTER,
            "DropRoleStmt": SQLQueryCommand.DROP,
            # TCL Commands
            "TransactionStmt": SQLQueryCommand.BEGIN,  # Will need refinement for different transaction types
            # PostgreSQL-specific Commands
            "VacuumStmt": SQLQueryCommand.VACUUM,
            "ExplainStmt": SQLQueryCommand.EXPLAIN,
            "CopyStmt": SQLQueryCommand.COPY,
            "ListenStmt": SQLQueryCommand.LISTEN,
            "NotifyStmt": SQLQueryCommand.NOTIFY,
            "PrepareStmt": SQLQueryCommand.PREPARE,
            "ExecuteStmt": SQLQueryCommand.EXECUTE,
            "DeallocateStmt": SQLQueryCommand.DEALLOCATE,
        }

        # Try to map the statement type, default to UNKNOWN
        return mapping.get(stmt_type, SQLQueryCommand.UNKNOWN)

    def validate_statements(self, original_query: str, parse_tree: Any) -> QueryValidationResults:
        """Validate the statements in the parse tree.

        Args:
            original_query: The raw SQL string the parse tree was generated from
            parse_tree: The parse tree to validate

        Returns:
            QueryValidationResults: A validation result object containing information about the SQL statements

        Raises:
            ValidationError: If the query is not valid
        """
        result = QueryValidationResults(original_query=original_query)

        if parse_tree is None:
            return result

        try:
            for stmt in parse_tree:
                if not hasattr(stmt, "stmt"):
                    continue

                stmt_node = stmt.stmt
                stmt_type = stmt_node.__class__.__name__
                logger.debug(f"Processing statement node type: {stmt_type}")
                # logger.debug(f"DEBUGGING stmt_node: {stmt_node}")
                logger.debug(f"DEBUGGING stmt_node.stmt_location: {stmt.stmt_location}")

                # Extract the object type if available
                object_type = None
                schema_name = None

                if hasattr(stmt_node, "relation") and stmt_node.relation is not None:
                    if hasattr(stmt_node.relation, "relname"):
                        object_type = stmt_node.relation.relname
                    if hasattr(stmt_node.relation, "schemaname"):
                        schema_name = stmt_node.relation.schemaname
                # For statements with 'relations' list (like TRUNCATE)
                elif hasattr(stmt_node, "relations") and stmt_node.relations:
                    for relation in stmt_node.relations:
                        if hasattr(relation, "relname"):
                            object_type = relation.relname
                        if hasattr(relation, "schemaname"):
                            schema_name = relation.schemaname
                        break

                # Simple approach: Set object_type based on statement type if not already set
                if object_type is None and stmt_type in self.STATEMENT_TYPE_TO_OBJECT_TYPE:
                    object_type = self.STATEMENT_TYPE_TO_OBJECT_TYPE[stmt_type]

                # Default schema to public if not set
                if schema_name is None:
                    schema_name = "public"

                # Get classification for this statement type
                classification = self.safety_config.classify_statement(stmt_type, stmt_node)
                logger.debug(
                    f"Statement category classified as: {classification.get('category', 'UNKNOWN')} "
                    f"- risk level: {classification.get('risk_level', 'UNKNOWN')}"
                )
                logger.debug(f"DEBUGGING QUERY EXTRACTION LOCATION: {stmt.stmt_location} - {stmt.stmt_len}")

                # Create validation result
                query_result = ValidatedStatement(
                    category=classification["category"],
                    command=self._map_to_command(stmt_type),
                    risk_level=classification["risk_level"],
                    needs_migration=classification["needs_migration"],
                    object_type=object_type,
                    schema_name=schema_name,
                    query=original_query[stmt.stmt_location : stmt.stmt_location + stmt.stmt_len]
                    if hasattr(stmt, "stmt_location") and hasattr(stmt, "stmt_len")
                    else None,
                )

                # logger.debug(f"Isolated query: {query_result.query}")
                logger.debug(
                    "Query validation result: "
                    f"category={query_result.category}, risk_level={query_result.risk_level}, "
                    f"needs_migration={query_result.needs_migration}, object_type={query_result.object_type}, "
                    f"schema_name={query_result.schema_name}, query={query_result.query}"
                )

                # Add result to the batch
                result.statements.append(query_result)

                # Update highest risk level
                if query_result.risk_level > result.highest_risk_level:
                    result.highest_risk_level = query_result.risk_level
                    logger.debug(f"Updated batch validation result to: {query_result.risk_level}")

            if len(result.statements) == 0:
                logger.debug("No valid statements found in the query")
                raise ValidationError("No queries were parsed - please check correctness of your query")

            logger.debug(
                f"Validated a total of {len(result.statements)} statements "
                f"with the highest risk level of: {result.highest_risk_level}"
            )

            return result

        except AttributeError as e:
            # Handle attempting to access missing attributes in the parse tree
            raise ValidationError(f"Error accessing parse tree structure: {str(e)}") from e
        except KeyError as e:
            # Handle missing keys in classification dictionary
            raise ValidationError(f"Missing classification key: {str(e)}") from e
```
--------------------------------------------------------------------------------
/tests/services/database/test_postgres_client.py:
--------------------------------------------------------------------------------

```python
import asyncpg
import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from supabase_mcp.exceptions import ConnectionError, QueryError, PermissionError as SupabasePermissionError
from supabase_mcp.services.database.postgres_client import PostgresClient, QueryResult, StatementResult
from supabase_mcp.services.database.sql.validator import (
    QueryValidationResults,
    SQLQueryCategory,
    SQLQueryCommand,
    ValidatedStatement,
)
from supabase_mcp.services.safety.models import OperationRiskLevel
from supabase_mcp.settings import Settings


@pytest.mark.asyncio(loop_scope="class")
class TestPostgresClient:
    """Unit tests for the Postgres client."""

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings for testing."""
        settings = MagicMock(spec=Settings)
        settings.supabase_project_ref = "test-project-ref"
        settings.supabase_db_password = "test-password"
        settings.supabase_region = "us-east-1"
        settings.database_url = "postgresql://test:test@localhost:5432/test"
        return settings

    @pytest.fixture
    async def mock_postgres_client(self, mock_settings):
        """Create a mock Postgres client for testing."""
        # Reset the singleton first
        await PostgresClient.reset()

        # Create client and mock execute_query directly
        client = PostgresClient(settings=mock_settings)
        return client

    async def test_execute_simple_select(self, mock_postgres_client: PostgresClient):
        """Test executing a simple SELECT query."""
        # Create a simple validation result with a SELECT query
        query = "SELECT 1 as number;"
        statement = ValidatedStatement(
            query=query,
            command=SQLQueryCommand.SELECT,
            category=SQLQueryCategory.DQL,
            risk_level=OperationRiskLevel.LOW,
            needs_migration=False,
            object_type=None,
            schema_name=None,
        )
        validation_result = QueryValidationResults(
            statements=[statement],
            original_query=query,
            highest_risk_level=OperationRiskLevel.LOW,
        )

        # Mock the query result
        expected_result = QueryResult(results=[
            StatementResult(rows=[{"number": 1}])
        ])

        with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
            # Execute the query
            result = await mock_postgres_client.execute_query(validation_result)

            # Verify the result
            assert isinstance(result, QueryResult)
            assert len(result.results) == 1
            assert isinstance(result.results[0], StatementResult)
            assert len(result.results[0].rows) == 1
            assert result.results[0].rows[0]["number"] == 1

    async def test_execute_multiple_statements(self, mock_postgres_client: PostgresClient):
        """Test executing multiple SQL statements in a single query."""
        # Create validation result with multiple statements
        query = "SELECT 1 as first; SELECT 2 as second;"
        statements = [
            ValidatedStatement(
                query="SELECT 1 as first;",
                command=SQLQueryCommand.SELECT,
                category=SQLQueryCategory.DQL,
                risk_level=OperationRiskLevel.LOW,
                needs_migration=False,
                object_type=None,
                schema_name=None,
            ),
            ValidatedStatement(
                query="SELECT 2 as second;",
                command=SQLQueryCommand.SELECT,
                category=SQLQueryCategory.DQL,
                risk_level=OperationRiskLevel.LOW,
                needs_migration=False,
                object_type=None,
                schema_name=None,
            ),
        ]
        validation_result = QueryValidationResults(
            statements=statements,
            original_query=query,
            highest_risk_level=OperationRiskLevel.LOW,
        )

        # Mock the query result
        expected_result = QueryResult(results=[
            StatementResult(rows=[{"first": 1}]),
            StatementResult(rows=[{"second": 2}])
        ])

        with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
            # Execute the query
            result = await mock_postgres_client.execute_query(validation_result)

            # Verify the result
            assert isinstance(result, QueryResult)
            assert len(result.results) == 2
            assert result.results[0].rows[0]["first"] == 1
            assert result.results[1].rows[0]["second"] == 2

    async def test_execute_query_with_parameters(self, mock_postgres_client: PostgresClient):
        """Test executing a query with parameters."""
        query = "SELECT 'test' as name, 42 as value;"
        statement = ValidatedStatement(
            query=query,
            command=SQLQueryCommand.SELECT,
            category=SQLQueryCategory.DQL,
            risk_level=OperationRiskLevel.LOW,
            needs_migration=False,
            object_type=None,
            schema_name=None,
        )
        validation_result = QueryValidationResults(
            statements=[statement],
            original_query=query,
            highest_risk_level=OperationRiskLevel.LOW,
        )

        # Mock the query result
        expected_result = QueryResult(results=[
            StatementResult(rows=[{"name": "test", "value": 42}])
        ])

        with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
            # Execute the query
            result = await mock_postgres_client.execute_query(validation_result)

            # Verify the result
            assert isinstance(result, QueryResult)
            assert len(result.results) == 1
            assert result.results[0].rows[0]["name"] == "test"
            assert result.results[0].rows[0]["value"] == 42

    async def test_permission_error(self, mock_postgres_client: PostgresClient):
        """Test handling a permission error."""
        # Create a mock error
        error = asyncpg.exceptions.InsufficientPrivilegeError("Permission denied")

        # Verify that the method raises PermissionError with the expected message
        with pytest.raises(SupabasePermissionError) as exc_info:
            await mock_postgres_client._handle_postgres_error(error)

        # Verify the error message
        assert "Access denied" in str(exc_info.value)
        assert "Permission denied" in str(exc_info.value)
        assert "live_dangerously" in str(exc_info.value)

    async def test_query_error(self, mock_postgres_client: PostgresClient):
        """Test handling a query error."""
        # Create a validation result with a syntactically valid but semantically incorrect query
        query = "SELECT * FROM nonexistent_table;"
        statement = ValidatedStatement(
            query=query,
            command=SQLQueryCommand.SELECT,
            category=SQLQueryCategory.DQL,
            risk_level=OperationRiskLevel.LOW,
            needs_migration=False,
            object_type="TABLE",
            schema_name="public",
        )
        validation_result = QueryValidationResults(
            statements=[statement],
            original_query=query,
            highest_risk_level=OperationRiskLevel.LOW,
        )

        # Mock execute_query to raise a QueryError
        with patch.object(
            mock_postgres_client, 'execute_query', side_effect=QueryError("relation \"nonexistent_table\" does not exist")
        ):
            # Execute the query - should raise a QueryError
            with pytest.raises(QueryError) as excinfo:
                await mock_postgres_client.execute_query(validation_result)

            # Verify the error message contains the specific error
            assert "nonexistent_table" in str(excinfo.value)

    async def test_schema_error(self, mock_postgres_client: PostgresClient):
        """Test handling a schema error."""
        # Create a validation result with a query referencing a non-existent column
        query = "SELECT nonexistent_column FROM information_schema.tables;"
        statement = ValidatedStatement(
            query=query,
            command=SQLQueryCommand.SELECT,
            category=SQLQueryCategory.DQL,
            risk_level=OperationRiskLevel.LOW,
            needs_migration=False,
            object_type="TABLE",
            schema_name="information_schema",
        )
        validation_result = QueryValidationResults(
            statements=[statement],
            original_query=query,
            highest_risk_level=OperationRiskLevel.LOW,
        )

        # Mock execute_query to raise a QueryError
        with patch.object(
            mock_postgres_client, 'execute_query', side_effect=QueryError("column \"nonexistent_column\" does not exist")
        ):
            # Execute the query - should raise a QueryError
            with pytest.raises(QueryError) as excinfo:
                await mock_postgres_client.execute_query(validation_result)

            # Verify the error message contains the specific error
            assert "nonexistent_column" in str(excinfo.value)

    async def test_write_operation(self, mock_postgres_client: PostgresClient):
        """Test a basic write operation (INSERT)."""
        # Create insert query
        insert_query = "INSERT INTO test_write (name) VALUES ('test_value') RETURNING id, name;"
        insert_statement = ValidatedStatement(
            query=insert_query,
            command=SQLQueryCommand.INSERT,
            category=SQLQueryCategory.DML,
            risk_level=OperationRiskLevel.MEDIUM,
            needs_migration=False,
            object_type="TABLE",
            schema_name="public",
        )
        insert_validation = QueryValidationResults(
            statements=[insert_statement],
            original_query=insert_query,
            highest_risk_level=OperationRiskLevel.MEDIUM,
        )

        # Mock the query result
        expected_result = QueryResult(results=[
            StatementResult(rows=[{"id": 1, "name": "test_value"}])
        ])

        with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
            # Execute the insert query
            result = await mock_postgres_client.execute_query(insert_validation, readonly=False)

            # Verify the result
            assert isinstance(result, QueryResult)
            assert len(result.results) == 1
            assert result.results[0].rows[0]["name"] == "test_value"
            assert result.results[0].rows[0]["id"] == 1

    async def test_ddl_operation(self, mock_postgres_client: PostgresClient):
        """Test a basic DDL operation (CREATE TABLE)."""
        # Create a test table
        create_query = "CREATE TEMPORARY TABLE test_ddl (id SERIAL PRIMARY KEY, value TEXT);"
        create_statement = ValidatedStatement(
            query=create_query,
            command=SQLQueryCommand.CREATE,
            category=SQLQueryCategory.DDL,
            risk_level=OperationRiskLevel.MEDIUM,
            needs_migration=False,
            object_type="TABLE",
            schema_name="public",
        )
        create_validation = QueryValidationResults(
            statements=[create_statement],
            original_query=create_query,
            highest_risk_level=OperationRiskLevel.MEDIUM,
        )

        # Mock the query result - DDL typically returns empty results
        expected_result = QueryResult(results=[
            StatementResult(rows=[])
        ])

        with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
            # Execute the create table query
            result = await mock_postgres_client.execute_query(create_validation, readonly=False)

            # Verify the result
            assert isinstance(result, QueryResult)
            assert len(result.results) == 1
            # DDL operations typically don't return rows
            assert result.results[0].rows == []

    async def test_execute_metadata_query(self, mock_postgres_client: PostgresClient):
        """Test executing a metadata query."""
        # Create a simple validation result with a SELECT query
        query = "SELECT schema_name FROM information_schema.schemata LIMIT 5;"
        statement = ValidatedStatement(
            query=query,
            command=SQLQueryCommand.SELECT,
            category=SQLQueryCategory.DQL,
            risk_level=OperationRiskLevel.LOW,
            needs_migration=False,
            object_type="schemata",
            schema_name="information_schema",
        )
        validation_result = QueryValidationResults(
            statements=[statement],
            original_query=query,
            highest_risk_level=OperationRiskLevel.LOW,
        )

        # Mock the query result
        expected_result = QueryResult(results=[
            StatementResult(rows=[
                {"schema_name": "public"},
                {"schema_name": "information_schema"},
                {"schema_name": "pg_catalog"},
                {"schema_name": "auth"},
                {"schema_name": "storage"}
            ])
        ])

        with patch.object(mock_postgres_client, 'execute_query', return_value=expected_result):
            # Execute the query
            result = await mock_postgres_client.execute_query(validation_result)

            # Verify the result
            assert isinstance(result, QueryResult)
            assert len(result.results) == 1
            assert len(result.results[0].rows) == 5
            assert "schema_name" in result.results[0].rows[0]

    async def test_connection_retry_mechanism(self, mock_postgres_client: PostgresClient):
        """Test that the tenacity retry mechanism works correctly for database connections."""
        # Reset the pool
        mock_postgres_client._pool = None

        # Mock create_pool to always raise a connection error
        with patch.object(
            mock_postgres_client, 'create_pool', side_effect=ConnectionError("Could not connect to database")
        ):
            # This should trigger the retry mechanism and eventually fail
            with pytest.raises(ConnectionError) as exc_info:
                await mock_postgres_client.ensure_pool()

            # Verify the error message indicates a connection failure after retries
            assert "Could not connect to database" in str(exc_info.value)
```
--------------------------------------------------------------------------------
/supabase_mcp/services/sdk/auth_admin_sdk_spec.py:
--------------------------------------------------------------------------------

```python
def get_auth_admin_methods_spec() -> dict:
    """Returns a detailed specification of all Auth Admin methods."""
    return {
        "get_user_by_id": {
            "description": "Retrieve a user by their ID",
            "parameters": {"uid": {"type": "string", "description": "The user's UUID", "required": True}},
            "returns": {"type": "object", "description": "User object containing all user data"},
            "example": {
                "request": {"uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"},
                "response": {
                    "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
                    "email": "[email protected]",
                    "phone": "",
                    "created_at": "2023-01-01T00:00:00Z",
                    "confirmed_at": "2023-01-01T00:00:00Z",
                    "last_sign_in_at": "2023-01-01T00:00:00Z",
                    "user_metadata": {"name": "John Doe"},
                    "app_metadata": {},
                },
            },
        },
        "list_users": {
            "description": "List all users with pagination",
            "parameters": {
                "page": {
                    "type": "integer",
                    "description": "Page number (starts at 1)",
                    "required": False,
                    "default": 1,
                },
                "per_page": {
                    "type": "integer",
                    "description": "Number of users per page",
                    "required": False,
                    "default": 50,
                },
            },
            "returns": {"type": "object", "description": "Paginated list of users with metadata"},
            "example": {
                "request": {"page": 1, "per_page": 10},
                "response": {
                    "users": [
                        {
                            "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
                            "email": "[email protected]",
                            "user_metadata": {"name": "John Doe"},
                        }
                    ],
                    "aud": "authenticated",
                    "total_count": 100,
                    "next_page": 2,
                },
            },
        },
        "create_user": {
            "description": "Create a new user. Does not send a confirmation email by default.",
            "parameters": {
                "email": {"type": "string", "description": "The user's email address"},
                "password": {"type": "string", "description": "The user's password"},
                "email_confirm": {
                    "type": "boolean",
                    "description": "Confirms the user's email address if set to true",
                    "default": False,
                },
                "phone": {"type": "string", "description": "The user's phone number with country code"},
                "phone_confirm": {
                    "type": "boolean",
                    "description": "Confirms the user's phone number if set to true",
                    "default": False,
                },
                "user_metadata": {
                    "type": "object",
                    "description": "A custom data object to store the user's metadata",
                },
                "app_metadata": {
                    "type": "object",
                    "description": "A custom data object to store the user's application specific metadata",
                },
                "role": {"type": "string", "description": "The role claim set in the user's access token JWT"},
                "ban_duration": {"type": "string", "description": "Determines how long a user is banned for"},
                "nonce": {
                    "type": "string",
                    "description": "The nonce (required for reauthentication if updating password)",
                },
            },
            "returns": {"type": "object", "description": "Created user object"},
            "example": {
                "request": {
                    "email": "[email protected]",
                    "password": "secure-password",
                    "email_confirm": True,
                    "user_metadata": {"name": "New User"},
                },
                "response": {
                    "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
                    "email": "[email protected]",
                    "email_confirmed_at": "2023-01-01T00:00:00Z",
                    "user_metadata": {"name": "New User"},
                },
            },
            "notes": "Either email or phone must be provided. Use invite_user_by_email() if you want to send an email invite.",
        },
        "delete_user": {
            "description": "Delete a user by their ID. Requires a service_role key.",
            "parameters": {
                "id": {"type": "string", "description": "The user's UUID", "required": True},
                "should_soft_delete": {
                    "type": "boolean",
                    "description": "If true, the user will be soft-deleted (preserving their data but disabling the account). Defaults to false.",
                    "required": False,
                    "default": False,
                },
            },
            "returns": {"type": "object", "description": "Success message"},
            "example": {
                "request": {"id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"},
                "response": {"message": "User deleted successfully"},
            },
            "notes": "This function should only be called on a server. Never expose your service_role key in the browser.",
        },
        "invite_user_by_email": {
            "description": "Sends an invite link to a user's email address. Typically used by administrators to invite users to join the application.",
            "parameters": {
                "email": {"type": "string", "description": "The email address of the user", "required": True},
                "options": {
                    "type": "object",
                    "description": "Optional settings for the invite",
                    "required": False,
                    "properties": {
                        "data": {
                            "type": "object",
                            "description": "A custom data object to store additional metadata about the user. Maps to auth.users.user_metadata",
                            "required": False,
                        },
                        "redirect_to": {
                            "type": "string",
                            "description": "The URL which will be appended to the email link. Once clicked the user will end up on this URL",
                            "required": False,
                        },
                    },
                },
            },
            "returns": {"type": "object", "description": "User object for the invited user"},
            "example": {
                "request": {
                    "email": "[email protected]",
                    "options": {"data": {"name": "John Doe"}, "redirect_to": "https://example.com/welcome"},
                },
                "response": {
                    "id": "a1a1a1a1-a1a1-a1a1-a1a1-a1a1a1a1a1a1",
                    "email": "[email protected]",
                    "role": "authenticated",
                    "email_confirmed_at": None,
                    "invited_at": "2023-01-01T00:00:00Z",
                },
            },
            "notes": "Note that PKCE is not supported when using invite_user_by_email. This is because the browser initiating the invite is often different from the browser accepting the invite.",
        },
        "generate_link": {
            "description": "Generate an email link for various authentication purposes. Handles user creation for signup, invite and magiclink types.",
            "parameters": {
                "type": {
                    "type": "string",
                    "description": "Link type: 'signup', 'invite', 'magiclink', 'recovery', 'email_change_current', 'email_change_new', 'phone_change'",
                    "required": True,
                    "enum": [
                        "signup",
                        "invite",
                        "magiclink",
                        "recovery",
                        "email_change_current",
                        "email_change_new",
                        "phone_change",
                    ],
                },
                "email": {"type": "string", "description": "User's email address", "required": True},
                "password": {
                    "type": "string",
                    "description": "User's password. Only required if type is signup",
                    "required": False,
                },
                "new_email": {
                    "type": "string",
                    "description": "New email address. Only required if type is email_change_current or email_change_new",
                    "required": False,
                },
                "options": {
                    "type": "object",
                    "description": "Additional options for the link",
                    "required": False,
                    "properties": {
                        "data": {
                            "type": "object",
                            "description": "Custom JSON object containing user metadata. Only accepted if type is signup, invite, or magiclink",
                            "required": False,
                        },
                        "redirect_to": {
                            "type": "string",
                            "description": "A redirect URL which will be appended to the generated email link",
                            "required": False,
                        },
                    },
                },
            },
            "returns": {"type": "object", "description": "Generated link details"},
            "example": {
                "request": {
                    "type": "signup",
                    "email": "[email protected]",
                    "password": "secure-password",
                    "options": {"data": {"name": "John Doe"}, "redirect_to": "https://example.com/welcome"},
                },
                "response": {
                    "action_link": "https://your-project.supabase.co/auth/v1/verify?token=...",
                    "email_otp": "123456",
                    "hashed_token": "...",
                    "redirect_to": "https://example.com/welcome",
                    "verification_type": "signup",
                },
            },
            "notes": "generate_link() only generates the email link for email_change_email if the Secure email change is enabled in your project's email auth provider settings.",
        },
        "update_user_by_id": {
            "description": "Update user attributes by ID. Requires a service_role key.",
            "parameters": {
                "uid": {"type": "string", "description": "The user's UUID", "required": True},
                "attributes": {
                    "type": "object",
                    "description": "The user attributes to update.",
                    "required": True,
                    "properties": {
                        "email": {"type": "string", "description": "The user's email"},
                        "phone": {"type": "string", "description": "The user's phone"},
                        "password": {"type": "string", "description": "The user's password"},
                        "email_confirm": {
                            "type": "boolean",
                            "description": "Confirms the user's email address if set to true",
                        },
                        "phone_confirm": {
                            "type": "boolean",
                            "description": "Confirms the user's phone number if set to true",
                        },
                        "user_metadata": {
                            "type": "object",
                            "description": "A custom data object to store the user's metadata.",
                        },
                        "app_metadata": {
                            "type": "object",
                            "description": "A custom data object to store the user's application specific metadata.",
                        },
                        "role": {
                            "type": "string",
                            "description": "The role claim set in the user's access token JWT",
                        },
                        "ban_duration": {
                            "type": "string",
                            "description": "Determines how long a user is banned for",
                        },
                        "nonce": {
                            "type": "string",
                            "description": "The nonce sent for reauthentication if the user's password is to be updated",
                        },
                    },
                },
            },
            "returns": {"type": "object", "description": "Updated user object"},
            "example": {
                "request": {
                    "uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
                    "attributes": {"email": "[email protected]", "user_metadata": {"name": "Updated Name"}},
                },
                "response": {
                    "id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
                    "email": "[email protected]",
                    "user_metadata": {"name": "Updated Name"},
                },
            },
            "notes": "This function should only be called on a server. Never expose your service_role key in the browser.",
        },
        "delete_factor": {
            "description": "Deletes a factor on a user. This will log the user out of all active sessions if the deleted factor was verified.",
            "parameters": {
                "user_id": {
                    "type": "string",
                    "description": "ID of the user whose factor is being deleted",
                    "required": True,
                },
                "id": {"type": "string", "description": "ID of the MFA factor to delete", "required": True},
            },
            "returns": {"type": "object", "description": "Success message"},
            "example": {
                "request": {"user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", "id": "totp-factor-id-123"},
                "response": {"message": "Factor deleted successfully"},
            },
            "notes": "This will log the user out of all active sessions if the deleted factor was verified.",
        },
    }
```
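One way this spec can be consumed, for example to render tool help text (a sketch; the dictionary shape is exactly as returned above):

```python
# Sketch: summarize each Auth Admin method and its required parameters.
from supabase_mcp.services.sdk.auth_admin_sdk_spec import get_auth_admin_methods_spec

spec = get_auth_admin_methods_spec()
for method, details in spec.items():
    params = details.get("parameters", {})
    required = [name for name, p in params.items() if isinstance(p, dict) and p.get("required")]
    print(f"{method}: {details['description']} (required: {', '.join(required) or 'none'})")
```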
""" @pytest.fixture def mock_settings(self): """Create mock settings for testing.""" settings = MagicMock(spec=Settings) settings.supabase_project_ref = "test-project-ref" settings.supabase_service_role_key = "test-service-role-key" settings.supabase_region = "us-east-1" settings.supabase_url = "https://test-project-ref.supabase.co" return settings @pytest.fixture async def mock_sdk_client(self, mock_settings): """Create a mock SDK client for testing.""" # Reset singleton SupabaseSDKClient.reset() # Mock the Supabase client mock_supabase = MagicMock() mock_auth_admin = MagicMock() mock_supabase.auth.admin = mock_auth_admin # Mock the create_async_client function to return our mock client with patch('supabase_mcp.clients.sdk_client.create_async_client', return_value=mock_supabase): # Create client - this will now use our mocked create_async_client client = SupabaseSDKClient.get_instance(settings=mock_settings) # Manually set the client to ensure it's available client.client = mock_supabase return client async def test_list_users(self, mock_sdk_client: SupabaseSDKClient): """Test listing users with pagination""" # Mock user data mock_users = [ MagicMock(id="user1", email="[email protected]", user_metadata={}), MagicMock(id="user2", email="[email protected]", user_metadata={}) ] # Mock the list_users method as an async function mock_sdk_client.client.auth.admin.list_users = AsyncMock(return_value=mock_users) # Create test parameters list_params = {"page": 1, "per_page": 10} # List users result = await mock_sdk_client.call_auth_admin_method("list_users", list_params) # Verify response format assert result is not None assert hasattr(result, "__iter__") # Should be iterable (list of users) assert len(result) == 2 # Check that the first user has expected attributes first_user = result[0] assert hasattr(first_user, "id") assert hasattr(first_user, "email") assert hasattr(first_user, "user_metadata") # Test with invalid parameters - mock the validation error mock_sdk_client.client.auth.admin.list_users = AsyncMock(side_effect=Exception("Bad Pagination Parameters")) invalid_params = {"page": -1, "per_page": 10} with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method("list_users", invalid_params) assert "Bad Pagination Parameters" in str(excinfo.value) async def test_get_user_by_id(self, mock_sdk_client: SupabaseSDKClient): """Test retrieving a user by ID""" # Mock user data test_email = get_test_email("get") user_id = str(uuid.uuid4()) mock_user = MagicMock( id=user_id, email=test_email, user_metadata={"name": "Test User", "test_id": TEST_ID} ) mock_response = MagicMock(user=mock_user) # Mock the get_user_by_id method as an async function mock_sdk_client.client.auth.admin.get_user_by_id = AsyncMock(return_value=mock_response) # Get the user by ID get_params = {"uid": user_id} get_result = await mock_sdk_client.call_auth_admin_method("get_user_by_id", get_params) # Verify user data assert get_result is not None assert hasattr(get_result, "user") assert get_result.user.id == user_id assert get_result.user.email == test_email # Test with invalid parameters (non-existent user ID) mock_sdk_client.client.auth.admin.get_user_by_id = AsyncMock(side_effect=Exception("user_id must be an UUID")) invalid_params = {"uid": "non-existent-user-id"} with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method("get_user_by_id", invalid_params) assert "user_id must be an UUID" in str(excinfo.value) async def test_create_user(self, mock_sdk_client: 
SupabaseSDKClient): """Test creating a new user""" # Create a new test user test_email = get_test_email("create") user_id = str(uuid.uuid4()) mock_user = MagicMock( id=user_id, email=test_email, user_metadata={"name": "Test User", "test_id": TEST_ID} ) mock_response = MagicMock(user=mock_user) # Mock the create_user method as an async function mock_sdk_client.client.auth.admin.create_user = AsyncMock(return_value=mock_response) create_params = { "email": test_email, "password": f"Password123!{TEST_ID}", "email_confirm": True, "user_metadata": {"name": "Test User", "test_id": TEST_ID}, } # Create the user create_result = await mock_sdk_client.call_auth_admin_method("create_user", create_params) assert create_result is not None assert hasattr(create_result, "user") assert hasattr(create_result.user, "id") assert create_result.user.id == user_id # Test with invalid parameters (missing required fields) mock_sdk_client.client.auth.admin.create_user = AsyncMock(side_effect=Exception("Invalid parameters")) invalid_params = {"user_metadata": {"name": "Invalid User"}} with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method("create_user", invalid_params) assert "Invalid parameters" in str(excinfo.value) async def test_update_user_by_id(self, mock_sdk_client: SupabaseSDKClient): """Test updating a user's attributes""" # Mock user data test_email = get_test_email("update") user_id = str(uuid.uuid4()) mock_user = MagicMock( id=user_id, email=test_email, user_metadata={"email": "[email protected]"} ) mock_response = MagicMock(user=mock_user) # Mock the update_user_by_id method as an async function mock_sdk_client.client.auth.admin.update_user_by_id = AsyncMock(return_value=mock_response) # Update the user update_params = { "uid": user_id, "attributes": { "user_metadata": { "email": "[email protected]", } }, } update_result = await mock_sdk_client.call_auth_admin_method("update_user_by_id", update_params) # Verify user was updated assert update_result is not None assert hasattr(update_result, "user") assert update_result.user.id == user_id assert update_result.user.user_metadata["email"] == "[email protected]" # Test with invalid parameters (non-existent user ID) mock_sdk_client.client.auth.admin.update_user_by_id = AsyncMock(side_effect=Exception("user_id must be an uuid")) invalid_params = { "uid": "non-existent-user-id", "attributes": {"user_metadata": {"name": "Invalid Update"}}, } with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method("update_user_by_id", invalid_params) assert "user_id must be an uuid" in str(excinfo.value).lower() async def test_delete_user(self, mock_sdk_client: SupabaseSDKClient): """Test deleting a user""" # Mock user data user_id = str(uuid.uuid4()) # Mock the delete_user method as an async function to return None (success) mock_sdk_client.client.auth.admin.delete_user = AsyncMock(return_value=None) # Delete the user delete_params = {"id": user_id} # The delete_user method returns None on success result = await mock_sdk_client.call_auth_admin_method("delete_user", delete_params) assert result is None # Test with invalid parameters (non-UUID format user ID) mock_sdk_client.client.auth.admin.delete_user = AsyncMock(side_effect=Exception("user_id must be an uuid")) invalid_params = {"id": "non-existent-user-id"} with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method("delete_user", invalid_params) assert "user_id must be an uuid" in str(excinfo.value).lower() async def 
test_invite_user_by_email(self, mock_sdk_client: SupabaseSDKClient): """Test inviting a user by email""" # Mock user data test_email = get_test_email("invite") user_id = str(uuid.uuid4()) mock_user = MagicMock( id=user_id, email=test_email, invited_at=datetime.now().isoformat() ) mock_response = MagicMock(user=mock_user) # Mock the invite_user_by_email method as an async function mock_sdk_client.client.auth.admin.invite_user_by_email = AsyncMock(return_value=mock_response) # Create invite parameters invite_params = { "email": test_email, "options": {"data": {"name": "Invited User", "test_id": TEST_ID, "invited_at": datetime.now().isoformat()}}, } # Invite the user result = await mock_sdk_client.call_auth_admin_method("invite_user_by_email", invite_params) # Verify response assert result is not None assert hasattr(result, "user") assert result.user.email == test_email assert hasattr(result.user, "invited_at") # Test with invalid parameters (missing email) mock_sdk_client.client.auth.admin.invite_user_by_email = AsyncMock(side_effect=Exception("Invalid parameters")) invalid_params = {"options": {"data": {"name": "Invalid Invite"}}} with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method("invite_user_by_email", invalid_params) assert "Invalid parameters" in str(excinfo.value) async def test_generate_link(self, mock_sdk_client: SupabaseSDKClient): """Test generating authentication links""" # Mock response for generate_link mock_properties = MagicMock(action_link="https://example.com/auth/link") mock_response = MagicMock(properties=mock_properties) # Mock the generate_link method as an async function mock_sdk_client.client.auth.admin.generate_link = AsyncMock(return_value=mock_response) # Test signup link link_params = { "type": "signup", "email": get_test_email("signup"), "password": f"Password123!{TEST_ID}", "options": { "data": {"name": "Signup User", "test_id": TEST_ID}, "redirect_to": "https://example.com/welcome", }, } # Generate link result = await mock_sdk_client.call_auth_admin_method("generate_link", link_params) # Verify response assert result is not None assert hasattr(result, "properties") assert hasattr(result.properties, "action_link") # Test with invalid parameters (invalid link type) mock_sdk_client.client.auth.admin.generate_link = AsyncMock(side_effect=Exception("Invalid parameters")) invalid_params = {"type": "invalid_type", "email": get_test_email("invalid")} with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method("generate_link", invalid_params) assert "Invalid parameters" in str(excinfo.value) or "invalid type" in str(excinfo.value).lower() async def test_delete_factor(self, mock_sdk_client: SupabaseSDKClient): """Test deleting an MFA factor""" # Mock the delete_factor method as an async function to raise not implemented mock_sdk_client.client.auth.admin.delete_factor = AsyncMock(side_effect=AttributeError("method not found")) # Attempt to delete a factor delete_factor_params = {"user_id": str(uuid.uuid4()), "id": "non-existent-factor-id"} with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method("delete_factor", delete_factor_params) # We expect this to fail with a specific error message assert "not implemented" in str(excinfo.value).lower() or "method not found" in str(excinfo.value).lower() async def test_empty_parameters(self, mock_sdk_client: SupabaseSDKClient): """Test validation errors with empty parameters for various methods""" # Test methods with empty 
parameters methods = ["get_user_by_id", "create_user", "update_user_by_id", "delete_user", "generate_link"] for method in methods: empty_params = {} # Mock the method to raise validation error setattr(mock_sdk_client.client.auth.admin, method, AsyncMock(side_effect=Exception("Invalid parameters"))) # Should raise PythonSDKError containing validation error details with pytest.raises(PythonSDKError) as excinfo: await mock_sdk_client.call_auth_admin_method(method, empty_params) # Verify error message contains validation details assert "Invalid parameters" in str(excinfo.value) or "validation error" in str(excinfo.value).lower() async def test_client_without_service_role_key(self, mock_settings): """Test that an exception is raised when attempting to use the SDK client without a service role key.""" # Create settings without service role key mock_settings.supabase_service_role_key = None # Reset singleton SupabaseSDKClient.reset() # Create client client = SupabaseSDKClient.get_instance(settings=mock_settings) # Attempt to call a method - should raise an exception with pytest.raises(PythonSDKError) as excinfo: await client.call_auth_admin_method("list_users", {}) assert "service role key is not configured" in str(excinfo.value) ``` -------------------------------------------------------------------------------- /supabase_mcp/services/database/postgres_client.py: -------------------------------------------------------------------------------- ```python from __future__ import annotations import urllib.parse from collections.abc import Awaitable, Callable from typing import Any, TypeVar import asyncpg from pydantic import BaseModel, Field from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential from supabase_mcp.exceptions import ConnectionError, PermissionError, QueryError from supabase_mcp.logger import logger from supabase_mcp.services.database.sql.models import QueryValidationResults from supabase_mcp.services.database.sql.validator import SQLValidator from supabase_mcp.settings import Settings # Define a type variable for generic return types T = TypeVar("T") # TODO: Use a context manager to properly handle the connection pool class StatementResult(BaseModel): """Represents the result of a single SQL statement.""" rows: list[dict[str, Any]] = Field( default_factory=list, description="List of rows returned by the statement. Is empty if the statement is a DDL statement.", ) class QueryResult(BaseModel): """Represents results of query execution, consisting of one or more statements.""" results: list[StatementResult] = Field( description="List of results from the statements in the query.", ) # Helper function for retry decorator to safely log exceptions def log_db_retry_attempt(retry_state: RetryCallState) -> None: """Log database retry attempts. Args: retry_state: Current retry state from tenacity """ if retry_state.outcome is not None and retry_state.outcome.failed: exception = retry_state.outcome.exception() exception_str = str(exception) logger.warning(f"Database error, retrying ({retry_state.attempt_number}/3): {exception_str}") # Add the new AsyncSupabaseClient class class PostgresClient: """Asynchronous client for interacting with Supabase PostgreSQL database.""" _instance: PostgresClient | None = None # Singleton instance def __init__( self, settings: Settings, project_ref: str | None = None, db_password: str | None = None, db_region: str | None = None, ): """Initialize client configuration (but don't connect yet). 
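Outside the tests, the same proxy is invoked the same way; a minimal sketch, assuming an already-initialized client obtained via `SupabaseSDKClient.get_instance(settings=...)` as shown above:

```python
# Sketch: creating a confirmed user through the Auth Admin proxy.
async def create_confirmed_user(sdk_client, email: str, password: str):
    return await sdk_client.call_auth_admin_method(
        "create_user",
        {"email": email, "password": password, "email_confirm": True},
    )
```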

--------------------------------------------------------------------------------
/supabase_mcp/services/database/postgres_client.py:
--------------------------------------------------------------------------------

```python
from __future__ import annotations

import asyncio
import urllib.parse
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar

import asyncpg
from pydantic import BaseModel, Field
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from supabase_mcp.exceptions import ConnectionError, PermissionError, QueryError
from supabase_mcp.logger import logger
from supabase_mcp.services.database.sql.models import QueryValidationResults
from supabase_mcp.services.database.sql.validator import SQLValidator
from supabase_mcp.settings import Settings

# Define a type variable for generic return types
T = TypeVar("T")

# TODO: Use a context manager to properly handle the connection pool


class StatementResult(BaseModel):
    """Represents the result of a single SQL statement."""

    rows: list[dict[str, Any]] = Field(
        default_factory=list,
        description="List of rows returned by the statement. Is empty if the statement is a DDL statement.",
    )


class QueryResult(BaseModel):
    """Represents results of query execution, consisting of one or more statements."""

    results: list[StatementResult] = Field(
        description="List of results from the statements in the query.",
    )


# Helper function for retry decorator to safely log exceptions
def log_db_retry_attempt(retry_state: RetryCallState) -> None:
    """Log database retry attempts.

    Args:
        retry_state: Current retry state from tenacity
    """
    if retry_state.outcome is not None and retry_state.outcome.failed:
        exception = retry_state.outcome.exception()
        exception_str = str(exception)
        logger.warning(f"Database error, retrying ({retry_state.attempt_number}/3): {exception_str}")


class PostgresClient:
    """Asynchronous client for interacting with a Supabase PostgreSQL database."""

    _instance: PostgresClient | None = None  # Singleton instance

    def __init__(
        self,
        settings: Settings,
        project_ref: str | None = None,
        db_password: str | None = None,
        db_region: str | None = None,
    ):
        """Initialize client configuration (but don't connect yet).

        Args:
            settings: Settings instance to use for configuration.
            project_ref: Optional Supabase project reference. If not provided, will be taken from settings.
            db_password: Optional database password. If not provided, will be taken from settings.
            db_region: Optional database region. If not provided, will be taken from settings.
        """
        self._pool: asyncpg.Pool[asyncpg.Record] | None = None
        self._settings = settings
        self.project_ref = project_ref or self._settings.supabase_project_ref
        self.db_password = db_password or self._settings.supabase_db_password
        self.db_region = db_region or self._settings.supabase_region
        self.db_url = self._build_connection_string()
        self.sql_validator: SQLValidator = SQLValidator()

        # Only log once during initialization, with clear project info
        is_local = self.project_ref.startswith("127.0.0.1")
        logger.info(
            f"✔️ PostgreSQL client initialized successfully for {'local' if is_local else 'remote'} "
            f"project: {self.project_ref} (region: {self.db_region})"
        )

    @classmethod
    def get_instance(
        cls,
        settings: Settings,
        project_ref: str | None = None,
        db_password: str | None = None,
    ) -> PostgresClient:
        """Create and return a configured PostgresClient instance.

        This is the recommended way to create a client instance.

        Args:
            settings: Settings instance to use for configuration
            project_ref: Optional Supabase project reference
            db_password: Optional database password

        Returns:
            Configured PostgresClient instance
        """
        if cls._instance is None:
            cls._instance = cls(
                settings=settings,
                project_ref=project_ref,
                db_password=db_password,
            )
            # Doesn't connect yet - will connect lazily when needed
        return cls._instance
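    # Note (added for clarity, not in the original source): _build_connection_string
    # below produces one of two URL shapes, depending on the project ref:
    #   local  -> postgresql://postgres:<password>@<project_ref>/postgres
    #             (project_ref is a 127.0.0.1 host[:port] during local development)
    #   remote -> postgresql://postgres.<project_ref>:<password>@aws-0-<region>.pooler.supabase.com:6543/postgres
    # The remote form connects through Supabase's transaction pooler on port 6543.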
    def _build_connection_string(self) -> str:
        """Build the database connection string for asyncpg.

        Returns:
            PostgreSQL connection string compatible with asyncpg
        """
        encoded_password = urllib.parse.quote_plus(self.db_password)

        if self.project_ref.startswith("127.0.0.1"):
            # Local development
            connection_string = f"postgresql://postgres:{encoded_password}@{self.project_ref}/postgres"
            return connection_string

        # Production Supabase - via transaction pooler
        connection_string = (
            f"postgresql://postgres.{self.project_ref}:{encoded_password}"
            f"@aws-0-{self._settings.supabase_region}.pooler.supabase.com:6543/postgres"
        )
        return connection_string

    @retry(
        retry=retry_if_exception_type(
            (
                asyncpg.exceptions.ConnectionDoesNotExistError,  # Connection lost
                asyncpg.exceptions.InterfaceError,  # Connection disruption
                asyncpg.exceptions.TooManyConnectionsError,  # Temporary connection limit
                OSError,  # Network issues
            )
        ),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        before_sleep=log_db_retry_attempt,
    )
    async def create_pool(self) -> asyncpg.Pool[asyncpg.Record]:
        """Create and configure a database connection pool.

        Returns:
            Configured asyncpg connection pool

        Raises:
            ConnectionError: If unable to establish a connection to the database
        """
        try:
            logger.debug(f"Creating connection pool for project: {self.project_ref}")

            # Create the pool with conservative settings
            pool = await asyncpg.create_pool(
                self.db_url,
                min_size=2,  # Minimum connections to keep ready
                max_size=10,  # Maximum connections allowed
                statement_cache_size=0,  # Disable prepared-statement caching (transaction-mode poolers don't support it)
                command_timeout=30.0,  # Command timeout in seconds
                max_inactive_connection_lifetime=300.0,  # 5 minutes
            )

            # Test the connection with a simple query
            async with pool.acquire() as conn:
                await conn.execute("SELECT 1")

            logger.info("✓ Database connection established successfully")
            return pool

        except asyncpg.PostgresError as e:
            # Extract connection details for better error reporting
            host_part = self.db_url.split("@")[1].split("/")[0] if "@" in self.db_url else "unknown"

            # Check specifically for the "Tenant or user not found" error, which is often caused by a region mismatch
            if "Tenant or user not found" in str(e):
                error_message = (
                    "CONNECTION ERROR: Region mismatch detected!\n\n"
                    f"Could not connect to Supabase project '{self.project_ref}'.\n\n"
                    "This error typically occurs when your SUPABASE_REGION setting doesn't match your project's actual region.\n"
                    f"Your configuration is using region: '{self.db_region}' (default: us-east-1)\n\n"
                    "ACTION REQUIRED: Please set the correct SUPABASE_REGION in your MCP server configuration.\n"
                    "You can find your project's region in the Supabase dashboard under Project Settings."
                )
            else:
                error_message = (
                    f"Could not connect to database: {e}\n"
                    f"Connection attempted to: {host_part} via Transaction Pooler\n"
                    f"Project ref: {self.project_ref}\n"
                    f"Region: {self.db_region}\n\n"
                    f"Please check:\n"
                    f"1. Your Supabase project reference is correct\n"
                    f"2. Your database password is correct\n"
                    f"3. Your region setting matches your Supabase project region\n"
                    f"4. Your Supabase project is active and the database is online\n"
                )

            logger.error(f"Failed to connect to database: {e}")
            logger.error(f"Connection details: {host_part}, Project: {self.project_ref}, Region: {self.db_region}")

            raise ConnectionError(error_message) from e

        except OSError as e:
            # For network-related errors, provide a message that clearly indicates
            # this is a network/system issue rather than a database configuration problem
            host_part = self.db_url.split("@")[1].split("/")[0] if "@" in self.db_url else "unknown"

            error_message = (
                f"Network error while connecting to database: {e}\n"
                f"Connection attempted to: {host_part}\n\n"
                f"This appears to be a network or system issue rather than a database configuration problem.\n"
                f"Please check:\n"
                f"1. Your internet connection is working\n"
                f"2. Any firewalls or network security settings allow connections to {host_part}\n"
                f"3. DNS resolution is working correctly\n"
                f"4. The Supabase service is not experiencing an outage\n"
            )

            logger.error(f"Network error connecting to database: {e}")
            logger.error(f"Connection details: {host_part}")

            raise ConnectionError(error_message) from e

    async def ensure_pool(self) -> None:
        """Ensure a valid connection pool exists.

        This method is called before executing queries to make sure
        we have an active connection pool.
        """
        if self._pool is None:
            logger.debug("No active connection pool, creating one")
            self._pool = await self.create_pool()
        else:
            logger.debug("Using existing connection pool")
    async def close(self) -> None:
        """Close the connection pool and release all resources.

        This should be called when shutting down the application.
        """
        if self._pool:
            await asyncio.wait_for(self._pool.close(), timeout=5.0)
            self._pool = None
        else:
            logger.debug("No PostgreSQL connection pool to close")

    @classmethod
    async def reset(cls) -> None:
        """Reset the singleton instance cleanly.

        This closes any open connections and resets the singleton instance.
        """
        if cls._instance is not None:
            await cls._instance.close()
            cls._instance = None
            logger.info("PostgresClient instance reset complete")

    async def with_connection(self, operation_func: Callable[[asyncpg.Connection[Any]], Awaitable[T]]) -> T:
        """Execute an operation with a database connection.

        Args:
            operation_func: Async function that takes a connection and returns a result

        Returns:
            The result of the operation function

        Raises:
            ConnectionError: If a database connection issue occurs
        """
        # Ensure we have an active connection pool
        await self.ensure_pool()
        assert self._pool is not None  # guaranteed by ensure_pool; keeps type checkers happy

        # Acquire a connection from the pool and execute the operation
        async with self._pool.acquire() as conn:
            return await operation_func(conn)

    async def with_transaction(
        self, conn: asyncpg.Connection[Any], operation_func: Callable[[], Awaitable[T]], readonly: bool = False
    ) -> T:
        """Execute an operation within a transaction.

        Args:
            conn: Database connection
            operation_func: Async function that executes within the transaction
            readonly: Whether the transaction is read-only

        Returns:
            The result of the operation function

        Raises:
            QueryError: If the query execution fails
        """
        # Execute the operation within a transaction
        async with conn.transaction(readonly=readonly):
            return await operation_func()

    async def execute_statement(self, conn: asyncpg.Connection[Any], query: str) -> StatementResult:
        """Execute a single SQL statement.

        Args:
            conn: Database connection
            query: SQL query to execute

        Returns:
            StatementResult containing the rows returned by the statement

        Raises:
            QueryError: If the statement execution fails
        """
        try:
            # Execute the query
            result = await conn.fetch(query)

            # Convert records to dictionaries
            rows = [dict(record) for record in result]

            # Log success
            logger.debug(f"Statement executed successfully, rows: {len(rows)}")

            # Return the result
            return StatementResult(rows=rows)
        except asyncpg.PostgresError as e:
            await self._handle_postgres_error(e)
            raise  # unreachable: _handle_postgres_error always raises, but this satisfies type checkers
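    # Note (added for clarity, not in the original source): the retry decorator
    # below re-runs the *entire* query on transient connection errors, so every
    # statement executes again in a fresh transaction. That is safe for the
    # readonly=True default; for write queries it relies on the aborted
    # transaction having rolled back when the connection was lost.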
    @retry(
        retry=retry_if_exception_type(
            (
                asyncpg.exceptions.ConnectionDoesNotExistError,  # Connection lost
                asyncpg.exceptions.InterfaceError,  # Connection disruption
                asyncpg.exceptions.TooManyConnectionsError,  # Temporary connection limit
                OSError,  # Network issues
            )
        ),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        before_sleep=log_db_retry_attempt,
    )
    async def execute_query(
        self,
        validated_query: QueryValidationResults,
        readonly: bool = True,  # Default to read-only for safety
    ) -> QueryResult:
        """Execute a SQL query asynchronously with proper transaction management.

        Args:
            validated_query: Validated query containing statements to execute
            readonly: Whether to execute in read-only mode

        Returns:
            QueryResult containing the results of all statements

        Raises:
            ConnectionError: If a database connection issue occurs
            QueryError: If the query execution fails
            PermissionError: When user lacks required privileges
        """
        # Log query execution (truncate long queries for readability)
        truncated_query = (
            validated_query.original_query[:100] + "..."
            if len(validated_query.original_query) > 100
            else validated_query.original_query
        )
        logger.debug(f"Executing query (readonly={readonly}): {truncated_query}")

        # Define the operation to execute all statements within a transaction
        async def execute_all_statements(conn: asyncpg.Connection[Any]) -> QueryResult:
            async def transaction_operation() -> list[StatementResult]:
                results = []
                for statement in validated_query.statements:
                    if statement.query:  # Skip statements with no query
                        result = await self.execute_statement(conn, statement.query)
                        results.append(result)
                    else:
                        logger.warning(f"Statement has no query, statement: {statement}")
                return results

            # Execute the operation within a transaction
            results = await self.with_transaction(conn, transaction_operation, readonly)
            return QueryResult(results=results)

        # Execute the operation with a connection
        return await self.with_connection(execute_all_statements)

    async def _handle_postgres_error(self, error: asyncpg.PostgresError) -> None:
        """Handle PostgreSQL errors and convert them to appropriate exceptions.

        Args:
            error: PostgreSQL error

        Raises:
            PermissionError: When user lacks required privileges
            QueryError: For schema errors or general query errors
        """
        if isinstance(error, asyncpg.exceptions.InsufficientPrivilegeError):
            logger.error(f"Permission denied: {error}")
            raise PermissionError(
                f"Access denied: {str(error)}. Use live_dangerously('database', True) for write operations."
            ) from error
        elif isinstance(
            error,
            (
                asyncpg.exceptions.UndefinedTableError,
                asyncpg.exceptions.UndefinedColumnError,
            ),
        ):
            logger.error(f"Schema error: {error}")
            raise QueryError(str(error)) from error
        else:
            logger.error(f"Database error: {error}")
            raise QueryError(f"Query execution failed: {str(error)}") from error
```
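
As a closing illustration (not part of the repository), the sketch below exercises only the methods defined above (`get_instance`, `with_connection`, and `reset`) to show the lazy connection flow: no pool exists until the first operation triggers `ensure_pool`. The assumption that `Settings()` resolves the project ref, password, and region from the environment is hypothetical here.

```python
# Hypothetical usage sketch (added for illustration): shows the lazy-connection
# flow of PostgresClient. Assumes Settings() loads the project ref, password,
# and region from the environment.
import asyncio

import asyncpg

from supabase_mcp.services.database.postgres_client import PostgresClient
from supabase_mcp.settings import Settings


async def main() -> None:
    client = PostgresClient.get_instance(settings=Settings())

    # No pool exists yet; with_connection() calls ensure_pool() on first use.
    async def fetch_version(conn: asyncpg.Connection) -> str:
        return await conn.fetchval("SELECT version()")

    print(await client.with_connection(fetch_version))

    # Close the pool and clear the singleton on shutdown.
    await PostgresClient.reset()


if __name__ == "__main__":
    asyncio.run(main())
```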