This is page 2 of 6. Use http://codebase.md/alexander-zuev/supabase-mcp-server?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .claude
│   └── settings.local.json
├── .dockerignore
├── .env.example
├── .env.test.example
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── roadmap_item.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── ci.yaml
│       ├── docs
│       │   └── release-checklist.md
│       └── publish.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── CHANGELOG.MD
├── codecov.yml
├── CONTRIBUTING.MD
├── Dockerfile
├── LICENSE
├── llms-full.txt
├── pyproject.toml
├── README.md
├── smithery.yaml
├── supabase_mcp
│   ├── __init__.py
│   ├── clients
│   │   ├── api_client.py
│   │   ├── base_http_client.py
│   │   ├── management_client.py
│   │   └── sdk_client.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── container.py
│   │   └── feature_manager.py
│   ├── exceptions.py
│   ├── logger.py
│   ├── main.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── api_manager.py
│   │   │   ├── spec_manager.py
│   │   │   └── specs
│   │   │       └── api_spec.json
│   │   ├── database
│   │   │   ├── __init__.py
│   │   │   ├── migration_manager.py
│   │   │   ├── postgres_client.py
│   │   │   ├── query_manager.py
│   │   │   └── sql
│   │   │       ├── loader.py
│   │   │       ├── models.py
│   │   │       ├── queries
│   │   │       │   ├── create_migration.sql
│   │   │       │   ├── get_migrations.sql
│   │   │       │   ├── get_schemas.sql
│   │   │       │   ├── get_table_schema.sql
│   │   │       │   ├── get_tables.sql
│   │   │       │   ├── init_migrations.sql
│   │   │       │   └── logs
│   │   │       │       ├── auth_logs.sql
│   │   │       │       ├── cron_logs.sql
│   │   │       │       ├── edge_logs.sql
│   │   │       │       ├── function_edge_logs.sql
│   │   │       │       ├── pgbouncer_logs.sql
│   │   │       │       ├── postgres_logs.sql
│   │   │       │       ├── postgrest_logs.sql
│   │   │       │       ├── realtime_logs.sql
│   │   │       │       ├── storage_logs.sql
│   │   │       │       └── supavisor_logs.sql
│   │   │       └── validator.py
│   │   ├── logs
│   │   │   ├── __init__.py
│   │   │   └── log_manager.py
│   │   ├── safety
│   │   │   ├── __init__.py
│   │   │   ├── models.py
│   │   │   ├── safety_configs.py
│   │   │   └── safety_manager.py
│   │   └── sdk
│   │       ├── __init__.py
│   │       ├── auth_admin_models.py
│   │       └── auth_admin_sdk_spec.py
│   ├── settings.py
│   └── tools
│       ├── __init__.py
│       ├── descriptions
│       │   ├── api_tools.yaml
│       │   ├── database_tools.yaml
│       │   ├── logs_and_analytics_tools.yaml
│       │   ├── safety_tools.yaml
│       │   └── sdk_tools.yaml
│       ├── manager.py
│       └── registry.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── test_api_client.py
│   │   │   ├── test_api_manager.py
│   │   │   └── test_spec_manager.py
│   │   ├── database
│   │   │   ├── sql
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conftest.py
│   │   │   │   ├── test_loader.py
│   │   │   │   ├── test_sql_validator_integration.py
│   │   │   │   └── test_sql_validator.py
│   │   │   ├── test_migration_manager.py
│   │   │   ├── test_postgres_client.py
│   │   │   └── test_query_manager.py
│   │   ├── logs
│   │   │   └── test_log_manager.py
│   │   ├── safety
│   │   │   ├── test_api_safety_config.py
│   │   │   ├── test_safety_manager.py
│   │   │   └── test_sql_safety_config.py
│   │   └── sdk
│   │       ├── test_auth_admin_models.py
│   │       └── test_sdk_client.py
│   ├── test_container.py
│   ├── test_main.py
│   ├── test_settings.py
│   ├── test_tool_manager.py
│   ├── test_tools_integration.py.bak
│   └── test_tools.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/tests/services/api/test_api_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Any
  2 | from unittest.mock import MagicMock, patch
  3 |
  4 | import pytest
  5 |
  6 | from supabase_mcp.exceptions import SafetyError
  7 | from supabase_mcp.services.api.api_manager import SupabaseApiManager
  8 | from supabase_mcp.services.safety.models import ClientType
  9 |
 10 |
 11 | class TestApiManager:
 12 |     """Tests for the API Manager."""
 13 |
 14 |     @pytest.mark.unit
 15 |     def test_path_parameter_replacement(self, mock_api_manager: SupabaseApiManager):
 16 |         """
 17 |         Test that path parameters are correctly replaced in API paths.
 18 |
 19 |         This test verifies that the API Manager correctly replaces path placeholders
 20 |         with actual values, handling both required and optional parameters.
 21 |         """
 22 |         # Use the mock_api_manager fixture instead of creating one manually
 23 |         api_manager = mock_api_manager
 24 |
 25 |         # Test with a simple path and required parameters (avoiding 'ref' which is auto-injected)
 26 |         path = "/v1/organizations/{slug}/members"
 27 |         path_params = {"slug": "example-org"}
 28 |
 29 |         result = api_manager.replace_path_params(path, path_params)
 30 |         expected = "/v1/organizations/example-org/members"
 31 |         assert result == expected, f"Expected {expected}, got {result}"
 32 |
 33 |         # Test with missing required parameters
 34 |         path = "/v1/organizations/{slug}/members/{id}"
 35 |         path_params = {"slug": "example-org"}
 36 |
 37 |         with pytest.raises(ValueError) as excinfo:
 38 |             api_manager.replace_path_params(path, path_params)
 39 |         assert "Missing path parameters" in str(excinfo.value)
 40 |
 41 |         # Test with extra parameters (should be rejected with a ValueError)
 42 |         path = "/v1/organizations/{slug}"
 43 |         path_params = {"slug": "example-org", "extra": "should-be-ignored"}
 44 |
 45 |         with pytest.raises(ValueError) as excinfo:
 46 |             api_manager.replace_path_params(path, path_params)
 47 |         assert "Unknown path parameter" in str(excinfo.value)
 48 |
 49 |         # Test with no parameters
 50 |         path = "/v1/organizations"
 51 |         result = api_manager.replace_path_params(path)
 52 |         expected = "/v1/organizations"
 53 |         assert result == expected, f"Expected {expected}, got {result}"
 54 |
 55 |     @pytest.mark.asyncio
 56 |     @pytest.mark.unit
 57 |     @patch("supabase_mcp.services.api.api_manager.logger")
 58 |     async def test_safety_validation(self, mock_logger: MagicMock, mock_api_manager: SupabaseApiManager):
 59 |         """
 60 |         Test that API operations are properly validated through the safety manager.
 61 |
 62 |         This test verifies that the API Manager correctly validates operations
 63 |         before executing them, and handles safety errors appropriately.
 64 |         """
 65 |         # Use the mock_api_manager fixture instead of creating one manually
 66 |         api_manager = mock_api_manager
 67 |
 68 |         # Mock replace_path_params to return a fixed, already-substituted path
 69 |         api_manager.replace_path_params = MagicMock(return_value="/v1/organizations/example-org")
 70 |
 71 |         # Replace the client's execute_request with an awaitable stub returning a simple response
 72 |         mock_response = {"success": True}
 73 |
 74 |         async def mock_execute_request(*args: Any, **kwargs: Any) -> dict[str, Any]:
 75 |             return mock_response
 76 |
 77 |         api_manager.client.execute_request = mock_execute_request
 78 |
 79 |         # Test a successful operation
 80 |         method = "GET"
 81 |         path = "/v1/organizations/{slug}"
 82 |         path_params = {"slug": "example-org"}
 83 |
 84 |         result = await api_manager.execute_request(method, path, path_params)
 85 |
 86 |         # Verify that the safety manager was called with the correct parameters
 87 |         api_manager.safety_manager.validate_operation.assert_called_once_with(
 88 |             ClientType.API, (method, path, path_params, None, None), has_confirmation=False
 89 |         )
 90 |
 91 |         # Verify that the result is what we expected
 92 |         assert result == {"success": True}
 93 |
 94 |         # Test an operation that fails safety validation
 95 |         api_manager.safety_manager.validate_operation.reset_mock()
 96 |
 97 |         # Make the safety manager raise a SafetyError
 98 |         def raise_safety_error(*args: Any, **kwargs: Any) -> None:
 99 |             raise SafetyError("Operation not allowed")
100 |
101 |         api_manager.safety_manager.validate_operation.side_effect = raise_safety_error
102 |
103 |         # The execute_request method should raise the SafetyError
104 |         with pytest.raises(SafetyError) as excinfo:
105 |             await api_manager.execute_request("DELETE", "/v1/organizations/{slug}", {"slug": "example-org"})
106 |
107 |         assert "Operation not allowed" in str(excinfo.value)
108 |
109 |     @pytest.mark.asyncio
110 |     @pytest.mark.unit
111 |     async def test_retrieve_logs_basic(self, mock_api_manager: SupabaseApiManager):
112 |         """
113 |         Test that the retrieve_logs method correctly builds and executes a logs query.
114 |
115 |         This test verifies that the API Manager correctly builds a logs query using
116 |         the LogManager and executes it through the Management API.
117 |         """
118 |         # Mock the log_manager's build_logs_query method
119 |         mock_api_manager.log_manager.build_logs_query = MagicMock(return_value="SELECT * FROM postgres_logs LIMIT 10")
120 |
121 |         # Mock the execute_request method to return a simple response
122 |         mock_response = {"result": [{"id": "123", "event_message": "test"}]}
123 |
124 |         async def mock_execute_request(*args: Any, **kwargs: Any) -> dict[str, Any]:
125 |             return mock_response
126 |
127 |         mock_api_manager.execute_request = mock_execute_request
128 |
129 |         # Execute the method
130 |         result = await mock_api_manager.retrieve_logs(
131 |             collection="postgres",
132 |             limit=10,
133 |             hours_ago=24,
134 |         )
135 |
136 |         # Verify that the log_manager was called with the correct parameters
137 |         mock_api_manager.log_manager.build_logs_query.assert_called_once_with(
138 |             collection="postgres",
139 |             limit=10,
140 |             hours_ago=24,
141 |             filters=None,
142 |             search=None,
143 |             custom_query=None,
144 |         )
145 |
146 |         # Verify that the result is what we expected
147 |         assert result == {"result": [{"id": "123", "event_message": "test"}]}
148 |
149 |     @pytest.mark.asyncio
150 |     @pytest.mark.unit
151 |     async def test_retrieve_logs_error_handling(self, mock_api_manager: SupabaseApiManager):
152 |         """
153 |         Test that the retrieve_logs method correctly handles errors.
154 |
155 |         This test verifies that the API Manager correctly handles errors that occur
156 |         during log retrieval and propagates them to the caller.
157 |         """
158 |         # Mock the log_manager's build_logs_query method
159 |         mock_api_manager.log_manager.build_logs_query = MagicMock(return_value="SELECT * FROM postgres_logs LIMIT 10")
160 |
161 |         # Mock the execute_request method to raise an exception
162 |         async def mock_execute_request_error(*args: Any, **kwargs: Any) -> dict[str, Any]:
163 |             raise Exception("API error")
164 |
165 |         mock_api_manager.execute_request = mock_execute_request_error
166 |
167 |         # The retrieve_logs method should propagate the exception
168 |         with pytest.raises(Exception) as excinfo:
169 |             await mock_api_manager.retrieve_logs(collection="postgres")
170 |
171 |         assert "API error" in str(excinfo.value)
172 |
```

--------------------------------------------------------------------------------
/tests/test_tool_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from unittest.mock import MagicMock, mock_open, patch
  2 |
  3 | from supabase_mcp.tools.manager import ToolManager, ToolName
  4 |
  5 |
  6 | class TestToolManager:
  7 |     """Tests for the ToolManager class."""
  8 |
  9 |     def test_singleton_pattern(self):
 10 |         """Test that ToolManager follows the singleton pattern."""
 11 |         # Get two instances
 12 |         manager1 = ToolManager.get_instance()
 13 |         manager2 = ToolManager.get_instance()
 14 |
 15 |         # They should be the same object
 16 |         assert manager1 is manager2
 17 |
 18 |         # Reset the singleton for other tests
 19 |         # pylint: disable=protected-access
 20 |         # We need to reset the singleton for test isolation
 21 |         ToolManager._instance = None  # type: ignore
 22 |
 23 |     @patch("supabase_mcp.tools.manager.Path")
 24 |     @patch("supabase_mcp.tools.manager.yaml.safe_load")
 25 |     def test_load_descriptions(self, mock_yaml_load: MagicMock, mock_path: MagicMock):
 26 |         """Test that descriptions are loaded correctly from YAML files."""
 27 |         # Setup mock directory structure
 28 |         mock_file_path = MagicMock()
 29 |         mock_dir = MagicMock()
 30 |
 31 |         # Mock the Path(__file__) call
 32 |         mock_path.return_value = mock_file_path
 33 |         mock_file_path.parent = mock_dir
 34 |         mock_dir.__truediv__.return_value = mock_dir  # For the / operator
 35 |
 36 |         # Mock directory existence check
 37 |         mock_dir.exists.return_value = True
 38 |
 39 |         # Mock the glob to return some YAML files
 40 |         mock_file1 = MagicMock()
 41 |         mock_file1.name = "database_tools.yaml"
 42 |         mock_file2 = MagicMock()
 43 |         mock_file2.name = "api_tools.yaml"
 44 |         mock_dir.glob.return_value = [mock_file1, mock_file2]
 45 |
 46 |         # Mock the file open and YAML load
 47 |         mock_yaml_data = {"get_schemas": "Description for get_schemas", "get_tables": "Description for get_tables"}
 48 |         mock_yaml_load.return_value = mock_yaml_data
 49 |
 50 |         # Create a new instance to trigger _load_descriptions
 51 |         with patch("builtins.open", mock_open(read_data="dummy yaml content")):
 52 |             # We need to create the manager to trigger _load_descriptions
 53 |             ToolManager()
 54 |
 55 |         # Verify the descriptions were loaded
 56 |         assert mock_dir.glob.call_count > 0
 57 |         assert mock_dir.glob.call_args[0][0] == "*.yaml"
 58 |         assert mock_yaml_load.call_count >= 1
 59 |
 60 |         # Reset the singleton for other tests
 61 |         # pylint: disable=protected-access
 62 |         ToolManager._instance = None  # type: ignore
 63 |
 64 |     def test_get_description_valid_tool(self):
 65 |         """Test getting a description for a valid tool."""
 66 |         # Setup
 67 |         manager = ToolManager.get_instance()
 68 |
 69 |         # Force the descriptions to have a known value for testing
 70 |         # pylint: disable=protected-access
 71 |         # We need to set the descriptions directly for testing
 72 |         manager.descriptions = {
 73 |             ToolName.GET_SCHEMAS.value: "Description for get_schemas",
 74 |             ToolName.GET_TABLES.value: "Description for get_tables",
 75 |         }
 76 |
 77 |         # Test
 78 |         description = manager.get_description(ToolName.GET_SCHEMAS.value)
 79 |
 80 |         # Verify
 81 |         assert description == "Description for get_schemas"
 82 |
 83 |         # Reset the singleton for other tests
 84 |         # pylint: disable=protected-access
 85 |         ToolManager._instance = None  # type: ignore
 86 |
 87 |     def test_get_description_invalid_tool(self):
 88 |         """Test getting a description for an invalid tool."""
 89 |         # Setup
 90 |         manager = ToolManager.get_instance()
 91 |
 92 |         # Force the descriptions to have a known value for testing
 93 |         # pylint: disable=protected-access
 94 |         # We need to set the descriptions directly for testing
 95 |         manager.descriptions = {
 96 |             ToolName.GET_SCHEMAS.value: "Description for get_schemas",
 97 |             ToolName.GET_TABLES.value: "Description for get_tables",
 98 |         }
 99 |
100 |         # Test and verify
101 |         description = manager.get_description("nonexistent_tool")
102 |         assert description == ""  # The method returns an empty string for unknown tools
103 |
104 |         # Reset the singleton for other tests
105 |         # pylint: disable=protected-access
106 |         ToolManager._instance = None  # type: ignore
107 |
108 |     def test_all_tool_names_have_descriptions(self):
109 |         """Test that all tools defined in ToolName enum have descriptions."""
110 |         # Setup - get a fresh instance
111 |         # Reset the singleton first to ensure we get a clean instance
112 |         # pylint: disable=protected-access
113 |         ToolManager._instance = None  # type: ignore
114 |
115 |         # Get a fresh instance that will load the real YAML files
116 |         manager = ToolManager.get_instance()
117 |
118 |         # Print the loaded descriptions for debugging
119 |         print(f"\nLoaded descriptions: {manager.descriptions}")
120 |
121 |         # Verify that we have at least some descriptions loaded
122 |         assert len(manager.descriptions) > 0, "No descriptions were loaded"
123 |
124 |         # Check that descriptions are not empty
125 |         empty_descriptions: list[str] = []
126 |         for tool_name, description in manager.descriptions.items():
127 |             if not description or len(description.strip()) == 0:
128 |                 empty_descriptions.append(tool_name)
129 |
130 |         # Fail if we found any empty descriptions
131 |         assert len(empty_descriptions) == 0, f"Found empty descriptions for tools: {empty_descriptions}"
132 |
133 |         # Check that at least some of the tool names have descriptions
134 |         found_descriptions = 0
135 |         missing_descriptions: list[str] = []
136 |
137 |         for tool in ToolName:  # distinct name: tool_name above is a str key, this is an enum member
138 |             description = manager.get_description(tool.value)
139 |             if description:
140 |                 found_descriptions += 1
141 |             else:
142 |                 missing_descriptions.append(tool.value)
143 |
144 |         # Print missing descriptions for debugging
145 |         if missing_descriptions:
146 |             print(f"\nMissing descriptions for: {missing_descriptions}")
147 |
148 |         # We should have at least some descriptions
149 |         assert found_descriptions > 0, "No tool has a description"
150 |
151 |         # Reset the singleton for other tests
152 |         # pylint: disable=protected-access
153 |         ToolManager._instance = None  # type: ignore
154 |
155 |     @patch.object(ToolManager, "_load_descriptions")
156 |     def test_initialization_loads_descriptions(self, mock_load_descriptions: MagicMock):
157 |         """Test that descriptions are loaded during initialization."""
158 |         # Create a new instance
159 |         # We need to create the manager to trigger __init__
160 |         ToolManager()
161 |
162 |         # Verify _load_descriptions was called
163 |         assert mock_load_descriptions.call_count > 0
164 |
165 |         # Reset the singleton for other tests
166 |         # pylint: disable=protected-access
167 |         ToolManager._instance = None  # type: ignore
168 |
169 |     def test_tool_enum_completeness(self):
170 |         """Test that the ToolName enum contains all expected tools."""
171 |         # Get all tool values from the enum
172 |         tool_values = [tool.value for tool in ToolName]
173 |
174 |         # Verify the total number of tools
175 |         # Update this number when new tools are added
176 |         expected_tool_count = 12
177 |         assert len(tool_values) == expected_tool_count, f"Expected {expected_tool_count} tools, got {len(tool_values)}"
178 |
179 |         # Verify specific tools are included
180 |         assert "retrieve_logs" in tool_values, "retrieve_logs tool is missing from ToolName enum"
181 |
182 |         # Reset the singleton for other tests
183 |         # pylint: disable=protected-access
184 |         ToolManager._instance = None  # type: ignore
185 |
```

--------------------------------------------------------------------------------
/tests/services/safety/test_api_safety_config.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Unit tests for the APISafetyConfig class.
  3 |
  4 | This file contains unit test cases for the APISafetyConfig class, which is responsible for
  5 | determining the risk level of API operations and whether they are allowed or require confirmation.
  6 | """
  7 |
  8 | import pytest
  9 |
 10 | from supabase_mcp.services.safety.models import OperationRiskLevel, SafetyMode
 11 | from supabase_mcp.services.safety.safety_configs import APISafetyConfig, HTTPMethod
 12 |
 13 |
 14 | @pytest.mark.unit
 15 | class TestAPISafetyConfig:
 16 |     """Unit tests for the APISafetyConfig class."""
 17 |
 18 |     def test_get_risk_level_low_risk(self):
 19 |         """Test getting risk level for low-risk operations (GET requests)."""
 20 |         config = APISafetyConfig()
 21 |         # API operations are tuples of (method, path, path_params, query_params, request_body)
 22 |         operation = ("GET", "/v1/projects/{ref}/functions", {}, {}, {})
 23 |         risk_level = config.get_risk_level(operation)
 24 |         assert risk_level == OperationRiskLevel.LOW
 25 |
 26 |     def test_get_risk_level_medium_risk(self):
 27 |         """Test getting risk level for medium-risk operations (POST/PUT/PATCH)."""
 28 |         config = APISafetyConfig()
 29 |
 30 |         # Test POST request
 31 |         operation = ("POST", "/v1/projects/{ref}/functions", {}, {}, {})
 32 |         risk_level = config.get_risk_level(operation)
 33 |         assert risk_level == OperationRiskLevel.MEDIUM
 34 |
 35 |         # Test PUT request
 36 |         operation = ("PUT", "/v1/projects/{ref}/functions", {}, {}, {})
 37 |         risk_level = config.get_risk_level(operation)
 38 |         assert risk_level == OperationRiskLevel.MEDIUM
 39 |
 40 |         # Test PATCH request
 41 |         operation = ("PATCH", "/v1/projects/{ref}/functions/{function_slug}", {}, {}, {})
 42 |         risk_level = config.get_risk_level(operation)
 43 |         assert risk_level == OperationRiskLevel.MEDIUM
 44 |
 45 |     def test_get_risk_level_high_risk(self):
 46 |         """Test getting risk level for high-risk operations."""
 47 |         config = APISafetyConfig()
 48 |
 49 |         # Test DELETE request for a function
 50 |         operation = ("DELETE", "/v1/projects/{ref}/functions/{function_slug}", {}, {}, {})
 51 |         risk_level = config.get_risk_level(operation)
 52 |         assert risk_level == OperationRiskLevel.HIGH
 53 |
 54 |         # Test other high-risk operations
 55 |         high_risk_paths = [
 56 |             "/v1/projects/{ref}/branches/{branch_id}",
 57 |             "/v1/projects/{ref}/custom-hostname",
 58 |             "/v1/projects/{ref}/network-bans",
 59 |         ]
 60 |
 61 |         for path in high_risk_paths:
 62 |             operation = ("DELETE", path, {}, {}, {})
 63 |             risk_level = config.get_risk_level(operation)
 64 |             assert risk_level == OperationRiskLevel.HIGH, f"Path {path} should be HIGH risk"
 65 |
 66 |     def test_get_risk_level_extreme_risk(self):
 67 |         """Test getting risk level for extreme-risk operations."""
 68 |         config = APISafetyConfig()
 69 |
 70 |         # Test DELETE request for a project
 71 |         operation = ("DELETE", "/v1/projects/{ref}", {}, {}, {})
 72 |         risk_level = config.get_risk_level(operation)
 73 |         assert risk_level == OperationRiskLevel.EXTREME
 74 |
 75 |     def test_is_operation_allowed(self):
 76 |         """Test if operations are allowed based on risk level and safety mode."""
 77 |         config = APISafetyConfig()
 78 |
 79 |         # Low risk operations should be allowed in both safe and unsafe modes
 80 |         assert config.is_operation_allowed(OperationRiskLevel.LOW, SafetyMode.SAFE) is True
 81 |         assert config.is_operation_allowed(OperationRiskLevel.LOW, SafetyMode.UNSAFE) is True
 82 |
 83 |         # Medium/high risk operations should only be allowed in unsafe mode
 84 |         assert config.is_operation_allowed(OperationRiskLevel.MEDIUM, SafetyMode.SAFE) is False
 85 |         assert config.is_operation_allowed(OperationRiskLevel.MEDIUM, SafetyMode.UNSAFE) is True
 86 |         assert config.is_operation_allowed(OperationRiskLevel.HIGH, SafetyMode.SAFE) is False
 87 |         assert config.is_operation_allowed(OperationRiskLevel.HIGH, SafetyMode.UNSAFE) is True
 88 |
 89 |         # Extreme risk operations should not be allowed in safe mode
 90 |         assert config.is_operation_allowed(OperationRiskLevel.EXTREME, SafetyMode.SAFE) is False
 91 |         # In the current implementation, extreme risk operations are never allowed
 92 |         assert config.is_operation_allowed(OperationRiskLevel.EXTREME, SafetyMode.UNSAFE) is False
 93 |
 94 |     def test_needs_confirmation(self):
 95 |         """Test if operations need confirmation based on risk level."""
 96 |         config = APISafetyConfig()
 97 |
 98 |         # Low and medium risk operations should not need confirmation
 99 |         assert config.needs_confirmation(OperationRiskLevel.LOW) is False
100 |         assert config.needs_confirmation(OperationRiskLevel.MEDIUM) is False
101 |
102 |         # High and extreme risk operations should need confirmation
103 |         assert config.needs_confirmation(OperationRiskLevel.HIGH) is True
104 |         assert config.needs_confirmation(OperationRiskLevel.EXTREME) is True
105 |
106 |     def test_path_matching(self):
107 |         """Test that path patterns are correctly matched."""
108 |         config = APISafetyConfig()
109 |
110 |         # Test exact path matching
111 |         operation = ("GET", "/v1/projects/{ref}/functions", {}, {}, {})
112 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
113 |
114 |         # Test path with parameters
115 |         operation = ("GET", "/v1/projects/abc123/functions", {}, {}, {})
116 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
117 |
118 |         # Test path with multiple parameters
119 |         operation = ("DELETE", "/v1/projects/abc123/functions/my-function", {}, {}, {})
120 |         assert config.get_risk_level(operation) == OperationRiskLevel.HIGH
121 |
122 |         # Test path that doesn't match any pattern (unmatched paths default to LOW, even for non-GET methods)
123 |         operation = ("DELETE", "/v1/some/unknown/path", {}, {}, {})
124 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
125 |
126 |         # Test path that doesn't match any pattern (should default to LOW for GET)
127 |         operation = ("GET", "/v1/some/unknown/path", {}, {}, {})
128 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
129 |
130 |     def test_method_case_insensitivity(self):
131 |         """Test that HTTP method matching is case-insensitive."""
132 |         config = APISafetyConfig()
133 |
134 |         # Test with lowercase method
135 |         operation = ("get", "/v1/projects/{ref}/functions", {}, {}, {})
136 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
137 |
138 |         # Test with uppercase method
139 |         operation = ("GET", "/v1/projects/{ref}/functions", {}, {}, {})
140 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
141 |
142 |         # Test with mixed case method
143 |         operation = ("GeT", "/v1/projects/{ref}/functions", {}, {}, {})
144 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
145 |
146 |     def test_path_safety_config_structure(self):
147 |         """Test that the PATH_SAFETY_CONFIG structure is correctly defined."""
148 |         config = APISafetyConfig()
149 |
150 |         # Check that the config has the expected structure
151 |         assert hasattr(config, "PATH_SAFETY_CONFIG")
152 |
153 |         # Check that risk levels are represented as keys
154 |         assert OperationRiskLevel.MEDIUM in config.PATH_SAFETY_CONFIG
155 |         assert OperationRiskLevel.HIGH in config.PATH_SAFETY_CONFIG
156 |         assert OperationRiskLevel.EXTREME in config.PATH_SAFETY_CONFIG
157 |
158 |         # Check that each risk level has a dictionary of methods to paths
159 |         for methods_dict in config.PATH_SAFETY_CONFIG.values():
160 |             assert isinstance(methods_dict, dict)
161 |             for method, paths in methods_dict.items():
162 |                 assert isinstance(method, HTTPMethod)
163 |                 assert isinstance(paths, list)
164 |                 for path in paths:
165 |                     assert isinstance(path, str)
166 |
```

--------------------------------------------------------------------------------
/supabase_mcp/settings.py:
--------------------------------------------------------------------------------

```python
  1 | import os
  2 | from pathlib import Path
  3 | from typing import Literal, get_args
  4 |
  5 | from pydantic import Field, ValidationInfo, field_validator
  6 | from pydantic_settings import BaseSettings, SettingsConfigDict
  7 |
  8 | from supabase_mcp.logger import logger
  9 |
 10 | SUPPORTED_REGIONS = Literal[
 11 |     "us-west-1",  # West US (North California)
 12 |     "us-east-1",  # East US (North Virginia)
 13 |     "us-east-2",  # East US (Ohio)
 14 |     "ca-central-1",  # Canada (Central)
 15 |     "eu-west-1",  # West EU (Ireland)
 16 |     "eu-west-2",  # West Europe (London)
 17 |     "eu-west-3",  # West EU (Paris)
 18 |     "eu-central-1",  # Central EU (Frankfurt)
 19 |     "eu-central-2",  # Central Europe (Zurich)
 20 |     "eu-north-1",  # North EU (Stockholm)
 21 |     "ap-south-1",  # South Asia (Mumbai)
 22 |     "ap-southeast-1",  # Southeast Asia (Singapore)
 23 |     "ap-northeast-1",  # Northeast Asia (Tokyo)
 24 |     "ap-northeast-2",  # Northeast Asia (Seoul)
 25 |     "ap-southeast-2",  # Oceania (Sydney)
 26 |     "sa-east-1",  # South America (São Paulo)
 27 | ]
 28 |
 29 |
 30 | def find_config_file(env_file: str = ".env") -> str | None:
 31 |     """Find the specified env file in order of precedence:
 32 |     1. Current working directory (where command is run)
 33 |     2. Global config:
 34 |        - Windows: %APPDATA%/supabase-mcp/{env_file}
 35 |        - macOS/Linux: ~/.config/supabase-mcp/{env_file}
 36 |
 37 |     Args:
 38 |         env_file: The name of the environment file to look for (default: ".env")
 39 |
 40 |     Returns:
 41 |         The path to the found config file, or None if not found
 42 |     """
 43 |     # 1. Check current directory
 44 |     cwd_config = Path.cwd() / env_file
 45 |     if cwd_config.exists():
 46 |         return str(cwd_config)
 47 |
 48 |     # 2. Check global config (honour env_file here too, as documented above)
 49 |     home = Path.home()
 50 |     if os.name == "nt":  # Windows
 51 |         global_config = Path(os.environ.get("APPDATA", "")) / "supabase-mcp" / env_file
 52 |     else:  # macOS/Linux
 53 |         global_config = home / ".config" / "supabase-mcp" / env_file
 54 |
 55 |     if global_config.exists():
 56 |         logger.error(
 57 |             f"DEPRECATED: {global_config} is deprecated and will be removed in a future release. "
 58 |             "Use your IDE's native .json config file to configure access to MCP."
 59 |         )
 60 |         return str(global_config)
 61 |
 62 |     return None
 63 |
 64 |
 65 | class Settings(BaseSettings):
 66 |     """Initializes settings for Supabase MCP server."""
 67 |
 68 |     supabase_project_ref: str = Field(
 69 |         default="127.0.0.1:54322",  # Local Supabase default
 70 |         description="Supabase project ref - Must be 20 chars for remote projects, can be local address for development",
 71 |         alias="SUPABASE_PROJECT_REF",
 72 |     )
 73 |     supabase_db_password: str | None = Field(
 74 |         default=None,  # Will be validated based on project_ref
 75 |         description="Supabase database password - Required for remote projects, defaults to 'postgres' for local",
 76 |         alias="SUPABASE_DB_PASSWORD",
 77 |     )
 78 |     supabase_region: str = Field(
 79 |         default="us-east-1",  # East US (North Virginia) - Supabase's default region
 80 |         description="Supabase region for connection",
 81 |         alias="SUPABASE_REGION",
 82 |     )
 83 |     supabase_access_token: str | None = Field(
 84 |         default=None,
 85 |         description="Optional personal access token for accessing Supabase Management API",
 86 |         alias="SUPABASE_ACCESS_TOKEN",
 87 |     )
 88 |     supabase_service_role_key: str | None = Field(
 89 |         default=None,
 90 |         description="Optional service role key for accessing Python SDK",
 91 |         alias="SUPABASE_SERVICE_ROLE_KEY",
 92 |     )
 93 |
 94 |     supabase_api_url: str = Field(
 95 |         default="https://api.supabase.com",
 96 |         description="Supabase API URL",
 97 |     )
 98 |
 99 |     query_api_key: str = Field(
100 |         default="test-key",
101 |         description="TheQuery.dev API key",
102 |         alias="QUERY_API_KEY",
103 |     )
104 |
105 |     query_api_url: str = Field(
106 |         default="https://api.thequery.dev/v1",
107 |         description="TheQuery.dev API URL",
108 |         alias="QUERY_API_URL",
109 |     )
110 |
111 |     @field_validator("supabase_region")
112 |     @classmethod
113 |     def validate_region(cls, v: str, info: ValidationInfo) -> str:
114 |         """Validate that the region is supported by Supabase."""
115 |         # Get the project_ref from the values
116 |         values = info.data
117 |         project_ref = values.get("supabase_project_ref", "")
118 |
119 |         # If this is a remote project and region is the default
120 |         if not project_ref.startswith("127.0.0.1") and v == "us-east-1" and "SUPABASE_REGION" not in os.environ:
121 |             logger.warning(
122 |                 "You're connecting to a remote Supabase project but haven't specified a region. "
123 |                 "Using default 'us-east-1', which may cause 'Tenant or user not found' errors if incorrect. "
124 |                 "Please set the correct SUPABASE_REGION in your configuration."
125 |             )
126 |
127 |         # Validate that the region is supported (get_args is the public API for Literal members)
128 |         if v not in get_args(SUPPORTED_REGIONS):
129 |             supported = "\n - ".join([""] + list(get_args(SUPPORTED_REGIONS)))
130 |             raise ValueError(f"Region '{v}' is not supported. Supported regions are:{supported}")
131 |         return v
132 |
133 |     @field_validator("supabase_project_ref")
134 |     @classmethod
135 |     def validate_project_ref(cls, v: str) -> str:
136 |         """Validate the project ref format."""
137 |         if v.startswith("127.0.0.1"):
138 |             # Local development - allow default format
139 |             return v
140 |
141 |         # Remote project - must be 20 chars
142 |         if len(v) != 20:
143 |             logger.error("Invalid Supabase project ref format")
144 |             raise ValueError(
145 |                 "Invalid Supabase project ref format. "
146 |                 "Remote project refs must be exactly 20 characters long. "
147 |                 f"Got {len(v)} characters instead."
148 |             )
149 |         return v
150 |
151 |     @field_validator("supabase_db_password")
152 |     @classmethod
153 |     def validate_db_password(cls, v: str | None, info: ValidationInfo) -> str:
154 |         """Validate database password based on project type."""
155 |         project_ref = info.data.get("supabase_project_ref", "")
156 |
157 |         # For local development, allow default password
158 |         if project_ref.startswith("127.0.0.1"):
159 |             return v or "postgres"  # Default to postgres for local
160 |
161 |         # For remote projects, password is required
162 |         if not v:
163 |             logger.error("SUPABASE_DB_PASSWORD is required when connecting to a remote instance")
164 |             raise ValueError(
165 |                 "Database password is required for remote Supabase projects. "
166 |                 "Please set SUPABASE_DB_PASSWORD in your environment variables."
167 |             )
168 |         return v
169 |
170 |     @classmethod
171 |     def with_config(cls, config_file: str | None = None) -> "Settings":
172 |         """Create Settings with a specific config file.
173 |
174 |         Args:
175 |             config_file: Path to .env file to use, or None for no config file
176 |         """
177 |
178 |         # Create a new Settings class with the specific config
179 |         class SettingsWithConfig(cls):
180 |             model_config = SettingsConfigDict(env_file=config_file, env_file_encoding="utf-8")
181 |
182 |         instance = SettingsWithConfig()
183 |
184 |         # Log configuration source and precedence - simplified to a single clear message
185 |         env_vars_present = any(var in os.environ for var in ["SUPABASE_PROJECT_REF", "SUPABASE_DB_PASSWORD"])
186 |
187 |         if env_vars_present and config_file:
188 |             logger.info(f"Using environment variables (highest precedence) over config file: {config_file}")
189 |         elif env_vars_present:
190 |             logger.info("Using environment variables for configuration")
191 |         elif config_file:
192 |             logger.info(f"Using settings from config file: {config_file}")
193 |         else:
194 |             logger.info("Using default settings (local development)")
195 |
196 |         return instance
197 |
198 |
199 | # Module-level singleton - maintains existing interface
200 | settings = Settings.with_config(find_config_file())
201 |
```

--------------------------------------------------------------------------------
/tests/services/api/test_api_client.py:
--------------------------------------------------------------------------------

```python
  1 | import httpx
  2 | import pytest
  3 | from unittest.mock import AsyncMock, MagicMock, patch
  4 |
  5 | from supabase_mcp.clients.management_client import ManagementAPIClient
  6 | from supabase_mcp.exceptions import APIClientError, APIConnectionError
  7 | from supabase_mcp.settings import Settings
  8 |
  9 |
 10 | @pytest.mark.asyncio(loop_scope="module")
 11 | class TestAPIClient:
 12 |     """Unit tests for the API client."""
 13 |
 14 |     @pytest.fixture
 15 |     def mock_settings(self):
 16 |         """Create mock settings for testing."""
 17 |         settings = MagicMock(spec=Settings)
 18 |         settings.supabase_access_token = "test-token"
 19 |         settings.supabase_project_ref = "test-project-ref"
 20 |         settings.supabase_region = "us-east-1"
 21 |         settings.query_api_url = "https://api.test.com"
 22 |         settings.supabase_api_url = "https://api.supabase.com"
 23 |         return settings
 24 |
 25 |     async def test_execute_get_request(self, mock_settings):
 26 |         """Test executing a GET request to the API."""
 27 |         # Create client but don't mock the httpx client yet
 28 |         client = ManagementAPIClient(settings=mock_settings)
 29 |
 30 |         # Setup mock response
 31 |         mock_response = MagicMock(spec=httpx.Response)
 32 |         mock_response.status_code = 404
 33 |         mock_response.is_success = False
 34 |         mock_response.headers = {"content-type": "application/json"}
 35 |         mock_response.json.return_value = {"message": "Cannot GET /v1/health"}
 36 |         mock_response.text = '{"message": "Cannot GET /v1/health"}'
 37 |         mock_response.content = b'{"message": "Cannot GET /v1/health"}'
 38 |
 39 |         # Mock the send_request method to return our mock response
 40 |         with patch.object(client, 'send_request', return_value=mock_response):
 41 |             path = "/v1/health"
 42 |
 43 |
| # Execute the request and expect a 404 error 44 | with pytest.raises(APIClientError) as exc_info: 45 | await client.execute_request( 46 | method="GET", 47 | path=path, 48 | ) 49 | 50 | # Verify the error details 51 | assert exc_info.value.status_code == 404 52 | assert "Cannot GET /v1/health" in str(exc_info.value) 53 | 54 | async def test_request_preparation(self, mock_settings): 55 | """Test that requests are properly prepared with headers and parameters.""" 56 | client = ManagementAPIClient(settings=mock_settings) 57 | 58 | # Prepare a request with parameters 59 | method = "GET" 60 | path = "/v1/health" 61 | request_params = {"param1": "value1", "param2": "value2"} 62 | 63 | # Prepare the request 64 | request = client.prepare_request( 65 | method=method, 66 | path=path, 67 | request_params=request_params, 68 | ) 69 | 70 | # Verify the request 71 | assert request.method == method 72 | assert path in str(request.url) 73 | assert "param1=value1" in str(request.url) 74 | assert "param2=value2" in str(request.url) 75 | assert "Content-Type" in request.headers 76 | assert request.headers["Content-Type"] == "application/json" 77 | 78 | async def test_error_handling(self, mock_settings): 79 | """Test handling of API errors.""" 80 | client = ManagementAPIClient(settings=mock_settings) 81 | 82 | # Setup mock response 83 | mock_response = MagicMock(spec=httpx.Response) 84 | mock_response.status_code = 404 85 | mock_response.is_success = False 86 | mock_response.headers = {"content-type": "application/json"} 87 | mock_response.json.return_value = {"message": "Cannot GET /v1/nonexistent-endpoint"} 88 | mock_response.text = '{"message": "Cannot GET /v1/nonexistent-endpoint"}' 89 | mock_response.content = b'{"message": "Cannot GET /v1/nonexistent-endpoint"}' 90 | 91 | with patch.object(client, 'send_request', return_value=mock_response): 92 | path = "/v1/nonexistent-endpoint" 93 | 94 | # Execute the request and expect an APIClientError 95 | with pytest.raises(APIClientError) 
as exc_info: 96 | await client.execute_request( 97 | method="GET", 98 | path=path, 99 | ) 100 | 101 | # Verify the error details 102 | assert exc_info.value.status_code == 404 103 | assert "Cannot GET /v1/nonexistent-endpoint" in str(exc_info.value) 104 | 105 | async def test_request_with_body(self, mock_settings): 106 | """Test executing a request with a body.""" 107 | client = ManagementAPIClient(settings=mock_settings) 108 | 109 | # Test the request preparation 110 | method = "POST" 111 | path = "/v1/health/check" 112 | request_body = {"test": "data", "nested": {"value": 123}} 113 | 114 | # Prepare the request 115 | request = client.prepare_request( 116 | method=method, 117 | path=path, 118 | request_body=request_body, 119 | ) 120 | 121 | # Verify the request 122 | assert request.method == method 123 | assert path in str(request.url) 124 | assert request.content # Should have content for the body 125 | assert "Content-Type" in request.headers 126 | assert request.headers["Content-Type"] == "application/json" 127 | 128 | async def test_response_parsing(self, mock_settings): 129 | """Test parsing API responses.""" 130 | client = ManagementAPIClient(settings=mock_settings) 131 | 132 | # Setup mock response 133 | mock_response = MagicMock(spec=httpx.Response) 134 | mock_response.status_code = 200 135 | mock_response.is_success = True 136 | mock_response.headers = {"content-type": "application/json"} 137 | mock_response.json.return_value = [{"id": "project1", "name": "Test Project"}] 138 | mock_response.content = b'[{"id": "project1", "name": "Test Project"}]' 139 | 140 | with patch.object(client, 'send_request', return_value=mock_response): 141 | path = "/v1/projects" 142 | 143 | # Execute the request 144 | response = await client.execute_request( 145 | method="GET", 146 | path=path, 147 | ) 148 | 149 | # Verify the response is parsed correctly 150 | assert isinstance(response, list) 151 | assert len(response) > 0 152 | assert "id" in response[0] 153 | 154 | async 
def test_request_retry_mechanism(self, mock_settings): 155 | """Test that the tenacity retry mechanism works correctly for API requests.""" 156 | client = ManagementAPIClient(settings=mock_settings) 157 | 158 | # Create a mock request object for the NetworkError 159 | mock_request = MagicMock(spec=httpx.Request) 160 | mock_request.method = "GET" 161 | mock_request.url = "https://api.supabase.com/v1/projects" 162 | 163 | # Mock the client's send method to always raise a network error 164 | with patch.object(client.client, 'send', side_effect=httpx.NetworkError("Simulated network failure", request=mock_request)): 165 | # Execute a request - this should trigger retries and eventually fail 166 | with pytest.raises(APIConnectionError) as exc_info: 167 | await client.execute_request( 168 | method="GET", 169 | path="/v1/projects", 170 | ) 171 | 172 | # Verify the error message indicates retries were attempted 173 | assert "Network error after 3 retry attempts" in str(exc_info.value) 174 | 175 | async def test_request_without_access_token(self, mock_settings): 176 | """Test that an exception is raised when attempting to send a request without an access token.""" 177 | # Create client with no access token 178 | mock_settings.supabase_access_token = None 179 | client = ManagementAPIClient(settings=mock_settings) 180 | 181 | # Attempt to execute a request - should raise an exception 182 | with pytest.raises(APIClientError) as exc_info: 183 | await client.execute_request( 184 | method="GET", 185 | path="/v1/projects", 186 | ) 187 | 188 | assert "Supabase access token is not configured" in str(exc_info.value) ``` -------------------------------------------------------------------------------- /supabase_mcp/clients/base_http_client.py: -------------------------------------------------------------------------------- ```python 1 | from abc import ABC, abstractmethod 2 | from json.decoder import JSONDecodeError 3 | from typing import Any, TypeVar 4 | 5 | import httpx 6 | from 
pydantic import BaseModel 7 | from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential 8 | 9 | from supabase_mcp.exceptions import ( 10 | APIClientError, 11 | APIConnectionError, 12 | APIResponseError, 13 | APIServerError, 14 | UnexpectedError, 15 | ) 16 | from supabase_mcp.logger import logger 17 | 18 | T = TypeVar("T") 19 | 20 | 21 | # Helper function for retry decorator to safely log exceptions 22 | def log_retry_attempt(retry_state: RetryCallState) -> None: 23 | """Log retry attempts with exception details if available.""" 24 | exception = retry_state.outcome.exception() if retry_state.outcome and retry_state.outcome.failed else None 25 | exception_str = str(exception) if exception else "Unknown error" 26 | logger.warning(f"Network error, retrying ({retry_state.attempt_number}/3): {exception_str}") 27 | 28 | 29 | class AsyncHTTPClient(ABC): 30 | """Abstract base class for async HTTP clients.""" 31 | 32 | @abstractmethod 33 | async def _ensure_client(self) -> httpx.AsyncClient: 34 | """Ensure client exists and is ready for use. 35 | 36 | Creates the client if it doesn't exist yet. 37 | Returns the client instance. 38 | """ 39 | pass 40 | 41 | @abstractmethod 42 | async def close(self) -> None: 43 | """Close the client and release resources. 44 | 45 | Should be called when the client is no longer needed. 46 | """ 47 | pass 48 | 49 | def prepare_request( 50 | self, 51 | client: httpx.AsyncClient, 52 | method: str, 53 | path: str, 54 | request_params: dict[str, Any] | None = None, 55 | request_body: dict[str, Any] | None = None, 56 | ) -> httpx.Request: 57 | """ 58 | Prepare an HTTP request. 59 | 60 | Args: 61 | client: The httpx client to use 62 | method: HTTP method (GET, POST, etc.) 
63 | path: API path 64 | request_params: Query parameters 65 | request_body: Request body 66 | 67 | Returns: 68 | Prepared httpx.Request object 69 | 70 | Raises: 71 | APIClientError: If request preparation fails 72 | """ 73 | try: 74 | return client.build_request(method=method, url=path, params=request_params, json=request_body) 75 | except Exception as e: 76 | raise APIClientError( 77 | message=f"Failed to build request: {str(e)}", 78 | status_code=None, 79 | ) from e 80 | 81 | @retry( 82 | retry=retry_if_exception_type(httpx.NetworkError), # This includes ConnectError and TimeoutException 83 | stop=stop_after_attempt(3), 84 | wait=wait_exponential(multiplier=1, min=2, max=10), 85 | reraise=True, # Ensure the original exception is raised 86 | before_sleep=log_retry_attempt, 87 | ) 88 | async def send_request(self, client: httpx.AsyncClient, request: httpx.Request) -> httpx.Response: 89 | """ 90 | Send an HTTP request with retry logic for transient errors. 91 | 92 | Args: 93 | client: The httpx client to use 94 | request: Prepared httpx.Request object 95 | 96 | Returns: 97 | httpx.Response object 98 | 99 | Raises: 100 | APIConnectionError: For connection issues 101 | APIClientError: For other request errors 102 | """ 103 | try: 104 | return await client.send(request) 105 | except httpx.NetworkError as e: 106 | # All NetworkErrors will be retried by the decorator 107 | # This will only be reached after all retries are exhausted 108 | logger.error(f"Network error after all retry attempts: {str(e)}") 109 | raise APIConnectionError( 110 | message=f"Network error after 3 retry attempts: {str(e)}", 111 | status_code=None, 112 | ) from e 113 | except Exception as e: 114 | # Other exceptions won't be retried 115 | raise APIClientError( 116 | message=f"Request failed: {str(e)}", 117 | status_code=None, 118 | ) from e 119 | 120 | def parse_response(self, response: httpx.Response) -> dict[str, Any]: 121 | """ 122 | Parse an HTTP response as JSON. 
123 | 124 | Args: 125 | response: httpx.Response object 126 | 127 | Returns: 128 | Parsed response body as dictionary 129 | 130 | Raises: 131 | APIResponseError: If response cannot be parsed as JSON 132 | """ 133 | if not response.content: 134 | return {} 135 | 136 | try: 137 | return response.json() 138 | except JSONDecodeError as e: 139 | raise APIResponseError( 140 | message=f"Failed to parse response as JSON: {str(e)}", 141 | status_code=response.status_code, 142 | response_body={"raw_content": response.text}, 143 | ) from e 144 | 145 | def handle_error_response(self, response: httpx.Response, parsed_body: dict[str, Any] | None = None) -> None: 146 | """ 147 | Handle error responses based on status code. 148 | 149 | Args: 150 | response: httpx.Response object 151 | parsed_body: Parsed response body if available 152 | 153 | Raises: 154 | APIClientError: For client errors (4xx) 155 | APIServerError: For server errors (5xx) 156 | UnexpectedError: For unexpected status codes 157 | """ 158 | # Extract error message 159 | error_message = f"API request failed: {response.status_code}" 160 | if parsed_body and "message" in parsed_body: 161 | error_message = parsed_body["message"] 162 | 163 | # Determine error type based on status code 164 | if 400 <= response.status_code < 500: 165 | raise APIClientError( 166 | message=error_message, 167 | status_code=response.status_code, 168 | response_body=parsed_body, 169 | ) 170 | elif response.status_code >= 500: 171 | raise APIServerError( 172 | message=error_message, 173 | status_code=response.status_code, 174 | response_body=parsed_body, 175 | ) 176 | else: 177 | # This should not happen, but just in case 178 | raise UnexpectedError( 179 | message=f"Unexpected status code: {response.status_code}", 180 | status_code=response.status_code, 181 | response_body=parsed_body, 182 | ) 183 | 184 | async def execute_request( 185 | self, 186 | method: str, 187 | path: str, 188 | request_params: dict[str, Any] | None = None, 189 | 
request_body: dict[str, Any] | None = None, 190 | ) -> dict[str, Any] | BaseModel: 191 | """ 192 | Execute an HTTP request. 193 | 194 | Args: 195 | method: HTTP method (GET, POST, etc.) 196 | path: API path 197 | request_params: Query parameters 198 | request_body: Request body 199 | 200 | Returns: 201 | API response as a dictionary 202 | 203 | Raises: 204 | APIClientError: For client errors (4xx) 205 | APIConnectionError: For connection issues 206 | APIResponseError: For response parsing errors 207 | UnexpectedError: For unexpected errors 208 | """ 209 | # Log detailed request information 210 | logger.info(f"API Client: Executing {method} request to {path}") 211 | if request_params: 212 | logger.debug(f"Request params: {request_params}") 213 | if request_body: 214 | logger.debug(f"Request body: {request_body}") 215 | 216 | # Get client 217 | client = await self._ensure_client() 218 | 219 | # Prepare request 220 | request = self.prepare_request(client, method, path, request_params, request_body) 221 | 222 | # Send request 223 | response = await self.send_request(client, request) 224 | 225 | # Parse response (for both success and error cases) 226 | parsed_body = self.parse_response(response) 227 | 228 | # Check if successful 229 | if not response.is_success: 230 | logger.warning(f"Request failed: {method} {path} - Status {response.status_code}") 231 | self.handle_error_response(response, parsed_body) 232 | 233 | # Log success and return 234 | logger.info(f"Request successful: {method} {path} - Status {response.status_code}") 235 | return parsed_body 236 | ``` -------------------------------------------------------------------------------- /tests/services/logs/test_log_manager.py: -------------------------------------------------------------------------------- ```python 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | 5 | from supabase_mcp.services.database.sql.loader import SQLLoader 6 | from supabase_mcp.services.logs.log_manager import LogManager 7 | 
8 | 9 | class TestLogManager: 10 | """Tests for the LogManager class.""" 11 | 12 | def test_init(self): 13 | """Test initialization of LogManager.""" 14 | log_manager = LogManager() 15 | assert isinstance(log_manager.sql_loader, SQLLoader) 16 | assert log_manager.COLLECTION_TO_TABLE["postgres"] == "postgres_logs" 17 | assert log_manager.COLLECTION_TO_TABLE["api_gateway"] == "edge_logs" 18 | assert log_manager.COLLECTION_TO_TABLE["edge_functions"] == "function_edge_logs" 19 | 20 | @pytest.mark.parametrize( 21 | "collection,hours_ago,filters,search,expected_clause", 22 | [ 23 | # Test with hours_ago only 24 | ( 25 | "postgres", 26 | 24, 27 | None, 28 | None, 29 | "WHERE postgres_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)", 30 | ), 31 | # Test with search only 32 | ( 33 | "auth", 34 | None, 35 | None, 36 | "error", 37 | "WHERE event_message LIKE '%error%'", 38 | ), 39 | # Test with filters only 40 | ( 41 | "api_gateway", 42 | None, 43 | [{"field": "status_code", "operator": "=", "value": 500}], 44 | None, 45 | "WHERE status_code = 500", 46 | ), 47 | # Test with string value in filters 48 | ( 49 | "api_gateway", 50 | None, 51 | [{"field": "method", "operator": "=", "value": "GET"}], 52 | None, 53 | "WHERE method = 'GET'", 54 | ), 55 | # Test with multiple filters 56 | ( 57 | "postgres", 58 | None, 59 | [ 60 | {"field": "parsed.error_severity", "operator": "=", "value": "ERROR"}, 61 | {"field": "parsed.application_name", "operator": "LIKE", "value": "app%"}, 62 | ], 63 | None, 64 | "WHERE parsed.error_severity = 'ERROR' AND parsed.application_name LIKE 'app%'", 65 | ), 66 | # Test with hours_ago and search 67 | ( 68 | "storage", 69 | 12, 70 | None, 71 | "upload", 72 | "WHERE storage_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 12 HOUR) AND event_message LIKE '%upload%'", 73 | ), 74 | # Test with all parameters 75 | ( 76 | "edge_functions", 77 | 6, 78 | [{"field": "response.status_code", "operator": ">", "value": 400}], 79 | 
"timeout", 80 | "WHERE function_edge_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 6 HOUR) AND event_message LIKE '%timeout%' AND response.status_code > 400", 81 | ), 82 | # Test with cron logs (special case) 83 | ( 84 | "cron", 85 | 24, 86 | None, 87 | None, 88 | "AND postgres_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)", 89 | ), 90 | # Test with cron logs and other parameters 91 | ( 92 | "cron", 93 | 12, 94 | [{"field": "parsed.error_severity", "operator": "=", "value": "ERROR"}], 95 | "failed", 96 | "AND postgres_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 12 HOUR) AND event_message LIKE '%failed%' AND parsed.error_severity = 'ERROR'", 97 | ), 98 | ], 99 | ) 100 | def test_build_where_clause(self, collection, hours_ago, filters, search, expected_clause): 101 | """Test building WHERE clauses for different scenarios.""" 102 | log_manager = LogManager() 103 | where_clause = log_manager._build_where_clause( 104 | collection=collection, hours_ago=hours_ago, filters=filters, search=search 105 | ) 106 | assert where_clause == expected_clause 107 | 108 | def test_build_where_clause_escapes_single_quotes(self): 109 | """Test that single quotes in search strings are properly escaped.""" 110 | log_manager = LogManager() 111 | where_clause = log_manager._build_where_clause(collection="postgres", search="O'Reilly") 112 | assert where_clause == "WHERE event_message LIKE '%O''Reilly%'" 113 | 114 | # Test with filters containing single quotes 115 | where_clause = log_manager._build_where_clause( 116 | collection="postgres", 117 | filters=[{"field": "parsed.query", "operator": "LIKE", "value": "SELECT * FROM O'Reilly"}], 118 | ) 119 | assert where_clause == "WHERE parsed.query LIKE 'SELECT * FROM O''Reilly'" 120 | 121 | @patch.object(SQLLoader, "get_logs_query") 122 | def test_build_logs_query_with_custom_query(self, mock_get_logs_query): 123 | """Test building a logs query with a custom query.""" 124 | log_manager = 
LogManager() 125 | custom_query = "SELECT * FROM postgres_logs LIMIT 10" 126 | 127 | query = log_manager.build_logs_query(collection="postgres", custom_query=custom_query) 128 | 129 | assert query == custom_query 130 | # Ensure get_logs_query is not called when custom_query is provided 131 | mock_get_logs_query.assert_not_called() 132 | 133 | @patch.object(LogManager, "_build_where_clause") 134 | @patch.object(SQLLoader, "get_logs_query") 135 | def test_build_logs_query_standard(self, mock_get_logs_query, mock_build_where_clause): 136 | """Test building a standard logs query.""" 137 | log_manager = LogManager() 138 | mock_build_where_clause.return_value = "WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)" 139 | mock_get_logs_query.return_value = "SELECT * FROM postgres_logs WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR) LIMIT 20" 140 | 141 | query = log_manager.build_logs_query( 142 | collection="postgres", 143 | limit=20, 144 | hours_ago=24, 145 | filters=[{"field": "parsed.error_severity", "operator": "=", "value": "ERROR"}], 146 | search="connection", 147 | ) 148 | 149 | mock_build_where_clause.assert_called_once_with( 150 | collection="postgres", 151 | hours_ago=24, 152 | filters=[{"field": "parsed.error_severity", "operator": "=", "value": "ERROR"}], 153 | search="connection", 154 | ) 155 | 156 | mock_get_logs_query.assert_called_once_with( 157 | collection="postgres", 158 | where_clause="WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)", 159 | limit=20, 160 | ) 161 | 162 | assert ( 163 | query 164 | == "SELECT * FROM postgres_logs WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR) LIMIT 20" 165 | ) 166 | 167 | @patch.object(SQLLoader, "get_logs_query") 168 | def test_build_logs_query_integration(self, mock_get_logs_query, sql_loader): 169 | """Test building a logs query with integration between components.""" 170 | # Setup 171 | log_manager = LogManager() 172 | 
log_manager.sql_loader = sql_loader 173 | 174 | # Mock the SQL loader to return a predictable result 175 | mock_get_logs_query.return_value = ( 176 | "SELECT id, postgres_logs.timestamp, event_message FROM postgres_logs " 177 | "WHERE postgres_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR) " 178 | "ORDER BY timestamp DESC LIMIT 10" 179 | ) 180 | 181 | # Execute 182 | query = log_manager.build_logs_query( 183 | collection="postgres", 184 | limit=10, 185 | hours_ago=24, 186 | ) 187 | 188 | # Verify 189 | assert "SELECT id, postgres_logs.timestamp, event_message FROM postgres_logs" in query 190 | assert "LIMIT 10" in query 191 | mock_get_logs_query.assert_called_once() 192 | 193 | def test_unknown_collection(self): 194 | """Test handling of unknown collections.""" 195 | log_manager = LogManager() 196 | 197 | # Test with a collection that doesn't exist in the mapping 198 | where_clause = log_manager._build_where_clause( 199 | collection="unknown_collection", 200 | hours_ago=24, 201 | ) 202 | 203 | # Should use the collection name as the table name 204 | assert ( 205 | where_clause == "WHERE unknown_collection.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)" 206 | ) 207 | ``` -------------------------------------------------------------------------------- /supabase_mcp/services/database/query_manager.py: -------------------------------------------------------------------------------- ```python 1 | from supabase_mcp.exceptions import OperationNotAllowedError 2 | from supabase_mcp.logger import logger 3 | from supabase_mcp.services.database.migration_manager import MigrationManager 4 | from supabase_mcp.services.database.postgres_client import PostgresClient, QueryResult 5 | from supabase_mcp.services.database.sql.loader import SQLLoader 6 | from supabase_mcp.services.database.sql.models import QueryValidationResults 7 | from supabase_mcp.services.database.sql.validator import SQLValidator 8 | from 
supabase_mcp.services.safety.models import ClientType, SafetyMode 9 | from supabase_mcp.services.safety.safety_manager import SafetyManager 10 | 11 | 12 | class QueryManager: 13 | """ 14 | Manages SQL query execution with validation and migration handling. 15 | 16 | This class is responsible for: 17 | 1. Validating SQL queries for safety 18 | 2. Executing queries through the database client 19 | 3. Managing migrations for queries that require them 20 | 4. Loading SQL queries from files 21 | 22 | It acts as a central point for all SQL operations, ensuring consistent 23 | validation and execution patterns. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | postgres_client: PostgresClient, 29 | safety_manager: SafetyManager, 30 | sql_validator: SQLValidator | None = None, 31 | migration_manager: MigrationManager | None = None, 32 | sql_loader: SQLLoader | None = None, 33 | ): 34 | """ 35 | Initialize the QueryManager. 36 | 37 | Args: 38 | postgres_client: The database client to use for executing queries 39 | safety_manager: The safety manager to use for validating operations 40 | sql_validator: Optional SQL validator to use 41 | migration_manager: Optional migration manager to use 42 | sql_loader: Optional SQL loader to use 43 | """ 44 | self.db_client = postgres_client 45 | self.safety_manager = safety_manager 46 | self.validator = sql_validator or SQLValidator() 47 | self.sql_loader = sql_loader or SQLLoader() 48 | self.migration_manager = migration_manager or MigrationManager(loader=self.sql_loader) 49 | 50 | def check_readonly(self) -> bool: 51 | """Returns true if current safety mode is SAFE.""" 52 | result = self.safety_manager.get_safety_mode(ClientType.DATABASE) == SafetyMode.SAFE 53 | logger.debug(f"Check readonly result: {result}") 54 | return result 55 | 56 | async def handle_query(self, query: str, has_confirmation: bool = False, migration_name: str = "") -> QueryResult: 57 | """ 58 | Handle a SQL query with validation and potential migration. 
Uses migration name, if provided. 59 | 60 | This method: 61 | 1. Validates the query for safety 62 | 2. Checks if the query requires migration 63 | 3. Handles migration if needed 64 | 4. Executes the query 65 | 66 | Args: 67 | query: SQL query to execute 68 | params: Query parameters 69 | has_confirmation: Whether the operation has been confirmed by the user 70 | 71 | Returns: 72 | QueryResult: The result of the query execution 73 | 74 | Raises: 75 | OperationNotAllowedError: If the query is not allowed in the current safety mode 76 | ConfirmationRequiredError: If the query requires confirmation and has_confirmation is False 77 | """ 78 | # 1. Run through the validator 79 | validated_query = self.validator.validate_query(query) 80 | 81 | # 2. Ensure execution is allowed 82 | self.safety_manager.validate_operation(ClientType.DATABASE, validated_query, has_confirmation) 83 | logger.debug(f"Operation with risk level {validated_query.highest_risk_level} validated successfully") 84 | 85 | # 3. Handle migration if needed 86 | await self.handle_migration(validated_query, query, migration_name) 87 | 88 | # 4. Execute the query 89 | return await self.handle_query_execution(validated_query) 90 | 91 | async def handle_query_execution(self, validated_query: QueryValidationResults) -> QueryResult: 92 | """ 93 | Handle query execution with validation and potential migration. 94 | 95 | This method: 96 | 1. Checks the readonly mode 97 | 2. Executes the query 98 | 3. 
Returns the result 99 | 100 | Args: 101 | validated_query: The validation result 102 | query: The original query 103 | 104 | Returns: 105 | QueryResult: The result of the query execution 106 | """ 107 | readonly = self.check_readonly() 108 | result = await self.db_client.execute_query(validated_query, readonly) 109 | logger.debug(f"Query result: {result}") 110 | return result 111 | 112 | async def handle_migration( 113 | self, validation_result: QueryValidationResults, original_query: str, migration_name: str = "" 114 | ) -> None: 115 | """ 116 | Handle migration for a query that requires it. 117 | 118 | Args: 119 | validation_result: The validation result 120 | query: The original query 121 | migration_name: Migration name to use, if provided 122 | """ 123 | # 1. Check if migration is needed 124 | if not validation_result.needs_migration(): 125 | logger.debug("No migration needed for this query") 126 | return 127 | 128 | # 2. Prepare migration query 129 | migration_query, name = self.migration_manager.prepare_migration_query( 130 | validation_result, original_query, migration_name 131 | ) 132 | logger.debug("Migration query prepared") 133 | 134 | # 3. 
Execute migration query 135 | try: 136 | # First, ensure the migration schema exists 137 | await self.init_migration_schema() 138 | 139 | # Then execute the migration query 140 | migration_validation = self.validator.validate_query(migration_query) 141 | await self.db_client.execute_query(migration_validation, readonly=False) 142 | logger.info(f"Migration '{name}' executed successfully") 143 | except Exception as e: 144 | logger.debug(f"Migration failure details: {str(e)}") 145 | # We don't want to fail the main query if migration fails 146 | # Just log the error and continue 147 | logger.warning(f"Failed to record migration '{name}': {e}") 148 | 149 | async def init_migration_schema(self) -> None: 150 | """Initialize the migrations schema and table if they don't exist.""" 151 | try: 152 | # Get the initialization query 153 | init_query = self.sql_loader.get_init_migrations_query() 154 | 155 | # Validate and execute it 156 | init_validation = self.validator.validate_query(init_query) 157 | await self.db_client.execute_query(init_validation, readonly=False) 158 | logger.debug("Migrations schema initialized successfully") 159 | except Exception as e: 160 | logger.warning(f"Failed to initialize migrations schema: {e}") 161 | 162 | async def handle_confirmation(self, confirmation_id: str) -> QueryResult: 163 | """ 164 | Handle a confirmed operation using its confirmation ID. 165 | 166 | This method retrieves the stored operation and passes it to handle_query. 
167 | 168 | Args: 169 | confirmation_id: The unique ID of the confirmation to process 170 | 171 | Returns: 172 | QueryResult: The result of the query execution 173 | """ 174 | # Get the stored operation 175 | operation = self.safety_manager.get_stored_operation(confirmation_id) 176 | if not operation: 177 | raise OperationNotAllowedError(f"Invalid or expired confirmation ID: {confirmation_id}") 178 | 179 | # Get the query from the operation 180 | query = operation.original_query 181 | logger.debug(f"Processing confirmed operation with ID {confirmation_id}") 182 | 183 | # Call handle_query with the query and has_confirmation=True 184 | return await self.handle_query(query, has_confirmation=True) 185 | 186 | def get_schemas_query(self) -> str: 187 | """Get a query to list all schemas.""" 188 | return self.sql_loader.get_schemas_query() 189 | 190 | def get_tables_query(self, schema_name: str) -> str: 191 | """Get a query to list all tables in a schema.""" 192 | return self.sql_loader.get_tables_query(schema_name) 193 | 194 | def get_table_schema_query(self, schema_name: str, table: str) -> str: 195 | """Get a query to get the schema of a table.""" 196 | return self.sql_loader.get_table_schema_query(schema_name, table) 197 | 198 | def get_migrations_query( 199 | self, limit: int = 50, offset: int = 0, name_pattern: str = "", include_full_queries: bool = False 200 | ) -> str: 201 | """Get a query to list migrations.""" 202 | return self.sql_loader.get_migrations_query( 203 | limit=limit, offset=offset, name_pattern=name_pattern, include_full_queries=include_full_queries 204 | ) 205 | ``` -------------------------------------------------------------------------------- /supabase_mcp/clients/management_client.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | 3 | from json.decoder import JSONDecodeError 4 | from typing import Any 5 | 6 | import httpx 7 | from httpx import Request, 
Response 8 | from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential 9 | 10 | from supabase_mcp.exceptions import ( 11 | APIClientError, 12 | APIConnectionError, 13 | APIResponseError, 14 | APIServerError, 15 | UnexpectedError, 16 | ) 17 | from supabase_mcp.logger import logger 18 | from supabase_mcp.settings import Settings 19 | 20 | 21 | # Helper function for retry decorator to safely log exceptions 22 | def log_retry_attempt(retry_state: RetryCallState) -> None: 23 | """Log retry attempts with exception details if available.""" 24 | exception = retry_state.outcome.exception() if retry_state.outcome and retry_state.outcome.failed else None 25 | exception_str = str(exception) if exception else "Unknown error" 26 | logger.warning(f"Network error, retrying ({retry_state.attempt_number}/3): {exception_str}") 27 | 28 | 29 | class ManagementAPIClient: 30 | """ 31 | Client for Supabase Management API. 32 | 33 | Handles low-level HTTP requests to the Supabase Management API. 34 | """ 35 | 36 | def __init__(self, settings: Settings) -> None: 37 | """Initialize the API client with default settings.""" 38 | self.settings = settings 39 | self.client = self.create_httpx_client(settings) 40 | 41 | logger.info("✔️ Management API client initialized successfully") 42 | 43 | def create_httpx_client(self, settings: Settings) -> httpx.AsyncClient: 44 | """Create and configure an httpx client for API requests.""" 45 | headers = { 46 | "Authorization": f"Bearer {settings.supabase_access_token}", 47 | "Content-Type": "application/json", 48 | } 49 | 50 | return httpx.AsyncClient( 51 | base_url=settings.supabase_api_url, 52 | headers=headers, 53 | timeout=30.0, 54 | ) 55 | 56 | def prepare_request( 57 | self, 58 | method: str, 59 | path: str, 60 | request_params: dict[str, Any] | None = None, 61 | request_body: dict[str, Any] | None = None, 62 | ) -> Request: 63 | """ 64 | Prepare an HTTP request to the Supabase Management API. 
65 | 66 | Args: 67 | method: HTTP method (GET, POST, etc.) 68 | path: API path 69 | request_params: Query parameters 70 | request_body: Request body 71 | 72 | Returns: 73 | Prepared httpx.Request object 74 | 75 | Raises: 76 | APIClientError: If request preparation fails 77 | """ 78 | try: 79 | return self.client.build_request(method=method, url=path, params=request_params, json=request_body) 80 | except Exception as e: 81 | raise APIClientError( 82 | message=f"Failed to build request: {str(e)}", 83 | status_code=None, 84 | ) from e 85 | 86 | @retry( 87 | retry=retry_if_exception_type(httpx.NetworkError), # This includes ConnectError and TimeoutException 88 | stop=stop_after_attempt(3), 89 | wait=wait_exponential(multiplier=1, min=2, max=10), 90 | reraise=True, # Ensure the original exception is raised 91 | before_sleep=log_retry_attempt, 92 | ) 93 | async def send_request(self, request: Request) -> Response: 94 | """ 95 | Send an HTTP request with retry logic for transient errors. 96 | 97 | Args: 98 | request: Prepared httpx.Request object 99 | 100 | Returns: 101 | httpx.Response object 102 | 103 | Raises: 104 | APIConnectionError: For connection issues 105 | APIClientError: For other request errors 106 | """ 107 | try: 108 | return await self.client.send(request) 109 | except httpx.NetworkError as e: 110 | # All NetworkErrors will be retried by the decorator 111 | # This will only be reached after all retries are exhausted 112 | logger.error(f"Network error after all retry attempts: {str(e)}") 113 | raise APIConnectionError( 114 | message=f"Network error after 3 retry attempts: {str(e)}", 115 | status_code=None, 116 | ) from e 117 | except Exception as e: 118 | # Other exceptions won't be retried 119 | raise APIClientError( 120 | message=f"Request failed: {str(e)}", 121 | status_code=None, 122 | ) from e 123 | 124 | def parse_response(self, response: Response) -> dict[str, Any]: 125 | """ 126 | Parse an HTTP response as JSON. 
127 | 128 | Args: 129 | response: httpx.Response object 130 | 131 | Returns: 132 | Parsed response body as dictionary 133 | 134 | Raises: 135 | APIResponseError: If response cannot be parsed as JSON 136 | """ 137 | if not response.content: 138 | return {} 139 | 140 | try: 141 | return response.json() 142 | except JSONDecodeError as e: 143 | raise APIResponseError( 144 | message=f"Failed to parse response as JSON: {str(e)}", 145 | status_code=response.status_code, 146 | response_body={"raw_content": response.text}, 147 | ) from e 148 | 149 | def handle_error_response(self, response: Response, parsed_body: dict[str, Any] | None = None) -> None: 150 | """ 151 | Handle error responses based on status code. 152 | 153 | Args: 154 | response: httpx.Response object 155 | parsed_body: Parsed response body if available 156 | 157 | Raises: 158 | APIClientError: For client errors (4xx) 159 | APIServerError: For server errors (5xx) 160 | UnexpectedError: For unexpected status codes 161 | """ 162 | # Extract error message 163 | error_message = f"API request failed: {response.status_code}" 164 | if parsed_body and "message" in parsed_body: 165 | error_message = parsed_body["message"] 166 | 167 | # Determine error type based on status code 168 | if 400 <= response.status_code < 500: 169 | raise APIClientError( 170 | message=error_message, 171 | status_code=response.status_code, 172 | response_body=parsed_body, 173 | ) 174 | elif response.status_code >= 500: 175 | raise APIServerError( 176 | message=error_message, 177 | status_code=response.status_code, 178 | response_body=parsed_body, 179 | ) 180 | else: 181 | # This should not happen, but just in case 182 | raise UnexpectedError( 183 | message=f"Unexpected status code: {response.status_code}", 184 | status_code=response.status_code, 185 | response_body=parsed_body, 186 | ) 187 | 188 | async def execute_request( 189 | self, 190 | method: str, 191 | path: str, 192 | request_params: dict[str, Any] | None = None, 193 | request_body: 
dict[str, Any] | None = None, 194 | ) -> dict[str, Any]: 195 | """ 196 | Execute an HTTP request to the Supabase Management API. 197 | 198 | Args: 199 | method: HTTP method (GET, POST, etc.) 200 | path: API path 201 | request_params: Query parameters 202 | request_body: Request body 203 | 204 | Returns: 205 | API response as a dictionary 206 | 207 | Raises: 208 | APIClientError: For client errors (4xx) 209 | APIConnectionError: For connection issues 210 | APIResponseError: For response parsing errors 211 | UnexpectedError: For unexpected errors 212 | """ 213 | # Check if access token is available 214 | if not self.settings.supabase_access_token: 215 | raise APIClientError( 216 | "Supabase access token is not configured. Set SUPABASE_ACCESS_TOKEN environment variable to use Management API tools." 217 | ) 218 | 219 | # Log detailed request information 220 | logger.info(f"API Client: Executing {method} request to {path}") 221 | if request_params: 222 | logger.debug(f"Request params: {request_params}") 223 | if request_body: 224 | logger.debug(f"Request body: {request_body}") 225 | 226 | # Prepare request 227 | request = self.prepare_request(method, path, request_params, request_body) 228 | 229 | # Send request 230 | response = await self.send_request(request) 231 | 232 | # Parse response (for both success and error cases) 233 | parsed_body = self.parse_response(response) 234 | 235 | # Check if successful 236 | if not response.is_success: 237 | logger.warning(f"Request failed: {method} {path} - Status {response.status_code}") 238 | self.handle_error_response(response, parsed_body) 239 | 240 | # Log success and return 241 | logger.info(f"Request successful: {method} {path} - Status {response.status_code}") 242 | return parsed_body 243 | 244 | async def close(self) -> None: 245 | """Close the HTTP client and release resources.""" 246 | if self.client: 247 | await self.client.aclose() 248 | logger.info("HTTP API client closed") 249 | ``` 
-------------------------------------------------------------------------------- /supabase_mcp/services/api/spec_manager.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | from enum import Enum 3 | from pathlib import Path 4 | from typing import Any 5 | 6 | import httpx 7 | 8 | from supabase_mcp.logger import logger 9 | 10 | # Constants 11 | SPEC_URL = "https://api.supabase.com/api/v1-json" 12 | LOCAL_SPEC_PATH = Path(__file__).parent / "specs" / "api_spec.json" 13 | 14 | 15 | class ApiDomain(str, Enum): 16 | """Enum of all possible domains in the Supabase Management API.""" 17 | 18 | ANALYTICS = "Analytics" 19 | AUTH = "Auth" 20 | DATABASE = "Database" 21 | DOMAINS = "Domains" 22 | EDGE_FUNCTIONS = "Edge Functions" 23 | ENVIRONMENTS = "Environments" 24 | OAUTH = "OAuth" 25 | ORGANIZATIONS = "Organizations" 26 | PROJECTS = "Projects" 27 | REST = "Rest" 28 | SECRETS = "Secrets" 29 | STORAGE = "Storage" 30 | 31 | @classmethod 32 | def list(cls) -> list[str]: 33 | """Return a list of all domain values.""" 34 | return [domain.value for domain in cls] 35 | 36 | 37 | class ApiSpecManager: 38 | """ 39 | Manages the OpenAPI specification for the Supabase Management API. 40 | Handles spec loading, caching, and validation. 41 | """ 42 | 43 | def __init__(self) -> None: 44 | self.spec: dict[str, Any] | None = None 45 | self._paths_cache: dict[str, dict[str, str]] | None = None 46 | self._domains_cache: list[str] | None = None 47 | 48 | async def _fetch_remote_spec(self) -> dict[str, Any] | None: 49 | """ 50 | Fetch latest OpenAPI spec from Supabase API. 51 | Returns None if fetch fails. 
52 | """ 53 | try: 54 | async with httpx.AsyncClient() as client: 55 | response = await client.get(SPEC_URL) 56 | if response.status_code == 200: 57 | return response.json() 58 | logger.warning(f"Failed to fetch API spec: {response.status_code}") 59 | return None 60 | except Exception as e: 61 | logger.warning(f"Error fetching API spec: {e}") 62 | return None 63 | 64 | def _load_local_spec(self) -> dict[str, Any]: 65 | """ 66 | Load OpenAPI spec from local file. 67 | This is our fallback spec shipped with the server. 68 | """ 69 | try: 70 | with open(LOCAL_SPEC_PATH) as f: 71 | return json.load(f) 72 | except FileNotFoundError: 73 | logger.error(f"Local spec not found at {LOCAL_SPEC_PATH}") 74 | raise 75 | except json.JSONDecodeError as e: 76 | logger.error(f"Invalid JSON in local spec: {e}") 77 | raise 78 | 79 | async def get_spec(self) -> dict[str, Any]: 80 | """Retrieve the enriched spec.""" 81 | if self.spec is None: 82 | raw_spec = await self._fetch_remote_spec() 83 | if not raw_spec: 84 | # If remote fetch fails, use our fallback spec 85 | logger.info("Using fallback API spec") 86 | raw_spec = self._load_local_spec() 87 | self.spec = raw_spec 88 | 89 | return self.spec 90 | 91 | def get_all_paths_and_methods(self) -> dict[str, dict[str, str]]: 92 | """ 93 | Returns a dictionary of all paths and their methods with operation IDs. 94 | 95 | Returns: 96 | Dict[str, Dict[str, str]]: {path: {method: operationId}} 97 | """ 98 | if self._paths_cache is None: 99 | self._build_caches() 100 | return self._paths_cache or {} 101 | 102 | def get_paths_and_methods_by_domain(self, domain: str) -> dict[str, dict[str, str]]: 103 | """ 104 | Returns paths and methods within a specific domain (tag). 105 | 106 | Args: 107 | domain (str): The domain name (e.g., "Auth", "Projects"). 
108 | 109 | Returns: 110 | Dict[str, Dict[str, str]]: {path: {method: operationId}} 111 | """ 112 | 113 | if self._paths_cache is None: 114 | self._build_caches() 115 | 116 | # Validate domain using enum 117 | try: 118 | valid_domain = ApiDomain(domain).value 119 | except ValueError as e: 120 | raise ValueError(f"Invalid domain: {domain}") from e 121 | 122 | domain_paths: dict[str, dict[str, str]] = {} 123 | if self.spec: 124 | for path, methods in self.spec.get("paths", {}).items(): 125 | for method, details in methods.items(): 126 | if valid_domain in details.get("tags", []): 127 | if path not in domain_paths: 128 | domain_paths[path] = {} 129 | domain_paths[path][method] = details.get("operationId", "") 130 | return domain_paths 131 | 132 | def get_all_domains(self) -> list[str]: 133 | """ 134 | Returns a list of all available domains (tags). 135 | 136 | Returns: 137 | List[str]: List of domain names. 138 | """ 139 | if self._domains_cache is None: 140 | self._build_caches() 141 | return self._domains_cache or [] 142 | 143 | def get_spec_for_path_and_method(self, path: str, method: str) -> dict[str, Any] | None: 144 | """ 145 | Returns the full specification for a given path and HTTP method. 146 | 147 | Args: 148 | path (str): The API path (e.g., "/v1/projects"). 149 | method (str): The HTTP method (e.g., "get", "post"). 150 | 151 | Returns: 152 | Optional[Dict[str, Any]]: The full spec for the operation, or None if not found. 153 | """ 154 | if self.spec is None: 155 | return None 156 | 157 | path_spec = self.spec.get("paths", {}).get(path) 158 | if path_spec: 159 | return path_spec.get(method.lower()) # Ensure lowercase method 160 | return None 161 | 162 | def get_spec_part(self, part: str, *args: str | int) -> Any: 163 | """ 164 | Safely retrieves a nested part of the OpenAPI spec. 165 | 166 | Args: 167 | part: The top-level key (e.g., 'paths', 'components'). 168 | *args: Subsequent keys or indices to traverse the spec. 
169 | 170 | Returns: 171 | The value at the specified location in the spec, or None if not found. 172 | """ 173 | if self.spec is None: 174 | return None 175 | 176 | current = self.spec.get(part) 177 | for key in args: 178 | if isinstance(current, dict) and key in current: 179 | current = current[key] 180 | elif isinstance(current, list) and isinstance(key, int) and 0 <= key < len(current): 181 | current = current[key] 182 | else: 183 | return None # Key not found or invalid index 184 | return current 185 | 186 | def _build_caches(self) -> None: 187 | """ 188 | Build internal caches for faster lookups. 189 | This populates _paths_cache and _domains_cache. 190 | """ 191 | if self.spec is None: 192 | logger.error("Cannot build caches: OpenAPI spec not loaded") 193 | return 194 | 195 | # Build paths cache 196 | paths_cache: dict[str, dict[str, str]] = {} 197 | domains_set = set() 198 | 199 | for path, methods in self.spec.get("paths", {}).items(): 200 | for method, details in methods.items(): 201 | # Add to paths cache 202 | if path not in paths_cache: 203 | paths_cache[path] = {} 204 | paths_cache[path][method] = details.get("operationId", "") 205 | 206 | # Collect domains (tags) 207 | for tag in details.get("tags", []): 208 | domains_set.add(tag) 209 | 210 | self._paths_cache = paths_cache 211 | self._domains_cache = sorted(list(domains_set)) 212 | 213 | 214 | # Example usage (assuming you have an instance of ApiSpecManager called 'spec_manager'): 215 | async def main() -> None: 216 | """Test function to demonstrate ApiSpecManager functionality.""" 217 | # Create a new instance of ApiSpecManager 218 | spec_manager = ApiSpecManager() 219 | 220 | # Load the spec 221 | await spec_manager.get_spec() 222 | 223 | # Print the path to help debug 224 | print(f"Looking for spec at: {LOCAL_SPEC_PATH}") 225 | 226 | # 1. Get all domains 227 | all_domains = spec_manager.get_all_domains() 228 | print("\nAll Domains:") 229 | print(all_domains) 230 | 231 | # 2. 
Get all paths and methods 232 | all_paths = spec_manager.get_all_paths_and_methods() 233 | print("\nAll Paths and Methods (sample):") 234 | # Just print a few to avoid overwhelming output 235 | for i, (path, methods) in enumerate(all_paths.items()): 236 | if i >= 5: # Limit to 5 paths 237 | break 238 | print(f" {path}:") 239 | for method, operation_id in methods.items(): 240 | print(f" {method}: {operation_id}") 241 | 242 | # 3. Get paths and methods for the "Edge Functions" domain 243 | edge_paths = spec_manager.get_paths_and_methods_by_domain("Edge Functions") 244 | print("\nEdge Functions Paths and Methods:") 245 | for path, methods in edge_paths.items(): 246 | print(f" {path}:") 247 | for method, operation_id in methods.items(): 248 | print(f" {method}: {operation_id}") 249 | 250 | # 4. Get the full spec for a specific path and method 251 | path = "/v1/projects/{ref}/functions" 252 | method = "GET" 253 | full_spec = spec_manager.get_spec_for_path_and_method(path, method) 254 | print(f"\nFull Spec for {method} {path}:") 255 | if full_spec: 256 | print(json.dumps(full_spec, indent=2)[:500] + "...") # Truncate for readability 257 | else: 258 | print("Spec not found for this path/method") 259 | 260 | 261 | if __name__ == "__main__": 262 | import asyncio 263 | 264 | asyncio.run(main()) 265 | ``` -------------------------------------------------------------------------------- /supabase_mcp/tools/registry.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Literal 2 | 3 | from mcp.server.fastmcp import FastMCP 4 | 5 | from supabase_mcp.core.container import ServicesContainer 6 | from supabase_mcp.services.database.postgres_client import QueryResult 7 | from supabase_mcp.tools.manager import ToolName 8 | 9 | 10 | class ToolRegistry: 11 | """Responsible for registering tools with the MCP server""" 12 | 13 | def __init__(self, mcp: FastMCP, services_container: ServicesContainer): 14 | self.mcp = mcp 
15 | self.services_container = services_container 16 | 17 | def register_tools(self) -> FastMCP: 18 | """Register all tools with the MCP server""" 19 | mcp = self.mcp 20 | services_container = self.services_container 21 | 22 | tool_manager = services_container.tool_manager 23 | feature_manager = services_container.feature_manager 24 | 25 | @mcp.tool(description=tool_manager.get_description(ToolName.GET_SCHEMAS)) # type: ignore 26 | async def get_schemas() -> QueryResult: 27 | """List all database schemas with their sizes and table counts.""" 28 | return await feature_manager.execute_tool(ToolName.GET_SCHEMAS, services_container=services_container) 29 | 30 | @mcp.tool(description=tool_manager.get_description(ToolName.GET_TABLES)) # type: ignore 31 | async def get_tables(schema_name: str) -> QueryResult: 32 | """List all tables, foreign tables, and views in a schema with their sizes, row counts, and metadata.""" 33 | return await feature_manager.execute_tool( 34 | ToolName.GET_TABLES, services_container=services_container, schema_name=schema_name 35 | ) 36 | 37 | @mcp.tool(description=tool_manager.get_description(ToolName.GET_TABLE_SCHEMA)) # type: ignore 38 | async def get_table_schema(schema_name: str, table: str) -> QueryResult: 39 | """Get detailed table structure including columns, keys, and relationships.""" 40 | return await feature_manager.execute_tool( 41 | ToolName.GET_TABLE_SCHEMA, 42 | services_container=services_container, 43 | schema_name=schema_name, 44 | table=table, 45 | ) 46 | 47 | @mcp.tool(description=tool_manager.get_description(ToolName.EXECUTE_POSTGRESQL)) # type: ignore 48 | async def execute_postgresql(query: str, migration_name: str = "") -> QueryResult: 49 | """Execute PostgreSQL statements against your Supabase database.""" 50 | return await feature_manager.execute_tool( 51 | ToolName.EXECUTE_POSTGRESQL, 52 | services_container=services_container, 53 | query=query, 54 | migration_name=migration_name, 55 | ) 56 | 57 | 
@mcp.tool(description=tool_manager.get_description(ToolName.RETRIEVE_MIGRATIONS)) # type: ignore 58 | async def retrieve_migrations( 59 | limit: int = 50, 60 | offset: int = 0, 61 | name_pattern: str = "", 62 | include_full_queries: bool = False, 63 | ) -> QueryResult: 64 | """Retrieve a list of all migrations a user has from Supabase. 65 | 66 | SAFETY: This is a low-risk read operation that can be executed in SAFE mode. 67 | """ 68 | 69 | result = await feature_manager.execute_tool( 70 | ToolName.RETRIEVE_MIGRATIONS, 71 | services_container=services_container, 72 | limit=limit, 73 | offset=offset, 74 | name_pattern=name_pattern, 75 | include_full_queries=include_full_queries, 76 | ) 77 | return QueryResult.model_validate(result) 78 | 79 | @mcp.tool(description=tool_manager.get_description(ToolName.SEND_MANAGEMENT_API_REQUEST)) # type: ignore 80 | async def send_management_api_request( 81 | method: str, 82 | path: str, 83 | path_params: dict[str, str], 84 | request_params: dict[str, Any], 85 | request_body: dict[str, Any], 86 | ) -> dict[str, Any]: 87 | """Execute a Supabase Management API request.""" 88 | return await feature_manager.execute_tool( 89 | ToolName.SEND_MANAGEMENT_API_REQUEST, 90 | services_container=services_container, 91 | method=method, 92 | path=path, 93 | path_params=path_params, 94 | request_params=request_params, 95 | request_body=request_body, 96 | ) 97 | 98 | @mcp.tool(description=tool_manager.get_description(ToolName.GET_MANAGEMENT_API_SPEC)) # type: ignore 99 | async def get_management_api_spec(params: dict[str, Any] = {}) -> dict[str, Any]: 100 | """Get the Supabase Management API specification. 101 | 102 | This tool can be used in four different ways (and then some ;)): 103 | 1. Without parameters: Returns all domains (default) 104 | 2. With path and method: Returns the full specification for a specific API endpoint 105 | 3. With domain only: Returns all paths and methods within that domain 106 | 4. 
With all_paths=True: Returns all paths and methods 107 | 108 | Args: 109 | params: Dictionary containing optional parameters: 110 | - path: Optional API path (e.g., "/v1/projects/{ref}/functions") 111 | - method: Optional HTTP method (e.g., "GET", "POST") 112 | - domain: Optional domain/tag name (e.g., "Auth", "Storage") 113 | - all_paths: If True, returns all paths and methods 114 | 115 | Returns: 116 | API specification based on the provided parameters 117 | """ 118 | return await feature_manager.execute_tool( 119 | ToolName.GET_MANAGEMENT_API_SPEC, services_container=services_container, params=params 120 | ) 121 | 122 | @mcp.tool(description=tool_manager.get_description(ToolName.GET_AUTH_ADMIN_METHODS_SPEC)) # type: ignore 123 | async def get_auth_admin_methods_spec() -> dict[str, Any]: 124 | """Get Python SDK methods specification for Auth Admin.""" 125 | return await feature_manager.execute_tool( 126 | ToolName.GET_AUTH_ADMIN_METHODS_SPEC, services_container=services_container 127 | ) 128 | 129 | @mcp.tool(description=tool_manager.get_description(ToolName.CALL_AUTH_ADMIN_METHOD)) # type: ignore 130 | async def call_auth_admin_method(method: str, params: dict[str, Any]) -> dict[str, Any]: 131 | """Call an Auth Admin method from Supabase Python SDK.""" 132 | return await feature_manager.execute_tool( 133 | ToolName.CALL_AUTH_ADMIN_METHOD, services_container=services_container, method=method, params=params 134 | ) 135 | 136 | @mcp.tool(description=tool_manager.get_description(ToolName.LIVE_DANGEROUSLY)) # type: ignore 137 | async def live_dangerously( 138 | service: Literal["api", "database"], enable_unsafe_mode: bool = False 139 | ) -> dict[str, Any]: 140 | """ 141 | Toggle between safe and unsafe operation modes for API or Database services. 
142 | 143 | This function controls the safety level for operations, allowing you to: 144 | - Enable write operations for the database (INSERT, UPDATE, DELETE, schema changes) 145 | - Enable state-changing operations for the Management API 146 | """ 147 | return await feature_manager.execute_tool( 148 | ToolName.LIVE_DANGEROUSLY, 149 | services_container=services_container, 150 | service=service, 151 | enable_unsafe_mode=enable_unsafe_mode, 152 | ) 153 | 154 | @mcp.tool(description=tool_manager.get_description(ToolName.CONFIRM_DESTRUCTIVE_OPERATION)) # type: ignore 155 | async def confirm_destructive_operation( 156 | operation_type: Literal["api", "database"], confirmation_id: str, user_confirmation: bool = False 157 | ) -> QueryResult | dict[str, Any]: 158 | """Execute a destructive operation after confirmation. Use this only after reviewing the risks with the user.""" 159 | return await feature_manager.execute_tool( 160 | ToolName.CONFIRM_DESTRUCTIVE_OPERATION, 161 | services_container=services_container, 162 | operation_type=operation_type, 163 | confirmation_id=confirmation_id, 164 | user_confirmation=user_confirmation, 165 | ) 166 | 167 | @mcp.tool(description=tool_manager.get_description(ToolName.RETRIEVE_LOGS)) # type: ignore 168 | async def retrieve_logs( 169 | collection: str, 170 | limit: int = 20, 171 | hours_ago: int = 1, 172 | filters: list[dict[str, Any]] = [], 173 | search: str = "", 174 | custom_query: str = "", 175 | ) -> dict[str, Any]: 176 | """Retrieve logs from your Supabase project's services for debugging and monitoring.""" 177 | return await feature_manager.execute_tool( 178 | ToolName.RETRIEVE_LOGS, 179 | services_container=services_container, 180 | collection=collection, 181 | limit=limit, 182 | hours_ago=hours_ago, 183 | filters=filters, 184 | search=search, 185 | custom_query=custom_query, 186 | ) 187 | 188 | return mcp 189 | ``` -------------------------------------------------------------------------------- 
/tests/services/database/test_query_manager.py: -------------------------------------------------------------------------------- ```python 1 | from unittest.mock import AsyncMock, MagicMock 2 | 3 | import pytest 4 | 5 | from supabase_mcp.exceptions import SafetyError 6 | from supabase_mcp.services.database.query_manager import QueryManager 7 | from supabase_mcp.services.database.sql.loader import SQLLoader 8 | from supabase_mcp.services.database.sql.validator import ( 9 | QueryValidationResults, 10 | SQLQueryCategory, 11 | SQLQueryCommand, 12 | SQLValidator, 13 | ValidatedStatement, 14 | ) 15 | from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel 16 | 17 | 18 | @pytest.mark.asyncio(loop_scope="module") 19 | class TestQueryManager: 20 | """Tests for the Query Manager.""" 21 | 22 | @pytest.mark.unit 23 | async def test_query_execution(self, mock_query_manager: QueryManager): 24 | """Test query execution through the Query Manager.""" 25 | 26 | query_manager = mock_query_manager 27 | 28 | # Ensure validator and safety_manager are proper mocks 29 | query_manager.validator = MagicMock() 30 | query_manager.safety_manager = MagicMock() 31 | 32 | # Create a mock validation result for a SELECT query 33 | validated_statement = ValidatedStatement( 34 | category=SQLQueryCategory.DQL, 35 | command=SQLQueryCommand.SELECT, 36 | risk_level=OperationRiskLevel.LOW, 37 | query="SELECT * FROM users", 38 | needs_migration=False, 39 | object_type="TABLE", 40 | schema_name="public", 41 | ) 42 | 43 | validation_result = QueryValidationResults( 44 | statements=[validated_statement], 45 | highest_risk_level=OperationRiskLevel.LOW, 46 | has_transaction_control=False, 47 | original_query="SELECT * FROM users", 48 | ) 49 | 50 | # Make the validator return our mock validation result 51 | query_manager.validator.validate_query.return_value = validation_result 52 | 53 | # Make the db_client return a mock query result 54 | mock_query_result = MagicMock() 55 | 
query_manager.db_client.execute_query = AsyncMock(return_value=mock_query_result) 56 | 57 | # Execute a query 58 | query = "SELECT * FROM users" 59 | result = await query_manager.handle_query(query) 60 | 61 | # Verify the validator was called with the query 62 | query_manager.validator.validate_query.assert_called_once_with(query) 63 | 64 | # Verify the db_client was called with the validation result 65 | query_manager.db_client.execute_query.assert_called_once_with(validation_result, False) 66 | 67 | # Verify the result is what we expect 68 | assert result == mock_query_result 69 | 70 | @pytest.mark.asyncio 71 | @pytest.mark.unit 72 | async def test_safety_validation_blocks_dangerous_query(self, mock_query_manager: QueryManager): 73 | """Test that the safety validation blocks dangerous queries.""" 74 | 75 | # Create a query manager with the mock dependencies 76 | query_manager = mock_query_manager 77 | 78 | # Ensure validator and safety_manager are proper mocks 79 | query_manager.validator = MagicMock() 80 | query_manager.safety_manager = MagicMock() 81 | 82 | # Create a mock validation result for a DROP TABLE query 83 | validated_statement = ValidatedStatement( 84 | category=SQLQueryCategory.DDL, 85 | command=SQLQueryCommand.DROP, 86 | risk_level=OperationRiskLevel.EXTREME, 87 | query="DROP TABLE users", 88 | needs_migration=False, 89 | object_type="TABLE", 90 | schema_name="public", 91 | ) 92 | 93 | validation_result = QueryValidationResults( 94 | statements=[validated_statement], 95 | highest_risk_level=OperationRiskLevel.EXTREME, 96 | has_transaction_control=False, 97 | original_query="DROP TABLE users", 98 | ) 99 | 100 | # Make the validator return our mock validation result 101 | query_manager.validator.validate_query.return_value = validation_result 102 | 103 | # Make the safety manager raise a SafetyError 104 | error_message = "Operation not allowed in SAFE mode" 105 | query_manager.safety_manager.validate_operation.side_effect = SafetyError(error_message) 
106 | 107 | # Execute a query - should raise a SafetyError 108 | query = "DROP TABLE users" 109 | with pytest.raises(SafetyError) as excinfo: 110 | await query_manager.handle_query(query) 111 | 112 | # Verify the error message 113 | assert error_message in str(excinfo.value) 114 | 115 | # Verify the validator was called with the query 116 | query_manager.validator.validate_query.assert_called_once_with(query) 117 | 118 | # Verify the safety manager was called with the validation result 119 | query_manager.safety_manager.validate_operation.assert_called_once_with( 120 | ClientType.DATABASE, validation_result, False 121 | ) 122 | 123 | # Verify the db_client was not called 124 | query_manager.db_client.execute_query.assert_not_called() 125 | 126 | @pytest.mark.unit 127 | async def test_get_migrations_query(self, query_manager_integration: QueryManager): 128 | """Test that get_migrations_query returns a valid query string.""" 129 | # Test with default parameters 130 | query = query_manager_integration.get_migrations_query() 131 | assert isinstance(query, str) 132 | assert "supabase_migrations.schema_migrations" in query 133 | assert "LIMIT 50" in query 134 | 135 | # Test with custom parameters 136 | custom_query = query_manager_integration.get_migrations_query( 137 | limit=10, offset=5, name_pattern="test", include_full_queries=True 138 | ) 139 | assert isinstance(custom_query, str) 140 | assert "supabase_migrations.schema_migrations" in custom_query 141 | assert "LIMIT 10" in custom_query 142 | assert "OFFSET 5" in custom_query 143 | assert "name ILIKE" in custom_query 144 | assert "statements" in custom_query # Should include statements column when include_full_queries=True 145 | 146 | @pytest.mark.unit 147 | async def test_init_migration_schema(self): 148 | """Test that init_migration_schema initializes the migration schema correctly.""" 149 | # Create minimal mocks 150 | postgres_client = MagicMock() 151 | postgres_client.execute_query = AsyncMock() 152 | 153 | 
safety_manager = MagicMock() 154 | 155 | # Create a real SQLLoader and SQLValidator 156 | sql_loader = SQLLoader() 157 | sql_validator = SQLValidator() 158 | 159 | # Create the QueryManager with minimal mocking 160 | query_manager = QueryManager( 161 | postgres_client=postgres_client, 162 | safety_manager=safety_manager, 163 | sql_validator=sql_validator, 164 | sql_loader=sql_loader, 165 | ) 166 | 167 | # Call the method 168 | await query_manager.init_migration_schema() 169 | 170 | # Verify that the SQL loader was used to get the init migrations query 171 | # and that the query was executed 172 | assert postgres_client.execute_query.called 173 | 174 | # Get the arguments that execute_query was called with 175 | call_args = postgres_client.execute_query.call_args 176 | assert call_args is not None 177 | 178 | # The first argument should be a QueryValidationResults object 179 | args, _ = call_args # Use _ to ignore unused kwargs 180 | assert len(args) > 0 181 | validation_result = args[0] 182 | assert isinstance(validation_result, QueryValidationResults) 183 | 184 | # Check that the validation result contains the expected SQL 185 | init_query = sql_loader.get_init_migrations_query() 186 | assert any(stmt.query and stmt.query in init_query for stmt in validation_result.statements) 187 | 188 | @pytest.mark.unit 189 | async def test_handle_migration(self): 190 | """Test that handle_migration correctly handles migrations when needed.""" 191 | # Create minimal mocks 192 | postgres_client = MagicMock() 193 | postgres_client.execute_query = AsyncMock() 194 | 195 | safety_manager = MagicMock() 196 | 197 | # Create a real SQLLoader 198 | sql_loader = SQLLoader() 199 | 200 | # Create a mock MigrationManager 201 | migration_manager = MagicMock() 202 | migration_query = "INSERT INTO _migrations.migrations (name) VALUES ('test_migration')" 203 | migration_name = "test_migration" 204 | migration_manager.prepare_migration_query.return_value = (migration_query, migration_name) 205 | 
206 | # Create a real SQLValidator 207 | sql_validator = SQLValidator() 208 | 209 | # Create the QueryManager with minimal mocking 210 | query_manager = QueryManager( 211 | postgres_client=postgres_client, 212 | safety_manager=safety_manager, 213 | sql_validator=sql_validator, 214 | sql_loader=sql_loader, 215 | migration_manager=migration_manager, 216 | ) 217 | 218 | # Create a validation result that needs migration 219 | validated_statement = ValidatedStatement( 220 | category=SQLQueryCategory.DDL, 221 | command=SQLQueryCommand.CREATE, 222 | risk_level=OperationRiskLevel.MEDIUM, 223 | query="CREATE TABLE test (id INT)", 224 | needs_migration=True, 225 | object_type="TABLE", 226 | schema_name="public", 227 | ) 228 | 229 | validation_result = QueryValidationResults( 230 | statements=[validated_statement], 231 | highest_risk_level=OperationRiskLevel.MEDIUM, 232 | has_transaction_control=False, 233 | original_query="CREATE TABLE test (id INT)", 234 | ) 235 | 236 | # Call the method 237 | await query_manager.handle_migration(validation_result, "CREATE TABLE test (id INT)", "test_migration") 238 | 239 | # Verify that the migration manager was called to prepare the migration query 240 | migration_manager.prepare_migration_query.assert_called_once_with( 241 | validation_result, "CREATE TABLE test (id INT)", "test_migration" 242 | ) 243 | 244 | # Verify that execute_query was called at least twice 245 | # Once for init_migration_schema and once for the migration query 246 | assert postgres_client.execute_query.call_count >= 2 247 | ``` -------------------------------------------------------------------------------- /supabase_mcp/services/safety/safety_manager.py: -------------------------------------------------------------------------------- ```python 1 | import time 2 | import uuid 3 | from typing import Any, Optional 4 | 5 | from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError 6 | from supabase_mcp.logger import logger 7 | from 
supabase_mcp.services.safety.models import ClientType, SafetyMode 8 | from supabase_mcp.services.safety.safety_configs import APISafetyConfig, SafetyConfigBase, SQLSafetyConfig 9 | 10 | 11 | class SafetyManager: 12 | """A singleton service that maintains current safety state. 13 | 14 | Provides methods to: 15 | - Get/set safety modes for different clients 16 | - Register safety configurations 17 | - Check if operations are allowed 18 | Serves as the central point for safety decisions""" 19 | 20 | _instance: Optional["SafetyManager"] = None 21 | 22 | def __init__(self) -> None: 23 | """Initialize the safety manager with default safety modes.""" 24 | self._safety_modes: dict[ClientType, SafetyMode] = { 25 | ClientType.DATABASE: SafetyMode.SAFE, 26 | ClientType.API: SafetyMode.SAFE, 27 | } 28 | self._safety_configs: dict[ClientType, SafetyConfigBase[Any]] = {} 29 | self._pending_confirmations: dict[str, dict[str, Any]] = {} 30 | self._confirmation_expiry = 300 # 5 minutes in seconds 31 | 32 | @classmethod 33 | def get_instance(cls) -> "SafetyManager": 34 | """Get the singleton instance of the safety manager.""" 35 | if cls._instance is None: 36 | cls._instance = SafetyManager() 37 | return cls._instance 38 | 39 | def register_safety_configs(self) -> bool: 40 | """Register all safety configurations with the SafetyManager. 41 | 42 | Returns: 43 | bool: True if all configurations were registered successfully 44 | """ 45 | # Register SQL safety config 46 | sql_config = SQLSafetyConfig() 47 | self.register_config(ClientType.DATABASE, sql_config) 48 | 49 | # Register API safety config 50 | api_config = APISafetyConfig() 51 | self.register_config(ClientType.API, api_config) 52 | 53 | logger.info("✓ Safety configurations registered successfully") 54 | return True 55 | 56 | def register_config(self, client_type: ClientType, config: SafetyConfigBase[Any]) -> None: 57 | """Register a safety configuration for a client type. 
58 | 59 | Args: 60 | client_type: The client type to register the configuration for 61 | config: The safety configuration for the client 62 | """ 63 | self._safety_configs[client_type] = config 64 | 65 | def get_safety_mode(self, client_type: ClientType) -> SafetyMode: 66 | """Get the current safety mode for a client type. 67 | 68 | Args: 69 | client_type: The client type to get the safety mode for 70 | 71 | Returns: 72 | The current safety mode for the client type 73 | """ 74 | if client_type not in self._safety_modes: 75 | logger.warning(f"No safety mode registered for {client_type}, defaulting to SAFE") 76 | return SafetyMode.SAFE 77 | return self._safety_modes[client_type] 78 | 79 | def set_safety_mode(self, client_type: ClientType, mode: SafetyMode) -> None: 80 | """Set the safety mode for a client type. 81 | 82 | Args: 83 | client_type: The client type to set the safety mode for 84 | mode: The safety mode to set 85 | """ 86 | self._safety_modes[client_type] = mode 87 | logger.debug(f"Set safety mode for {client_type} to {mode}") 88 | 89 | def validate_operation( 90 | self, 91 | client_type: ClientType, 92 | operation: Any, 93 | has_confirmation: bool = False, 94 | ) -> None: 95 | """Validate if an operation is allowed for a client type. 96 | 97 | This method will raise appropriate exceptions if the operation is not allowed 98 | or requires confirmation. 
99 | 100 | Args: 101 | client_type: The client type to check the operation for 102 | operation: The operation to check 103 | has_confirmation: Whether the operation has been confirmed by the user 104 | 105 | Raises: 106 | OperationNotAllowedError: If the operation is not allowed in the current safety mode 107 | ConfirmationRequiredError: If the operation requires confirmation and has_confirmation is False 108 | """ 109 | # Get the current safety mode and config 110 | mode = self.get_safety_mode(client_type) 111 | config = self._safety_configs.get(client_type) 112 | 113 | if not config: 114 | message = f"No safety configuration registered for {client_type}" 115 | logger.warning(message) 116 | raise OperationNotAllowedError(message) 117 | 118 | # Get the risk level for the operation 119 | risk_level = config.get_risk_level(operation) 120 | logger.debug(f"Operation risk level: {risk_level}") 121 | 122 | # Check if the operation is allowed in the current mode 123 | is_allowed = config.is_operation_allowed(risk_level, mode) 124 | if not is_allowed: 125 | message = f"Operation with risk level {risk_level} is not allowed in {mode} mode" 126 | logger.debug(f"Operation with risk level {risk_level} not allowed in {mode} mode") 127 | raise OperationNotAllowedError(message) 128 | 129 | # Check if the operation needs confirmation 130 | needs_confirmation = config.needs_confirmation(risk_level) 131 | if needs_confirmation and not has_confirmation: 132 | # Store the operation for later confirmation 133 | confirmation_id = self._store_confirmation(client_type, operation, risk_level) 134 | 135 | message = ( 136 | f"Operation with risk level {risk_level} requires explicit user confirmation.\n\n" 137 | f"WHAT HAPPENED: This high-risk operation was rejected for safety reasons.\n" 138 | f"WHAT TO DO: 1. Review the operation with the user and explain the risks\n" 139 | f" 2. 
If the user approves, use the confirmation tool with this ID: {confirmation_id}\n\n" 140 | f'CONFIRMATION COMMAND: confirm_destructive_postgresql(confirmation_id="{confirmation_id}", user_confirmation=True)' 141 | ) 142 | logger.debug( 143 | f"Operation with risk level {risk_level} requires confirmation, stored with ID {confirmation_id}" 144 | ) 145 | raise ConfirmationRequiredError(message) 146 | 147 | logger.debug(f"Operation with risk level {risk_level} allowed in {mode} mode") 148 | 149 | def _store_confirmation(self, client_type: ClientType, operation: Any, risk_level: int) -> str: 150 | """Store an operation that needs confirmation. 151 | 152 | Args: 153 | client_type: The client type the operation is for 154 | operation: The operation to store 155 | risk_level: The risk level of the operation 156 | 157 | Returns: 158 | A unique confirmation ID 159 | """ 160 | # Generate a unique ID 161 | confirmation_id = f"conf_{uuid.uuid4().hex[:8]}" 162 | 163 | # Store the operation with metadata 164 | self._pending_confirmations[confirmation_id] = { 165 | "operation": operation, 166 | "client_type": client_type, 167 | "risk_level": risk_level, 168 | "timestamp": time.time(), 169 | } 170 | 171 | # Clean up expired confirmations 172 | self._cleanup_expired_confirmations() 173 | 174 | return confirmation_id 175 | 176 | def _get_confirmation(self, confirmation_id: str) -> dict[str, Any] | None: 177 | """Retrieve a stored confirmation by ID. 
178 | 179 | Args: 180 | confirmation_id: The ID of the confirmation to retrieve 181 | 182 | Returns: 183 | The stored confirmation data or None if not found or expired 184 | """ 185 | # Clean up expired confirmations first 186 | self._cleanup_expired_confirmations() 187 | 188 | # Return the stored confirmation if it exists 189 | return self._pending_confirmations.get(confirmation_id) 190 | 191 | def _cleanup_expired_confirmations(self) -> None: 192 | """Remove expired confirmations from storage.""" 193 | current_time = time.time() 194 | expired_ids = [ 195 | conf_id 196 | for conf_id, data in self._pending_confirmations.items() 197 | if current_time - data["timestamp"] > self._confirmation_expiry 198 | ] 199 | 200 | for conf_id in expired_ids: 201 | logger.debug(f"Removing expired confirmation with ID {conf_id}") 202 | del self._pending_confirmations[conf_id] 203 | 204 | def get_stored_operation(self, confirmation_id: str) -> Any | None: 205 | """Get a stored operation by its confirmation ID. 206 | 207 | Args: 208 | confirmation_id: The confirmation ID to get the operation for 209 | 210 | Returns: 211 | The stored operation, or None if not found 212 | """ 213 | confirmation = self._get_confirmation(confirmation_id) 214 | if confirmation is None: 215 | return None 216 | return confirmation.get("operation") 217 | 218 | def get_operations_by_risk_level( 219 | self, risk_level: str, client_type: ClientType = ClientType.DATABASE 220 | ) -> dict[str, list[str]]: 221 | """Get operations for a specific risk level. 
222 | 223 | Args: 224 | risk_level: The risk level to get operations for 225 | client_type: The client type to get operations for 226 | 227 | Returns: 228 | A dictionary mapping HTTP methods to lists of paths 229 | """ 230 | # Get the config for the specified client type 231 | config = self._safety_configs.get(client_type) 232 | if not config or not hasattr(config, "PATH_SAFETY_CONFIG"): 233 | return {} 234 | 235 | # Get the operations for this risk level 236 | risk_config = getattr(config, "PATH_SAFETY_CONFIG", {}) 237 | if risk_level in risk_config: 238 | return risk_config[risk_level] 239 | 240 | def get_current_mode(self, client_type: ClientType) -> str: 241 | """Get the current safety mode as a string. 242 | 243 | Args: 244 | client_type: The client type to get the mode for 245 | 246 | Returns: 247 | The current safety mode as a string 248 | """ 249 | mode = self.get_safety_mode(client_type) 250 | return str(mode) 251 | 252 | @classmethod 253 | def reset(cls) -> None: 254 | """Reset the singleton instance cleanly. 255 | 256 | This closes any open connections and resets the singleton instance. 
257 | """ 258 | if cls._instance is not None: 259 | cls._instance = None 260 | logger.info("SafetyManager instance reset complete") 261 | ``` -------------------------------------------------------------------------------- /supabase_mcp/clients/sdk_client.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | 3 | from typing import Any, TypeVar 4 | 5 | from pydantic import BaseModel, ValidationError 6 | from supabase import AsyncClient, create_async_client 7 | from supabase.lib.client_options import AsyncClientOptions 8 | 9 | from supabase_mcp.exceptions import PythonSDKError 10 | from supabase_mcp.logger import logger 11 | from supabase_mcp.services.sdk.auth_admin_models import ( 12 | PARAM_MODELS, 13 | CreateUserParams, 14 | DeleteFactorParams, 15 | DeleteUserParams, 16 | GenerateLinkParams, 17 | GetUserByIdParams, 18 | InviteUserByEmailParams, 19 | ListUsersParams, 20 | UpdateUserByIdParams, 21 | ) 22 | from supabase_mcp.services.sdk.auth_admin_sdk_spec import get_auth_admin_methods_spec 23 | from supabase_mcp.settings import Settings 24 | 25 | T = TypeVar("T", bound=BaseModel) 26 | 27 | 28 | class IncorrectSDKParamsError(PythonSDKError): 29 | """Error raised when the parameters passed to the SDK are incorrect.""" 30 | 31 | pass 32 | 33 | 34 | class SupabaseSDKClient: 35 | """Supabase Python SDK client, which exposes functionality related to Auth admin of the Python SDK.""" 36 | 37 | _instance: SupabaseSDKClient | None = None 38 | 39 | def __init__( 40 | self, 41 | settings: Settings | None = None, 42 | project_ref: str | None = None, 43 | service_role_key: str | None = None, 44 | ): 45 | self.client: AsyncClient | None = None 46 | self.settings = settings 47 | self.project_ref = settings.supabase_project_ref if settings else project_ref 48 | self.service_role_key = settings.supabase_service_role_key if settings else service_role_key 49 | self.supabase_url = 
self.get_supabase_url() 50 | logger.info(f"✔️ Supabase SDK client initialized successfully for project {self.project_ref}") 51 | 52 | def get_supabase_url(self) -> str: 53 | """Returns the Supabase URL based on the project reference""" 54 | if not self.project_ref: 55 | raise PythonSDKError("Project reference is not set") 56 | if self.project_ref.startswith("127.0.0.1"): 57 | # Return the default Supabase API URL 58 | return "http://127.0.0.1:54321" 59 | return f"https://{self.project_ref}.supabase.co" 60 | 61 | @classmethod 62 | def create( 63 | cls, 64 | settings: Settings | None = None, 65 | project_ref: str | None = None, 66 | service_role_key: str | None = None, 67 | ) -> SupabaseSDKClient: 68 | if cls._instance is None: 69 | cls._instance = cls(settings, project_ref, service_role_key) 70 | return cls._instance 71 | 72 | @classmethod 73 | def get_instance( 74 | cls, 75 | settings: Settings | None = None, 76 | project_ref: str | None = None, 77 | service_role_key: str | None = None, 78 | ) -> SupabaseSDKClient: 79 | """Returns the singleton instance""" 80 | if cls._instance is None: 81 | cls.create(settings, project_ref, service_role_key) 82 | return cls._instance 83 | 84 | async def create_client(self) -> AsyncClient: 85 | """Creates a new Supabase client""" 86 | try: 87 | client = await create_async_client( 88 | self.supabase_url, 89 | self.service_role_key, 90 | options=AsyncClientOptions( 91 | auto_refresh_token=False, 92 | persist_session=False, 93 | ), 94 | ) 95 | return client 96 | except Exception as e: 97 | logger.error(f"Error creating Supabase client: {e}") 98 | raise PythonSDKError(f"Error creating Supabase client: {e}") from e 99 | 100 | async def get_client(self) -> AsyncClient: 101 | """Returns the Supabase client""" 102 | if not self.client: 103 | self.client = await self.create_client() 104 | logger.info(f"Created Supabase SDK client for project {self.project_ref}") 105 | return self.client 106 | 107 | async def close(self) -> None: 108 | 
"""Reset the client reference to allow garbage collection.""" 109 | self.client = None 110 | logger.info("Supabase SDK client reference cleared") 111 | 112 | def return_python_sdk_spec(self) -> dict: 113 | """Returns the Python SDK spec""" 114 | return get_auth_admin_methods_spec() 115 | 116 | def _validate_params(self, method: str, params: dict, param_model_cls: type[T]) -> T: 117 | """Validate parameters using the appropriate Pydantic model""" 118 | try: 119 | return param_model_cls.model_validate(params) 120 | except ValidationError as e: 121 | raise PythonSDKError(f"Invalid parameters for method {method}: {str(e)}") from e 122 | 123 | async def _get_user_by_id(self, params: GetUserByIdParams) -> dict: 124 | """Get user by ID implementation""" 125 | self.client = await self.get_client() 126 | admin_auth_client = self.client.auth.admin 127 | result = await admin_auth_client.get_user_by_id(params.uid) 128 | return result 129 | 130 | async def _list_users(self, params: ListUsersParams) -> dict: 131 | """List users implementation""" 132 | self.client = await self.get_client() 133 | admin_auth_client = self.client.auth.admin 134 | result = await admin_auth_client.list_users(page=params.page, per_page=params.per_page) 135 | return result 136 | 137 | async def _create_user(self, params: CreateUserParams) -> dict: 138 | """Create user implementation""" 139 | self.client = await self.get_client() 140 | admin_auth_client = self.client.auth.admin 141 | user_data = params.model_dump(exclude_none=True) 142 | result = await admin_auth_client.create_user(user_data) 143 | return result 144 | 145 | async def _delete_user(self, params: DeleteUserParams) -> dict: 146 | """Delete user implementation""" 147 | self.client = await self.get_client() 148 | admin_auth_client = self.client.auth.admin 149 | result = await admin_auth_client.delete_user(params.id, should_soft_delete=params.should_soft_delete) 150 | return result 151 | 152 | async def _invite_user_by_email(self, params: 
InviteUserByEmailParams) -> dict: 153 | """Invite user by email implementation""" 154 | self.client = await self.get_client() 155 | admin_auth_client = self.client.auth.admin 156 | options = params.options if params.options else {} 157 | result = await admin_auth_client.invite_user_by_email(params.email, options) 158 | return result 159 | 160 | async def _generate_link(self, params: GenerateLinkParams) -> dict: 161 | """Generate link implementation""" 162 | self.client = await self.get_client() 163 | admin_auth_client = self.client.auth.admin 164 | 165 | # Create a params dictionary as expected by the SDK 166 | params_dict = params.model_dump(exclude_none=True) 167 | 168 | try: 169 | # The SDK expects a single 'params' parameter containing all the fields 170 | result = await admin_auth_client.generate_link(params=params_dict) 171 | return result 172 | except TypeError as e: 173 | # Catch parameter errors and provide a more helpful message 174 | error_msg = str(e) 175 | if "unexpected keyword argument" in error_msg: 176 | raise IncorrectSDKParamsError( 177 | f"Incorrect parameters for generate_link: {error_msg}. " 178 | f"Please check the SDK specification for the correct parameter structure." 
179 | ) from e 180 | raise 181 | 182 | async def _update_user_by_id(self, params: UpdateUserByIdParams) -> dict: 183 | """Update user by ID implementation""" 184 | self.client = await self.get_client() 185 | admin_auth_client = self.client.auth.admin 186 | uid = params.uid 187 | attributes = params.attributes.model_dump(exclude={"uid"}, exclude_none=True) 188 | result = await admin_auth_client.update_user_by_id(uid, attributes) 189 | return result 190 | 191 | async def _delete_factor(self, params: DeleteFactorParams) -> dict: 192 | """Delete factor implementation""" 193 | # This method is not implemented in the Supabase SDK yet 194 | raise NotImplementedError("The delete_factor method is not implemented in the Supabase SDK yet") 195 | 196 | async def call_auth_admin_method(self, method: str, params: dict[str, Any]) -> Any: 197 | """Calls a method of the Python SDK client""" 198 | # Check if service role key is available 199 | if not self.service_role_key: 200 | raise PythonSDKError( 201 | "Supabase service role key is not configured. Set SUPABASE_SERVICE_ROLE_KEY environment variable to use Auth Admin tools." 202 | ) 203 | 204 | if not self.client: 205 | self.client = await self.get_client() 206 | if not self.client: 207 | raise PythonSDKError("Python SDK client not initialized") 208 | 209 | # Validate method exists 210 | if method not in PARAM_MODELS: 211 | available_methods = ", ".join(PARAM_MODELS.keys()) 212 | raise PythonSDKError(f"Unknown method: {method}. 
Available methods: {available_methods}") 213 | 214 | # Get the appropriate model class and validate parameters 215 | param_model_cls = PARAM_MODELS[method] 216 | validated_params = self._validate_params(method, params, param_model_cls) 217 | 218 | # Method dispatch using a dictionary of method implementations 219 | method_handlers = { 220 | "get_user_by_id": self._get_user_by_id, 221 | "list_users": self._list_users, 222 | "create_user": self._create_user, 223 | "delete_user": self._delete_user, 224 | "invite_user_by_email": self._invite_user_by_email, 225 | "generate_link": self._generate_link, 226 | "update_user_by_id": self._update_user_by_id, 227 | "delete_factor": self._delete_factor, 228 | } 229 | 230 | # Call the appropriate method handler 231 | try: 232 | handler = method_handlers.get(method) 233 | if not handler: 234 | raise PythonSDKError(f"Method {method} is not implemented") 235 | 236 | logger.debug(f"Python SDK request params: {validated_params}") 237 | return await handler(validated_params) 238 | except Exception as e: 239 | if isinstance(e, IncorrectSDKParamsError): 240 | # Re-raise our custom error without wrapping it 241 | raise e 242 | logger.error(f"Error calling {method}: {e}") 243 | raise PythonSDKError(f"Error calling {method}: {str(e)}") from e 244 | 245 | @classmethod 246 | def reset(cls) -> None: 247 | """Reset the singleton instance cleanly. 248 | 249 | This closes any open connections and resets the singleton instance. 
250 | """ 251 | if cls._instance is not None: 252 | cls._instance = None 253 | logger.info("SupabaseSDKClient instance reset complete") 254 | ``` -------------------------------------------------------------------------------- /tests/services/sdk/test_auth_admin_models.py: -------------------------------------------------------------------------------- ```python 1 | import pytest 2 | from pydantic import ValidationError 3 | 4 | from supabase_mcp.services.sdk.auth_admin_models import ( 5 | PARAM_MODELS, 6 | AdminUserAttributes, 7 | CreateUserParams, 8 | DeleteFactorParams, 9 | DeleteUserParams, 10 | GenerateLinkParams, 11 | GetUserByIdParams, 12 | InviteUserByEmailParams, 13 | ListUsersParams, 14 | UpdateUserByIdParams, 15 | UserMetadata, 16 | ) 17 | 18 | 19 | class TestModelConversion: 20 | """Test conversion from JSON data to models and validation""" 21 | 22 | def test_get_user_by_id_conversion(self): 23 | """Test conversion of get_user_by_id JSON data""" 24 | # Valid payload 25 | valid_payload = {"uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"} 26 | params = GetUserByIdParams.model_validate(valid_payload) 27 | assert params.uid == valid_payload["uid"] 28 | 29 | # Invalid payload (missing required uid) 30 | invalid_payload = {} 31 | with pytest.raises(ValidationError) as excinfo: 32 | GetUserByIdParams.model_validate(invalid_payload) 33 | assert "uid" in str(excinfo.value) 34 | 35 | def test_list_users_conversion(self): 36 | """Test conversion of list_users JSON data""" 37 | # Valid payload with custom values 38 | valid_payload = {"page": 2, "per_page": 20} 39 | params = ListUsersParams.model_validate(valid_payload) 40 | assert params.page == valid_payload["page"] 41 | assert params.per_page == valid_payload["per_page"] 42 | 43 | # Valid payload with defaults 44 | empty_payload = {} 45 | params = ListUsersParams.model_validate(empty_payload) 46 | assert params.page == 1 47 | assert params.per_page == 50 48 | 49 | # Invalid payload (non-integer values) 50 | 
invalid_payload = {"page": "not-a-number", "per_page": "also-not-a-number"} 51 | with pytest.raises(ValidationError) as excinfo: 52 | ListUsersParams.model_validate(invalid_payload) 53 | assert "page" in str(excinfo.value) 54 | 55 | def test_create_user_conversion(self): 56 | """Test conversion of create_user JSON data""" 57 | # Valid payload with email 58 | valid_payload = { 59 | "email": "[email protected]", 60 | "password": "secure-password", 61 | "email_confirm": True, 62 | "user_metadata": UserMetadata(email="[email protected]"), 63 | } 64 | params = CreateUserParams.model_validate(valid_payload) 65 | assert params.email == valid_payload["email"] 66 | assert params.password == valid_payload["password"] 67 | assert params.email_confirm is True 68 | assert params.user_metadata == valid_payload["user_metadata"] 69 | 70 | # Valid payload with phone 71 | valid_phone_payload = { 72 | "phone": "+1234567890", 73 | "password": "secure-password", 74 | "phone_confirm": True, 75 | } 76 | params = CreateUserParams.model_validate(valid_phone_payload) 77 | assert params.phone == valid_phone_payload["phone"] 78 | assert params.password == valid_phone_payload["password"] 79 | assert params.phone_confirm is True 80 | 81 | # Invalid payload (missing both email and phone) 82 | invalid_payload = {"password": "secure-password"} 83 | with pytest.raises(ValidationError) as excinfo: 84 | CreateUserParams.model_validate(invalid_payload) 85 | assert "Either email or phone must be provided" in str(excinfo.value) 86 | 87 | def test_delete_user_conversion(self): 88 | """Test conversion of delete_user JSON data""" 89 | # Valid payload with custom values 90 | valid_payload = {"id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", "should_soft_delete": True} 91 | params = DeleteUserParams.model_validate(valid_payload) 92 | assert params.id == valid_payload["id"] 93 | assert params.should_soft_delete is True 94 | 95 | # Valid payload with defaults 96 | valid_payload = {"id": 
"d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"} 97 | params = DeleteUserParams.model_validate(valid_payload) 98 | assert params.id == valid_payload["id"] 99 | assert params.should_soft_delete is False 100 | 101 | # Invalid payload (missing id) 102 | invalid_payload = {"should_soft_delete": True} 103 | with pytest.raises(ValidationError) as excinfo: 104 | DeleteUserParams.model_validate(invalid_payload) 105 | assert "id" in str(excinfo.value) 106 | 107 | def test_invite_user_by_email_conversion(self): 108 | """Test conversion of invite_user_by_email JSON data""" 109 | # Valid payload with options 110 | valid_payload = { 111 | "email": "[email protected]", 112 | "options": {"data": {"name": "Invited User"}, "redirect_to": "https://example.com/welcome"}, 113 | } 114 | params = InviteUserByEmailParams.model_validate(valid_payload) 115 | assert params.email == valid_payload["email"] 116 | assert params.options == valid_payload["options"] 117 | 118 | # Valid payload without options 119 | valid_payload = {"email": "[email protected]"} 120 | params = InviteUserByEmailParams.model_validate(valid_payload) 121 | assert params.email == valid_payload["email"] 122 | assert params.options is None 123 | 124 | # Invalid payload (missing email) 125 | invalid_payload = {"options": {"data": {"name": "Invited User"}}} 126 | with pytest.raises(ValidationError) as excinfo: 127 | InviteUserByEmailParams.model_validate(invalid_payload) 128 | assert "email" in str(excinfo.value) 129 | 130 | def test_generate_link_conversion(self): 131 | """Test conversion of generate_link JSON data""" 132 | # Valid signup link payload 133 | valid_signup_payload = { 134 | "type": "signup", 135 | "email": "[email protected]", 136 | "password": "secure-password", 137 | "options": {"data": {"name": "New User"}, "redirect_to": "https://example.com/welcome"}, 138 | } 139 | params = GenerateLinkParams.model_validate(valid_signup_payload) 140 | assert params.type == valid_signup_payload["type"] 141 | assert params.email == 
valid_signup_payload["email"] 142 | assert params.password == valid_signup_payload["password"] 143 | assert params.options == valid_signup_payload["options"] 144 | 145 | # Valid email_change link payload 146 | valid_email_change_payload = { 147 | "type": "email_change_current", 148 | "email": "[email protected]", 149 | "new_email": "[email protected]", 150 | } 151 | params = GenerateLinkParams.model_validate(valid_email_change_payload) 152 | assert params.type == valid_email_change_payload["type"] 153 | assert params.email == valid_email_change_payload["email"] 154 | assert params.new_email == valid_email_change_payload["new_email"] 155 | 156 | # Invalid payload (missing password for signup) 157 | invalid_signup_payload = { 158 | "type": "signup", 159 | "email": "[email protected]", 160 | } 161 | with pytest.raises(ValidationError) as excinfo: 162 | GenerateLinkParams.model_validate(invalid_signup_payload) 163 | assert "Password is required for signup links" in str(excinfo.value) 164 | 165 | # Invalid payload (missing new_email for email_change) 166 | invalid_email_change_payload = { 167 | "type": "email_change_current", 168 | "email": "[email protected]", 169 | } 170 | with pytest.raises(ValidationError) as excinfo: 171 | GenerateLinkParams.model_validate(invalid_email_change_payload) 172 | assert "new_email is required for email change links" in str(excinfo.value) 173 | 174 | # Invalid payload (invalid type) 175 | invalid_type_payload = { 176 | "type": "invalid-type", 177 | "email": "[email protected]", 178 | } 179 | with pytest.raises(ValidationError) as excinfo: 180 | GenerateLinkParams.model_validate(invalid_type_payload) 181 | assert "type" in str(excinfo.value) 182 | 183 | def test_update_user_by_id_conversion(self): 184 | """Test conversion of update_user_by_id JSON data""" 185 | # Valid payload 186 | valid_payload = { 187 | "uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", 188 | "attributes": AdminUserAttributes(email="[email protected]", 
email_verified=True), 189 | } 190 | params = UpdateUserByIdParams.model_validate(valid_payload) 191 | assert params.uid == valid_payload["uid"] 192 | assert params.attributes == valid_payload["attributes"] 193 | 194 | # Invalid payload (incorrect metadata and missing uids) 195 | invalid_payload = { 196 | "email": "[email protected]", 197 | "user_metadata": {"name": "Updated User"}, 198 | } 199 | with pytest.raises(ValidationError) as excinfo: 200 | UpdateUserByIdParams.model_validate(invalid_payload) 201 | assert "uid" in str(excinfo.value) 202 | 203 | def test_delete_factor_conversion(self): 204 | """Test conversion of delete_factor JSON data""" 205 | # Valid payload 206 | valid_payload = { 207 | "user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", 208 | "id": "totp-factor-id-123", 209 | } 210 | params = DeleteFactorParams.model_validate(valid_payload) 211 | assert params.user_id == valid_payload["user_id"] 212 | assert params.id == valid_payload["id"] 213 | 214 | # Invalid payload (missing user_id) 215 | invalid_payload = { 216 | "id": "totp-factor-id-123", 217 | } 218 | with pytest.raises(ValidationError) as excinfo: 219 | DeleteFactorParams.model_validate(invalid_payload) 220 | assert "user_id" in str(excinfo.value) 221 | 222 | # Invalid payload (missing id) 223 | invalid_payload = { 224 | "user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", 225 | } 226 | with pytest.raises(ValidationError) as excinfo: 227 | DeleteFactorParams.model_validate(invalid_payload) 228 | assert "id" in str(excinfo.value) 229 | 230 | def test_param_models_mapping(self): 231 | """Test PARAM_MODELS mapping functionality""" 232 | # Test that all methods have the correct corresponding model 233 | method_model_pairs = [ 234 | ("get_user_by_id", GetUserByIdParams), 235 | ("list_users", ListUsersParams), 236 | ("create_user", CreateUserParams), 237 | ("delete_user", DeleteUserParams), 238 | ("invite_user_by_email", InviteUserByEmailParams), 239 | ("generate_link", GenerateLinkParams), 240 | 
("update_user_by_id", UpdateUserByIdParams), 241 | ("delete_factor", DeleteFactorParams), 242 | ] 243 | 244 | for method, expected_model in method_model_pairs: 245 | assert method in PARAM_MODELS 246 | assert PARAM_MODELS[method] == expected_model 247 | 248 | # Test actual validation of data through PARAM_MODELS mapping 249 | method = "create_user" 250 | model_class = PARAM_MODELS[method] 251 | 252 | valid_payload = {"email": "[email protected]", "password": "secure-password"} 253 | 254 | params = model_class.model_validate(valid_payload) 255 | assert params.email == valid_payload["email"] 256 | assert params.password == valid_payload["password"] 257 | ``` -------------------------------------------------------------------------------- /tests/test_tools.py: -------------------------------------------------------------------------------- ```python 1 | """Unit tests for tools - no external dependencies.""" 2 | import uuid 3 | from unittest.mock import AsyncMock, MagicMock, patch 4 | 5 | import pytest 6 | from mcp.server.fastmcp import FastMCP 7 | 8 | from supabase_mcp.core.container import ServicesContainer 9 | from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError, PythonSDKError 10 | from supabase_mcp.services.database.postgres_client import QueryResult, StatementResult 11 | from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel, SafetyMode 12 | 13 | 14 | @pytest.mark.asyncio 15 | class TestDatabaseToolsUnit: 16 | """Unit tests for database tools.""" 17 | 18 | @pytest.fixture 19 | def mock_container(self): 20 | """Create a mock container with all necessary services.""" 21 | container = MagicMock(spec=ServicesContainer) 22 | 23 | # Mock query manager 24 | container.query_manager = MagicMock() 25 | container.query_manager.handle_query = AsyncMock() 26 | container.query_manager.get_schemas_query = MagicMock(return_value="SELECT * FROM schemas") 27 | container.query_manager.get_tables_query = 
MagicMock(return_value="SELECT * FROM tables") 28 | container.query_manager.get_table_schema_query = MagicMock(return_value="SELECT * FROM columns") 29 | 30 | # Mock safety manager 31 | container.safety_manager = MagicMock() 32 | container.safety_manager.check_permission = MagicMock() 33 | container.safety_manager.is_unsafe_mode = MagicMock(return_value=False) 34 | 35 | return container 36 | 37 | async def test_get_schemas_returns_query_result(self, mock_container): 38 | """Test that get_schemas returns proper QueryResult.""" 39 | # Setup mock response 40 | mock_result = QueryResult(results=[ 41 | StatementResult(rows=[ 42 | {"schema_name": "public", "total_size": "100MB", "table_count": 10}, 43 | {"schema_name": "auth", "total_size": "50MB", "table_count": 5} 44 | ]) 45 | ]) 46 | mock_container.query_manager.handle_query.return_value = mock_result 47 | 48 | # Execute 49 | query = mock_container.query_manager.get_schemas_query() 50 | result = await mock_container.query_manager.handle_query(query) 51 | 52 | # Verify 53 | assert isinstance(result, QueryResult) 54 | assert len(result.results[0].rows) == 2 55 | assert result.results[0].rows[0]["schema_name"] == "public" 56 | 57 | async def test_get_tables_with_schema_filter(self, mock_container): 58 | """Test that get_tables properly filters by schema.""" 59 | # Setup 60 | mock_result = QueryResult(results=[ 61 | StatementResult(rows=[ 62 | {"table_name": "users", "table_type": "BASE TABLE", "row_count": 100, "size_bytes": 1024} 63 | ]) 64 | ]) 65 | mock_container.query_manager.handle_query.return_value = mock_result 66 | 67 | # Execute 68 | query = mock_container.query_manager.get_tables_query("public") 69 | result = await mock_container.query_manager.handle_query(query) 70 | 71 | # Verify 72 | mock_container.query_manager.get_tables_query.assert_called_with("public") 73 | assert result.results[0].rows[0]["table_name"] == "users" 74 | 75 | async def test_unsafe_query_blocked_in_safe_mode(self, mock_container): 76 | 
"""Test that unsafe queries are blocked in safe mode.""" 77 | # Setup 78 | mock_container.safety_manager.is_unsafe_mode.return_value = False 79 | mock_container.safety_manager.check_permission.side_effect = OperationNotAllowedError( 80 | "DROP operations are not allowed in safe mode" 81 | ) 82 | 83 | # Execute & Verify 84 | with pytest.raises(OperationNotAllowedError): 85 | mock_container.safety_manager.check_permission( 86 | ClientType.DATABASE, 87 | OperationRiskLevel.HIGH 88 | ) 89 | 90 | 91 | @pytest.mark.asyncio 92 | class TestAPIToolsUnit: 93 | """Unit tests for API tools.""" 94 | 95 | @pytest.fixture 96 | def mock_container(self): 97 | """Create a mock container with API services.""" 98 | container = MagicMock(spec=ServicesContainer) 99 | 100 | # Mock API manager 101 | container.api_manager = MagicMock() 102 | container.api_manager.send_request = AsyncMock() 103 | container.api_manager.spec_manager = MagicMock() 104 | container.api_manager.spec_manager.get_full_spec = MagicMock(return_value={"paths": {}}) 105 | 106 | # Mock safety manager 107 | container.safety_manager = MagicMock() 108 | container.safety_manager.check_permission = MagicMock() 109 | container.safety_manager.is_unsafe_mode = MagicMock(return_value=False) 110 | 111 | return container 112 | 113 | async def test_api_request_success(self, mock_container): 114 | """Test successful API request.""" 115 | # Setup 116 | mock_response = {"id": "123", "name": "Test Project"} 117 | mock_container.api_manager.send_request.return_value = mock_response 118 | 119 | # Execute 120 | result = await mock_container.api_manager.send_request( 121 | "GET", "/v1/projects", {} 122 | ) 123 | 124 | # Verify 125 | assert result["id"] == "123" 126 | assert result["name"] == "Test Project" 127 | 128 | async def test_api_spec_retrieval(self, mock_container): 129 | """Test API spec retrieval.""" 130 | # Setup 131 | expected_spec = { 132 | "paths": { 133 | "/v1/projects": { 134 | "get": {"summary": "List projects"} 135 | } 
136 | } 137 | } 138 | mock_container.api_manager.spec_manager.get_full_spec.return_value = expected_spec 139 | 140 | # Execute 141 | spec = mock_container.api_manager.spec_manager.get_full_spec() 142 | 143 | # Verify 144 | assert "paths" in spec 145 | assert "/v1/projects" in spec["paths"] 146 | 147 | async def test_medium_risk_api_blocked_in_safe_mode(self, mock_container): 148 | """Test that medium risk API operations are blocked in safe mode.""" 149 | # Setup 150 | mock_container.safety_manager.check_permission.side_effect = ConfirmationRequiredError( 151 | "This operation requires confirmation", 152 | {"method": "POST", "path": "/v1/projects"} 153 | ) 154 | 155 | # Execute & Verify 156 | with pytest.raises(ConfirmationRequiredError) as exc_info: 157 | mock_container.safety_manager.check_permission( 158 | ClientType.API, 159 | OperationRiskLevel.MEDIUM 160 | ) 161 | 162 | assert "requires confirmation" in str(exc_info.value) 163 | 164 | 165 | @pytest.mark.asyncio 166 | class TestAuthToolsUnit: 167 | """Unit tests for auth tools.""" 168 | 169 | @pytest.fixture 170 | def mock_container(self): 171 | """Create a mock container with SDK client.""" 172 | container = MagicMock(spec=ServicesContainer) 173 | 174 | # Mock SDK client 175 | container.sdk_client = MagicMock() 176 | container.sdk_client.call_auth_admin_method = AsyncMock() 177 | container.sdk_client.return_python_sdk_spec = MagicMock(return_value={ 178 | "methods": ["list_users", "create_user", "delete_user"] 179 | }) 180 | 181 | return container 182 | 183 | async def test_list_users_success(self, mock_container): 184 | """Test listing users successfully.""" 185 | # Setup 186 | mock_users = [ 187 | {"id": "user1", "email": "[email protected]"}, 188 | {"id": "user2", "email": "[email protected]"} 189 | ] 190 | mock_container.sdk_client.call_auth_admin_method.return_value = mock_users 191 | 192 | # Execute 193 | result = await mock_container.sdk_client.call_auth_admin_method( 194 | "list_users", {"page": 1, 
"per_page": 10} 195 | ) 196 | 197 | # Verify 198 | assert len(result) == 2 199 | assert result[0]["email"] == "[email protected]" 200 | 201 | async def test_invalid_method_raises_error(self, mock_container): 202 | """Test that invalid method names raise errors.""" 203 | # Setup 204 | mock_container.sdk_client.call_auth_admin_method.side_effect = PythonSDKError( 205 | "Unknown method: invalid_method" 206 | ) 207 | 208 | # Execute & Verify 209 | with pytest.raises(PythonSDKError) as exc_info: 210 | await mock_container.sdk_client.call_auth_admin_method( 211 | "invalid_method", {} 212 | ) 213 | 214 | assert "Unknown method" in str(exc_info.value) 215 | 216 | async def test_create_user_validation(self, mock_container): 217 | """Test user creation with validation.""" 218 | # Setup 219 | new_user = {"id": str(uuid.uuid4()), "email": "[email protected]"} 220 | mock_container.sdk_client.call_auth_admin_method.return_value = {"user": new_user} 221 | 222 | # Execute 223 | result = await mock_container.sdk_client.call_auth_admin_method( 224 | "create_user", 225 | {"email": "[email protected]", "password": "TestPass123!"} 226 | ) 227 | 228 | # Verify 229 | assert result["user"]["email"] == "[email protected]" 230 | mock_container.sdk_client.call_auth_admin_method.assert_called_once() 231 | 232 | 233 | @pytest.mark.asyncio 234 | class TestSafetyToolsUnit: 235 | """Unit tests for safety tools - these already work well.""" 236 | 237 | @pytest.fixture 238 | def mock_container(self): 239 | """Create a mock container with safety manager.""" 240 | container = MagicMock(spec=ServicesContainer) 241 | 242 | # Mock safety manager with proper methods 243 | container.safety_manager = MagicMock() 244 | container.safety_manager.set_unsafe_mode = MagicMock() 245 | container.safety_manager.get_mode = MagicMock(return_value=SafetyMode.SAFE) 246 | container.safety_manager.confirm_operation = MagicMock() 247 | container.safety_manager.is_unsafe_mode = MagicMock(return_value=False) 248 | 249 | 
return container 250 | 251 | async def test_live_dangerously_enables_unsafe_mode(self, mock_container): 252 | """Test that live_dangerously enables unsafe mode.""" 253 | # Execute 254 | mock_container.safety_manager.set_unsafe_mode(ClientType.DATABASE, True) 255 | 256 | # Verify 257 | mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.DATABASE, True) 258 | 259 | async def test_confirm_operation_stores_confirmation(self, mock_container): 260 | """Test that confirm operation stores the confirmation.""" 261 | # Setup 262 | confirmation_id = str(uuid.uuid4()) 263 | 264 | # Execute 265 | mock_container.safety_manager.confirm_operation(confirmation_id) 266 | 267 | # Verify 268 | mock_container.safety_manager.confirm_operation.assert_called_with(confirmation_id) 269 | 270 | async def test_safety_mode_switching(self, mock_container): 271 | """Test switching between safe and unsafe modes.""" 272 | # Test enabling unsafe mode 273 | mock_container.safety_manager.set_unsafe_mode(ClientType.API, True) 274 | mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.API, True) 275 | 276 | # Test disabling unsafe mode 277 | mock_container.safety_manager.set_unsafe_mode(ClientType.API, False) 278 | mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.API, False) ``` -------------------------------------------------------------------------------- /supabase_mcp/core/feature_manager.py: -------------------------------------------------------------------------------- ```python 1 | from typing import TYPE_CHECKING, Any, Literal 2 | 3 | from supabase_mcp.clients.api_client import ApiClient 4 | from supabase_mcp.exceptions import APIError, ConfirmationRequiredError, FeatureAccessError, FeatureTemporaryError 5 | from supabase_mcp.logger import logger 6 | from supabase_mcp.services.database.postgres_client import QueryResult 7 | from supabase_mcp.services.safety.models import ClientType, SafetyMode 8 | from 
supabase_mcp.tools.manager import ToolName 9 | 10 | if TYPE_CHECKING: 11 | from supabase_mcp.core.container import ServicesContainer 12 | 13 | 14 | class FeatureManager: 15 | """Service for managing features, access to them and their configuration.""" 16 | 17 | def __init__(self, api_client: ApiClient): 18 | """Initialize the feature service. 19 | 20 | Args: 21 | api_client: Client for communicating with the API 22 | """ 23 | self.api_client = api_client 24 | 25 | async def check_feature_access(self, feature_name: str) -> None: 26 | """Check if the user has access to a feature. 27 | 28 | Args: 29 | feature_name: Name of the feature to check 30 | 31 | Raises: 32 | FeatureAccessError: If the user doesn't have access to the feature 33 | """ 34 | try: 35 | # Use the API client to check feature access 36 | response = await self.api_client.check_feature_access(feature_name) 37 | 38 | # If access is not granted, raise an exception 39 | if not response.access_granted: 40 | logger.info(f"Feature access denied: {feature_name}") 41 | raise FeatureAccessError(feature_name) 42 | 43 | logger.debug(f"Feature access granted: {feature_name}") 44 | 45 | except APIError as e: 46 | logger.error(f"API error checking feature access: {feature_name} - {e}") 47 | raise FeatureTemporaryError(feature_name, e.status_code, e.response_body) from e 48 | except Exception as e: 49 | if not isinstance(e, FeatureAccessError): 50 | logger.error(f"Unexpected error checking feature access: {feature_name} - {e}") 51 | raise FeatureTemporaryError(feature_name) from e 52 | raise 53 | 54 | async def execute_tool(self, tool_name: ToolName, services_container: "ServicesContainer", **kwargs: Any) -> Any: 55 | """Execute a tool with feature access check. 
56 | 57 | Args: 58 | tool_name: Name of the tool to execute 59 | services_container: Container with all services 60 | **kwargs: Arguments to pass to the tool 61 | 62 | Returns: 63 | Result of the tool execution 64 | """ 65 | # Check feature access 66 | await self.check_feature_access(tool_name.value) 67 | 68 | # Execute the appropriate tool based on name 69 | if tool_name == ToolName.GET_SCHEMAS: 70 | return await self.get_schemas(services_container) 71 | elif tool_name == ToolName.GET_TABLES: 72 | return await self.get_tables(services_container, **kwargs) 73 | elif tool_name == ToolName.GET_TABLE_SCHEMA: 74 | return await self.get_table_schema(services_container, **kwargs) 75 | elif tool_name == ToolName.EXECUTE_POSTGRESQL: 76 | return await self.execute_postgresql(services_container, **kwargs) 77 | elif tool_name == ToolName.RETRIEVE_MIGRATIONS: 78 | return await self.retrieve_migrations(services_container, **kwargs) 79 | elif tool_name == ToolName.SEND_MANAGEMENT_API_REQUEST: 80 | return await self.send_management_api_request(services_container, **kwargs) 81 | elif tool_name == ToolName.GET_MANAGEMENT_API_SPEC: 82 | return await self.get_management_api_spec(services_container, **kwargs) 83 | elif tool_name == ToolName.GET_AUTH_ADMIN_METHODS_SPEC: 84 | return await self.get_auth_admin_methods_spec(services_container) 85 | elif tool_name == ToolName.CALL_AUTH_ADMIN_METHOD: 86 | return await self.call_auth_admin_method(services_container, **kwargs) 87 | elif tool_name == ToolName.LIVE_DANGEROUSLY: 88 | return await self.live_dangerously(services_container, **kwargs) 89 | elif tool_name == ToolName.CONFIRM_DESTRUCTIVE_OPERATION: 90 | return await self.confirm_destructive_operation(services_container, **kwargs) 91 | elif tool_name == ToolName.RETRIEVE_LOGS: 92 | return await self.retrieve_logs(services_container, **kwargs) 93 | else: 94 | raise ValueError(f"Unknown tool: {tool_name}") 95 | 96 | async def get_schemas(self, container: "ServicesContainer") -> 
QueryResult: 97 | """List all database schemas with their sizes and table counts.""" 98 | query_manager = container.query_manager 99 | query = query_manager.get_schemas_query() 100 | return await query_manager.handle_query(query) 101 | 102 | async def get_tables(self, container: "ServicesContainer", schema_name: str) -> QueryResult: 103 | """List all tables, foreign tables, and views in a schema with their sizes, row counts, and metadata.""" 104 | query_manager = container.query_manager 105 | query = query_manager.get_tables_query(schema_name) 106 | return await query_manager.handle_query(query) 107 | 108 | async def get_table_schema(self, container: "ServicesContainer", schema_name: str, table: str) -> QueryResult: 109 | """Get detailed table structure including columns, keys, and relationships.""" 110 | query_manager = container.query_manager 111 | query = query_manager.get_table_schema_query(schema_name, table) 112 | return await query_manager.handle_query(query) 113 | 114 | async def execute_postgresql( 115 | self, container: "ServicesContainer", query: str, migration_name: str = "" 116 | ) -> QueryResult: 117 | """Execute PostgreSQL statements against your Supabase database.""" 118 | query_manager = container.query_manager 119 | return await query_manager.handle_query(query, has_confirmation=False, migration_name=migration_name) 120 | 121 | async def retrieve_migrations( 122 | self, 123 | container: "ServicesContainer", 124 | limit: int = 50, 125 | offset: int = 0, 126 | name_pattern: str = "", 127 | include_full_queries: bool = False, 128 | ) -> QueryResult: 129 | """Retrieve a list of all migrations a user has from Supabase.""" 130 | query_manager = container.query_manager 131 | query = query_manager.get_migrations_query( 132 | limit=limit, offset=offset, name_pattern=name_pattern, include_full_queries=include_full_queries 133 | ) 134 | return await query_manager.handle_query(query) 135 | 136 | async def send_management_api_request( 137 | self, 138 | 
container: "ServicesContainer", 139 | method: str, 140 | path: str, 141 | path_params: dict[str, str], 142 | request_params: dict[str, Any], 143 | request_body: dict[str, Any], 144 | ) -> dict[str, Any]: 145 | """Execute a Supabase Management API request.""" 146 | api_manager = container.api_manager 147 | return await api_manager.execute_request(method, path, path_params, request_params, request_body) 148 | 149 | async def get_management_api_spec( 150 | self, container: "ServicesContainer", params: dict[str, Any] = {} 151 | ) -> dict[str, Any]: 152 | """Get the Supabase Management API specification.""" 153 | path = params.get("path") 154 | method = params.get("method") 155 | domain = params.get("domain") 156 | all_paths = params.get("all_paths", False) 157 | 158 | logger.debug( 159 | f"Getting management API spec with path: {path}, method: {method}, domain: {domain}, all_paths: {all_paths}" 160 | ) 161 | api_manager = container.api_manager 162 | return await api_manager.handle_spec_request(path, method, domain, all_paths) 163 | 164 | async def get_auth_admin_methods_spec(self, container: "ServicesContainer") -> dict[str, Any]: 165 | """Get Python SDK methods specification for Auth Admin.""" 166 | sdk_client = container.sdk_client 167 | return sdk_client.return_python_sdk_spec() 168 | 169 | async def call_auth_admin_method( 170 | self, container: "ServicesContainer", method: str, params: dict[str, Any] 171 | ) -> dict[str, Any]: 172 | """Call an Auth Admin method from Supabase Python SDK.""" 173 | sdk_client = container.sdk_client 174 | return await sdk_client.call_auth_admin_method(method, params) 175 | 176 | async def live_dangerously( 177 | self, container: "ServicesContainer", service: Literal["api", "database"], enable_unsafe_mode: bool = False 178 | ) -> dict[str, Any]: 179 | """ 180 | Toggle between safe and unsafe operation modes for API or Database services. 
181 | 182 | This function controls the safety level for operations, allowing you to: 183 | - Enable write operations for the database (INSERT, UPDATE, DELETE, schema changes) 184 | - Enable state-changing operations for the Management API 185 | """ 186 | safety_manager = container.safety_manager 187 | if service == "api": 188 | # Set the safety mode in the safety manager 189 | new_mode = SafetyMode.UNSAFE if enable_unsafe_mode else SafetyMode.SAFE 190 | safety_manager.set_safety_mode(ClientType.API, new_mode) 191 | 192 | # Return the actual mode that was set 193 | return {"service": "api", "mode": safety_manager.get_safety_mode(ClientType.API)} 194 | elif service == "database": 195 | # Set the safety mode in the safety manager 196 | new_mode = SafetyMode.UNSAFE if enable_unsafe_mode else SafetyMode.SAFE 197 | safety_manager.set_safety_mode(ClientType.DATABASE, new_mode) 198 | 199 | # Return the actual mode that was set 200 | return {"service": "database", "mode": safety_manager.get_safety_mode(ClientType.DATABASE)} 201 | 202 | async def confirm_destructive_operation( 203 | self, 204 | container: "ServicesContainer", 205 | operation_type: Literal["api", "database"], 206 | confirmation_id: str, 207 | user_confirmation: bool = False, 208 | ) -> QueryResult | dict[str, Any]: 209 | """Execute a destructive operation after confirmation. 
Use this only after reviewing the risks with the user.""" 210 | api_manager = container.api_manager 211 | query_manager = container.query_manager 212 | if not user_confirmation: 213 | raise ConfirmationRequiredError("Destructive operation requires explicit user confirmation.") 214 | 215 | if operation_type == "api": 216 | return await api_manager.handle_confirmation(confirmation_id) 217 | elif operation_type == "database": 218 | return await query_manager.handle_confirmation(confirmation_id) 219 | 220 | async def retrieve_logs( 221 | self, 222 | container: "ServicesContainer", 223 | collection: str, 224 | limit: int = 20, 225 | hours_ago: int = 1, 226 | filters: list[dict[str, Any]] = [], 227 | search: str = "", 228 | custom_query: str = "", 229 | ) -> dict[str, Any]: 230 | """Retrieve logs from your Supabase project's services for debugging and monitoring.""" 231 | logger.info( 232 | f"Tool called: retrieve_logs(collection={collection}, limit={limit}, hours_ago={hours_ago}, filters={filters}, search={search}, custom_query={'<custom>' if custom_query else None})" 233 | ) 234 | 235 | api_manager = container.api_manager 236 | result = await api_manager.retrieve_logs( 237 | collection=collection, 238 | limit=limit, 239 | hours_ago=hours_ago, 240 | filters=filters, 241 | search=search, 242 | custom_query=custom_query, 243 | ) 244 | 245 | logger.info(f"Tool completed: retrieve_logs - Retrieved log entries for collection={collection}") 246 | 247 | return result 248 | ``` -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | from collections.abc import AsyncGenerator, Generator 3 | from pathlib import Path 4 | from typing import Any 5 | from unittest.mock import AsyncMock, MagicMock 6 | 7 | import pytest 8 | import pytest_asyncio 9 | from dotenv import load_dotenv 10 | from mcp.server.fastmcp import FastMCP 11 | 12 | 
from supabase_mcp.clients.management_client import ManagementAPIClient 13 | from supabase_mcp.clients.sdk_client import SupabaseSDKClient 14 | from supabase_mcp.core.container import ServicesContainer 15 | from supabase_mcp.logger import logger 16 | from supabase_mcp.services.api.api_manager import SupabaseApiManager 17 | from supabase_mcp.services.api.spec_manager import ApiSpecManager 18 | from supabase_mcp.services.database.migration_manager import MigrationManager 19 | from supabase_mcp.services.database.postgres_client import PostgresClient 20 | from supabase_mcp.services.database.query_manager import QueryManager 21 | from supabase_mcp.services.database.sql.loader import SQLLoader 22 | from supabase_mcp.services.database.sql.validator import SQLValidator 23 | from supabase_mcp.services.safety.safety_manager import SafetyManager 24 | from supabase_mcp.settings import Settings, find_config_file 25 | from supabase_mcp.tools import ToolManager 26 | from supabase_mcp.tools.registry import ToolRegistry 27 | 28 | # ====================== 29 | # Environment Fixtures 30 | # ====================== 31 | 32 | 33 | @pytest.fixture 34 | def clean_environment() -> Generator[None, None, None]: 35 | """Fixture to provide a clean environment without any Supabase-related env vars.""" 36 | # Save original environment 37 | original_env = dict(os.environ) 38 | 39 | # Remove all Supabase-related environment variables 40 | for key in list(os.environ.keys()): 41 | if key.startswith("SUPABASE_"): 42 | del os.environ[key] 43 | 44 | yield 45 | 46 | # Restore original environment 47 | os.environ.clear() 48 | os.environ.update(original_env) 49 | 50 | 51 | def load_test_env() -> dict[str, str | None]: 52 | """Load test environment variables from .env.test file""" 53 | env_test_path = Path(__file__).parent.parent / ".env.test" 54 | if not env_test_path.exists(): 55 | raise FileNotFoundError(f"Test environment file not found at {env_test_path}") 56 | 57 | load_dotenv(env_test_path) 58 | 
return { 59 | "SUPABASE_PROJECT_REF": os.getenv("SUPABASE_PROJECT_REF"), 60 | "SUPABASE_DB_PASSWORD": os.getenv("SUPABASE_DB_PASSWORD"), 61 | "SUPABASE_SERVICE_ROLE_KEY": os.getenv("SUPABASE_SERVICE_ROLE_KEY"), 62 | "SUPABASE_ACCESS_TOKEN": os.getenv("SUPABASE_ACCESS_TOKEN"), 63 | } 64 | 65 | 66 | @pytest.fixture(scope="session") 67 | def settings_integration() -> Settings: 68 | """Fixture providing settings for integration tests. 69 | 70 | This fixture loads settings from environment variables or .env.test file. 71 | Uses session scope since settings don't change during tests. 72 | """ 73 | return Settings.with_config(find_config_file(".env.test")) 74 | 75 | 76 | @pytest.fixture 77 | def mock_validator() -> SQLValidator: 78 | """Fixture providing a mock SQLValidator for integration tests.""" 79 | return SQLValidator() 80 | 81 | 82 | @pytest.fixture 83 | def settings_integration_custom_env() -> Generator[Settings, None, None]: 84 | """Fixture that provides Settings instance for integration tests using .env.test""" 85 | 86 | # Load custom environment variables 87 | test_env = load_test_env() 88 | original_env = dict(os.environ) 89 | 90 | # Set up test environment 91 | for key, value in test_env.items(): 92 | if value is not None: 93 | os.environ[key] = value 94 | 95 | # Create fresh settings instance 96 | settings = Settings() 97 | logger.info(f"Custom connection settings initialized: {settings}") 98 | 99 | yield settings 100 | 101 | # Restore original environment 102 | os.environ.clear() 103 | os.environ.update(original_env) 104 | 105 | 106 | # ====================== 107 | # Service Fixtures 108 | # ====================== 109 | 110 | 111 | @pytest_asyncio.fixture(scope="module") 112 | async def postgres_client_integration(settings_integration: Settings) -> AsyncGenerator[PostgresClient, None]: 113 | # Reset before creation 114 | await PostgresClient.reset() 115 | 116 | # Create a client 117 | client = PostgresClient(settings=settings_integration) 118 | 119 | try: 
120 | yield client 121 | finally: 122 | await client.close() 123 | 124 | 125 | @pytest_asyncio.fixture(scope="module") 126 | async def spec_manager_integration() -> AsyncGenerator[ApiSpecManager, None]: 127 | """Fixture providing an ApiSpecManager instance for tests.""" 128 | manager = ApiSpecManager() 129 | yield manager 130 | 131 | 132 | @pytest_asyncio.fixture(scope="module") 133 | async def api_client_integration(settings_integration: Settings) -> AsyncGenerator[ManagementAPIClient, None]: 134 | # We don't need to reset since it's not a singleton 135 | client = ManagementAPIClient(settings=settings_integration) 136 | 137 | try: 138 | yield client 139 | finally: 140 | await client.close() 141 | 142 | 143 | @pytest_asyncio.fixture(scope="module") 144 | async def sdk_client_integration(settings_integration: Settings) -> AsyncGenerator[SupabaseSDKClient, None]: 145 | """Fixture providing a SupabaseSDKClient instance for tests. 146 | 147 | Uses function scope to ensure a fresh client for each test. 
148 | """ 149 | client = SupabaseSDKClient.get_instance(settings=settings_integration) 150 | try: 151 | yield client 152 | finally: 153 | # Reset the singleton to ensure a fresh client for the next test 154 | SupabaseSDKClient.reset() 155 | 156 | 157 | @pytest.fixture(scope="module") 158 | def safety_manager_integration() -> SafetyManager: 159 | """Fixture providing a safety manager for integration tests.""" 160 | # Reset the safety manager singleton 161 | SafetyManager.reset() 162 | 163 | # Create a new safety manager 164 | safety_manager = SafetyManager.get_instance() 165 | safety_manager.register_safety_configs() 166 | 167 | return safety_manager 168 | 169 | 170 | @pytest.fixture(scope="module") 171 | def tool_manager_integration() -> ToolManager: 172 | """Fixture providing a tool manager for integration tests.""" 173 | # Reset the tool manager singleton 174 | ToolManager.reset() 175 | return ToolManager.get_instance() 176 | 177 | 178 | @pytest.fixture(scope="module") 179 | def query_manager_integration( 180 | postgres_client_integration: PostgresClient, 181 | safety_manager_integration: SafetyManager, 182 | ) -> QueryManager: 183 | """Fixture providing a query manager for integration tests.""" 184 | query_manager = QueryManager( 185 | postgres_client=postgres_client_integration, 186 | safety_manager=safety_manager_integration, 187 | ) 188 | return query_manager 189 | 190 | 191 | @pytest.fixture(scope="module") 192 | def mock_api_manager() -> SupabaseApiManager: 193 | """Fixture providing a properly mocked API manager for unit tests.""" 194 | # Create mock dependencies 195 | mock_client = MagicMock() 196 | mock_safety_manager = MagicMock() 197 | mock_spec_manager = MagicMock() 198 | 199 | # Create the API manager with proper constructor arguments 200 | api_manager = SupabaseApiManager(api_client=mock_client, safety_manager=mock_safety_manager) 201 | 202 | # Add the spec_manager attribute 203 | api_manager.spec_manager = mock_spec_manager 204 | 205 | return 
api_manager 206 | 207 | 208 | @pytest.fixture 209 | def mock_query_manager() -> QueryManager: 210 | """Fixture providing a properly mocked Query manager for unit tests.""" 211 | # Create mock dependencies 212 | mock_safety_manager = MagicMock() 213 | mock_postgres_client = MagicMock() 214 | mock_validator = MagicMock() 215 | 216 | # Create the Query manager with proper constructor arguments 217 | query_manager = QueryManager( 218 | postgres_client=mock_postgres_client, 219 | safety_manager=mock_safety_manager, 220 | ) 221 | 222 | # Replace the validator with a mock 223 | query_manager.validator = mock_validator 224 | 225 | # Store the postgres client as an attribute for tests to access 226 | query_manager.db_client = mock_postgres_client 227 | 228 | # Make execute_query_async an AsyncMock 229 | query_manager.db_client.execute_query_async = AsyncMock() 230 | 231 | return query_manager 232 | 233 | 234 | @pytest_asyncio.fixture(scope="module") 235 | async def api_manager_integration( 236 | api_client_integration: ManagementAPIClient, 237 | safety_manager_integration: SafetyManager, 238 | ) -> AsyncGenerator[SupabaseApiManager, None]: 239 | """Fixture providing an API manager for integration tests.""" 240 | 241 | # Create a new API manager 242 | api_manager = SupabaseApiManager.get_instance( 243 | api_client=api_client_integration, 244 | safety_manager=safety_manager_integration, 245 | ) 246 | 247 | try: 248 | yield api_manager 249 | finally: 250 | # Reset the API manager singleton 251 | SupabaseApiManager.reset() 252 | 253 | 254 | # ====================== 255 | # Mock MCP Server 256 | # ====================== 257 | 258 | 259 | @pytest.fixture 260 | def mock_mcp_server() -> Any: 261 | """Fixture providing a mock MCP server for integration tests.""" 262 | 263 | # Create a simple mock MCP server that mimics the FastMCP interface 264 | class MockMCP: 265 | def __init__(self) -> None: 266 | self.tools: dict[str, Any] = {} 267 | self.name = "mock_mcp" 268 | 269 | def 
register_tool(self, name: str, func: Any, **kwargs: Any) -> None: 270 | """Register a tool with the MCP server.""" 271 | self.tools[name] = func 272 | 273 | def run(self) -> None: 274 | """Mock run method.""" 275 | pass 276 | 277 | return MockMCP() 278 | 279 | 280 | @pytest.fixture(scope="module") 281 | def mock_mcp_server_integration() -> Any: 282 | """Fixture providing a mock MCP server for integration tests.""" 283 | return FastMCP(name="supabase") 284 | 285 | 286 | # ====================== 287 | # Container Fixture 288 | # ====================== 289 | 290 | 291 | @pytest.fixture(scope="module") 292 | def container_integration( 293 | postgres_client_integration: PostgresClient, 294 | api_client_integration: ManagementAPIClient, 295 | sdk_client_integration: SupabaseSDKClient, 296 | api_manager_integration: SupabaseApiManager, 297 | safety_manager_integration: SafetyManager, 298 | query_manager_integration: QueryManager, 299 | tool_manager_integration: ToolManager, 300 | mock_mcp_server_integration: FastMCP, 301 | ) -> ServicesContainer: 302 | """Fixture providing a basic Container for integration tests. 303 | 304 | This container includes all services needed for integration testing, 305 | but is not initialized. 
306 | """ 307 | # Create a new container with all the services 308 | container = ServicesContainer( 309 | mcp_server=mock_mcp_server_integration, 310 | postgres_client=postgres_client_integration, 311 | api_client=api_client_integration, 312 | sdk_client=sdk_client_integration, 313 | api_manager=api_manager_integration, 314 | safety_manager=safety_manager_integration, 315 | query_manager=query_manager_integration, 316 | tool_manager=tool_manager_integration, 317 | ) 318 | 319 | logger.info("✓ Integration container created successfully.") 320 | 321 | return container 322 | 323 | 324 | @pytest.fixture(scope="module") 325 | def initialized_container_integration( 326 | container_integration: ServicesContainer, 327 | settings_integration: Settings, 328 | ) -> ServicesContainer: 329 | """Fixture providing a fully initialized Container for integration tests. 330 | 331 | This container is initialized with all services and ready to use. 332 | """ 333 | container_integration.initialize_services(settings_integration) 334 | logger.info("✓ Integration container initialized successfully.") 335 | 336 | return container_integration 337 | 338 | 339 | @pytest.fixture(scope="module") 340 | def tools_registry_integration( 341 | initialized_container_integration: ServicesContainer, 342 | ) -> ServicesContainer: 343 | """Fixture providing a Container with tools registered for integration tests. 344 | 345 | This container has all tools registered with the MCP server. 
346 | """ 347 | container = initialized_container_integration 348 | mcp_server = container.mcp_server 349 | 350 | registry = ToolRegistry(mcp_server, container) 351 | registry.register_tools() 352 | 353 | logger.info("✓ Tools registered with MCP server successfully.") 354 | 355 | return container 356 | 357 | 358 | @pytest.fixture 359 | def sql_loader() -> SQLLoader: 360 | """Fixture providing a SQLLoader instance for tests.""" 361 | return SQLLoader() 362 | 363 | 364 | @pytest.fixture 365 | def migration_manager(sql_loader: SQLLoader) -> MigrationManager: 366 | """Fixture providing a MigrationManager instance for tests.""" 367 | return MigrationManager(loader=sql_loader) 368 | ``` -------------------------------------------------------------------------------- /tests/services/api/test_spec_manager.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | from unittest.mock import AsyncMock, MagicMock, mock_open, patch 3 | 4 | import httpx 5 | import pytest 6 | 7 | from supabase_mcp.services.api.spec_manager import ApiSpecManager 8 | 9 | # Test data 10 | SAMPLE_SPEC = {"openapi": "3.0.0", "paths": {"/v1/test": {"get": {"operationId": "test"}}}} 11 | 12 | 13 | class TestApiSpecManager: 14 | """Integration tests for api spec manager tools.""" 15 | 16 | # Local Spec Tests 17 | def test_load_local_spec_success(self, spec_manager_integration: ApiSpecManager): 18 | """Test successful loading of local spec file""" 19 | mock_file = mock_open(read_data=json.dumps(SAMPLE_SPEC)) 20 | 21 | with patch("builtins.open", mock_file): 22 | result = spec_manager_integration._load_local_spec() 23 | 24 | assert result == SAMPLE_SPEC 25 | mock_file.assert_called_once() 26 | 27 | def test_load_local_spec_file_not_found(self, spec_manager_integration: ApiSpecManager): 28 | """Test handling of missing local spec file""" 29 | with patch("builtins.open", side_effect=FileNotFoundError), pytest.raises(FileNotFoundError): 30 | 
spec_manager_integration._load_local_spec() 31 | 32 | def test_load_local_spec_invalid_json(self, spec_manager_integration: ApiSpecManager): 33 | """Test handling of invalid JSON in local spec""" 34 | mock_file = mock_open(read_data="invalid json") 35 | 36 | with patch("builtins.open", mock_file), pytest.raises(json.JSONDecodeError): 37 | spec_manager_integration._load_local_spec() 38 | 39 | # Remote Spec Tests 40 | @pytest.mark.asyncio 41 | async def test_fetch_remote_spec_success(self, spec_manager_integration: ApiSpecManager): 42 | """Test successful remote spec fetch""" 43 | mock_response = MagicMock() 44 | mock_response.status_code = 200 45 | mock_response.json.return_value = SAMPLE_SPEC 46 | 47 | mock_client = AsyncMock() 48 | mock_client.get.return_value = mock_response 49 | mock_client.__aenter__.return_value = mock_client # Mock async context manager 50 | 51 | with patch("httpx.AsyncClient", return_value=mock_client): 52 | result = await spec_manager_integration._fetch_remote_spec() 53 | 54 | assert result == SAMPLE_SPEC 55 | mock_client.get.assert_called_once() 56 | 57 | @pytest.mark.asyncio 58 | async def test_fetch_remote_spec_api_error(self, spec_manager_integration: ApiSpecManager): 59 | """Test handling of API error during remote fetch""" 60 | mock_response = MagicMock() 61 | mock_response.status_code = 500 62 | 63 | mock_client = AsyncMock() 64 | mock_client.get.return_value = mock_response 65 | mock_client.__aenter__.return_value = mock_client # Mock async context manager 66 | 67 | with patch("httpx.AsyncClient", return_value=mock_client): 68 | result = await spec_manager_integration._fetch_remote_spec() 69 | 70 | assert result is None 71 | 72 | @pytest.mark.asyncio 73 | async def test_fetch_remote_spec_network_error(self, spec_manager_integration: ApiSpecManager): 74 | """Test handling of network error during remote fetch""" 75 | mock_client = AsyncMock() 76 | mock_client.get.side_effect = httpx.NetworkError("Network error") 77 | 78 | with 
patch("httpx.AsyncClient", return_value=mock_client): 79 | result = await spec_manager_integration._fetch_remote_spec() 80 | 81 | assert result is None 82 | 83 | # Startup Flow Tests 84 | @pytest.mark.asyncio 85 | async def test_startup_remote_success(self, spec_manager_integration: ApiSpecManager): 86 | """Test successful startup with remote fetch""" 87 | # Reset spec to None to ensure we're testing the fetch 88 | spec_manager_integration.spec = None 89 | 90 | # Mock the fetch method to return sample spec 91 | mock_fetch = AsyncMock(return_value=SAMPLE_SPEC) 92 | 93 | with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch): 94 | result = await spec_manager_integration.get_spec() 95 | 96 | assert result == SAMPLE_SPEC 97 | mock_fetch.assert_called_once() 98 | 99 | @pytest.mark.asyncio 100 | async def test_get_spec_remote_fail_local_fallback(self, spec_manager_integration: ApiSpecManager): 101 | """Test get_spec with remote failure and local fallback""" 102 | # Reset spec to None to ensure we're testing the fetch and fallback 103 | spec_manager_integration.spec = None 104 | 105 | # Mock fetch to fail and local to succeed 106 | mock_fetch = AsyncMock(return_value=None) 107 | mock_local = MagicMock(return_value=SAMPLE_SPEC) 108 | 109 | with ( 110 | patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch), 111 | patch.object(spec_manager_integration, "_load_local_spec", mock_local), 112 | ): 113 | result = await spec_manager_integration.get_spec() 114 | 115 | assert result == SAMPLE_SPEC 116 | mock_fetch.assert_called_once() 117 | mock_local.assert_called_once() 118 | 119 | @pytest.mark.asyncio 120 | async def test_get_spec_both_fail(self, spec_manager_integration: ApiSpecManager): 121 | """Test get_spec with both remote and local failure""" 122 | # Reset spec to None to ensure we're testing the fetch and fallback 123 | spec_manager_integration.spec = None 124 | 125 | # Mock both fetch and local to fail 126 | mock_fetch = 
AsyncMock(return_value=None) 127 | mock_local = MagicMock(side_effect=FileNotFoundError("Test file not found")) 128 | 129 | with ( 130 | patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch), 131 | patch.object(spec_manager_integration, "_load_local_spec", mock_local), 132 | pytest.raises(FileNotFoundError), 133 | ): 134 | await spec_manager_integration.get_spec() 135 | 136 | mock_fetch.assert_called_once() 137 | mock_local.assert_called_once() 138 | 139 | @pytest.mark.asyncio 140 | async def test_get_spec_cached(self, spec_manager_integration: ApiSpecManager): 141 | """Test that get_spec returns cached spec if available""" 142 | # Set the spec directly 143 | spec_manager_integration.spec = SAMPLE_SPEC 144 | 145 | # Mock the fetch method to verify it's not called 146 | mock_fetch = AsyncMock() 147 | 148 | with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch): 149 | result = await spec_manager_integration.get_spec() 150 | 151 | assert result == SAMPLE_SPEC 152 | mock_fetch.assert_not_called() 153 | 154 | @pytest.mark.asyncio 155 | async def test_get_spec_not_loaded(self, spec_manager_integration: ApiSpecManager): 156 | """Test behavior when spec is not loaded but can be loaded""" 157 | # Reset spec to None 158 | spec_manager_integration.spec = None 159 | 160 | # Mock fetch to succeed 161 | mock_fetch = AsyncMock(return_value=SAMPLE_SPEC) 162 | 163 | with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch): 164 | result = await spec_manager_integration.get_spec() 165 | 166 | assert result == SAMPLE_SPEC 167 | mock_fetch.assert_called_once() 168 | 169 | @pytest.mark.asyncio 170 | async def test_comprehensive_spec_retrieval(self, spec_manager_integration: ApiSpecManager): 171 | """ 172 | Comprehensive test of API spec retrieval and functionality. 173 | This test exactly mirrors the main() function to ensure all aspects work correctly. 
174 | """ 175 | # Create a fresh instance to avoid any cached data from other tests 176 | from supabase_mcp.services.api.spec_manager import LOCAL_SPEC_PATH, ApiSpecManager 177 | 178 | spec_manager = ApiSpecManager() 179 | 180 | # Print the path being used (for debugging) 181 | print(f"\nTest is looking for spec at: {LOCAL_SPEC_PATH}") 182 | 183 | # Load the spec 184 | spec = await spec_manager.get_spec() 185 | assert spec is not None, "Spec should be loaded successfully" 186 | 187 | # 1. Test get_all_domains 188 | all_domains = spec_manager.get_all_domains() 189 | print(f"\nAll domains: {all_domains}") 190 | assert len(all_domains) > 0, "Should have at least one domain" 191 | 192 | # Verify all expected domains are present 193 | expected_domains = [ 194 | "Analytics", 195 | "Auth", 196 | "Database", 197 | "Domains", 198 | "Edge Functions", 199 | "Environments", 200 | "OAuth", 201 | "Organizations", 202 | "Projects", 203 | "Rest", 204 | "Secrets", 205 | "Storage", 206 | ] 207 | for domain in expected_domains: 208 | assert domain in all_domains, f"Domain '{domain}' should be in the list of domains" 209 | 210 | # 2. Test get_all_paths_and_methods 211 | all_paths = spec_manager.get_all_paths_and_methods() 212 | assert len(all_paths) > 0, "Should have at least one path" 213 | 214 | # Sample a few paths to verify 215 | sample_paths = list(all_paths.keys())[:5] 216 | print("\nSample paths:") 217 | for path in sample_paths: 218 | print(f" {path}:") 219 | assert path.startswith("/v1/"), f"Path {path} should start with /v1/" 220 | assert len(all_paths[path]) > 0, f"Path {path} should have at least one method" 221 | for method, operation_id in all_paths[path].items(): 222 | print(f" {method}: {operation_id}") 223 | assert method.lower() in ["get", "post", "put", "patch", "delete"], f"Method {method} should be valid" 224 | assert operation_id.startswith("v1-"), f"Operation ID {operation_id} should start with v1-" 225 | 226 | # 3. 
Test get_paths_and_methods_by_domain for each domain 227 | for domain in expected_domains: 228 | domain_paths = spec_manager.get_paths_and_methods_by_domain(domain) 229 | assert len(domain_paths) > 0, f"Domain {domain} should have at least one path" 230 | print(f"\n{domain} domain has {len(domain_paths)} paths") 231 | 232 | # 4. Test Edge Functions domain specifically 233 | edge_paths = spec_manager.get_paths_and_methods_by_domain("Edge Functions") 234 | print("\nEdge Functions Paths and Methods:") 235 | for path in edge_paths: 236 | print(f" {path}") 237 | for method, operation_id in edge_paths[path].items(): 238 | print(f" {method}: {operation_id}") 239 | 240 | # Verify specific Edge Functions paths exist 241 | expected_edge_paths = [ 242 | "/v1/projects/{ref}/functions", 243 | "/v1/projects/{ref}/functions/{function_slug}", 244 | "/v1/projects/{ref}/functions/deploy", 245 | ] 246 | for path in expected_edge_paths: 247 | assert path in edge_paths, f"Expected path {path} should be in Edge Functions domain" 248 | 249 | # 5. Test get_spec_for_path_and_method 250 | # Test for Edge Functions 251 | path = "/v1/projects/{ref}/functions" 252 | method = "GET" 253 | full_spec = spec_manager.get_spec_for_path_and_method(path, method) 254 | assert full_spec is not None, f"Should find spec for {method} {path}" 255 | assert "operationId" in full_spec, "Spec should include operationId" 256 | assert full_spec["operationId"] == "v1-list-all-functions", "Should have correct operationId" 257 | 258 | # Test for another domain (Auth) 259 | auth_path = "/v1/projects/{ref}/config/auth" 260 | auth_method = "GET" 261 | auth_spec = spec_manager.get_spec_for_path_and_method(auth_path, auth_method) 262 | assert auth_spec is not None, f"Should find spec for {auth_method} {auth_path}" 263 | assert "operationId" in auth_spec, "Auth spec should include operationId" 264 | 265 | # 6. 
Test get_spec_part 266 | # Get a specific schema 267 | schema = spec_manager.get_spec_part("components", "schemas", "FunctionResponse") 268 | assert schema is not None, "Should find FunctionResponse schema" 269 | assert "properties" in schema, "Schema should have properties" 270 | 271 | # 7. Test caching behavior 272 | # Call get_spec again - should use cached version 273 | import time 274 | 275 | start_time = time.time() 276 | await spec_manager.get_spec() 277 | end_time = time.time() 278 | assert (end_time - start_time) < 0.1, "Cached spec retrieval should be fast" 279 | ```