This is page 2 of 6. Use http://codebase.md/alexander-zuev/supabase-mcp-server?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .claude
│   └── settings.local.json
├── .dockerignore
├── .env.example
├── .env.test.example
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── roadmap_item.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── ci.yaml
│       ├── docs
│       │   └── release-checklist.md
│       └── publish.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── CHANGELOG.MD
├── codecov.yml
├── CONTRIBUTING.MD
├── Dockerfile
├── LICENSE
├── llms-full.txt
├── pyproject.toml
├── README.md
├── smithery.yaml
├── supabase_mcp
│   ├── __init__.py
│   ├── clients
│   │   ├── api_client.py
│   │   ├── base_http_client.py
│   │   ├── management_client.py
│   │   └── sdk_client.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── container.py
│   │   └── feature_manager.py
│   ├── exceptions.py
│   ├── logger.py
│   ├── main.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── api_manager.py
│   │   │   ├── spec_manager.py
│   │   │   └── specs
│   │   │       └── api_spec.json
│   │   ├── database
│   │   │   ├── __init__.py
│   │   │   ├── migration_manager.py
│   │   │   ├── postgres_client.py
│   │   │   ├── query_manager.py
│   │   │   └── sql
│   │   │       ├── loader.py
│   │   │       ├── models.py
│   │   │       ├── queries
│   │   │       │   ├── create_migration.sql
│   │   │       │   ├── get_migrations.sql
│   │   │       │   ├── get_schemas.sql
│   │   │       │   ├── get_table_schema.sql
│   │   │       │   ├── get_tables.sql
│   │   │       │   ├── init_migrations.sql
│   │   │       │   └── logs
│   │   │       │       ├── auth_logs.sql
│   │   │       │       ├── cron_logs.sql
│   │   │       │       ├── edge_logs.sql
│   │   │       │       ├── function_edge_logs.sql
│   │   │       │       ├── pgbouncer_logs.sql
│   │   │       │       ├── postgres_logs.sql
│   │   │       │       ├── postgrest_logs.sql
│   │   │       │       ├── realtime_logs.sql
│   │   │       │       ├── storage_logs.sql
│   │   │       │       └── supavisor_logs.sql
│   │   │       └── validator.py
│   │   ├── logs
│   │   │   ├── __init__.py
│   │   │   └── log_manager.py
│   │   ├── safety
│   │   │   ├── __init__.py
│   │   │   ├── models.py
│   │   │   ├── safety_configs.py
│   │   │   └── safety_manager.py
│   │   └── sdk
│   │       ├── __init__.py
│   │       ├── auth_admin_models.py
│   │       └── auth_admin_sdk_spec.py
│   ├── settings.py
│   └── tools
│       ├── __init__.py
│       ├── descriptions
│       │   ├── api_tools.yaml
│       │   ├── database_tools.yaml
│       │   ├── logs_and_analytics_tools.yaml
│       │   ├── safety_tools.yaml
│       │   └── sdk_tools.yaml
│       ├── manager.py
│       └── registry.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── test_api_client.py
│   │   │   ├── test_api_manager.py
│   │   │   └── test_spec_manager.py
│   │   ├── database
│   │   │   ├── sql
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conftest.py
│   │   │   │   ├── test_loader.py
│   │   │   │   ├── test_sql_validator_integration.py
│   │   │   │   └── test_sql_validator.py
│   │   │   ├── test_migration_manager.py
│   │   │   ├── test_postgres_client.py
│   │   │   └── test_query_manager.py
│   │   ├── logs
│   │   │   └── test_log_manager.py
│   │   ├── safety
│   │   │   ├── test_api_safety_config.py
│   │   │   ├── test_safety_manager.py
│   │   │   └── test_sql_safety_config.py
│   │   └── sdk
│   │       ├── test_auth_admin_models.py
│   │       └── test_sdk_client.py
│   ├── test_container.py
│   ├── test_main.py
│   ├── test_settings.py
│   ├── test_tool_manager.py
│   ├── test_tools_integration.py.bak
│   └── test_tools.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/tests/services/api/test_api_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Any
  2 | from unittest.mock import MagicMock, patch
  3 | 
  4 | import pytest
  5 | 
  6 | from supabase_mcp.exceptions import SafetyError
  7 | from supabase_mcp.services.api.api_manager import SupabaseApiManager
  8 | from supabase_mcp.services.safety.models import ClientType
  9 | 
 10 | 
 11 | class TestApiManager:
 12 |     """Tests for the API Manager."""
 13 | 
 14 |     @pytest.mark.unit
 15 |     def test_path_parameter_replacement(self, mock_api_manager: SupabaseApiManager):
 16 |         """
 17 |         Test that path parameters are correctly replaced in API paths.
 18 | 
 19 |         This test verifies that the API Manager correctly replaces path placeholders
 20 |         with actual values, handling both required and optional parameters.
 21 |         """
 22 |         # Use the mock_api_manager fixture instead of creating one manually
 23 |         api_manager = mock_api_manager
 24 | 
 25 |         # Test with a simple path and required parameters (avoiding 'ref' which is auto-injected)
 26 |         path = "/v1/organizations/{slug}/members"
 27 |         path_params = {"slug": "example-org"}
 28 | 
 29 |         result = api_manager.replace_path_params(path, path_params)
 30 |         expected = "/v1/organizations/example-org/members"
 31 |         assert result == expected, f"Expected {expected}, got {result}"
 32 | 
 33 |         # Test with missing required parameters
 34 |         path = "/v1/organizations/{slug}/members/{id}"
 35 |         path_params = {"slug": "example-org"}
 36 | 
 37 |         with pytest.raises(ValueError) as excinfo:
 38 |             api_manager.replace_path_params(path, path_params)
 39 |         assert "Missing path parameters" in str(excinfo.value)
 40 | 
 41 |         # Test with extra parameters (unknown parameters should be rejected)
 42 |         path = "/v1/organizations/{slug}"
 43 |         path_params = {"slug": "example-org", "extra": "should-be-ignored"}
 44 | 
 45 |         with pytest.raises(ValueError) as excinfo:
 46 |             api_manager.replace_path_params(path, path_params)
 47 |         assert "Unknown path parameter" in str(excinfo.value)
 48 | 
 49 |         # Test with no parameters
 50 |         path = "/v1/organizations"
 51 |         result = api_manager.replace_path_params(path)
 52 |         expected = "/v1/organizations"
 53 |         assert result == expected, f"Expected {expected}, got {result}"
 54 | 
 55 |     @pytest.mark.asyncio
 56 |     @pytest.mark.unit
 57 |     @patch("supabase_mcp.services.api.api_manager.logger")
 58 |     async def test_safety_validation(self, mock_logger: MagicMock, mock_api_manager: SupabaseApiManager):
 59 |         """
 60 |         Test that API operations are properly validated through the safety manager.
 61 | 
 62 |         This test verifies that the API Manager correctly validates operations
 63 |         before executing them, and handles safety errors appropriately.
 64 |         """
 65 |         # Use the mock_api_manager fixture instead of creating one manually
 66 |         api_manager = mock_api_manager
 67 | 
 68 |         # Mock the replace_path_params method to return the path unchanged
 69 |         api_manager.replace_path_params = MagicMock(return_value="/v1/organizations/example-org")
 70 | 
 71 |         # Mock the client's execute_request method to return a simple response
 72 |         mock_response = {"success": True}
 73 |         api_manager.client.execute_request = MagicMock()
 74 |         api_manager.client.execute_request.return_value = mock_response
 75 | 
 76 |         # Make the mock awaitable
 77 |         async def mock_execute_request(*args: Any, **kwargs: Any) -> dict[str, Any]:
 78 |             return mock_response
 79 | 
 80 |         api_manager.client.execute_request = mock_execute_request
 81 | 
 82 |         # Test a successful operation
 83 |         method = "GET"
 84 |         path = "/v1/organizations/{slug}"
 85 |         path_params = {"slug": "example-org"}
 86 | 
 87 |         result = await api_manager.execute_request(method, path, path_params)
 88 | 
 89 |         # Verify that the safety manager was called with the correct parameters
 90 |         api_manager.safety_manager.validate_operation.assert_called_once_with(
 91 |             ClientType.API, (method, path, path_params, None, None), has_confirmation=False
 92 |         )
 93 | 
 94 |         # Verify that the result is what we expected
 95 |         assert result == {"success": True}
 96 | 
 97 |         # Test an operation that fails safety validation
 98 |         api_manager.safety_manager.validate_operation.reset_mock()
 99 | 
100 |         # Make the safety manager raise a SafetyError
101 |         def raise_safety_error(*args: Any, **kwargs: Any) -> None:
102 |             raise SafetyError("Operation not allowed")
103 | 
104 |         api_manager.safety_manager.validate_operation.side_effect = raise_safety_error
105 | 
106 |         # The execute_request method should raise the SafetyError
107 |         with pytest.raises(SafetyError) as excinfo:
108 |             await api_manager.execute_request("DELETE", "/v1/organizations/{slug}", {"slug": "example-org"})
109 | 
110 |         assert "Operation not allowed" in str(excinfo.value)
111 | 
112 |     @pytest.mark.asyncio
113 |     @pytest.mark.unit
114 |     async def test_retrieve_logs_basic(self, mock_api_manager: SupabaseApiManager):
115 |         """
116 |         Test that the retrieve_logs method correctly builds and executes a logs query.
117 | 
118 |         This test verifies that the API Manager correctly builds a logs query using
119 |         the LogManager and executes it through the Management API.
120 |         """
121 |         # Mock the log_manager's build_logs_query method
122 |         mock_api_manager.log_manager.build_logs_query = MagicMock(return_value="SELECT * FROM postgres_logs LIMIT 10")
123 | 
124 |         # Mock the execute_request method to return a simple response
125 |         mock_response = {"result": [{"id": "123", "event_message": "test"}]}
126 | 
127 |         async def mock_execute_request(*args: Any, **kwargs: Any) -> dict[str, Any]:
128 |             return mock_response
129 | 
130 |         mock_api_manager.execute_request = mock_execute_request
131 | 
132 |         # Execute the method
133 |         result = await mock_api_manager.retrieve_logs(
134 |             collection="postgres",
135 |             limit=10,
136 |             hours_ago=24,
137 |         )
138 | 
139 |         # Verify that the log_manager was called with the correct parameters
140 |         mock_api_manager.log_manager.build_logs_query.assert_called_once_with(
141 |             collection="postgres",
142 |             limit=10,
143 |             hours_ago=24,
144 |             filters=None,
145 |             search=None,
146 |             custom_query=None,
147 |         )
148 | 
149 |         # Verify that the result is what we expected
150 |         assert result == {"result": [{"id": "123", "event_message": "test"}]}
151 | 
152 |     @pytest.mark.asyncio
153 |     @pytest.mark.unit
154 |     async def test_retrieve_logs_error_handling(self, mock_api_manager: SupabaseApiManager):
155 |         """
156 |         Test that the retrieve_logs method correctly handles errors.
157 | 
158 |         This test verifies that the API Manager correctly handles errors that occur
159 |         during log retrieval and propagates them to the caller.
160 |         """
161 |         # Mock the log_manager's build_logs_query method
162 |         mock_api_manager.log_manager.build_logs_query = MagicMock(return_value="SELECT * FROM postgres_logs LIMIT 10")
163 | 
164 |         # Mock the execute_request method to raise an exception
165 |         async def mock_execute_request_error(*args: Any, **kwargs: Any) -> dict[str, Any]:
166 |             raise Exception("API error")
167 | 
168 |         mock_api_manager.execute_request = mock_execute_request_error
169 | 
170 |         # The retrieve_logs method should propagate the exception
171 |         with pytest.raises(Exception) as excinfo:
172 |             await mock_api_manager.retrieve_logs(collection="postgres")
173 | 
174 |         assert "API error" in str(excinfo.value)
175 | 
```
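
These tests depend on a `mock_api_manager` fixture from `tests/conftest.py`, which is not shown on this page. As a rough sketch of what that fixture needs to provide, given the attributes the assertions above touch (`client`, `safety_manager`, `log_manager`), something like the following would suffice; the real fixture may be wired differently:

```python
import pytest
from unittest.mock import MagicMock

from supabase_mcp.services.api.api_manager import SupabaseApiManager


@pytest.fixture
def mock_api_manager() -> SupabaseApiManager:
    # Sketch only: bypass __init__ so no real settings or HTTP clients are
    # needed, then attach mocks for the collaborators the tests exercise.
    manager = SupabaseApiManager.__new__(SupabaseApiManager)
    manager.client = MagicMock()
    manager.safety_manager = MagicMock()
    manager.log_manager = MagicMock()
    return manager
```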

--------------------------------------------------------------------------------
/tests/test_tool_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from unittest.mock import MagicMock, mock_open, patch
  2 | 
  3 | from supabase_mcp.tools.manager import ToolManager, ToolName
  4 | 
  5 | 
  6 | class TestToolManager:
  7 |     """Tests for the ToolManager class."""
  8 | 
  9 |     def test_singleton_pattern(self):
 10 |         """Test that ToolManager follows the singleton pattern."""
 11 |         # Get two instances
 12 |         manager1 = ToolManager.get_instance()
 13 |         manager2 = ToolManager.get_instance()
 14 | 
 15 |         # They should be the same object
 16 |         assert manager1 is manager2
 17 | 
 18 |         # Reset the singleton for other tests
 19 |         # pylint: disable=protected-access
 20 |         # We need to reset the singleton for test isolation
 21 |         ToolManager._instance = None  # type: ignore
 22 | 
 23 |     @patch("supabase_mcp.tools.manager.Path")
 24 |     @patch("supabase_mcp.tools.manager.yaml.safe_load")
 25 |     def test_load_descriptions(self, mock_yaml_load: MagicMock, mock_path: MagicMock):
 26 |         """Test that descriptions are loaded correctly from YAML files."""
 27 |         # Setup mock directory structure
 28 |         mock_file_path = MagicMock()
 29 |         mock_dir = MagicMock()
 30 | 
 31 |         # Mock the Path(__file__) call
 32 |         mock_path.return_value = mock_file_path
 33 |         mock_file_path.parent = mock_dir
 34 |         mock_dir.__truediv__.return_value = mock_dir  # For the / operator
 35 | 
 36 |         # Mock directory existence check
 37 |         mock_dir.exists.return_value = True
 38 | 
 39 |         # Mock the glob to return some YAML files
 40 |         mock_file1 = MagicMock()
 41 |         mock_file1.name = "database_tools.yaml"
 42 |         mock_file2 = MagicMock()
 43 |         mock_file2.name = "api_tools.yaml"
 44 |         mock_dir.glob.return_value = [mock_file1, mock_file2]
 45 | 
 46 |         # Mock the file open and YAML load
 47 |         mock_yaml_data = {"get_schemas": "Description for get_schemas", "get_tables": "Description for get_tables"}
 48 |         mock_yaml_load.return_value = mock_yaml_data
 49 | 
 50 |         # Create a new instance to trigger _load_descriptions
 51 |         with patch("builtins.open", mock_open(read_data="dummy yaml content")):
 52 |             # We need to create the manager to trigger _load_descriptions
 53 |             ToolManager()
 54 | 
 55 |         # Verify the descriptions were loaded
 56 |         assert mock_dir.glob.call_count > 0
 57 |         assert mock_dir.glob.call_args[0][0] == "*.yaml"
 58 |         assert mock_yaml_load.call_count >= 1
 59 | 
 60 |         # Reset the singleton for other tests
 61 |         # pylint: disable=protected-access
 62 |         ToolManager._instance = None  # type: ignore
 63 | 
 64 |     def test_get_description_valid_tool(self):
 65 |         """Test getting a description for a valid tool."""
 66 |         # Setup
 67 |         manager = ToolManager.get_instance()
 68 | 
 69 |         # Force the descriptions to have a known value for testing
 70 |         # pylint: disable=protected-access
 71 |         # We need to set the descriptions directly for testing
 72 |         manager.descriptions = {
 73 |             ToolName.GET_SCHEMAS.value: "Description for get_schemas",
 74 |             ToolName.GET_TABLES.value: "Description for get_tables",
 75 |         }
 76 | 
 77 |         # Test
 78 |         description = manager.get_description(ToolName.GET_SCHEMAS.value)
 79 | 
 80 |         # Verify
 81 |         assert description == "Description for get_schemas"
 82 | 
 83 |         # Reset the singleton for other tests
 84 |         # pylint: disable=protected-access
 85 |         ToolManager._instance = None  # type: ignore
 86 | 
 87 |     def test_get_description_invalid_tool(self):
 88 |         """Test getting a description for an invalid tool."""
 89 |         # Setup
 90 |         manager = ToolManager.get_instance()
 91 | 
 92 |         # Force the descriptions to have a known value for testing
 93 |         # pylint: disable=protected-access
 94 |         # We need to set the descriptions directly for testing
 95 |         manager.descriptions = {
 96 |             ToolName.GET_SCHEMAS.value: "Description for get_schemas",
 97 |             ToolName.GET_TABLES.value: "Description for get_tables",
 98 |         }
 99 | 
100 |         # Test and verify
101 |         description = manager.get_description("nonexistent_tool")
102 |         assert description == ""  # The method returns an empty string for unknown tools
103 | 
104 |         # Reset the singleton for other tests
105 |         # pylint: disable=protected-access
106 |         ToolManager._instance = None  # type: ignore
107 | 
108 |     def test_all_tool_names_have_descriptions(self):
109 |         """Test that all tools defined in ToolName enum have descriptions."""
110 |         # Setup - get a fresh instance
111 |         # Reset the singleton first to ensure we get a clean instance
112 |         # pylint: disable=protected-access
113 |         ToolManager._instance = None  # type: ignore
114 | 
115 |         # Get a fresh instance that will load the real YAML files
116 |         manager = ToolManager.get_instance()
117 | 
118 |         # Print the loaded descriptions for debugging
119 |         print(f"\nLoaded descriptions: {manager.descriptions}")
120 | 
121 |         # Verify that we have at least some descriptions loaded
122 |         assert len(manager.descriptions) > 0, "No descriptions were loaded"
123 | 
124 |         # Check that descriptions are not empty
125 |         empty_descriptions: list[str] = []
126 |         for tool_name, description in manager.descriptions.items():
127 |             if not description or len(description.strip()) == 0:
128 |                 empty_descriptions.append(tool_name)
129 | 
130 |         # Fail if we found any empty descriptions
131 |         assert len(empty_descriptions) == 0, f"Found empty descriptions for tools: {empty_descriptions}"
132 | 
133 |         # Check that at least some of the tool names have descriptions
134 |         found_descriptions = 0
135 |         missing_descriptions: list[str] = []
136 | 
137 |         for tool_name in ToolName:
138 |             description = manager.get_description(tool_name.value)
139 |             if description:
140 |                 found_descriptions += 1
141 |             else:
142 |                 missing_descriptions.append(tool_name.value)
143 | 
144 |         # Print missing descriptions for debugging
145 |         if missing_descriptions:
146 |             print(f"\nMissing descriptions for: {missing_descriptions}")
147 | 
148 |         # We should have at least some descriptions
149 |         assert found_descriptions > 0, "No tool has a description"
150 | 
151 |         # Reset the singleton for other tests
152 |         # pylint: disable=protected-access
153 |         ToolManager._instance = None  # type: ignore
154 | 
155 |     @patch.object(ToolManager, "_load_descriptions")
156 |     def test_initialization_loads_descriptions(self, mock_load_descriptions: MagicMock):
157 |         """Test that descriptions are loaded during initialization."""
158 |         # Create a new instance
159 |         # We need to create the manager to trigger __init__
160 |         ToolManager()
161 | 
162 |         # Verify _load_descriptions was called
163 |         assert mock_load_descriptions.call_count > 0
164 | 
165 |         # Reset the singleton for other tests
166 |         # pylint: disable=protected-access
167 |         ToolManager._instance = None  # type: ignore
168 | 
169 |     def test_tool_enum_completeness(self):
170 |         """Test that the ToolName enum contains all expected tools."""
171 |         # Get all tool values from the enum
172 |         tool_values = [tool.value for tool in ToolName]
173 | 
174 |         # Verify the total number of tools
175 |         # Update this number when new tools are added
176 |         expected_tool_count = 12
177 |         assert len(tool_values) == expected_tool_count, f"Expected {expected_tool_count} tools, got {len(tool_values)}"
178 | 
179 |         # Verify specific tools are included
180 |         assert "retrieve_logs" in tool_values, "retrieve_logs tool is missing from ToolName enum"
181 | 
182 |         # Reset the singleton for other tests
183 |         # pylint: disable=protected-access
184 |         ToolManager._instance = None  # type: ignore
185 | 
```
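
The reset dance (`ToolManager._instance = None`) in each test works because `ToolManager` follows the classic lazy singleton pattern. An illustrative reduction of that pattern (not the actual `ToolManager` code):

```python
class LazySingleton:
    """Minimal lazy-singleton shape matching what the tests above assume."""

    _instance: "LazySingleton | None" = None

    @classmethod
    def get_instance(cls) -> "LazySingleton":
        # Create the shared instance on first access, then reuse it
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance


assert LazySingleton.get_instance() is LazySingleton.get_instance()
```

Resetting `_instance` to `None` between tests forces the next `get_instance()` call to build a fresh object, which is why every test above ends with that cleanup step.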

--------------------------------------------------------------------------------
/tests/services/safety/test_api_safety_config.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Unit tests for the APISafetyConfig class.
  3 | 
  4 | This file contains unit test cases for the APISafetyConfig class, which is responsible for
  5 | determining the risk level of API operations and whether they are allowed or require confirmation.
  6 | """
  7 | 
  8 | import pytest
  9 | 
 10 | from supabase_mcp.services.safety.models import OperationRiskLevel, SafetyMode
 11 | from supabase_mcp.services.safety.safety_configs import APISafetyConfig, HTTPMethod
 12 | 
 13 | 
 14 | @pytest.mark.unit
 15 | class TestAPISafetyConfig:
 16 |     """Unit tests for the APISafetyConfig class."""
 17 | 
 18 |     def test_get_risk_level_low_risk(self):
 19 |         """Test getting risk level for low-risk operations (GET requests)."""
 20 |         config = APISafetyConfig()
 21 |         # API operations are tuples of (method, path, path_params, query_params, request_body)
 22 |         operation = ("GET", "/v1/projects/{ref}/functions", {}, {}, {})
 23 |         risk_level = config.get_risk_level(operation)
 24 |         assert risk_level == OperationRiskLevel.LOW
 25 | 
 26 |     def test_get_risk_level_medium_risk(self):
 27 |         """Test getting risk level for medium-risk operations (POST/PUT/PATCH)."""
 28 |         config = APISafetyConfig()
 29 | 
 30 |         # Test POST request
 31 |         operation = ("POST", "/v1/projects/{ref}/functions", {}, {}, {})
 32 |         risk_level = config.get_risk_level(operation)
 33 |         assert risk_level == OperationRiskLevel.MEDIUM
 34 | 
 35 |         # Test PUT request
 36 |         operation = ("PUT", "/v1/projects/{ref}/functions", {}, {}, {})
 37 |         risk_level = config.get_risk_level(operation)
 38 |         assert risk_level == OperationRiskLevel.MEDIUM
 39 | 
 40 |         # Test PATCH request
 41 |         operation = ("PATCH", "/v1/projects/{ref}/functions/{function_slug}", {}, {}, {})
 42 |         risk_level = config.get_risk_level(operation)
 43 |         assert risk_level == OperationRiskLevel.MEDIUM
 44 | 
 45 |     def test_get_risk_level_high_risk(self):
 46 |         """Test getting risk level for high-risk operations."""
 47 |         config = APISafetyConfig()
 48 | 
 49 |         # Test DELETE request for a function
 50 |         operation = ("DELETE", "/v1/projects/{ref}/functions/{function_slug}", {}, {}, {})
 51 |         risk_level = config.get_risk_level(operation)
 52 |         assert risk_level == OperationRiskLevel.HIGH
 53 | 
 54 |         # Test other high-risk operations
 55 |         high_risk_paths = [
 56 |             "/v1/projects/{ref}/branches/{branch_id}",
 57 |             "/v1/projects/{ref}/custom-hostname",
 58 |             "/v1/projects/{ref}/network-bans",
 59 |         ]
 60 | 
 61 |         for path in high_risk_paths:
 62 |             operation = ("DELETE", path, {}, {}, {})
 63 |             risk_level = config.get_risk_level(operation)
 64 |             assert risk_level == OperationRiskLevel.HIGH, f"Path {path} should be HIGH risk"
 65 | 
 66 |     def test_get_risk_level_extreme_risk(self):
 67 |         """Test getting risk level for extreme-risk operations."""
 68 |         config = APISafetyConfig()
 69 | 
 70 |         # Test DELETE request for a project
 71 |         operation = ("DELETE", "/v1/projects/{ref}", {}, {}, {})
 72 |         risk_level = config.get_risk_level(operation)
 73 |         assert risk_level == OperationRiskLevel.EXTREME
 74 | 
 75 |     def test_is_operation_allowed(self):
 76 |         """Test if operations are allowed based on risk level and safety mode."""
 77 |         config = APISafetyConfig()
 78 | 
 79 |         # Low risk operations should be allowed in both safe and unsafe modes
 80 |         assert config.is_operation_allowed(OperationRiskLevel.LOW, SafetyMode.SAFE) is True
 81 |         assert config.is_operation_allowed(OperationRiskLevel.LOW, SafetyMode.UNSAFE) is True
 82 | 
 83 |         # Medium/high risk operations should only be allowed in unsafe mode
 84 |         assert config.is_operation_allowed(OperationRiskLevel.MEDIUM, SafetyMode.SAFE) is False
 85 |         assert config.is_operation_allowed(OperationRiskLevel.MEDIUM, SafetyMode.UNSAFE) is True
 86 |         assert config.is_operation_allowed(OperationRiskLevel.HIGH, SafetyMode.SAFE) is False
 87 |         assert config.is_operation_allowed(OperationRiskLevel.HIGH, SafetyMode.UNSAFE) is True
 88 | 
 89 |         # Extreme risk operations should not be allowed in safe mode
 90 |         assert config.is_operation_allowed(OperationRiskLevel.EXTREME, SafetyMode.SAFE) is False
 91 |         # In the current implementation, extreme risk operations are never allowed
 92 |         assert config.is_operation_allowed(OperationRiskLevel.EXTREME, SafetyMode.UNSAFE) is False
 93 | 
 94 |     def test_needs_confirmation(self):
 95 |         """Test if operations need confirmation based on risk level."""
 96 |         config = APISafetyConfig()
 97 | 
 98 |         # Low and medium risk operations should not need confirmation
 99 |         assert config.needs_confirmation(OperationRiskLevel.LOW) is False
100 |         assert config.needs_confirmation(OperationRiskLevel.MEDIUM) is False
101 | 
102 |         # High and extreme risk operations should need confirmation
103 |         assert config.needs_confirmation(OperationRiskLevel.HIGH) is True
104 |         assert config.needs_confirmation(OperationRiskLevel.EXTREME) is True
105 | 
106 |     def test_path_matching(self):
107 |         """Test that path patterns are correctly matched."""
108 |         config = APISafetyConfig()
109 | 
110 |         # Test exact path matching
111 |         operation = ("GET", "/v1/projects/{ref}/functions", {}, {}, {})
112 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
113 | 
114 |         # Test path with parameters
115 |         operation = ("GET", "/v1/projects/abc123/functions", {}, {}, {})
116 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
117 | 
118 |         # Test path with multiple parameters
119 |         operation = ("DELETE", "/v1/projects/abc123/functions/my-function", {}, {}, {})
120 |         assert config.get_risk_level(operation) == OperationRiskLevel.HIGH
121 | 
122 |         # Test path that doesn't match any pattern (unmatched paths default to LOW, even for non-GET)
123 |         operation = ("DELETE", "/v1/some/unknown/path", {}, {}, {})
124 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
125 | 
126 |         # Test path that doesn't match any pattern (should default to LOW for GET)
127 |         operation = ("GET", "/v1/some/unknown/path", {}, {}, {})
128 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
129 | 
130 |     def test_method_case_insensitivity(self):
131 |         """Test that HTTP method matching is case-insensitive."""
132 |         config = APISafetyConfig()
133 | 
134 |         # Test with lowercase method
135 |         operation = ("get", "/v1/projects/{ref}/functions", {}, {}, {})
136 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
137 | 
138 |         # Test with uppercase method
139 |         operation = ("GET", "/v1/projects/{ref}/functions", {}, {}, {})
140 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
141 | 
142 |         # Test with mixed case method
143 |         operation = ("GeT", "/v1/projects/{ref}/functions", {}, {}, {})
144 |         assert config.get_risk_level(operation) == OperationRiskLevel.LOW
145 | 
146 |     def test_path_safety_config_structure(self):
147 |         """Test that the PATH_SAFETY_CONFIG structure is correctly defined."""
148 |         config = APISafetyConfig()
149 | 
150 |         # Check that the config has the expected structure
151 |         assert hasattr(config, "PATH_SAFETY_CONFIG")
152 | 
153 |         # Check that risk levels are represented as keys
154 |         assert OperationRiskLevel.MEDIUM in config.PATH_SAFETY_CONFIG
155 |         assert OperationRiskLevel.HIGH in config.PATH_SAFETY_CONFIG
156 |         assert OperationRiskLevel.EXTREME in config.PATH_SAFETY_CONFIG
157 | 
158 |         # Check that each risk level has a dictionary of methods to paths
159 |         for risk_level, methods_dict in config.PATH_SAFETY_CONFIG.items():
160 |             assert isinstance(methods_dict, dict)
161 |             for method, paths in methods_dict.items():
162 |                 assert isinstance(method, HTTPMethod)
163 |                 assert isinstance(paths, list)
164 |                 for path in paths:
165 |                     assert isinstance(path, str)
166 | 
```
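
The last test pins down the shape of `PATH_SAFETY_CONFIG`: risk level, then HTTP method, then a list of path patterns. A hypothetical fragment consistent with the assertions in this file (the real mapping lives in `supabase_mcp/services/safety/safety_configs.py`; the `HTTPMethod` members used here are assumed):

```python
from supabase_mcp.services.safety.models import OperationRiskLevel
from supabase_mcp.services.safety.safety_configs import HTTPMethod

# Illustrative fragment only; entries are inferred from the tests above.
EXAMPLE_PATH_SAFETY_CONFIG = {
    OperationRiskLevel.MEDIUM: {
        HTTPMethod.POST: ["/v1/projects/{ref}/functions"],
    },
    OperationRiskLevel.HIGH: {
        HTTPMethod.DELETE: ["/v1/projects/{ref}/functions/{function_slug}"],
    },
    OperationRiskLevel.EXTREME: {
        HTTPMethod.DELETE: ["/v1/projects/{ref}"],
    },
}
```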

--------------------------------------------------------------------------------
/supabase_mcp/settings.py:
--------------------------------------------------------------------------------

```python
  1 | import os
  2 | from pathlib import Path
  3 | from typing import Literal
  4 | 
  5 | from pydantic import Field, ValidationInfo, field_validator
  6 | from pydantic_settings import BaseSettings, SettingsConfigDict
  7 | 
  8 | from supabase_mcp.logger import logger
  9 | 
 10 | SUPPORTED_REGIONS = Literal[
 11 |     "us-west-1",  # West US (North California)
 12 |     "us-east-1",  # East US (North Virginia)
 13 |     "us-east-2",  # East US (Ohio)
 14 |     "ca-central-1",  # Canada (Central)
 15 |     "eu-west-1",  # West EU (Ireland)
 16 |     "eu-west-2",  # West Europe (London)
 17 |     "eu-west-3",  # West EU (Paris)
 18 |     "eu-central-1",  # Central EU (Frankfurt)
 19 |     "eu-central-2",  # Central Europe (Zurich)
 20 |     "eu-north-1",  # North EU (Stockholm)
 21 |     "ap-south-1",  # South Asia (Mumbai)
 22 |     "ap-southeast-1",  # Southeast Asia (Singapore)
 23 |     "ap-northeast-1",  # Northeast Asia (Tokyo)
 24 |     "ap-northeast-2",  # Northeast Asia (Seoul)
 25 |     "ap-southeast-2",  # Oceania (Sydney)
 26 |     "sa-east-1",  # South America (São Paulo)
 27 | ]
 28 | 
 29 | 
 30 | def find_config_file(env_file: str = ".env") -> str | None:
 31 |     """Find the specified env file in order of precedence:
 32 |     1. Current working directory (where command is run)
 33 |     2. Global config:
 34 |        - Windows: %APPDATA%/supabase-mcp/{env_file}
 35 |        - macOS/Linux: ~/.config/supabase-mcp/{env_file}
 36 | 
 37 |     Args:
 38 |         env_file: The name of the environment file to look for (default: ".env")
 39 | 
 40 |     Returns:
 41 |         The path to the found config file, or None if not found
 42 |     """
 43 |     # 1. Check current directory
 44 |     cwd_config = Path.cwd() / env_file
 45 |     if cwd_config.exists():
 46 |         return str(cwd_config)
 47 | 
 48 |     # 2. Check global config
 49 |     home = Path.home()
 50 |     if os.name == "nt":  # Windows
 51 |         global_config = Path(os.environ.get("APPDATA", "")) / "supabase-mcp" / env_file
 52 |     else:  # macOS/Linux
 53 |         global_config = home / ".config" / "supabase-mcp" / env_file
 54 | 
 55 |     if global_config.exists():
 56 |         logger.error(
 57 |             f"DEPRECATED: {global_config} is deprecated and will be removed in a future release. "
 58 |             "Use your IDE's native .json config file to configure access to MCP."
 59 |         )
 60 |         return str(global_config)
 61 | 
 62 |     return None
 63 | 
 64 | 
 65 | class Settings(BaseSettings):
 66 |     """Initializes settings for Supabase MCP server."""
 67 | 
 68 |     supabase_project_ref: str = Field(
 69 |         default="127.0.0.1:54322",  # Local Supabase default
 70 |         description="Supabase project ref - Must be 20 chars for remote projects, can be local address for development",
 71 |         alias="SUPABASE_PROJECT_REF",
 72 |     )
 73 |     supabase_db_password: str | None = Field(
 74 |         default=None,  # Will be validated based on project_ref
 75 |         description="Supabase database password - Required for remote projects, defaults to 'postgres' for local",
 76 |         alias="SUPABASE_DB_PASSWORD",
 77 |     )
 78 |     supabase_region: str = Field(
 79 |         default="us-east-1",  # East US (North Virginia) - Supabase's default region
 80 |         description="Supabase region for connection",
 81 |         alias="SUPABASE_REGION",
 82 |     )
 83 |     supabase_access_token: str | None = Field(
 84 |         default=None,
 85 |         description="Optional personal access token for accessing Supabase Management API",
 86 |         alias="SUPABASE_ACCESS_TOKEN",
 87 |     )
 88 |     supabase_service_role_key: str | None = Field(
 89 |         default=None,
 90 |         description="Optional service role key for accessing Python SDK",
 91 |         alias="SUPABASE_SERVICE_ROLE_KEY",
 92 |     )
 93 | 
 94 |     supabase_api_url: str = Field(
 95 |         default="https://api.supabase.com",
 96 |         description="Supabase API URL",
 97 |     )
 98 | 
 99 |     query_api_key: str = Field(
100 |         default="test-key",
101 |         description="TheQuery.dev API key",
102 |         alias="QUERY_API_KEY",
103 |     )
104 | 
105 |     query_api_url: str = Field(
106 |         default="https://api.thequery.dev/v1",
107 |         description="TheQuery.dev API URL",
108 |         alias="QUERY_API_URL",
109 |     )
110 | 
111 |     @field_validator("supabase_region")
112 |     @classmethod
113 |     def validate_region(cls, v: str, info: ValidationInfo) -> str:
114 |         """Validate that the region is supported by Supabase."""
115 |         # Get the project_ref from the values
116 |         values = info.data
117 |         project_ref = values.get("supabase_project_ref", "")
118 | 
119 |         # If this is a remote project and region is the default
120 |         if not project_ref.startswith("127.0.0.1") and v == "us-east-1" and "SUPABASE_REGION" not in os.environ:
121 |             logger.warning(
122 |                 "You're connecting to a remote Supabase project but haven't specified a region. "
123 |                 "Using default 'us-east-1', which may cause 'Tenant or user not found' errors if incorrect. "
124 |                 "Please set the correct SUPABASE_REGION in your configuration."
125 |             )
126 | 
127 |         # Validate that the region is supported
128 |         if v not in SUPPORTED_REGIONS.__args__:
129 |             supported = "\n  - ".join([""] + list(SUPPORTED_REGIONS.__args__))
130 |             raise ValueError(f"Region '{v}' is not supported. Supported regions are:{supported}")
131 |         return v
132 | 
133 |     @field_validator("supabase_project_ref")
134 |     @classmethod
135 |     def validate_project_ref(cls, v: str) -> str:
136 |         """Validate the project ref format."""
137 |         if v.startswith("127.0.0.1"):
138 |             # Local development - allow default format
139 |             return v
140 | 
141 |         # Remote project - must be 20 chars
142 |         if len(v) != 20:
143 |             logger.error("Invalid Supabase project ref format")
144 |             raise ValueError(
145 |                 "Invalid Supabase project ref format. "
146 |                 "Remote project refs must be exactly 20 characters long. "
147 |                 f"Got {len(v)} characters instead."
148 |             )
149 |         return v
150 | 
151 |     @field_validator("supabase_db_password")
152 |     @classmethod
153 |     def validate_db_password(cls, v: str | None, info: ValidationInfo) -> str:
154 |         """Validate database password based on project type."""
155 |         project_ref = info.data.get("supabase_project_ref", "")
156 | 
157 |         # For local development, allow default password
158 |         if project_ref.startswith("127.0.0.1"):
159 |             return v or "postgres"  # Default to postgres for local
160 | 
161 |         # For remote projects, password is required
162 |         if not v:
163 |             logger.error("SUPABASE_DB_PASSWORD is required when connecting to a remote instance")
164 |             raise ValueError(
165 |                 "Database password is required for remote Supabase projects. "
166 |                 "Please set SUPABASE_DB_PASSWORD in your environment variables."
167 |             )
168 |         return v
169 | 
170 |     @classmethod
171 |     def with_config(cls, config_file: str | None = None) -> "Settings":
172 |         """Create Settings with a specific config file.
173 | 
174 |         Args:
175 |             config_file: Path to .env file to use, or None for no config file
176 |         """
177 | 
178 |         # Create a new Settings class with the specific config
179 |         class SettingsWithConfig(cls):
180 |             model_config = SettingsConfigDict(env_file=config_file, env_file_encoding="utf-8")
181 | 
182 |         instance = SettingsWithConfig()
183 | 
184 |         # Log configuration source and precedence - simplified to a single clear message
185 |         env_vars_present = any(var in os.environ for var in ["SUPABASE_PROJECT_REF", "SUPABASE_DB_PASSWORD"])
186 | 
187 |         if env_vars_present and config_file:
188 |             logger.info(f"Using environment variables (highest precedence) over config file: {config_file}")
189 |         elif env_vars_present:
190 |             logger.info("Using environment variables for configuration")
191 |         elif config_file:
192 |             logger.info(f"Using settings from config file: {config_file}")
193 |         else:
194 |             logger.info("Using default settings (local development)")
195 | 
196 |         return instance
197 | 
198 | 
199 | # Module-level singleton - maintains existing interface
200 | settings = Settings.with_config(find_config_file())
201 | 
```
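
Since environment variables take precedence over any discovered `.env` file, pointing the server at a remote project only requires exporting the relevant variables before import. A usage sketch with placeholder values (a real remote ref is exactly 20 characters):

```python
import os

# Placeholder credentials; substitute real values.
os.environ["SUPABASE_PROJECT_REF"] = "abcdefghijklmnopqrst"  # 20 chars
os.environ["SUPABASE_DB_PASSWORD"] = "my-db-password"
os.environ["SUPABASE_REGION"] = "eu-central-1"

from supabase_mcp.settings import Settings, find_config_file

settings = Settings.with_config(find_config_file())
assert settings.supabase_project_ref == "abcdefghijklmnopqrst"
assert settings.supabase_region == "eu-central-1"
```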

--------------------------------------------------------------------------------
/tests/services/api/test_api_client.py:
--------------------------------------------------------------------------------

```python
  1 | import httpx
  2 | import pytest
  3 | from unittest.mock import MagicMock, patch
  4 | 
  5 | from supabase_mcp.clients.management_client import ManagementAPIClient
  6 | from supabase_mcp.exceptions import APIClientError, APIConnectionError
  7 | from supabase_mcp.settings import Settings
  8 | 
  9 | 
 10 | @pytest.mark.asyncio(loop_scope="module")
 11 | class TestAPIClient:
 12 |     """Unit tests for the API client."""
 13 | 
 14 |     @pytest.fixture
 15 |     def mock_settings(self):
 16 |         """Create mock settings for testing."""
 17 |         settings = MagicMock(spec=Settings)
 18 |         settings.supabase_access_token = "test-token"
 19 |         settings.supabase_project_ref = "test-project-ref"
 20 |         settings.supabase_region = "us-east-1"
 21 |         settings.query_api_url = "https://api.test.com"
 22 |         settings.supabase_api_url = "https://api.supabase.com"
 23 |         return settings
 24 | 
 25 |     async def test_execute_get_request(self, mock_settings):
 26 |         """Test executing a GET request to the API."""
 27 |         # Create client but don't mock the httpx client yet
 28 |         client = ManagementAPIClient(settings=mock_settings)
 29 |         
 30 |         # Setup mock response
 31 |         mock_response = MagicMock(spec=httpx.Response)
 32 |         mock_response.status_code = 404
 33 |         mock_response.is_success = False
 34 |         mock_response.headers = {"content-type": "application/json"}
 35 |         mock_response.json.return_value = {"message": "Cannot GET /v1/health"}
 36 |         mock_response.text = '{"message": "Cannot GET /v1/health"}'
 37 |         mock_response.content = b'{"message": "Cannot GET /v1/health"}'
 38 |         
 39 |         # Mock the send_request method to return our mock response
 40 |         with patch.object(client, 'send_request', return_value=mock_response):
 41 |             path = "/v1/health"
 42 |             
 43 |             # Execute the request and expect a 404 error
 44 |             with pytest.raises(APIClientError) as exc_info:
 45 |                 await client.execute_request(
 46 |                     method="GET",
 47 |                     path=path,
 48 |                 )
 49 |             
 50 |             # Verify the error details
 51 |             assert exc_info.value.status_code == 404
 52 |             assert "Cannot GET /v1/health" in str(exc_info.value)
 53 | 
 54 |     async def test_request_preparation(self, mock_settings):
 55 |         """Test that requests are properly prepared with headers and parameters."""
 56 |         client = ManagementAPIClient(settings=mock_settings)
 57 |         
 58 |         # Prepare a request with parameters
 59 |         method = "GET"
 60 |         path = "/v1/health"
 61 |         request_params = {"param1": "value1", "param2": "value2"}
 62 | 
 63 |         # Prepare the request
 64 |         request = client.prepare_request(
 65 |             method=method,
 66 |             path=path,
 67 |             request_params=request_params,
 68 |         )
 69 | 
 70 |         # Verify the request
 71 |         assert request.method == method
 72 |         assert path in str(request.url)
 73 |         assert "param1=value1" in str(request.url)
 74 |         assert "param2=value2" in str(request.url)
 75 |         assert "Content-Type" in request.headers
 76 |         assert request.headers["Content-Type"] == "application/json"
 77 | 
 78 |     async def test_error_handling(self, mock_settings):
 79 |         """Test handling of API errors."""
 80 |         client = ManagementAPIClient(settings=mock_settings)
 81 |         
 82 |         # Setup mock response
 83 |         mock_response = MagicMock(spec=httpx.Response)
 84 |         mock_response.status_code = 404
 85 |         mock_response.is_success = False
 86 |         mock_response.headers = {"content-type": "application/json"}
 87 |         mock_response.json.return_value = {"message": "Cannot GET /v1/nonexistent-endpoint"}
 88 |         mock_response.text = '{"message": "Cannot GET /v1/nonexistent-endpoint"}'
 89 |         mock_response.content = b'{"message": "Cannot GET /v1/nonexistent-endpoint"}'
 90 |         
 91 |         with patch.object(client, 'send_request', return_value=mock_response):
 92 |             path = "/v1/nonexistent-endpoint"
 93 |             
 94 |             # Execute the request and expect an APIClientError
 95 |             with pytest.raises(APIClientError) as exc_info:
 96 |                 await client.execute_request(
 97 |                     method="GET",
 98 |                     path=path,
 99 |                 )
100 |             
101 |             # Verify the error details
102 |             assert exc_info.value.status_code == 404
103 |             assert "Cannot GET /v1/nonexistent-endpoint" in str(exc_info.value)
104 | 
105 |     async def test_request_with_body(self, mock_settings):
106 |         """Test executing a request with a body."""
107 |         client = ManagementAPIClient(settings=mock_settings)
108 |         
109 |         # Test the request preparation
110 |         method = "POST"
111 |         path = "/v1/health/check"
112 |         request_body = {"test": "data", "nested": {"value": 123}}
113 | 
114 |         # Prepare the request
115 |         request = client.prepare_request(
116 |             method=method,
117 |             path=path,
118 |             request_body=request_body,
119 |         )
120 | 
121 |         # Verify the request
122 |         assert request.method == method
123 |         assert path in str(request.url)
124 |         assert request.content  # Should have content for the body
125 |         assert "Content-Type" in request.headers
126 |         assert request.headers["Content-Type"] == "application/json"
127 | 
128 |     async def test_response_parsing(self, mock_settings):
129 |         """Test parsing API responses."""
130 |         client = ManagementAPIClient(settings=mock_settings)
131 |         
132 |         # Setup mock response
133 |         mock_response = MagicMock(spec=httpx.Response)
134 |         mock_response.status_code = 200
135 |         mock_response.is_success = True
136 |         mock_response.headers = {"content-type": "application/json"}
137 |         mock_response.json.return_value = [{"id": "project1", "name": "Test Project"}]
138 |         mock_response.content = b'[{"id": "project1", "name": "Test Project"}]'
139 |         
140 |         with patch.object(client, 'send_request', return_value=mock_response):
141 |             path = "/v1/projects"
142 |             
143 |             # Execute the request
144 |             response = await client.execute_request(
145 |                 method="GET",
146 |                 path=path,
147 |             )
148 |             
149 |             # Verify the response is parsed correctly
150 |             assert isinstance(response, list)
151 |             assert len(response) > 0
152 |             assert "id" in response[0]
153 | 
154 |     async def test_request_retry_mechanism(self, mock_settings):
155 |         """Test that the tenacity retry mechanism works correctly for API requests."""
156 |         client = ManagementAPIClient(settings=mock_settings)
157 |         
158 |         # Create a mock request object for the NetworkError
159 |         mock_request = MagicMock(spec=httpx.Request)
160 |         mock_request.method = "GET"
161 |         mock_request.url = "https://api.supabase.com/v1/projects"
162 |         
163 |         # Mock the client's send method to always raise a network error
164 |         with patch.object(client.client, 'send', side_effect=httpx.NetworkError("Simulated network failure", request=mock_request)):
165 |             # Execute a request - this should trigger retries and eventually fail
166 |             with pytest.raises(APIConnectionError) as exc_info:
167 |                 await client.execute_request(
168 |                     method="GET",
169 |                     path="/v1/projects",
170 |                 )
171 |             
172 |             # Verify the error message indicates retries were attempted
173 |             assert "Network error after 3 retry attempts" in str(exc_info.value)
174 | 
175 |     async def test_request_without_access_token(self, mock_settings):
176 |         """Test that an exception is raised when attempting to send a request without an access token."""
177 |         # Create client with no access token
178 |         mock_settings.supabase_access_token = None
179 |         client = ManagementAPIClient(settings=mock_settings)
180 | 
181 |         # Attempt to execute a request - should raise an exception
182 |         with pytest.raises(APIClientError) as exc_info:
183 |             await client.execute_request(
184 |                 method="GET",
185 |                 path="/v1/projects",
186 |             )
187 |         
188 |         assert "Supabase access token is not configured" in str(exc_info.value)
```
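
As an aside, the hand-rolled `MagicMock(spec=httpx.Response)` stubs above can be replaced with real `httpx.Response` objects, which keep `.json()`, `.text`, `.content`, and `.is_success` mutually consistent without setting each attribute by hand. A sketch of the same 404 stub:

```python
import httpx

request = httpx.Request("GET", "https://api.supabase.com/v1/health")
response = httpx.Response(
    status_code=404,
    json={"message": "Cannot GET /v1/health"},
    request=request,
)

assert not response.is_success
assert response.json() == {"message": "Cannot GET /v1/health"}
assert b"Cannot GET" in response.content
```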

--------------------------------------------------------------------------------
/supabase_mcp/clients/base_http_client.py:
--------------------------------------------------------------------------------

```python
  1 | from abc import ABC, abstractmethod
  2 | from json.decoder import JSONDecodeError
  3 | from typing import Any, TypeVar
  4 | 
  5 | import httpx
  6 | from pydantic import BaseModel
  7 | from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
  8 | 
  9 | from supabase_mcp.exceptions import (
 10 |     APIClientError,
 11 |     APIConnectionError,
 12 |     APIResponseError,
 13 |     APIServerError,
 14 |     UnexpectedError,
 15 | )
 16 | from supabase_mcp.logger import logger
 17 | 
 18 | T = TypeVar("T")
 19 | 
 20 | 
 21 | # Helper function for retry decorator to safely log exceptions
 22 | def log_retry_attempt(retry_state: RetryCallState) -> None:
 23 |     """Log retry attempts with exception details if available."""
 24 |     exception = retry_state.outcome.exception() if retry_state.outcome and retry_state.outcome.failed else None
 25 |     exception_str = str(exception) if exception else "Unknown error"
 26 |     logger.warning(f"Network error, retrying ({retry_state.attempt_number}/3): {exception_str}")
 27 | 
 28 | 
 29 | class AsyncHTTPClient(ABC):
 30 |     """Abstract base class for async HTTP clients."""
 31 | 
 32 |     @abstractmethod
 33 |     async def _ensure_client(self) -> httpx.AsyncClient:
 34 |         """Ensure client exists and is ready for use.
 35 | 
 36 |         Creates the client if it doesn't exist yet.
 37 |         Returns the client instance.
 38 |         """
 39 |         pass
 40 | 
 41 |     @abstractmethod
 42 |     async def close(self) -> None:
 43 |         """Close the client and release resources.
 44 | 
 45 |         Should be called when the client is no longer needed.
 46 |         """
 47 |         pass
 48 | 
 49 |     def prepare_request(
 50 |         self,
 51 |         client: httpx.AsyncClient,
 52 |         method: str,
 53 |         path: str,
 54 |         request_params: dict[str, Any] | None = None,
 55 |         request_body: dict[str, Any] | None = None,
 56 |     ) -> httpx.Request:
 57 |         """
 58 |         Prepare an HTTP request.
 59 | 
 60 |         Args:
 61 |             client: The httpx client to use
 62 |             method: HTTP method (GET, POST, etc.)
 63 |             path: API path
 64 |             request_params: Query parameters
 65 |             request_body: Request body
 66 | 
 67 |         Returns:
 68 |             Prepared httpx.Request object
 69 | 
 70 |         Raises:
 71 |             APIClientError: If request preparation fails
 72 |         """
 73 |         try:
 74 |             return client.build_request(method=method, url=path, params=request_params, json=request_body)
 75 |         except Exception as e:
 76 |             raise APIClientError(
 77 |                 message=f"Failed to build request: {str(e)}",
 78 |                 status_code=None,
 79 |             ) from e
 80 | 
 81 |     @retry(
 82 |         retry=retry_if_exception_type((httpx.NetworkError, APIConnectionError)),  # Also retry the wrapped error; otherwise the conversion below defeats the retry
 83 |         stop=stop_after_attempt(3),
 84 |         wait=wait_exponential(multiplier=1, min=2, max=10),
 85 |         reraise=True,  # Ensure the original exception is raised
 86 |         before_sleep=log_retry_attempt,
 87 |     )
 88 |     async def send_request(self, client: httpx.AsyncClient, request: httpx.Request) -> httpx.Response:
 89 |         """
 90 |         Send an HTTP request with retry logic for transient errors.
 91 | 
 92 |         Args:
 93 |             client: The httpx client to use
 94 |             request: Prepared httpx.Request object
 95 | 
 96 |         Returns:
 97 |             httpx.Response object
 98 | 
 99 |         Raises:
100 |             APIConnectionError: For connection issues
101 |             APIClientError: For other request errors
102 |         """
103 |         try:
104 |             return await client.send(request)
105 |         except httpx.NetworkError as e:
106 |             # Wrapped as APIConnectionError, which the retry decorator also matches;
107 |             # with reraise=True, the final attempt's error propagates to the caller
108 |             logger.error(f"Network error while sending request: {str(e)}")
109 |             raise APIConnectionError(
110 |                 message=f"Network error after 3 retry attempts: {str(e)}",
111 |                 status_code=None,
112 |             ) from e
113 |         except Exception as e:
114 |             # Other exceptions won't be retried
115 |             raise APIClientError(
116 |                 message=f"Request failed: {str(e)}",
117 |                 status_code=None,
118 |             ) from e
119 | 
120 |     def parse_response(self, response: httpx.Response) -> dict[str, Any]:
121 |         """
122 |         Parse an HTTP response as JSON.
123 | 
124 |         Args:
125 |             response: httpx.Response object
126 | 
127 |         Returns:
128 |             Parsed response body as dictionary
129 | 
130 |         Raises:
131 |             APIResponseError: If response cannot be parsed as JSON
132 |         """
133 |         if not response.content:
134 |             return {}
135 | 
136 |         try:
137 |             return response.json()
138 |         except JSONDecodeError as e:
139 |             raise APIResponseError(
140 |                 message=f"Failed to parse response as JSON: {str(e)}",
141 |                 status_code=response.status_code,
142 |                 response_body={"raw_content": response.text},
143 |             ) from e
144 | 
145 |     def handle_error_response(self, response: httpx.Response, parsed_body: dict[str, Any] | None = None) -> None:
146 |         """
147 |         Handle error responses based on status code.
148 | 
149 |         Args:
150 |             response: httpx.Response object
151 |             parsed_body: Parsed response body if available
152 | 
153 |         Raises:
154 |             APIClientError: For client errors (4xx)
155 |             APIServerError: For server errors (5xx)
156 |             UnexpectedError: For unexpected status codes
157 |         """
158 |         # Extract error message
159 |         error_message = f"API request failed: {response.status_code}"
160 |         if parsed_body and "message" in parsed_body:
161 |             error_message = parsed_body["message"]
162 | 
163 |         # Determine error type based on status code
164 |         if 400 <= response.status_code < 500:
165 |             raise APIClientError(
166 |                 message=error_message,
167 |                 status_code=response.status_code,
168 |                 response_body=parsed_body,
169 |             )
170 |         elif response.status_code >= 500:
171 |             raise APIServerError(
172 |                 message=error_message,
173 |                 status_code=response.status_code,
174 |                 response_body=parsed_body,
175 |             )
176 |         else:
177 |             # This should not happen, but just in case
178 |             raise UnexpectedError(
179 |                 message=f"Unexpected status code: {response.status_code}",
180 |                 status_code=response.status_code,
181 |                 response_body=parsed_body,
182 |             )
183 | 
184 |     async def execute_request(
185 |         self,
186 |         method: str,
187 |         path: str,
188 |         request_params: dict[str, Any] | None = None,
189 |         request_body: dict[str, Any] | None = None,
190 |     ) -> dict[str, Any] | BaseModel:
191 |         """
192 |         Execute an HTTP request.
193 | 
194 |         Args:
195 |             method: HTTP method (GET, POST, etc.)
196 |             path: API path
197 |             request_params: Query parameters
198 |             request_body: Request body
199 | 
200 |         Returns:
201 |             Parsed response body as a dictionary (subclasses may return a Pydantic model)
202 | 
203 |         Raises:
204 |             APIClientError: For client errors (4xx)
205 |             APIConnectionError: For connection issues
206 |             APIResponseError: For response parsing errors
207 |             UnexpectedError: For unexpected errors
208 |         """
209 |         # Log detailed request information
210 |         logger.info(f"API Client: Executing {method} request to {path}")
211 |         if request_params:
212 |             logger.debug(f"Request params: {request_params}")
213 |         if request_body:
214 |             logger.debug(f"Request body: {request_body}")
215 | 
216 |         # Get client
217 |         client = await self._ensure_client()
218 | 
219 |         # Prepare request
220 |         request = self.prepare_request(client, method, path, request_params, request_body)
221 | 
222 |         # Send request
223 |         response = await self.send_request(client, request)
224 | 
225 |         # Parse response (for both success and error cases)
226 |         parsed_body = self.parse_response(response)
227 | 
228 |         # Check if successful
229 |         if not response.is_success:
230 |             logger.warning(f"Request failed: {method} {path} - Status {response.status_code}")
231 |             self.handle_error_response(response, parsed_body)
232 | 
233 |         # Log success and return
234 |         logger.info(f"Request successful: {method} {path} - Status {response.status_code}")
235 |         return parsed_body
236 | 
```
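
The class above maps HTTP outcomes onto typed exceptions. Below is a minimal, self-contained sketch of that status-code dispatch; the exception classes are illustrative stand-ins for the ones imported from `supabase_mcp.exceptions`, not the package's own.

```python
from typing import Any


class ClientError(Exception): ...
class ServerError(Exception): ...
class UnexpectedStatus(Exception): ...


def raise_for_status(status_code: int, body: dict[str, Any] | None = None) -> None:
    """Map an HTTP status code onto a typed exception, mirroring handle_error_response."""
    message = (body or {}).get("message", f"API request failed: {status_code}")
    if 400 <= status_code < 500:
        raise ClientError(message)
    if status_code >= 500:
        raise ServerError(message)
    # Only reached if a non-error status is passed in by mistake.
    raise UnexpectedStatus(f"Unexpected status code: {status_code}")


try:
    raise_for_status(404, {"message": "Project not found"})
except ClientError as e:
    print(f"client error: {e}")  # client error: Project not found
```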

--------------------------------------------------------------------------------
/tests/services/logs/test_log_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from unittest.mock import patch
  2 | 
  3 | import pytest
  4 | 
  5 | from supabase_mcp.services.database.sql.loader import SQLLoader
  6 | from supabase_mcp.services.logs.log_manager import LogManager
  7 | 
  8 | 
  9 | class TestLogManager:
 10 |     """Tests for the LogManager class."""
 11 | 
 12 |     def test_init(self):
 13 |         """Test initialization of LogManager."""
 14 |         log_manager = LogManager()
 15 |         assert isinstance(log_manager.sql_loader, SQLLoader)
 16 |         assert log_manager.COLLECTION_TO_TABLE["postgres"] == "postgres_logs"
 17 |         assert log_manager.COLLECTION_TO_TABLE["api_gateway"] == "edge_logs"
 18 |         assert log_manager.COLLECTION_TO_TABLE["edge_functions"] == "function_edge_logs"
 19 | 
 20 |     @pytest.mark.parametrize(
 21 |         "collection,hours_ago,filters,search,expected_clause",
 22 |         [
 23 |             # Test with hours_ago only
 24 |             (
 25 |                 "postgres",
 26 |                 24,
 27 |                 None,
 28 |                 None,
 29 |                 "WHERE postgres_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)",
 30 |             ),
 31 |             # Test with search only
 32 |             (
 33 |                 "auth",
 34 |                 None,
 35 |                 None,
 36 |                 "error",
 37 |                 "WHERE event_message LIKE '%error%'",
 38 |             ),
 39 |             # Test with filters only
 40 |             (
 41 |                 "api_gateway",
 42 |                 None,
 43 |                 [{"field": "status_code", "operator": "=", "value": 500}],
 44 |                 None,
 45 |                 "WHERE status_code = 500",
 46 |             ),
 47 |             # Test with string value in filters
 48 |             (
 49 |                 "api_gateway",
 50 |                 None,
 51 |                 [{"field": "method", "operator": "=", "value": "GET"}],
 52 |                 None,
 53 |                 "WHERE method = 'GET'",
 54 |             ),
 55 |             # Test with multiple filters
 56 |             (
 57 |                 "postgres",
 58 |                 None,
 59 |                 [
 60 |                     {"field": "parsed.error_severity", "operator": "=", "value": "ERROR"},
 61 |                     {"field": "parsed.application_name", "operator": "LIKE", "value": "app%"},
 62 |                 ],
 63 |                 None,
 64 |                 "WHERE parsed.error_severity = 'ERROR' AND parsed.application_name LIKE 'app%'",
 65 |             ),
 66 |             # Test with hours_ago and search
 67 |             (
 68 |                 "storage",
 69 |                 12,
 70 |                 None,
 71 |                 "upload",
 72 |                 "WHERE storage_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 12 HOUR) AND event_message LIKE '%upload%'",
 73 |             ),
 74 |             # Test with all parameters
 75 |             (
 76 |                 "edge_functions",
 77 |                 6,
 78 |                 [{"field": "response.status_code", "operator": ">", "value": 400}],
 79 |                 "timeout",
 80 |                 "WHERE function_edge_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 6 HOUR) AND event_message LIKE '%timeout%' AND response.status_code > 400",
 81 |             ),
 82 |             # Test with cron logs (special case)
 83 |             (
 84 |                 "cron",
 85 |                 24,
 86 |                 None,
 87 |                 None,
 88 |                 "AND postgres_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)",
 89 |             ),
 90 |             # Test with cron logs and other parameters
 91 |             (
 92 |                 "cron",
 93 |                 12,
 94 |                 [{"field": "parsed.error_severity", "operator": "=", "value": "ERROR"}],
 95 |                 "failed",
 96 |                 "AND postgres_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 12 HOUR) AND event_message LIKE '%failed%' AND parsed.error_severity = 'ERROR'",
 97 |             ),
 98 |         ],
 99 |     )
100 |     def test_build_where_clause(self, collection, hours_ago, filters, search, expected_clause):
101 |         """Test building WHERE clauses for different scenarios."""
102 |         log_manager = LogManager()
103 |         where_clause = log_manager._build_where_clause(
104 |             collection=collection, hours_ago=hours_ago, filters=filters, search=search
105 |         )
106 |         assert where_clause == expected_clause
107 | 
108 |     def test_build_where_clause_escapes_single_quotes(self):
109 |         """Test that single quotes in search strings are properly escaped."""
110 |         log_manager = LogManager()
111 |         where_clause = log_manager._build_where_clause(collection="postgres", search="O'Reilly")
112 |         assert where_clause == "WHERE event_message LIKE '%O''Reilly%'"
113 | 
114 |         # Test with filters containing single quotes
115 |         where_clause = log_manager._build_where_clause(
116 |             collection="postgres",
117 |             filters=[{"field": "parsed.query", "operator": "LIKE", "value": "SELECT * FROM O'Reilly"}],
118 |         )
119 |         assert where_clause == "WHERE parsed.query LIKE 'SELECT * FROM O''Reilly'"
120 | 
121 |     @patch.object(SQLLoader, "get_logs_query")
122 |     def test_build_logs_query_with_custom_query(self, mock_get_logs_query):
123 |         """Test building a logs query with a custom query."""
124 |         log_manager = LogManager()
125 |         custom_query = "SELECT * FROM postgres_logs LIMIT 10"
126 | 
127 |         query = log_manager.build_logs_query(collection="postgres", custom_query=custom_query)
128 | 
129 |         assert query == custom_query
130 |         # Ensure get_logs_query is not called when custom_query is provided
131 |         mock_get_logs_query.assert_not_called()
132 | 
133 |     @patch.object(LogManager, "_build_where_clause")
134 |     @patch.object(SQLLoader, "get_logs_query")
135 |     def test_build_logs_query_standard(self, mock_get_logs_query, mock_build_where_clause):
136 |         """Test building a standard logs query."""
137 |         log_manager = LogManager()
138 |         mock_build_where_clause.return_value = "WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)"
139 |         mock_get_logs_query.return_value = "SELECT * FROM postgres_logs WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR) LIMIT 20"
140 | 
141 |         query = log_manager.build_logs_query(
142 |             collection="postgres",
143 |             limit=20,
144 |             hours_ago=24,
145 |             filters=[{"field": "parsed.error_severity", "operator": "=", "value": "ERROR"}],
146 |             search="connection",
147 |         )
148 | 
149 |         mock_build_where_clause.assert_called_once_with(
150 |             collection="postgres",
151 |             hours_ago=24,
152 |             filters=[{"field": "parsed.error_severity", "operator": "=", "value": "ERROR"}],
153 |             search="connection",
154 |         )
155 | 
156 |         mock_get_logs_query.assert_called_once_with(
157 |             collection="postgres",
158 |             where_clause="WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)",
159 |             limit=20,
160 |         )
161 | 
162 |         assert (
163 |             query
164 |             == "SELECT * FROM postgres_logs WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR) LIMIT 20"
165 |         )
166 | 
167 |     @patch.object(SQLLoader, "get_logs_query")
168 |     def test_build_logs_query_integration(self, mock_get_logs_query, sql_loader):
169 |         """Test building a logs query with integration between components."""
170 |         # Setup
171 |         log_manager = LogManager()
172 |         log_manager.sql_loader = sql_loader
173 | 
174 |         # Mock the SQL loader to return a predictable result
175 |         mock_get_logs_query.return_value = (
176 |             "SELECT id, postgres_logs.timestamp, event_message FROM postgres_logs "
177 |             "WHERE postgres_logs.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR) "
178 |             "ORDER BY timestamp DESC LIMIT 10"
179 |         )
180 | 
181 |         # Execute
182 |         query = log_manager.build_logs_query(
183 |             collection="postgres",
184 |             limit=10,
185 |             hours_ago=24,
186 |         )
187 | 
188 |         # Verify
189 |         assert "SELECT id, postgres_logs.timestamp, event_message FROM postgres_logs" in query
190 |         assert "LIMIT 10" in query
191 |         mock_get_logs_query.assert_called_once()
192 | 
193 |     def test_unknown_collection(self):
194 |         """Test handling of unknown collections."""
195 |         log_manager = LogManager()
196 | 
197 |         # Test with a collection that doesn't exist in the mapping
198 |         where_clause = log_manager._build_where_clause(
199 |             collection="unknown_collection",
200 |             hours_ago=24,
201 |         )
202 | 
203 |         # Should use the collection name as the table name
204 |         assert (
205 |             where_clause == "WHERE unknown_collection.timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 24 HOUR)"
206 |         )
207 | 
```
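
The escaping behavior these tests pin down is plain single-quote doubling. A hedged sketch of just that technique, assuming nothing about LogManager internals beyond what the tests assert (`escape` and `build_search_clause` are hypothetical helpers, not part of the API):

```python
def escape(value: str) -> str:
    """Escape single quotes for a SQL string literal by doubling them."""
    return value.replace("'", "''")


def build_search_clause(search: str) -> str:
    """Build the LIKE clause shape the tests above expect."""
    return f"WHERE event_message LIKE '%{escape(search)}%'"


assert build_search_clause("O'Reilly") == "WHERE event_message LIKE '%O''Reilly%'"
print(build_search_clause("O'Reilly"))
```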

--------------------------------------------------------------------------------
/supabase_mcp/services/database/query_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from supabase_mcp.exceptions import OperationNotAllowedError
  2 | from supabase_mcp.logger import logger
  3 | from supabase_mcp.services.database.migration_manager import MigrationManager
  4 | from supabase_mcp.services.database.postgres_client import PostgresClient, QueryResult
  5 | from supabase_mcp.services.database.sql.loader import SQLLoader
  6 | from supabase_mcp.services.database.sql.models import QueryValidationResults
  7 | from supabase_mcp.services.database.sql.validator import SQLValidator
  8 | from supabase_mcp.services.safety.models import ClientType, SafetyMode
  9 | from supabase_mcp.services.safety.safety_manager import SafetyManager
 10 | 
 11 | 
 12 | class QueryManager:
 13 |     """
 14 |     Manages SQL query execution with validation and migration handling.
 15 | 
 16 |     This class is responsible for:
 17 |     1. Validating SQL queries for safety
 18 |     2. Executing queries through the database client
 19 |     3. Managing migrations for queries that require them
 20 |     4. Loading SQL queries from files
 21 | 
 22 |     It acts as a central point for all SQL operations, ensuring consistent
 23 |     validation and execution patterns.
 24 |     """
 25 | 
 26 |     def __init__(
 27 |         self,
 28 |         postgres_client: PostgresClient,
 29 |         safety_manager: SafetyManager,
 30 |         sql_validator: SQLValidator | None = None,
 31 |         migration_manager: MigrationManager | None = None,
 32 |         sql_loader: SQLLoader | None = None,
 33 |     ):
 34 |         """
 35 |         Initialize the QueryManager.
 36 | 
 37 |         Args:
 38 |             postgres_client: The database client to use for executing queries
 39 |             safety_manager: The safety manager to use for validating operations
 40 |             sql_validator: Optional SQL validator to use
 41 |             migration_manager: Optional migration manager to use
 42 |             sql_loader: Optional SQL loader to use
 43 |         """
 44 |         self.db_client = postgres_client
 45 |         self.safety_manager = safety_manager
 46 |         self.validator = sql_validator or SQLValidator()
 47 |         self.sql_loader = sql_loader or SQLLoader()
 48 |         self.migration_manager = migration_manager or MigrationManager(loader=self.sql_loader)
 49 | 
 50 |     def check_readonly(self) -> bool:
 51 |         """Return True if the current safety mode is SAFE."""
 52 |         result = self.safety_manager.get_safety_mode(ClientType.DATABASE) == SafetyMode.SAFE
 53 |         logger.debug(f"Check readonly result: {result}")
 54 |         return result
 55 | 
 56 |     async def handle_query(self, query: str, has_confirmation: bool = False, migration_name: str = "") -> QueryResult:
 57 |         """
 58 |         Handle a SQL query with validation and potential migration, using the migration name if provided.
 59 | 
 60 |         This method:
 61 |         1. Validates the query for safety
 62 |         2. Checks if the query requires migration
 63 |         3. Handles migration if needed
 64 |         4. Executes the query
 65 | 
 66 |         Args:
 67 |             query: SQL query to execute
 68 |             has_confirmation: Whether the operation has been confirmed by the user
 69 |             migration_name: Optional name to record the migration under, if one is created
 70 | 
 71 |         Returns:
 72 |             QueryResult: The result of the query execution
 73 | 
 74 |         Raises:
 75 |             OperationNotAllowedError: If the query is not allowed in the current safety mode
 76 |             ConfirmationRequiredError: If the query requires confirmation and has_confirmation is False
 77 |         """
 78 |         # 1. Run through the validator
 79 |         validated_query = self.validator.validate_query(query)
 80 | 
 81 |         # 2. Ensure execution is allowed
 82 |         self.safety_manager.validate_operation(ClientType.DATABASE, validated_query, has_confirmation)
 83 |         logger.debug(f"Operation with risk level {validated_query.highest_risk_level} validated successfully")
 84 | 
 85 |         # 3. Handle migration if needed
 86 |         await self.handle_migration(validated_query, query, migration_name)
 87 | 
 88 |         # 4. Execute the query
 89 |         return await self.handle_query_execution(validated_query)
 90 | 
 91 |     async def handle_query_execution(self, validated_query: QueryValidationResults) -> QueryResult:
 92 |         """
 93 |         Handle query execution with validation and potential migration.
 94 | 
 95 |         This method:
 96 |         1. Checks the readonly mode
 97 |         2. Executes the query
 98 |         3. Returns the result
 99 | 
100 |         Args:
101 |             validated_query: The validated query to execute, including its
102 |                 risk assessment
103 | 
104 |         Returns:
105 |             QueryResult: The result of the query execution
106 |         """
107 |         readonly = self.check_readonly()
108 |         result = await self.db_client.execute_query(validated_query, readonly)
109 |         logger.debug(f"Query result: {result}")
110 |         return result
111 | 
112 |     async def handle_migration(
113 |         self, validation_result: QueryValidationResults, original_query: str, migration_name: str = ""
114 |     ) -> None:
115 |         """
116 |         Handle migration for a query that requires it.
117 | 
118 |         Args:
119 |             validation_result: The validation result
120 |             original_query: The original query to record in the migration
121 |             migration_name: Migration name to use, if provided
122 |         """
123 |         # 1. Check if migration is needed
124 |         if not validation_result.needs_migration():
125 |             logger.debug("No migration needed for this query")
126 |             return
127 | 
128 |         # 2. Prepare migration query
129 |         migration_query, name = self.migration_manager.prepare_migration_query(
130 |             validation_result, original_query, migration_name
131 |         )
132 |         logger.debug("Migration query prepared")
133 | 
134 |         # 3. Execute migration query
135 |         try:
136 |             # First, ensure the migration schema exists
137 |             await self.init_migration_schema()
138 | 
139 |             # Then execute the migration query
140 |             migration_validation = self.validator.validate_query(migration_query)
141 |             await self.db_client.execute_query(migration_validation, readonly=False)
142 |             logger.info(f"Migration '{name}' executed successfully")
143 |         except Exception as e:
144 |             logger.debug(f"Migration failure details: {str(e)}")
145 |             # We don't want to fail the main query if migration fails
146 |             # Just log the error and continue
147 |             logger.warning(f"Failed to record migration '{name}': {e}")
148 | 
149 |     async def init_migration_schema(self) -> None:
150 |         """Initialize the migrations schema and table if they don't exist."""
151 |         try:
152 |             # Get the initialization query
153 |             init_query = self.sql_loader.get_init_migrations_query()
154 | 
155 |             # Validate and execute it
156 |             init_validation = self.validator.validate_query(init_query)
157 |             await self.db_client.execute_query(init_validation, readonly=False)
158 |             logger.debug("Migrations schema initialized successfully")
159 |         except Exception as e:
160 |             logger.warning(f"Failed to initialize migrations schema: {e}")
161 | 
162 |     async def handle_confirmation(self, confirmation_id: str) -> QueryResult:
163 |         """
164 |         Handle a confirmed operation using its confirmation ID.
165 | 
166 |         This method retrieves the stored operation and passes it to handle_query.
167 | 
168 |         Args:
169 |             confirmation_id: The unique ID of the confirmation to process
170 | 
171 |         Returns:
172 |             QueryResult: The result of the query execution
173 |         """
174 |         # Get the stored operation
175 |         operation = self.safety_manager.get_stored_operation(confirmation_id)
176 |         if not operation:
177 |             raise OperationNotAllowedError(f"Invalid or expired confirmation ID: {confirmation_id}")
178 | 
179 |         # Get the query from the operation
180 |         query = operation.original_query
181 |         logger.debug(f"Processing confirmed operation with ID {confirmation_id}")
182 | 
183 |         # Call handle_query with the query and has_confirmation=True
184 |         return await self.handle_query(query, has_confirmation=True)
185 | 
186 |     def get_schemas_query(self) -> str:
187 |         """Get a query to list all schemas."""
188 |         return self.sql_loader.get_schemas_query()
189 | 
190 |     def get_tables_query(self, schema_name: str) -> str:
191 |         """Get a query to list all tables in a schema."""
192 |         return self.sql_loader.get_tables_query(schema_name)
193 | 
194 |     def get_table_schema_query(self, schema_name: str, table: str) -> str:
195 |         """Get a query to get the schema of a table."""
196 |         return self.sql_loader.get_table_schema_query(schema_name, table)
197 | 
198 |     def get_migrations_query(
199 |         self, limit: int = 50, offset: int = 0, name_pattern: str = "", include_full_queries: bool = False
200 |     ) -> str:
201 |         """Get a query to list migrations."""
202 |         return self.sql_loader.get_migrations_query(
203 |             limit=limit, offset=offset, name_pattern=name_pattern, include_full_queries=include_full_queries
204 |         )
205 | 
```
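
For orientation, here is a compressed sketch of `handle_query`'s four-step control flow with stub collaborators in place of the real validator, safety manager, and Postgres client; the stub names and behaviors are invented for illustration only.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Validated:
    query: str
    needs_migration: bool


class StubValidator:
    def validate_query(self, query: str) -> Validated:
        # Toy rule: anything that is not a SELECT would need a migration record.
        return Validated(query, not query.lstrip().upper().startswith("SELECT"))


class StubSafety:
    def validate_operation(self, validated: Validated) -> None:
        if "DROP" in validated.query.upper():
            raise PermissionError("blocked in SAFE mode")


class StubClient:
    async def execute_query(self, validated: Validated, readonly: bool) -> str:
        return f"executed (readonly={readonly}): {validated.query}"


async def handle_query(query: str) -> str:
    validated = StubValidator().validate_query(query)  # 1. validate
    StubSafety().validate_operation(validated)         # 2. safety check
    if validated.needs_migration:
        print("would record a migration here (best effort, never fatal)")  # 3. migrate
    return await StubClient().execute_query(validated, readonly=True)      # 4. execute


print(asyncio.run(handle_query("SELECT * FROM users")))
```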

--------------------------------------------------------------------------------
/supabase_mcp/clients/management_client.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | 
  3 | from json.decoder import JSONDecodeError
  4 | from typing import Any
  5 | 
  6 | import httpx
  7 | from httpx import Request, Response
  8 | from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
  9 | 
 10 | from supabase_mcp.exceptions import (
 11 |     APIClientError,
 12 |     APIConnectionError,
 13 |     APIResponseError,
 14 |     APIServerError,
 15 |     UnexpectedError,
 16 | )
 17 | from supabase_mcp.logger import logger
 18 | from supabase_mcp.settings import Settings
 19 | 
 20 | 
 21 | # Helper function for retry decorator to safely log exceptions
 22 | def log_retry_attempt(retry_state: RetryCallState) -> None:
 23 |     """Log retry attempts with exception details if available."""
 24 |     exception = retry_state.outcome.exception() if retry_state.outcome and retry_state.outcome.failed else None
 25 |     exception_str = str(exception) if exception else "Unknown error"
 26 |     logger.warning(f"Network error, retrying ({retry_state.attempt_number}/3): {exception_str}")
 27 | 
 28 | 
 29 | class ManagementAPIClient:
 30 |     """
 31 |     Client for Supabase Management API.
 32 | 
 33 |     Handles low-level HTTP requests to the Supabase Management API.
 34 |     """
 35 | 
 36 |     def __init__(self, settings: Settings) -> None:
 37 |         """Initialize the API client with default settings."""
 38 |         self.settings = settings
 39 |         self.client = self.create_httpx_client(settings)
 40 | 
 41 |         logger.info("✔️ Management API client initialized successfully")
 42 | 
 43 |     def create_httpx_client(self, settings: Settings) -> httpx.AsyncClient:
 44 |         """Create and configure an httpx client for API requests."""
 45 |         headers = {
 46 |             "Authorization": f"Bearer {settings.supabase_access_token}",
 47 |             "Content-Type": "application/json",
 48 |         }
 49 | 
 50 |         return httpx.AsyncClient(
 51 |             base_url=settings.supabase_api_url,
 52 |             headers=headers,
 53 |             timeout=30.0,
 54 |         )
 55 | 
 56 |     def prepare_request(
 57 |         self,
 58 |         method: str,
 59 |         path: str,
 60 |         request_params: dict[str, Any] | None = None,
 61 |         request_body: dict[str, Any] | None = None,
 62 |     ) -> Request:
 63 |         """
 64 |         Prepare an HTTP request to the Supabase Management API.
 65 | 
 66 |         Args:
 67 |             method: HTTP method (GET, POST, etc.)
 68 |             path: API path
 69 |             request_params: Query parameters
 70 |             request_body: Request body
 71 | 
 72 |         Returns:
 73 |             Prepared httpx.Request object
 74 | 
 75 |         Raises:
 76 |             APIClientError: If request preparation fails
 77 |         """
 78 |         try:
 79 |             return self.client.build_request(method=method, url=path, params=request_params, json=request_body)
 80 |         except Exception as e:
 81 |             raise APIClientError(
 82 |                 message=f"Failed to build request: {str(e)}",
 83 |                 status_code=None,
 84 |             ) from e
 85 | 
 86 |     @retry(
 87 |         retry=retry_if_exception_type(httpx.NetworkError),  # This includes ConnectError and TimeoutException
 88 |         stop=stop_after_attempt(3),
 89 |         wait=wait_exponential(multiplier=1, min=2, max=10),
 90 |         reraise=True,  # Ensure the original exception is raised
 91 |         before_sleep=log_retry_attempt,
 92 |     )
 93 |     async def send_request(self, request: Request) -> Response:
 94 |         """
 95 |         Send an HTTP request with retry logic for transient errors.
 96 | 
 97 |         Args:
 98 |             request: Prepared httpx.Request object
 99 | 
100 |         Returns:
101 |             httpx.Response object
102 | 
103 |         Raises:
104 |             APIConnectionError: For connection issues
105 |             APIClientError: For other request errors
106 |         """
107 |         try:
108 |             return await self.client.send(request)
109 |         except httpx.NetworkError as e:
110 |             # All NetworkErrors will be retried by the decorator
111 |             # This will only be reached after all retries are exhausted
112 |             logger.error(f"Network error after all retry attempts: {str(e)}")
113 |             raise APIConnectionError(
114 |                 message=f"Network error after 3 retry attempts: {str(e)}",
115 |                 status_code=None,
116 |             ) from e
117 |         except Exception as e:
118 |             # Other exceptions won't be retried
119 |             raise APIClientError(
120 |                 message=f"Request failed: {str(e)}",
121 |                 status_code=None,
122 |             ) from e
123 | 
124 |     def parse_response(self, response: Response) -> dict[str, Any]:
125 |         """
126 |         Parse an HTTP response as JSON.
127 | 
128 |         Args:
129 |             response: httpx.Response object
130 | 
131 |         Returns:
132 |             Parsed response body as dictionary
133 | 
134 |         Raises:
135 |             APIResponseError: If response cannot be parsed as JSON
136 |         """
137 |         if not response.content:
138 |             return {}
139 | 
140 |         try:
141 |             return response.json()
142 |         except JSONDecodeError as e:
143 |             raise APIResponseError(
144 |                 message=f"Failed to parse response as JSON: {str(e)}",
145 |                 status_code=response.status_code,
146 |                 response_body={"raw_content": response.text},
147 |             ) from e
148 | 
149 |     def handle_error_response(self, response: Response, parsed_body: dict[str, Any] | None = None) -> None:
150 |         """
151 |         Handle error responses based on status code.
152 | 
153 |         Args:
154 |             response: httpx.Response object
155 |             parsed_body: Parsed response body if available
156 | 
157 |         Raises:
158 |             APIClientError: For client errors (4xx)
159 |             APIServerError: For server errors (5xx)
160 |             UnexpectedError: For unexpected status codes
161 |         """
162 |         # Extract error message
163 |         error_message = f"API request failed: {response.status_code}"
164 |         if parsed_body and "message" in parsed_body:
165 |             error_message = parsed_body["message"]
166 | 
167 |         # Determine error type based on status code
168 |         if 400 <= response.status_code < 500:
169 |             raise APIClientError(
170 |                 message=error_message,
171 |                 status_code=response.status_code,
172 |                 response_body=parsed_body,
173 |             )
174 |         elif response.status_code >= 500:
175 |             raise APIServerError(
176 |                 message=error_message,
177 |                 status_code=response.status_code,
178 |                 response_body=parsed_body,
179 |             )
180 |         else:
181 |             # This should not happen, but just in case
182 |             raise UnexpectedError(
183 |                 message=f"Unexpected status code: {response.status_code}",
184 |                 status_code=response.status_code,
185 |                 response_body=parsed_body,
186 |             )
187 | 
188 |     async def execute_request(
189 |         self,
190 |         method: str,
191 |         path: str,
192 |         request_params: dict[str, Any] | None = None,
193 |         request_body: dict[str, Any] | None = None,
194 |     ) -> dict[str, Any]:
195 |         """
196 |         Execute an HTTP request to the Supabase Management API.
197 | 
198 |         Args:
199 |             method: HTTP method (GET, POST, etc.)
200 |             path: API path
201 |             request_params: Query parameters
202 |             request_body: Request body
203 | 
204 |         Returns:
205 |             API response as a dictionary
206 | 
207 |         Raises:
208 |             APIClientError: For client errors (4xx)
209 |             APIConnectionError: For connection issues
210 |             APIResponseError: For response parsing errors
211 |             UnexpectedError: For unexpected errors
212 |         """
213 |         # Check if access token is available
214 |         if not self.settings.supabase_access_token:
215 |             raise APIClientError(
216 |                 "Supabase access token is not configured. Set SUPABASE_ACCESS_TOKEN environment variable to use Management API tools."
217 |             )
218 | 
219 |         # Log detailed request information
220 |         logger.info(f"API Client: Executing {method} request to {path}")
221 |         if request_params:
222 |             logger.debug(f"Request params: {request_params}")
223 |         if request_body:
224 |             logger.debug(f"Request body: {request_body}")
225 | 
226 |         # Prepare request
227 |         request = self.prepare_request(method, path, request_params, request_body)
228 | 
229 |         # Send request
230 |         response = await self.send_request(request)
231 | 
232 |         # Parse response (for both success and error cases)
233 |         parsed_body = self.parse_response(response)
234 | 
235 |         # Check if successful
236 |         if not response.is_success:
237 |             logger.warning(f"Request failed: {method} {path} - Status {response.status_code}")
238 |             self.handle_error_response(response, parsed_body)
239 | 
240 |         # Log success and return
241 |         logger.info(f"Request successful: {method} {path} - Status {response.status_code}")
242 |         return parsed_body
243 | 
244 |     async def close(self) -> None:
245 |         """Close the HTTP client and release resources."""
246 |         if self.client:
247 |             await self.client.aclose()
248 |             logger.info("HTTP API client closed")
249 | 
```
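
The retry policy on `send_request` comes straight from tenacity. A standalone sketch of the same decorator configuration applied to a deliberately flaky function (requires the `tenacity` package; the function and counter are invented for the demo):

```python
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

attempts = 0


@retry(
    retry=retry_if_exception_type(ConnectionError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=0.1, max=1),
    reraise=True,  # surface the original ConnectionError if all attempts fail
)
def flaky() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionError("transient network error")
    return f"succeeded on attempt {attempts}"


print(flaky())  # succeeded on attempt 3
```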

--------------------------------------------------------------------------------
/supabase_mcp/services/api/spec_manager.py:
--------------------------------------------------------------------------------

```python
  1 | import json
  2 | from enum import Enum
  3 | from pathlib import Path
  4 | from typing import Any
  5 | 
  6 | import httpx
  7 | 
  8 | from supabase_mcp.logger import logger
  9 | 
 10 | # Constants
 11 | SPEC_URL = "https://api.supabase.com/api/v1-json"
 12 | LOCAL_SPEC_PATH = Path(__file__).parent / "specs" / "api_spec.json"
 13 | 
 14 | 
 15 | class ApiDomain(str, Enum):
 16 |     """Enum of all possible domains in the Supabase Management API."""
 17 | 
 18 |     ANALYTICS = "Analytics"
 19 |     AUTH = "Auth"
 20 |     DATABASE = "Database"
 21 |     DOMAINS = "Domains"
 22 |     EDGE_FUNCTIONS = "Edge Functions"
 23 |     ENVIRONMENTS = "Environments"
 24 |     OAUTH = "OAuth"
 25 |     ORGANIZATIONS = "Organizations"
 26 |     PROJECTS = "Projects"
 27 |     REST = "Rest"
 28 |     SECRETS = "Secrets"
 29 |     STORAGE = "Storage"
 30 | 
 31 |     @classmethod
 32 |     def list(cls) -> list[str]:
 33 |         """Return a list of all domain values."""
 34 |         return [domain.value for domain in cls]
 35 | 
 36 | 
 37 | class ApiSpecManager:
 38 |     """
 39 |     Manages the OpenAPI specification for the Supabase Management API.
 40 |     Handles spec loading, caching, and validation.
 41 |     """
 42 | 
 43 |     def __init__(self) -> None:
 44 |         self.spec: dict[str, Any] | None = None
 45 |         self._paths_cache: dict[str, dict[str, str]] | None = None
 46 |         self._domains_cache: list[str] | None = None
 47 | 
 48 |     async def _fetch_remote_spec(self) -> dict[str, Any] | None:
 49 |         """
 50 |         Fetch latest OpenAPI spec from Supabase API.
 51 |         Returns None if fetch fails.
 52 |         """
 53 |         try:
 54 |             async with httpx.AsyncClient() as client:
 55 |                 response = await client.get(SPEC_URL)
 56 |                 if response.status_code == 200:
 57 |                     return response.json()
 58 |                 logger.warning(f"Failed to fetch API spec: {response.status_code}")
 59 |                 return None
 60 |         except Exception as e:
 61 |             logger.warning(f"Error fetching API spec: {e}")
 62 |             return None
 63 | 
 64 |     def _load_local_spec(self) -> dict[str, Any]:
 65 |         """
 66 |         Load OpenAPI spec from local file.
 67 |         This is our fallback spec shipped with the server.
 68 |         """
 69 |         try:
 70 |             with open(LOCAL_SPEC_PATH) as f:
 71 |                 return json.load(f)
 72 |         except FileNotFoundError:
 73 |             logger.error(f"Local spec not found at {LOCAL_SPEC_PATH}")
 74 |             raise
 75 |         except json.JSONDecodeError as e:
 76 |             logger.error(f"Invalid JSON in local spec: {e}")
 77 |             raise
 78 | 
 79 |     async def get_spec(self) -> dict[str, Any]:
 80 |         """Retrieve the spec, fetching the latest remotely and falling back to the bundled local copy."""
 81 |         if self.spec is None:
 82 |             raw_spec = await self._fetch_remote_spec()
 83 |             if not raw_spec:
 84 |                 # If remote fetch fails, use our fallback spec
 85 |                 logger.info("Using fallback API spec")
 86 |                 raw_spec = self._load_local_spec()
 87 |             self.spec = raw_spec
 88 | 
 89 |         return self.spec
 90 | 
 91 |     def get_all_paths_and_methods(self) -> dict[str, dict[str, str]]:
 92 |         """
 93 |         Returns a dictionary of all paths and their methods with operation IDs.
 94 | 
 95 |         Returns:
 96 |             Dict[str, Dict[str, str]]:  {path: {method: operationId}}
 97 |         """
 98 |         if self._paths_cache is None:
 99 |             self._build_caches()
100 |         return self._paths_cache or {}
101 | 
102 |     def get_paths_and_methods_by_domain(self, domain: str) -> dict[str, dict[str, str]]:
103 |         """
104 |         Returns paths and methods within a specific domain (tag).
105 | 
106 |         Args:
107 |             domain (str): The domain name (e.g., "Auth", "Projects").
108 | 
109 |         Returns:
110 |             Dict[str, Dict[str, str]]: {path: {method: operationId}}
111 |         """
112 | 
113 |         if self._paths_cache is None:
114 |             self._build_caches()
115 | 
116 |         # Validate domain using enum
117 |         try:
118 |             valid_domain = ApiDomain(domain).value
119 |         except ValueError as e:
120 |             raise ValueError(f"Invalid domain: {domain}") from e
121 | 
122 |         domain_paths: dict[str, dict[str, str]] = {}
123 |         if self.spec:
124 |             for path, methods in self.spec.get("paths", {}).items():
125 |                 for method, details in methods.items():
126 |                     if valid_domain in details.get("tags", []):
127 |                         if path not in domain_paths:
128 |                             domain_paths[path] = {}
129 |                         domain_paths[path][method] = details.get("operationId", "")
130 |         return domain_paths
131 | 
132 |     def get_all_domains(self) -> list[str]:
133 |         """
134 |         Returns a list of all available domains (tags).
135 | 
136 |         Returns:
137 |             List[str]:  List of domain names.
138 |         """
139 |         if self._domains_cache is None:
140 |             self._build_caches()
141 |         return self._domains_cache or []
142 | 
143 |     def get_spec_for_path_and_method(self, path: str, method: str) -> dict[str, Any] | None:
144 |         """
145 |         Returns the full specification for a given path and HTTP method.
146 | 
147 |         Args:
148 |             path (str): The API path (e.g., "/v1/projects").
149 |             method (str): The HTTP method (e.g., "get", "post").
150 | 
151 |         Returns:
152 |             Optional[Dict[str, Any]]: The full spec for the operation, or None if not found.
153 |         """
154 |         if self.spec is None:
155 |             return None
156 | 
157 |         path_spec = self.spec.get("paths", {}).get(path)
158 |         if path_spec:
159 |             return path_spec.get(method.lower())  # Ensure lowercase method
160 |         return None
161 | 
162 |     def get_spec_part(self, part: str, *args: str | int) -> Any:
163 |         """
164 |         Safely retrieves a nested part of the OpenAPI spec.
165 | 
166 |         Args:
167 |             part: The top-level key (e.g., 'paths', 'components').
168 |             *args:  Subsequent keys or indices to traverse the spec.
169 | 
170 |         Returns:
171 |             The value at the specified location in the spec, or None if not found.
172 |         """
173 |         if self.spec is None:
174 |             return None
175 | 
176 |         current = self.spec.get(part)
177 |         for key in args:
178 |             if isinstance(current, dict) and key in current:
179 |                 current = current[key]
180 |             elif isinstance(current, list) and isinstance(key, int) and 0 <= key < len(current):
181 |                 current = current[key]
182 |             else:
183 |                 return None  # Key not found or invalid index
184 |         return current
185 | 
186 |     def _build_caches(self) -> None:
187 |         """
188 |         Build internal caches for faster lookups.
189 |         This populates _paths_cache and _domains_cache.
190 |         """
191 |         if self.spec is None:
192 |             logger.error("Cannot build caches: OpenAPI spec not loaded")
193 |             return
194 | 
195 |         # Build paths cache
196 |         paths_cache: dict[str, dict[str, str]] = {}
197 |         domains_set = set()
198 | 
199 |         for path, methods in self.spec.get("paths", {}).items():
200 |             for method, details in methods.items():
201 |                 # Add to paths cache
202 |                 if path not in paths_cache:
203 |                     paths_cache[path] = {}
204 |                 paths_cache[path][method] = details.get("operationId", "")
205 | 
206 |                 # Collect domains (tags)
207 |                 for tag in details.get("tags", []):
208 |                     domains_set.add(tag)
209 | 
210 |         self._paths_cache = paths_cache
211 |         self._domains_cache = sorted(list(domains_set))
212 | 
213 | 
214 | # Example usage: exercise ApiSpecManager end-to-end against the bundled spec.
215 | async def main() -> None:
216 |     """Test function to demonstrate ApiSpecManager functionality."""
217 |     # Create a new instance of ApiSpecManager
218 |     spec_manager = ApiSpecManager()
219 | 
220 |     # Load the spec
221 |     await spec_manager.get_spec()
222 | 
223 |     # Print the path to help debug
224 |     print(f"Looking for spec at: {LOCAL_SPEC_PATH}")
225 | 
226 |     # 1. Get all domains
227 |     all_domains = spec_manager.get_all_domains()
228 |     print("\nAll Domains:")
229 |     print(all_domains)
230 | 
231 |     # 2. Get all paths and methods
232 |     all_paths = spec_manager.get_all_paths_and_methods()
233 |     print("\nAll Paths and Methods (sample):")
234 |     # Just print a few to avoid overwhelming output
235 |     for i, (path, methods) in enumerate(all_paths.items()):
236 |         if i >= 5:  # Limit to 5 paths
237 |             break
238 |         print(f"  {path}:")
239 |         for method, operation_id in methods.items():
240 |             print(f"    {method}: {operation_id}")
241 | 
242 |     # 3. Get paths and methods for the "Edge Functions" domain
243 |     edge_paths = spec_manager.get_paths_and_methods_by_domain("Edge Functions")
244 |     print("\nEdge Functions Paths and Methods:")
245 |     for path, methods in edge_paths.items():
246 |         print(f"  {path}:")
247 |         for method, operation_id in methods.items():
248 |             print(f"    {method}: {operation_id}")
249 | 
250 |     # 4. Get the full spec for a specific path and method
251 |     path = "/v1/projects/{ref}/functions"
252 |     method = "GET"
253 |     full_spec = spec_manager.get_spec_for_path_and_method(path, method)
254 |     print(f"\nFull Spec for {method} {path}:")
255 |     if full_spec:
256 |         print(json.dumps(full_spec, indent=2)[:500] + "...")  # Truncate for readability
257 |     else:
258 |         print("Spec not found for this path/method")
259 | 
260 | 
261 | if __name__ == "__main__":
262 |     import asyncio
263 | 
264 |     asyncio.run(main())
265 | 
```
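
The heart of `_build_caches` is a single pass over the spec's `paths` mapping. A minimal sketch of that pass, run against a tiny inline OpenAPI fragment (the paths and operation IDs below are illustrative, not taken from the real Supabase spec):

```python
spec = {
    "paths": {
        "/v1/projects": {
            "get": {"operationId": "list-projects", "tags": ["Projects"]},
            "post": {"operationId": "create-project", "tags": ["Projects"]},
        },
        "/v1/projects/{ref}/functions": {
            "get": {"operationId": "list-functions", "tags": ["Edge Functions"]},
        },
    }
}

# Build {path: {method: operationId}} and collect the set of tags (domains).
paths_cache: dict[str, dict[str, str]] = {}
domains: set[str] = set()

for path, methods in spec["paths"].items():
    for method, details in methods.items():
        paths_cache.setdefault(path, {})[method] = details.get("operationId", "")
        domains.update(details.get("tags", []))

print(paths_cache["/v1/projects"])  # {'get': 'list-projects', 'post': 'create-project'}
print(sorted(domains))              # ['Edge Functions', 'Projects']
```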

--------------------------------------------------------------------------------
/supabase_mcp/tools/registry.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Any, Literal
  2 | 
  3 | from mcp.server.fastmcp import FastMCP
  4 | 
  5 | from supabase_mcp.core.container import ServicesContainer
  6 | from supabase_mcp.services.database.postgres_client import QueryResult
  7 | from supabase_mcp.tools.manager import ToolName
  8 | 
  9 | 
 10 | class ToolRegistry:
 11 |     """Responsible for registering tools with the MCP server"""
 12 | 
 13 |     def __init__(self, mcp: FastMCP, services_container: ServicesContainer):
 14 |         self.mcp = mcp
 15 |         self.services_container = services_container
 16 | 
 17 |     def register_tools(self) -> FastMCP:
 18 |         """Register all tools with the MCP server"""
 19 |         mcp = self.mcp
 20 |         services_container = self.services_container
 21 | 
 22 |         tool_manager = services_container.tool_manager
 23 |         feature_manager = services_container.feature_manager
 24 | 
 25 |         @mcp.tool(description=tool_manager.get_description(ToolName.GET_SCHEMAS))  # type: ignore
 26 |         async def get_schemas() -> QueryResult:
 27 |             """List all database schemas with their sizes and table counts."""
 28 |             return await feature_manager.execute_tool(ToolName.GET_SCHEMAS, services_container=services_container)
 29 | 
 30 |         @mcp.tool(description=tool_manager.get_description(ToolName.GET_TABLES))  # type: ignore
 31 |         async def get_tables(schema_name: str) -> QueryResult:
 32 |             """List all tables, foreign tables, and views in a schema with their sizes, row counts, and metadata."""
 33 |             return await feature_manager.execute_tool(
 34 |                 ToolName.GET_TABLES, services_container=services_container, schema_name=schema_name
 35 |             )
 36 | 
 37 |         @mcp.tool(description=tool_manager.get_description(ToolName.GET_TABLE_SCHEMA))  # type: ignore
 38 |         async def get_table_schema(schema_name: str, table: str) -> QueryResult:
 39 |             """Get detailed table structure including columns, keys, and relationships."""
 40 |             return await feature_manager.execute_tool(
 41 |                 ToolName.GET_TABLE_SCHEMA,
 42 |                 services_container=services_container,
 43 |                 schema_name=schema_name,
 44 |                 table=table,
 45 |             )
 46 | 
 47 |         @mcp.tool(description=tool_manager.get_description(ToolName.EXECUTE_POSTGRESQL))  # type: ignore
 48 |         async def execute_postgresql(query: str, migration_name: str = "") -> QueryResult:
 49 |             """Execute PostgreSQL statements against your Supabase database."""
 50 |             return await feature_manager.execute_tool(
 51 |                 ToolName.EXECUTE_POSTGRESQL,
 52 |                 services_container=services_container,
 53 |                 query=query,
 54 |                 migration_name=migration_name,
 55 |             )
 56 | 
 57 |         @mcp.tool(description=tool_manager.get_description(ToolName.RETRIEVE_MIGRATIONS))  # type: ignore
 58 |         async def retrieve_migrations(
 59 |             limit: int = 50,
 60 |             offset: int = 0,
 61 |             name_pattern: str = "",
 62 |             include_full_queries: bool = False,
 63 |         ) -> QueryResult:
 64 |             """Retrieve a list of migrations applied to the Supabase project.
 65 | 
 66 |             SAFETY: This is a low-risk read operation that can be executed in SAFE mode.
 67 |             """
 68 | 
 69 |             result = await feature_manager.execute_tool(
 70 |                 ToolName.RETRIEVE_MIGRATIONS,
 71 |                 services_container=services_container,
 72 |                 limit=limit,
 73 |                 offset=offset,
 74 |                 name_pattern=name_pattern,
 75 |                 include_full_queries=include_full_queries,
 76 |             )
 77 |             return QueryResult.model_validate(result)
 78 | 
 79 |         @mcp.tool(description=tool_manager.get_description(ToolName.SEND_MANAGEMENT_API_REQUEST))  # type: ignore
 80 |         async def send_management_api_request(
 81 |             method: str,
 82 |             path: str,
 83 |             path_params: dict[str, str],
 84 |             request_params: dict[str, Any],
 85 |             request_body: dict[str, Any],
 86 |         ) -> dict[str, Any]:
 87 |             """Execute a Supabase Management API request."""
 88 |             return await feature_manager.execute_tool(
 89 |                 ToolName.SEND_MANAGEMENT_API_REQUEST,
 90 |                 services_container=services_container,
 91 |                 method=method,
 92 |                 path=path,
 93 |                 path_params=path_params,
 94 |                 request_params=request_params,
 95 |                 request_body=request_body,
 96 |             )
 97 | 
 98 |         @mcp.tool(description=tool_manager.get_description(ToolName.GET_MANAGEMENT_API_SPEC))  # type: ignore
 99 |         async def get_management_api_spec(params: dict[str, Any] = {}) -> dict[str, Any]:
100 |             """Get the Supabase Management API specification.
101 | 
102 |             This tool can be used in four different ways:
103 |             1. Without parameters: Returns all domains (default)
104 |             2. With path and method: Returns the full specification for a specific API endpoint
105 |             3. With domain only: Returns all paths and methods within that domain
106 |             4. With all_paths=True: Returns all paths and methods
107 | 
108 |             Args:
109 |                 params: Dictionary containing optional parameters:
110 |                     - path: Optional API path (e.g., "/v1/projects/{ref}/functions")
111 |                     - method: Optional HTTP method (e.g., "GET", "POST")
112 |                     - domain: Optional domain/tag name (e.g., "Auth", "Storage")
113 |                     - all_paths: If True, returns all paths and methods
114 | 
115 |             Returns:
116 |                 API specification based on the provided parameters
117 |             """
118 |             return await feature_manager.execute_tool(
119 |                 ToolName.GET_MANAGEMENT_API_SPEC, services_container=services_container, params=params
120 |             )
121 | 
122 |         @mcp.tool(description=tool_manager.get_description(ToolName.GET_AUTH_ADMIN_METHODS_SPEC))  # type: ignore
123 |         async def get_auth_admin_methods_spec() -> dict[str, Any]:
124 |             """Get Python SDK methods specification for Auth Admin."""
125 |             return await feature_manager.execute_tool(
126 |                 ToolName.GET_AUTH_ADMIN_METHODS_SPEC, services_container=services_container
127 |             )
128 | 
129 |         @mcp.tool(description=tool_manager.get_description(ToolName.CALL_AUTH_ADMIN_METHOD))  # type: ignore
130 |         async def call_auth_admin_method(method: str, params: dict[str, Any]) -> dict[str, Any]:
131 |             """Call an Auth Admin method from Supabase Python SDK."""
132 |             return await feature_manager.execute_tool(
133 |                 ToolName.CALL_AUTH_ADMIN_METHOD, services_container=services_container, method=method, params=params
134 |             )
135 | 
136 |         @mcp.tool(description=tool_manager.get_description(ToolName.LIVE_DANGEROUSLY))  # type: ignore
137 |         async def live_dangerously(
138 |             service: Literal["api", "database"], enable_unsafe_mode: bool = False
139 |         ) -> dict[str, Any]:
140 |             """
141 |             Toggle between safe and unsafe operation modes for API or Database services.
142 | 
143 |             This function controls the safety level for operations, allowing you to:
144 |             - Enable write operations for the database (INSERT, UPDATE, DELETE, schema changes)
145 |             - Enable state-changing operations for the Management API
146 |             """
147 |             return await feature_manager.execute_tool(
148 |                 ToolName.LIVE_DANGEROUSLY,
149 |                 services_container=services_container,
150 |                 service=service,
151 |                 enable_unsafe_mode=enable_unsafe_mode,
152 |             )
153 | 
154 |         @mcp.tool(description=tool_manager.get_description(ToolName.CONFIRM_DESTRUCTIVE_OPERATION))  # type: ignore
155 |         async def confirm_destructive_operation(
156 |             operation_type: Literal["api", "database"], confirmation_id: str, user_confirmation: bool = False
157 |         ) -> QueryResult | dict[str, Any]:
158 |             """Execute a destructive operation after confirmation. Use this only after reviewing the risks with the user."""
159 |             return await feature_manager.execute_tool(
160 |                 ToolName.CONFIRM_DESTRUCTIVE_OPERATION,
161 |                 services_container=services_container,
162 |                 operation_type=operation_type,
163 |                 confirmation_id=confirmation_id,
164 |                 user_confirmation=user_confirmation,
165 |             )
166 | 
167 |         @mcp.tool(description=tool_manager.get_description(ToolName.RETRIEVE_LOGS))  # type: ignore
168 |         async def retrieve_logs(
169 |             collection: str,
170 |             limit: int = 20,
171 |             hours_ago: int = 1,
172 |             filters: list[dict[str, Any]] = [],
173 |             search: str = "",
174 |             custom_query: str = "",
175 |         ) -> dict[str, Any]:
176 |             """Retrieve logs from your Supabase project's services for debugging and monitoring."""
177 |             return await feature_manager.execute_tool(
178 |                 ToolName.RETRIEVE_LOGS,
179 |                 services_container=services_container,
180 |                 collection=collection,
181 |                 limit=limit,
182 |                 hours_ago=hours_ago,
183 |                 filters=filters,
184 |                 search=search,
185 |                 custom_query=custom_query,
186 |             )
187 | 
188 |         return mcp
189 | 
```
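
Every tool above follows the same FastMCP registration pattern: a decorator supplies the description, and the function signature drives the tool's input schema. A minimal sketch of that pattern in isolation (requires the `mcp` package; the server name and `add` tool are invented for the demo):

```python
from mcp.server.fastmcp import FastMCP

mcp = FastMCP("demo-server")


@mcp.tool(description="Add two integers and return the sum.")
async def add(a: int, b: int) -> int:
    """FastMCP derives the tool's input schema from this signature."""
    return a + b


if __name__ == "__main__":
    # Serve the registered tool over stdio, as an MCP client expects.
    mcp.run()
```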

--------------------------------------------------------------------------------
/tests/services/database/test_query_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from unittest.mock import AsyncMock, MagicMock
  2 | 
  3 | import pytest
  4 | 
  5 | from supabase_mcp.exceptions import SafetyError
  6 | from supabase_mcp.services.database.query_manager import QueryManager
  7 | from supabase_mcp.services.database.sql.loader import SQLLoader
  8 | from supabase_mcp.services.database.sql.validator import (
  9 |     QueryValidationResults,
 10 |     SQLQueryCategory,
 11 |     SQLQueryCommand,
 12 |     SQLValidator,
 13 |     ValidatedStatement,
 14 | )
 15 | from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel
 16 | 
 17 | 
 18 | @pytest.mark.asyncio(loop_scope="module")
 19 | class TestQueryManager:
 20 |     """Tests for the Query Manager."""
 21 | 
 22 |     @pytest.mark.unit
 23 |     async def test_query_execution(self, mock_query_manager: QueryManager):
 24 |         """Test query execution through the Query Manager."""
 25 | 
 26 |         query_manager = mock_query_manager
 27 | 
 28 |         # Ensure validator and safety_manager are proper mocks
 29 |         query_manager.validator = MagicMock()
 30 |         query_manager.safety_manager = MagicMock()
 31 | 
 32 |         # Create a mock validation result for a SELECT query
 33 |         validated_statement = ValidatedStatement(
 34 |             category=SQLQueryCategory.DQL,
 35 |             command=SQLQueryCommand.SELECT,
 36 |             risk_level=OperationRiskLevel.LOW,
 37 |             query="SELECT * FROM users",
 38 |             needs_migration=False,
 39 |             object_type="TABLE",
 40 |             schema_name="public",
 41 |         )
 42 | 
 43 |         validation_result = QueryValidationResults(
 44 |             statements=[validated_statement],
 45 |             highest_risk_level=OperationRiskLevel.LOW,
 46 |             has_transaction_control=False,
 47 |             original_query="SELECT * FROM users",
 48 |         )
 49 | 
 50 |         # Make the validator return our mock validation result
 51 |         query_manager.validator.validate_query.return_value = validation_result
 52 | 
 53 |         # Make the db_client return a mock query result
 54 |         mock_query_result = MagicMock()
 55 |         query_manager.db_client.execute_query = AsyncMock(return_value=mock_query_result)
 56 | 
 57 |         # Execute a query
 58 |         query = "SELECT * FROM users"
 59 |         result = await query_manager.handle_query(query)
 60 | 
 61 |         # Verify the validator was called with the query
 62 |         query_manager.validator.validate_query.assert_called_once_with(query)
 63 | 
 64 |         # Verify the db_client was called with the validation result
 65 |         query_manager.db_client.execute_query.assert_called_once_with(validation_result, False)
 66 | 
 67 |         # Verify the result is what we expect
 68 |         assert result == mock_query_result
 69 | 
 70 |     @pytest.mark.asyncio
 71 |     @pytest.mark.unit
 72 |     async def test_safety_validation_blocks_dangerous_query(self, mock_query_manager: QueryManager):
 73 |         """Test that the safety validation blocks dangerous queries."""
 74 | 
 75 |         # Create a query manager with the mock dependencies
 76 |         query_manager = mock_query_manager
 77 | 
 78 |         # Ensure validator and safety_manager are proper mocks
 79 |         query_manager.validator = MagicMock()
 80 |         query_manager.safety_manager = MagicMock()
 81 | 
 82 |         # Create a mock validation result for a DROP TABLE query
 83 |         validated_statement = ValidatedStatement(
 84 |             category=SQLQueryCategory.DDL,
 85 |             command=SQLQueryCommand.DROP,
 86 |             risk_level=OperationRiskLevel.EXTREME,
 87 |             query="DROP TABLE users",
 88 |             needs_migration=False,
 89 |             object_type="TABLE",
 90 |             schema_name="public",
 91 |         )
 92 | 
 93 |         validation_result = QueryValidationResults(
 94 |             statements=[validated_statement],
 95 |             highest_risk_level=OperationRiskLevel.EXTREME,
 96 |             has_transaction_control=False,
 97 |             original_query="DROP TABLE users",
 98 |         )
 99 | 
100 |         # Make the validator return our mock validation result
101 |         query_manager.validator.validate_query.return_value = validation_result
102 | 
103 |         # Make the safety manager raise a SafetyError
104 |         error_message = "Operation not allowed in SAFE mode"
105 |         query_manager.safety_manager.validate_operation.side_effect = SafetyError(error_message)
106 | 
107 |         # Execute a query - should raise a SafetyError
108 |         query = "DROP TABLE users"
109 |         with pytest.raises(SafetyError) as excinfo:
110 |             await query_manager.handle_query(query)
111 | 
112 |         # Verify the error message
113 |         assert error_message in str(excinfo.value)
114 | 
115 |         # Verify the validator was called with the query
116 |         query_manager.validator.validate_query.assert_called_once_with(query)
117 | 
118 |         # Verify the safety manager was called with the validation result
119 |         query_manager.safety_manager.validate_operation.assert_called_once_with(
120 |             ClientType.DATABASE, validation_result, False
121 |         )
122 | 
123 |         # Verify the db_client was not called
124 |         query_manager.db_client.execute_query.assert_not_called()
125 | 
126 |     @pytest.mark.unit
127 |     async def test_get_migrations_query(self, query_manager_integration: QueryManager):
128 |         """Test that get_migrations_query returns a valid query string."""
129 |         # Test with default parameters
130 |         query = query_manager_integration.get_migrations_query()
131 |         assert isinstance(query, str)
132 |         assert "supabase_migrations.schema_migrations" in query
133 |         assert "LIMIT 50" in query
134 | 
135 |         # Test with custom parameters
136 |         custom_query = query_manager_integration.get_migrations_query(
137 |             limit=10, offset=5, name_pattern="test", include_full_queries=True
138 |         )
139 |         assert isinstance(custom_query, str)
140 |         assert "supabase_migrations.schema_migrations" in custom_query
141 |         assert "LIMIT 10" in custom_query
142 |         assert "OFFSET 5" in custom_query
143 |         assert "name ILIKE" in custom_query
144 |         assert "statements" in custom_query  # Should include statements column when include_full_queries=True
145 | 
146 |     @pytest.mark.unit
147 |     async def test_init_migration_schema(self):
148 |         """Test that init_migration_schema initializes the migration schema correctly."""
149 |         # Create minimal mocks
150 |         postgres_client = MagicMock()
151 |         postgres_client.execute_query = AsyncMock()
152 | 
153 |         safety_manager = MagicMock()
154 | 
155 |         # Create a real SQLLoader and SQLValidator
156 |         sql_loader = SQLLoader()
157 |         sql_validator = SQLValidator()
158 | 
159 |         # Create the QueryManager with minimal mocking
160 |         query_manager = QueryManager(
161 |             postgres_client=postgres_client,
162 |             safety_manager=safety_manager,
163 |             sql_validator=sql_validator,
164 |             sql_loader=sql_loader,
165 |         )
166 | 
167 |         # Call the method
168 |         await query_manager.init_migration_schema()
169 | 
170 |         # Verify that the SQL loader was used to get the init migrations query
171 |         # and that the query was executed
172 |         assert postgres_client.execute_query.called
173 | 
174 |         # Get the arguments that execute_query was called with
175 |         call_args = postgres_client.execute_query.call_args
176 |         assert call_args is not None
177 | 
178 |         # The first argument should be a QueryValidationResults object
179 |         args, _ = call_args  # Use _ to ignore unused kwargs
180 |         assert len(args) > 0
181 |         validation_result = args[0]
182 |         assert isinstance(validation_result, QueryValidationResults)
183 | 
184 |         # Check that the validation result contains the expected SQL
185 |         init_query = sql_loader.get_init_migrations_query()
186 |         assert any(stmt.query and stmt.query in init_query for stmt in validation_result.statements)
187 | 
188 |     @pytest.mark.unit
189 |     async def test_handle_migration(self):
190 |         """Test that handle_migration correctly handles migrations when needed."""
191 |         # Create minimal mocks
192 |         postgres_client = MagicMock()
193 |         postgres_client.execute_query = AsyncMock()
194 | 
195 |         safety_manager = MagicMock()
196 | 
197 |         # Create a real SQLLoader
198 |         sql_loader = SQLLoader()
199 | 
200 |         # Create a mock MigrationManager
201 |         migration_manager = MagicMock()
202 |         migration_query = "INSERT INTO _migrations.migrations (name) VALUES ('test_migration')"
203 |         migration_name = "test_migration"
204 |         migration_manager.prepare_migration_query.return_value = (migration_query, migration_name)
205 | 
206 |         # Create a real SQLValidator
207 |         sql_validator = SQLValidator()
208 | 
209 |         # Create the QueryManager with minimal mocking
210 |         query_manager = QueryManager(
211 |             postgres_client=postgres_client,
212 |             safety_manager=safety_manager,
213 |             sql_validator=sql_validator,
214 |             sql_loader=sql_loader,
215 |             migration_manager=migration_manager,
216 |         )
217 | 
218 |         # Create a validation result that needs migration
219 |         validated_statement = ValidatedStatement(
220 |             category=SQLQueryCategory.DDL,
221 |             command=SQLQueryCommand.CREATE,
222 |             risk_level=OperationRiskLevel.MEDIUM,
223 |             query="CREATE TABLE test (id INT)",
224 |             needs_migration=True,
225 |             object_type="TABLE",
226 |             schema_name="public",
227 |         )
228 | 
229 |         validation_result = QueryValidationResults(
230 |             statements=[validated_statement],
231 |             highest_risk_level=OperationRiskLevel.MEDIUM,
232 |             has_transaction_control=False,
233 |             original_query="CREATE TABLE test (id INT)",
234 |         )
235 | 
236 |         # Call the method
237 |         await query_manager.handle_migration(validation_result, "CREATE TABLE test (id INT)", "test_migration")
238 | 
239 |         # Verify that the migration manager was called to prepare the migration query
240 |         migration_manager.prepare_migration_query.assert_called_once_with(
241 |             validation_result, "CREATE TABLE test (id INT)", "test_migration"
242 |         )
243 | 
244 |         # Verify that execute_query was called at least twice
245 |         # Once for init_migration_schema and once for the migration query
246 |         assert postgres_client.execute_query.call_count >= 2
247 | 
```
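
The assertions above pin down the `handle_query` contract: validate first, safety-check second, execute last. As a minimal sketch, the pipeline reconstructed purely from those assertions looks like this (the free function and its name are illustrative, not the actual `QueryManager` implementation):

```python
from supabase_mcp.services.safety.models import ClientType


async def handle_query_flow(query_manager, query: str, has_confirmation: bool = False):
    """Condensed handle_query pipeline, as implied by the test assertions above."""
    # 1. Parse and classify the SQL statements.
    validation_result = query_manager.validator.validate_query(query)
    # 2. Safety check raises (e.g. SafetyError subclasses) when the operation is blocked.
    query_manager.safety_manager.validate_operation(ClientType.DATABASE, validation_result, has_confirmation)
    # 3. Execute only after validation and the safety check both pass.
    return await query_manager.db_client.execute_query(validation_result, has_confirmation)
```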

--------------------------------------------------------------------------------
/supabase_mcp/services/safety/safety_manager.py:
--------------------------------------------------------------------------------

```python
  1 | import time
  2 | import uuid
  3 | from typing import Any, Optional
  4 | 
  5 | from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError
  6 | from supabase_mcp.logger import logger
  7 | from supabase_mcp.services.safety.models import ClientType, SafetyMode
  8 | from supabase_mcp.services.safety.safety_configs import APISafetyConfig, SafetyConfigBase, SQLSafetyConfig
  9 | 
 10 | 
 11 | class SafetyManager:
 12 |     """A singleton service that maintains current safety state.
 13 | 
 14 |     Provides methods to:
 15 |       - Get/set safety modes for different clients
 16 |       - Register safety configurations
 17 |       - Check if operations are allowed
 18 |     Serves as the central point for safety decisions."""
 19 | 
 20 |     _instance: Optional["SafetyManager"] = None
 21 | 
 22 |     def __init__(self) -> None:
 23 |         """Initialize the safety manager with default safety modes."""
 24 |         self._safety_modes: dict[ClientType, SafetyMode] = {
 25 |             ClientType.DATABASE: SafetyMode.SAFE,
 26 |             ClientType.API: SafetyMode.SAFE,
 27 |         }
 28 |         self._safety_configs: dict[ClientType, SafetyConfigBase[Any]] = {}
 29 |         self._pending_confirmations: dict[str, dict[str, Any]] = {}
 30 |         self._confirmation_expiry = 300  # 5 minutes in seconds
 31 | 
 32 |     @classmethod
 33 |     def get_instance(cls) -> "SafetyManager":
 34 |         """Get the singleton instance of the safety manager."""
 35 |         if cls._instance is None:
 36 |             cls._instance = SafetyManager()
 37 |         return cls._instance
 38 | 
 39 |     def register_safety_configs(self) -> bool:
 40 |         """Register all safety configurations with the SafetyManager.
 41 | 
 42 |         Returns:
 43 |             bool: True if all configurations were registered successfully
 44 |         """
 45 |         # Register SQL safety config
 46 |         sql_config = SQLSafetyConfig()
 47 |         self.register_config(ClientType.DATABASE, sql_config)
 48 | 
 49 |         # Register API safety config
 50 |         api_config = APISafetyConfig()
 51 |         self.register_config(ClientType.API, api_config)
 52 | 
 53 |         logger.info("✓ Safety configurations registered successfully")
 54 |         return True
 55 | 
 56 |     def register_config(self, client_type: ClientType, config: SafetyConfigBase[Any]) -> None:
 57 |         """Register a safety configuration for a client type.
 58 | 
 59 |         Args:
 60 |             client_type: The client type to register the configuration for
 61 |             config: The safety configuration for the client
 62 |         """
 63 |         self._safety_configs[client_type] = config
 64 | 
 65 |     def get_safety_mode(self, client_type: ClientType) -> SafetyMode:
 66 |         """Get the current safety mode for a client type.
 67 | 
 68 |         Args:
 69 |             client_type: The client type to get the safety mode for
 70 | 
 71 |         Returns:
 72 |             The current safety mode for the client type
 73 |         """
 74 |         if client_type not in self._safety_modes:
 75 |             logger.warning(f"No safety mode registered for {client_type}, defaulting to SAFE")
 76 |             return SafetyMode.SAFE
 77 |         return self._safety_modes[client_type]
 78 | 
 79 |     def set_safety_mode(self, client_type: ClientType, mode: SafetyMode) -> None:
 80 |         """Set the safety mode for a client type.
 81 | 
 82 |         Args:
 83 |             client_type: The client type to set the safety mode for
 84 |             mode: The safety mode to set
 85 |         """
 86 |         self._safety_modes[client_type] = mode
 87 |         logger.debug(f"Set safety mode for {client_type} to {mode}")
 88 | 
 89 |     def validate_operation(
 90 |         self,
 91 |         client_type: ClientType,
 92 |         operation: Any,
 93 |         has_confirmation: bool = False,
 94 |     ) -> None:
 95 |         """Validate if an operation is allowed for a client type.
 96 | 
 97 |         This method will raise appropriate exceptions if the operation is not allowed
 98 |         or requires confirmation.
 99 | 
100 |         Args:
101 |             client_type: The client type to check the operation for
102 |             operation: The operation to check
103 |             has_confirmation: Whether the operation has been confirmed by the user
104 | 
105 |         Raises:
106 |             OperationNotAllowedError: If the operation is not allowed in the current safety mode
107 |             ConfirmationRequiredError: If the operation requires confirmation and has_confirmation is False
108 |         """
109 |         # Get the current safety mode and config
110 |         mode = self.get_safety_mode(client_type)
111 |         config = self._safety_configs.get(client_type)
112 | 
113 |         if not config:
114 |             message = f"No safety configuration registered for {client_type}"
115 |             logger.warning(message)
116 |             raise OperationNotAllowedError(message)
117 | 
118 |         # Get the risk level for the operation
119 |         risk_level = config.get_risk_level(operation)
120 |         logger.debug(f"Operation risk level: {risk_level}")
121 | 
122 |         # Check if the operation is allowed in the current mode
123 |         is_allowed = config.is_operation_allowed(risk_level, mode)
124 |         if not is_allowed:
125 |             message = f"Operation with risk level {risk_level} is not allowed in {mode} mode"
126 |             logger.debug(f"Operation with risk level {risk_level} not allowed in {mode} mode")
127 |             raise OperationNotAllowedError(message)
128 | 
129 |         # Check if the operation needs confirmation
130 |         needs_confirmation = config.needs_confirmation(risk_level)
131 |         if needs_confirmation and not has_confirmation:
132 |             # Store the operation for later confirmation
133 |             confirmation_id = self._store_confirmation(client_type, operation, risk_level)
134 | 
135 |             message = (
136 |                 f"Operation with risk level {risk_level} requires explicit user confirmation.\n\n"
137 |                 f"WHAT HAPPENED: This high-risk operation was rejected for safety reasons.\n"
138 |                 f"WHAT TO DO: 1. Review the operation with the user and explain the risks\n"
139 |                 f"            2. If the user approves, use the confirmation tool with this ID: {confirmation_id}\n\n"
140 |                 f'CONFIRMATION COMMAND: confirm_destructive_postgresql(confirmation_id="{confirmation_id}", user_confirmation=True)'
141 |             )
142 |             logger.debug(
143 |                 f"Operation with risk level {risk_level} requires confirmation, stored with ID {confirmation_id}"
144 |             )
145 |             raise ConfirmationRequiredError(message)
146 | 
147 |         logger.debug(f"Operation with risk level {risk_level} allowed in {mode} mode")
148 | 
149 |     def _store_confirmation(self, client_type: ClientType, operation: Any, risk_level: int) -> str:
150 |         """Store an operation that needs confirmation.
151 | 
152 |         Args:
153 |             client_type: The client type the operation is for
154 |             operation: The operation to store
155 |             risk_level: The risk level of the operation
156 | 
157 |         Returns:
158 |             A unique confirmation ID
159 |         """
160 |         # Generate a unique ID
161 |         confirmation_id = f"conf_{uuid.uuid4().hex[:8]}"
162 | 
163 |         # Store the operation with metadata
164 |         self._pending_confirmations[confirmation_id] = {
165 |             "operation": operation,
166 |             "client_type": client_type,
167 |             "risk_level": risk_level,
168 |             "timestamp": time.time(),
169 |         }
170 | 
171 |         # Clean up expired confirmations
172 |         self._cleanup_expired_confirmations()
173 | 
174 |         return confirmation_id
175 | 
176 |     def _get_confirmation(self, confirmation_id: str) -> dict[str, Any] | None:
177 |         """Retrieve a stored confirmation by ID.
178 | 
179 |         Args:
180 |             confirmation_id: The ID of the confirmation to retrieve
181 | 
182 |         Returns:
183 |             The stored confirmation data or None if not found or expired
184 |         """
185 |         # Clean up expired confirmations first
186 |         self._cleanup_expired_confirmations()
187 | 
188 |         # Return the stored confirmation if it exists
189 |         return self._pending_confirmations.get(confirmation_id)
190 | 
191 |     def _cleanup_expired_confirmations(self) -> None:
192 |         """Remove expired confirmations from storage."""
193 |         current_time = time.time()
194 |         expired_ids = [
195 |             conf_id
196 |             for conf_id, data in self._pending_confirmations.items()
197 |             if current_time - data["timestamp"] > self._confirmation_expiry
198 |         ]
199 | 
200 |         for conf_id in expired_ids:
201 |             logger.debug(f"Removing expired confirmation with ID {conf_id}")
202 |             del self._pending_confirmations[conf_id]
203 | 
204 |     def get_stored_operation(self, confirmation_id: str) -> Any | None:
205 |         """Get a stored operation by its confirmation ID.
206 | 
207 |         Args:
208 |             confirmation_id: The confirmation ID to get the operation for
209 | 
210 |         Returns:
211 |             The stored operation, or None if not found
212 |         """
213 |         confirmation = self._get_confirmation(confirmation_id)
214 |         if confirmation is None:
215 |             return None
216 |         return confirmation.get("operation")
217 | 
218 |     def get_operations_by_risk_level(
219 |         self, risk_level: str, client_type: ClientType = ClientType.DATABASE
220 |     ) -> dict[str, list[str]]:
221 |         """Get operations for a specific risk level.
222 | 
223 |         Args:
224 |             risk_level: The risk level to get operations for
225 |             client_type: The client type to get operations for
226 | 
227 |         Returns:
228 |             A dictionary mapping HTTP methods to lists of paths
229 |         """
230 |         # Get the config for the specified client type
231 |         config = self._safety_configs.get(client_type)
232 |         if not config or not hasattr(config, "PATH_SAFETY_CONFIG"):
233 |             return {}
234 | 
235 |         # Get the operations for this risk level
236 |         risk_config = getattr(config, "PATH_SAFETY_CONFIG", {})
 237 |         # Fall back to an empty mapping so the annotated return type always holds
 238 |         return risk_config.get(risk_level, {})
239 | 
240 |     def get_current_mode(self, client_type: ClientType) -> str:
241 |         """Get the current safety mode as a string.
242 | 
243 |         Args:
244 |             client_type: The client type to get the mode for
245 | 
246 |         Returns:
247 |             The current safety mode as a string
248 |         """
249 |         mode = self.get_safety_mode(client_type)
250 |         return str(mode)
251 | 
252 |     @classmethod
253 |     def reset(cls) -> None:
254 |         """Reset the singleton instance cleanly.
255 | 
 256 |         This clears the stored instance so the next get_instance() call builds a fresh one.
257 |         """
258 |         if cls._instance is not None:
259 |             cls._instance = None
260 |             logger.info("SafetyManager instance reset complete")
261 | 
```
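
For orientation, here is a minimal end-to-end sketch of the confirmation flow above. It assumes the SQL validation models live in `supabase_mcp.services.database.sql.models`, that `SafetyMode.UNSAFE` exists alongside `SAFE`, and that EXTREME-risk operations still require confirmation in UNSAFE mode; it mirrors the objects built in the tests earlier on this page and is an illustration, not repo code:

```python
from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError
from supabase_mcp.services.database.sql.models import (  # assumed module path
    QueryValidationResults,
    SQLQueryCategory,
    SQLQueryCommand,
    ValidatedStatement,
)
from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel, SafetyMode
from supabase_mcp.services.safety.safety_manager import SafetyManager

manager = SafetyManager.get_instance()
manager.register_safety_configs()
manager.set_safety_mode(ClientType.DATABASE, SafetyMode.UNSAFE)

# The same EXTREME-risk validation result the tests on this page construct.
drop_stmt = ValidatedStatement(
    category=SQLQueryCategory.DDL,
    command=SQLQueryCommand.DROP,
    risk_level=OperationRiskLevel.EXTREME,
    query="DROP TABLE users",
    needs_migration=False,
    object_type="TABLE",
    schema_name="public",
)
validation = QueryValidationResults(
    statements=[drop_stmt],
    highest_risk_level=OperationRiskLevel.EXTREME,
    has_transaction_control=False,
    original_query="DROP TABLE users",
)

try:
    manager.validate_operation(ClientType.DATABASE, validation, has_confirmation=False)
except ConfirmationRequiredError as exc:
    # The message embeds a conf_... ID; once the user approves, the operation
    # can be fetched back with manager.get_stored_operation(confirmation_id).
    print(exc)
except OperationNotAllowedError as exc:
    print(exc)
```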

--------------------------------------------------------------------------------
/supabase_mcp/clients/sdk_client.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | 
  3 | from typing import Any, TypeVar
  4 | 
  5 | from pydantic import BaseModel, ValidationError
  6 | from supabase import AsyncClient, create_async_client
  7 | from supabase.lib.client_options import AsyncClientOptions
  8 | 
  9 | from supabase_mcp.exceptions import PythonSDKError
 10 | from supabase_mcp.logger import logger
 11 | from supabase_mcp.services.sdk.auth_admin_models import (
 12 |     PARAM_MODELS,
 13 |     CreateUserParams,
 14 |     DeleteFactorParams,
 15 |     DeleteUserParams,
 16 |     GenerateLinkParams,
 17 |     GetUserByIdParams,
 18 |     InviteUserByEmailParams,
 19 |     ListUsersParams,
 20 |     UpdateUserByIdParams,
 21 | )
 22 | from supabase_mcp.services.sdk.auth_admin_sdk_spec import get_auth_admin_methods_spec
 23 | from supabase_mcp.settings import Settings
 24 | 
 25 | T = TypeVar("T", bound=BaseModel)
 26 | 
 27 | 
 28 | class IncorrectSDKParamsError(PythonSDKError):
 29 |     """Error raised when the parameters passed to the SDK are incorrect."""
 30 | 
 31 |     pass
 32 | 
 33 | 
 34 | class SupabaseSDKClient:
 35 |     """Supabase Python SDK client, which exposes functionality related to Auth admin of the Python SDK."""
 36 | 
 37 |     _instance: SupabaseSDKClient | None = None
 38 | 
 39 |     def __init__(
 40 |         self,
 41 |         settings: Settings | None = None,
 42 |         project_ref: str | None = None,
 43 |         service_role_key: str | None = None,
 44 |     ):
 45 |         self.client: AsyncClient | None = None
 46 |         self.settings = settings
 47 |         self.project_ref = settings.supabase_project_ref if settings else project_ref
 48 |         self.service_role_key = settings.supabase_service_role_key if settings else service_role_key
 49 |         self.supabase_url = self.get_supabase_url()
 50 |         logger.info(f"✔️ Supabase SDK client initialized successfully for project {self.project_ref}")
 51 | 
 52 |     def get_supabase_url(self) -> str:
 53 |         """Returns the Supabase URL based on the project reference"""
 54 |         if not self.project_ref:
 55 |             raise PythonSDKError("Project reference is not set")
 56 |         if self.project_ref.startswith("127.0.0.1"):
 57 |             # Return the default Supabase API URL
 58 |             return "http://127.0.0.1:54321"
 59 |         return f"https://{self.project_ref}.supabase.co"
 60 | 
 61 |     @classmethod
 62 |     def create(
 63 |         cls,
 64 |         settings: Settings | None = None,
 65 |         project_ref: str | None = None,
 66 |         service_role_key: str | None = None,
 67 |     ) -> SupabaseSDKClient:
 68 |         if cls._instance is None:
 69 |             cls._instance = cls(settings, project_ref, service_role_key)
 70 |         return cls._instance
 71 | 
 72 |     @classmethod
 73 |     def get_instance(
 74 |         cls,
 75 |         settings: Settings | None = None,
 76 |         project_ref: str | None = None,
 77 |         service_role_key: str | None = None,
 78 |     ) -> SupabaseSDKClient:
 79 |         """Returns the singleton instance"""
 80 |         if cls._instance is None:
 81 |             cls.create(settings, project_ref, service_role_key)
 82 |         return cls._instance
 83 | 
 84 |     async def create_client(self) -> AsyncClient:
 85 |         """Creates a new Supabase client"""
 86 |         try:
 87 |             client = await create_async_client(
 88 |                 self.supabase_url,
 89 |                 self.service_role_key,
 90 |                 options=AsyncClientOptions(
 91 |                     auto_refresh_token=False,
 92 |                     persist_session=False,
 93 |                 ),
 94 |             )
 95 |             return client
 96 |         except Exception as e:
 97 |             logger.error(f"Error creating Supabase client: {e}")
 98 |             raise PythonSDKError(f"Error creating Supabase client: {e}") from e
 99 | 
100 |     async def get_client(self) -> AsyncClient:
101 |         """Returns the Supabase client"""
102 |         if not self.client:
103 |             self.client = await self.create_client()
104 |             logger.info(f"Created Supabase SDK client for project {self.project_ref}")
105 |         return self.client
106 | 
107 |     async def close(self) -> None:
108 |         """Reset the client reference to allow garbage collection."""
109 |         self.client = None
110 |         logger.info("Supabase SDK client reference cleared")
111 | 
112 |     def return_python_sdk_spec(self) -> dict:
113 |         """Returns the Python SDK spec"""
114 |         return get_auth_admin_methods_spec()
115 | 
116 |     def _validate_params(self, method: str, params: dict, param_model_cls: type[T]) -> T:
117 |         """Validate parameters using the appropriate Pydantic model"""
118 |         try:
119 |             return param_model_cls.model_validate(params)
120 |         except ValidationError as e:
121 |             raise PythonSDKError(f"Invalid parameters for method {method}: {str(e)}") from e
122 | 
123 |     async def _get_user_by_id(self, params: GetUserByIdParams) -> dict:
124 |         """Get user by ID implementation"""
125 |         self.client = await self.get_client()
126 |         admin_auth_client = self.client.auth.admin
127 |         result = await admin_auth_client.get_user_by_id(params.uid)
128 |         return result
129 | 
130 |     async def _list_users(self, params: ListUsersParams) -> dict:
131 |         """List users implementation"""
132 |         self.client = await self.get_client()
133 |         admin_auth_client = self.client.auth.admin
134 |         result = await admin_auth_client.list_users(page=params.page, per_page=params.per_page)
135 |         return result
136 | 
137 |     async def _create_user(self, params: CreateUserParams) -> dict:
138 |         """Create user implementation"""
139 |         self.client = await self.get_client()
140 |         admin_auth_client = self.client.auth.admin
141 |         user_data = params.model_dump(exclude_none=True)
142 |         result = await admin_auth_client.create_user(user_data)
143 |         return result
144 | 
145 |     async def _delete_user(self, params: DeleteUserParams) -> dict:
146 |         """Delete user implementation"""
147 |         self.client = await self.get_client()
148 |         admin_auth_client = self.client.auth.admin
149 |         result = await admin_auth_client.delete_user(params.id, should_soft_delete=params.should_soft_delete)
150 |         return result
151 | 
152 |     async def _invite_user_by_email(self, params: InviteUserByEmailParams) -> dict:
153 |         """Invite user by email implementation"""
154 |         self.client = await self.get_client()
155 |         admin_auth_client = self.client.auth.admin
156 |         options = params.options if params.options else {}
157 |         result = await admin_auth_client.invite_user_by_email(params.email, options)
158 |         return result
159 | 
160 |     async def _generate_link(self, params: GenerateLinkParams) -> dict:
161 |         """Generate link implementation"""
162 |         self.client = await self.get_client()
163 |         admin_auth_client = self.client.auth.admin
164 | 
165 |         # Create a params dictionary as expected by the SDK
166 |         params_dict = params.model_dump(exclude_none=True)
167 | 
168 |         try:
169 |             # The SDK expects a single 'params' parameter containing all the fields
170 |             result = await admin_auth_client.generate_link(params=params_dict)
171 |             return result
172 |         except TypeError as e:
173 |             # Catch parameter errors and provide a more helpful message
174 |             error_msg = str(e)
175 |             if "unexpected keyword argument" in error_msg:
176 |                 raise IncorrectSDKParamsError(
177 |                     f"Incorrect parameters for generate_link: {error_msg}. "
178 |                     f"Please check the SDK specification for the correct parameter structure."
179 |                 ) from e
180 |             raise
181 | 
182 |     async def _update_user_by_id(self, params: UpdateUserByIdParams) -> dict:
183 |         """Update user by ID implementation"""
184 |         self.client = await self.get_client()
185 |         admin_auth_client = self.client.auth.admin
186 |         uid = params.uid
187 |         attributes = params.attributes.model_dump(exclude={"uid"}, exclude_none=True)
188 |         result = await admin_auth_client.update_user_by_id(uid, attributes)
189 |         return result
190 | 
191 |     async def _delete_factor(self, params: DeleteFactorParams) -> dict:
192 |         """Delete factor implementation"""
193 |         # This method is not implemented in the Supabase SDK yet
194 |         raise NotImplementedError("The delete_factor method is not implemented in the Supabase SDK yet")
195 | 
196 |     async def call_auth_admin_method(self, method: str, params: dict[str, Any]) -> Any:
197 |         """Calls a method of the Python SDK client"""
198 |         # Check if service role key is available
199 |         if not self.service_role_key:
200 |             raise PythonSDKError(
201 |                 "Supabase service role key is not configured. Set SUPABASE_SERVICE_ROLE_KEY environment variable to use Auth Admin tools."
202 |             )
203 | 
204 |         if not self.client:
205 |             self.client = await self.get_client()
206 |             if not self.client:
207 |                 raise PythonSDKError("Python SDK client not initialized")
208 | 
209 |         # Validate method exists
210 |         if method not in PARAM_MODELS:
211 |             available_methods = ", ".join(PARAM_MODELS.keys())
212 |             raise PythonSDKError(f"Unknown method: {method}. Available methods: {available_methods}")
213 | 
214 |         # Get the appropriate model class and validate parameters
215 |         param_model_cls = PARAM_MODELS[method]
216 |         validated_params = self._validate_params(method, params, param_model_cls)
217 | 
218 |         # Method dispatch using a dictionary of method implementations
219 |         method_handlers = {
220 |             "get_user_by_id": self._get_user_by_id,
221 |             "list_users": self._list_users,
222 |             "create_user": self._create_user,
223 |             "delete_user": self._delete_user,
224 |             "invite_user_by_email": self._invite_user_by_email,
225 |             "generate_link": self._generate_link,
226 |             "update_user_by_id": self._update_user_by_id,
227 |             "delete_factor": self._delete_factor,
228 |         }
229 | 
230 |         # Call the appropriate method handler
231 |         try:
232 |             handler = method_handlers.get(method)
233 |             if not handler:
234 |                 raise PythonSDKError(f"Method {method} is not implemented")
235 | 
236 |             logger.debug(f"Python SDK request params: {validated_params}")
237 |             return await handler(validated_params)
238 |         except Exception as e:
239 |             if isinstance(e, IncorrectSDKParamsError):
240 |                 # Re-raise our custom error without wrapping it
241 |                 raise e
242 |             logger.error(f"Error calling {method}: {e}")
243 |             raise PythonSDKError(f"Error calling {method}: {str(e)}") from e
244 | 
245 |     @classmethod
246 |     def reset(cls) -> None:
247 |         """Reset the singleton instance cleanly.
248 | 
 249 |         This drops the cached instance reference; call close() first if a client is open.
250 |         """
251 |         if cls._instance is not None:
252 |             cls._instance = None
253 |             logger.info("SupabaseSDKClient instance reset complete")
254 | 
```
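
A short usage sketch for the client above. The project ref and service role key are placeholders, and the call only succeeds against a reachable Supabase project:

```python
import asyncio

from supabase_mcp.clients.sdk_client import SupabaseSDKClient


async def main() -> None:
    # Placeholder credentials: substitute a real project ref and service role key.
    client = SupabaseSDKClient.get_instance(
        project_ref="abcdefghijklmnopqrst",
        service_role_key="example-service-role-key",
    )
    # Params are validated against ListUsersParams before dispatch, so bad
    # input surfaces as PythonSDKError rather than an SDK-level TypeError.
    users = await client.call_auth_admin_method("list_users", {"page": 1, "per_page": 10})
    print(users)


asyncio.run(main())
```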

--------------------------------------------------------------------------------
/tests/services/sdk/test_auth_admin_models.py:
--------------------------------------------------------------------------------

```python
  1 | import pytest
  2 | from pydantic import ValidationError
  3 | 
  4 | from supabase_mcp.services.sdk.auth_admin_models import (
  5 |     PARAM_MODELS,
  6 |     AdminUserAttributes,
  7 |     CreateUserParams,
  8 |     DeleteFactorParams,
  9 |     DeleteUserParams,
 10 |     GenerateLinkParams,
 11 |     GetUserByIdParams,
 12 |     InviteUserByEmailParams,
 13 |     ListUsersParams,
 14 |     UpdateUserByIdParams,
 15 |     UserMetadata,
 16 | )
 17 | 
 18 | 
 19 | class TestModelConversion:
 20 |     """Test conversion from JSON data to models and validation"""
 21 | 
 22 |     def test_get_user_by_id_conversion(self):
 23 |         """Test conversion of get_user_by_id JSON data"""
 24 |         # Valid payload
 25 |         valid_payload = {"uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"}
 26 |         params = GetUserByIdParams.model_validate(valid_payload)
 27 |         assert params.uid == valid_payload["uid"]
 28 | 
 29 |         # Invalid payload (missing required uid)
 30 |         invalid_payload = {}
 31 |         with pytest.raises(ValidationError) as excinfo:
 32 |             GetUserByIdParams.model_validate(invalid_payload)
 33 |         assert "uid" in str(excinfo.value)
 34 | 
 35 |     def test_list_users_conversion(self):
 36 |         """Test conversion of list_users JSON data"""
 37 |         # Valid payload with custom values
 38 |         valid_payload = {"page": 2, "per_page": 20}
 39 |         params = ListUsersParams.model_validate(valid_payload)
 40 |         assert params.page == valid_payload["page"]
 41 |         assert params.per_page == valid_payload["per_page"]
 42 | 
 43 |         # Valid payload with defaults
 44 |         empty_payload = {}
 45 |         params = ListUsersParams.model_validate(empty_payload)
 46 |         assert params.page == 1
 47 |         assert params.per_page == 50
 48 | 
 49 |         # Invalid payload (non-integer values)
 50 |         invalid_payload = {"page": "not-a-number", "per_page": "also-not-a-number"}
 51 |         with pytest.raises(ValidationError) as excinfo:
 52 |             ListUsersParams.model_validate(invalid_payload)
 53 |         assert "page" in str(excinfo.value)
 54 | 
 55 |     def test_create_user_conversion(self):
 56 |         """Test conversion of create_user JSON data"""
 57 |         # Valid payload with email
 58 |         valid_payload = {
  59 |             "email": "test@example.com",
  60 |             "password": "secure-password",
  61 |             "email_confirm": True,
  62 |             "user_metadata": UserMetadata(email="test@example.com"),
 63 |         }
 64 |         params = CreateUserParams.model_validate(valid_payload)
 65 |         assert params.email == valid_payload["email"]
 66 |         assert params.password == valid_payload["password"]
 67 |         assert params.email_confirm is True
 68 |         assert params.user_metadata == valid_payload["user_metadata"]
 69 | 
 70 |         # Valid payload with phone
 71 |         valid_phone_payload = {
 72 |             "phone": "+1234567890",
 73 |             "password": "secure-password",
 74 |             "phone_confirm": True,
 75 |         }
 76 |         params = CreateUserParams.model_validate(valid_phone_payload)
 77 |         assert params.phone == valid_phone_payload["phone"]
 78 |         assert params.password == valid_phone_payload["password"]
 79 |         assert params.phone_confirm is True
 80 | 
 81 |         # Invalid payload (missing both email and phone)
 82 |         invalid_payload = {"password": "secure-password"}
 83 |         with pytest.raises(ValidationError) as excinfo:
 84 |             CreateUserParams.model_validate(invalid_payload)
 85 |         assert "Either email or phone must be provided" in str(excinfo.value)
 86 | 
 87 |     def test_delete_user_conversion(self):
 88 |         """Test conversion of delete_user JSON data"""
 89 |         # Valid payload with custom values
 90 |         valid_payload = {"id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b", "should_soft_delete": True}
 91 |         params = DeleteUserParams.model_validate(valid_payload)
 92 |         assert params.id == valid_payload["id"]
 93 |         assert params.should_soft_delete is True
 94 | 
 95 |         # Valid payload with defaults
 96 |         valid_payload = {"id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b"}
 97 |         params = DeleteUserParams.model_validate(valid_payload)
 98 |         assert params.id == valid_payload["id"]
 99 |         assert params.should_soft_delete is False
100 | 
101 |         # Invalid payload (missing id)
102 |         invalid_payload = {"should_soft_delete": True}
103 |         with pytest.raises(ValidationError) as excinfo:
104 |             DeleteUserParams.model_validate(invalid_payload)
105 |         assert "id" in str(excinfo.value)
106 | 
107 |     def test_invite_user_by_email_conversion(self):
108 |         """Test conversion of invite_user_by_email JSON data"""
109 |         # Valid payload with options
110 |         valid_payload = {
 111 |             "email": "invited@example.com",
112 |             "options": {"data": {"name": "Invited User"}, "redirect_to": "https://example.com/welcome"},
113 |         }
114 |         params = InviteUserByEmailParams.model_validate(valid_payload)
115 |         assert params.email == valid_payload["email"]
116 |         assert params.options == valid_payload["options"]
117 | 
118 |         # Valid payload without options
 119 |         valid_payload = {"email": "invited@example.com"}
120 |         params = InviteUserByEmailParams.model_validate(valid_payload)
121 |         assert params.email == valid_payload["email"]
122 |         assert params.options is None
123 | 
124 |         # Invalid payload (missing email)
125 |         invalid_payload = {"options": {"data": {"name": "Invited User"}}}
126 |         with pytest.raises(ValidationError) as excinfo:
127 |             InviteUserByEmailParams.model_validate(invalid_payload)
128 |         assert "email" in str(excinfo.value)
129 | 
130 |     def test_generate_link_conversion(self):
131 |         """Test conversion of generate_link JSON data"""
132 |         # Valid signup link payload
133 |         valid_signup_payload = {
134 |             "type": "signup",
 135 |             "email": "signup@example.com",
136 |             "password": "secure-password",
137 |             "options": {"data": {"name": "New User"}, "redirect_to": "https://example.com/welcome"},
138 |         }
139 |         params = GenerateLinkParams.model_validate(valid_signup_payload)
140 |         assert params.type == valid_signup_payload["type"]
141 |         assert params.email == valid_signup_payload["email"]
142 |         assert params.password == valid_signup_payload["password"]
143 |         assert params.options == valid_signup_payload["options"]
144 | 
145 |         # Valid email_change link payload
146 |         valid_email_change_payload = {
147 |             "type": "email_change_current",
 148 |             "email": "current@example.com",
 149 |             "new_email": "new@example.com",
150 |         }
151 |         params = GenerateLinkParams.model_validate(valid_email_change_payload)
152 |         assert params.type == valid_email_change_payload["type"]
153 |         assert params.email == valid_email_change_payload["email"]
154 |         assert params.new_email == valid_email_change_payload["new_email"]
155 | 
156 |         # Invalid payload (missing password for signup)
157 |         invalid_signup_payload = {
158 |             "type": "signup",
 159 |             "email": "signup@example.com",
160 |         }
161 |         with pytest.raises(ValidationError) as excinfo:
162 |             GenerateLinkParams.model_validate(invalid_signup_payload)
163 |         assert "Password is required for signup links" in str(excinfo.value)
164 | 
165 |         # Invalid payload (missing new_email for email_change)
166 |         invalid_email_change_payload = {
167 |             "type": "email_change_current",
 168 |             "email": "current@example.com",
169 |         }
170 |         with pytest.raises(ValidationError) as excinfo:
171 |             GenerateLinkParams.model_validate(invalid_email_change_payload)
172 |         assert "new_email is required for email change links" in str(excinfo.value)
173 | 
174 |         # Invalid payload (invalid type)
175 |         invalid_type_payload = {
176 |             "type": "invalid-type",
 177 |             "email": "user@example.com",
178 |         }
179 |         with pytest.raises(ValidationError) as excinfo:
180 |             GenerateLinkParams.model_validate(invalid_type_payload)
181 |         assert "type" in str(excinfo.value)
182 | 
183 |     def test_update_user_by_id_conversion(self):
184 |         """Test conversion of update_user_by_id JSON data"""
185 |         # Valid payload
186 |         valid_payload = {
187 |             "uid": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
 188 |             "attributes": AdminUserAttributes(email="updated@example.com", email_verified=True),
189 |         }
190 |         params = UpdateUserByIdParams.model_validate(valid_payload)
191 |         assert params.uid == valid_payload["uid"]
192 |         assert params.attributes == valid_payload["attributes"]
193 | 
194 |         # Invalid payload (incorrect metadata and missing uids)
195 |         invalid_payload = {
 196 |             "email": "updated@example.com",
197 |             "user_metadata": {"name": "Updated User"},
198 |         }
199 |         with pytest.raises(ValidationError) as excinfo:
200 |             UpdateUserByIdParams.model_validate(invalid_payload)
201 |         assert "uid" in str(excinfo.value)
202 | 
203 |     def test_delete_factor_conversion(self):
204 |         """Test conversion of delete_factor JSON data"""
205 |         # Valid payload
206 |         valid_payload = {
207 |             "user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
208 |             "id": "totp-factor-id-123",
209 |         }
210 |         params = DeleteFactorParams.model_validate(valid_payload)
211 |         assert params.user_id == valid_payload["user_id"]
212 |         assert params.id == valid_payload["id"]
213 | 
214 |         # Invalid payload (missing user_id)
215 |         invalid_payload = {
216 |             "id": "totp-factor-id-123",
217 |         }
218 |         with pytest.raises(ValidationError) as excinfo:
219 |             DeleteFactorParams.model_validate(invalid_payload)
220 |         assert "user_id" in str(excinfo.value)
221 | 
222 |         # Invalid payload (missing id)
223 |         invalid_payload = {
224 |             "user_id": "d0e8c69f-e0c3-4a1c-b6d6-9a6c756a6a4b",
225 |         }
226 |         with pytest.raises(ValidationError) as excinfo:
227 |             DeleteFactorParams.model_validate(invalid_payload)
228 |         assert "id" in str(excinfo.value)
229 | 
230 |     def test_param_models_mapping(self):
231 |         """Test PARAM_MODELS mapping functionality"""
232 |         # Test that all methods have the correct corresponding model
233 |         method_model_pairs = [
234 |             ("get_user_by_id", GetUserByIdParams),
235 |             ("list_users", ListUsersParams),
236 |             ("create_user", CreateUserParams),
237 |             ("delete_user", DeleteUserParams),
238 |             ("invite_user_by_email", InviteUserByEmailParams),
239 |             ("generate_link", GenerateLinkParams),
240 |             ("update_user_by_id", UpdateUserByIdParams),
241 |             ("delete_factor", DeleteFactorParams),
242 |         ]
243 | 
244 |         for method, expected_model in method_model_pairs:
245 |             assert method in PARAM_MODELS
246 |             assert PARAM_MODELS[method] == expected_model
247 | 
248 |         # Test actual validation of data through PARAM_MODELS mapping
249 |         method = "create_user"
250 |         model_class = PARAM_MODELS[method]
251 | 
 252 |         valid_payload = {"email": "test@example.com", "password": "secure-password"}
253 | 
254 |         params = model_class.model_validate(valid_payload)
255 |         assert params.email == valid_payload["email"]
256 |         assert params.password == valid_payload["password"]
257 | 
```

--------------------------------------------------------------------------------
/tests/test_tools.py:
--------------------------------------------------------------------------------

```python
  1 | """Unit tests for tools - no external dependencies."""
  2 | import uuid
  3 | from unittest.mock import AsyncMock, MagicMock, patch
  4 | 
  5 | import pytest
  6 | from mcp.server.fastmcp import FastMCP
  7 | 
  8 | from supabase_mcp.core.container import ServicesContainer
  9 | from supabase_mcp.exceptions import ConfirmationRequiredError, OperationNotAllowedError, PythonSDKError
 10 | from supabase_mcp.services.database.postgres_client import QueryResult, StatementResult
 11 | from supabase_mcp.services.safety.models import ClientType, OperationRiskLevel, SafetyMode
 12 | 
 13 | 
 14 | @pytest.mark.asyncio
 15 | class TestDatabaseToolsUnit:
 16 |     """Unit tests for database tools."""
 17 | 
 18 |     @pytest.fixture
 19 |     def mock_container(self):
 20 |         """Create a mock container with all necessary services."""
 21 |         container = MagicMock(spec=ServicesContainer)
 22 |         
 23 |         # Mock query manager
 24 |         container.query_manager = MagicMock()
 25 |         container.query_manager.handle_query = AsyncMock()
 26 |         container.query_manager.get_schemas_query = MagicMock(return_value="SELECT * FROM schemas")
 27 |         container.query_manager.get_tables_query = MagicMock(return_value="SELECT * FROM tables")
 28 |         container.query_manager.get_table_schema_query = MagicMock(return_value="SELECT * FROM columns")
 29 |         
 30 |         # Mock safety manager
 31 |         container.safety_manager = MagicMock()
 32 |         container.safety_manager.check_permission = MagicMock()
 33 |         container.safety_manager.is_unsafe_mode = MagicMock(return_value=False)
 34 |         
 35 |         return container
 36 | 
 37 |     async def test_get_schemas_returns_query_result(self, mock_container):
 38 |         """Test that get_schemas returns proper QueryResult."""
 39 |         # Setup mock response
 40 |         mock_result = QueryResult(results=[
 41 |             StatementResult(rows=[
 42 |                 {"schema_name": "public", "total_size": "100MB", "table_count": 10},
 43 |                 {"schema_name": "auth", "total_size": "50MB", "table_count": 5}
 44 |             ])
 45 |         ])
 46 |         mock_container.query_manager.handle_query.return_value = mock_result
 47 |         
 48 |         # Execute
 49 |         query = mock_container.query_manager.get_schemas_query()
 50 |         result = await mock_container.query_manager.handle_query(query)
 51 |         
 52 |         # Verify
 53 |         assert isinstance(result, QueryResult)
 54 |         assert len(result.results[0].rows) == 2
 55 |         assert result.results[0].rows[0]["schema_name"] == "public"
 56 | 
 57 |     async def test_get_tables_with_schema_filter(self, mock_container):
 58 |         """Test that get_tables properly filters by schema."""
 59 |         # Setup
 60 |         mock_result = QueryResult(results=[
 61 |             StatementResult(rows=[
 62 |                 {"table_name": "users", "table_type": "BASE TABLE", "row_count": 100, "size_bytes": 1024}
 63 |             ])
 64 |         ])
 65 |         mock_container.query_manager.handle_query.return_value = mock_result
 66 |         
 67 |         # Execute
 68 |         query = mock_container.query_manager.get_tables_query("public")
 69 |         result = await mock_container.query_manager.handle_query(query)
 70 |         
 71 |         # Verify
 72 |         mock_container.query_manager.get_tables_query.assert_called_with("public")
 73 |         assert result.results[0].rows[0]["table_name"] == "users"
 74 | 
 75 |     async def test_unsafe_query_blocked_in_safe_mode(self, mock_container):
 76 |         """Test that unsafe queries are blocked in safe mode."""
 77 |         # Setup
 78 |         mock_container.safety_manager.is_unsafe_mode.return_value = False
 79 |         mock_container.safety_manager.check_permission.side_effect = OperationNotAllowedError(
 80 |             "DROP operations are not allowed in safe mode"
 81 |         )
 82 |         
 83 |         # Execute & Verify
 84 |         with pytest.raises(OperationNotAllowedError):
 85 |             mock_container.safety_manager.check_permission(
 86 |                 ClientType.DATABASE,
 87 |                 OperationRiskLevel.HIGH
 88 |             )
 89 | 
 90 | 
 91 | @pytest.mark.asyncio
 92 | class TestAPIToolsUnit:
 93 |     """Unit tests for API tools."""
 94 | 
 95 |     @pytest.fixture
 96 |     def mock_container(self):
 97 |         """Create a mock container with API services."""
 98 |         container = MagicMock(spec=ServicesContainer)
 99 |         
100 |         # Mock API manager
101 |         container.api_manager = MagicMock()
102 |         container.api_manager.send_request = AsyncMock()
103 |         container.api_manager.spec_manager = MagicMock()
104 |         container.api_manager.spec_manager.get_full_spec = MagicMock(return_value={"paths": {}})
105 |         
106 |         # Mock safety manager
107 |         container.safety_manager = MagicMock()
108 |         container.safety_manager.check_permission = MagicMock()
109 |         container.safety_manager.is_unsafe_mode = MagicMock(return_value=False)
110 |         
111 |         return container
112 | 
113 |     async def test_api_request_success(self, mock_container):
114 |         """Test successful API request."""
115 |         # Setup
116 |         mock_response = {"id": "123", "name": "Test Project"}
117 |         mock_container.api_manager.send_request.return_value = mock_response
118 |         
119 |         # Execute
120 |         result = await mock_container.api_manager.send_request(
121 |             "GET", "/v1/projects", {}
122 |         )
123 |         
124 |         # Verify
125 |         assert result["id"] == "123"
126 |         assert result["name"] == "Test Project"
127 | 
128 |     async def test_api_spec_retrieval(self, mock_container):
129 |         """Test API spec retrieval."""
130 |         # Setup
131 |         expected_spec = {
132 |             "paths": {
133 |                 "/v1/projects": {
134 |                     "get": {"summary": "List projects"}
135 |                 }
136 |             }
137 |         }
138 |         mock_container.api_manager.spec_manager.get_full_spec.return_value = expected_spec
139 |         
140 |         # Execute
141 |         spec = mock_container.api_manager.spec_manager.get_full_spec()
142 |         
143 |         # Verify
144 |         assert "paths" in spec
145 |         assert "/v1/projects" in spec["paths"]
146 | 
147 |     async def test_medium_risk_api_blocked_in_safe_mode(self, mock_container):
148 |         """Test that medium risk API operations are blocked in safe mode."""
149 |         # Setup
150 |         mock_container.safety_manager.check_permission.side_effect = ConfirmationRequiredError(
151 |             "This operation requires confirmation",
152 |             {"method": "POST", "path": "/v1/projects"}
153 |         )
154 |         
155 |         # Execute & Verify
156 |         with pytest.raises(ConfirmationRequiredError) as exc_info:
157 |             mock_container.safety_manager.check_permission(
158 |                 ClientType.API,
159 |                 OperationRiskLevel.MEDIUM
160 |             )
161 |         
162 |         assert "requires confirmation" in str(exc_info.value)
163 | 
164 | 
165 | @pytest.mark.asyncio
166 | class TestAuthToolsUnit:
167 |     """Unit tests for auth tools."""
168 | 
169 |     @pytest.fixture
170 |     def mock_container(self):
171 |         """Create a mock container with SDK client."""
172 |         container = MagicMock(spec=ServicesContainer)
173 |         
174 |         # Mock SDK client
175 |         container.sdk_client = MagicMock()
176 |         container.sdk_client.call_auth_admin_method = AsyncMock()
177 |         container.sdk_client.return_python_sdk_spec = MagicMock(return_value={
178 |             "methods": ["list_users", "create_user", "delete_user"]
179 |         })
180 |         
181 |         return container
182 | 
183 |     async def test_list_users_success(self, mock_container):
184 |         """Test listing users successfully."""
185 |         # Setup
186 |         mock_users = [
 187 |             {"id": "user1", "email": "user1@example.com"},
 188 |             {"id": "user2", "email": "user2@example.com"}
189 |         ]
190 |         mock_container.sdk_client.call_auth_admin_method.return_value = mock_users
191 |         
192 |         # Execute
193 |         result = await mock_container.sdk_client.call_auth_admin_method(
194 |             "list_users", {"page": 1, "per_page": 10}
195 |         )
196 |         
197 |         # Verify
198 |         assert len(result) == 2
 199 |         assert result[0]["email"] == "user1@example.com"
200 | 
201 |     async def test_invalid_method_raises_error(self, mock_container):
202 |         """Test that invalid method names raise errors."""
203 |         # Setup
204 |         mock_container.sdk_client.call_auth_admin_method.side_effect = PythonSDKError(
205 |             "Unknown method: invalid_method"
206 |         )
207 |         
208 |         # Execute & Verify
209 |         with pytest.raises(PythonSDKError) as exc_info:
210 |             await mock_container.sdk_client.call_auth_admin_method(
211 |                 "invalid_method", {}
212 |             )
213 |         
214 |         assert "Unknown method" in str(exc_info.value)
215 | 
216 |     async def test_create_user_validation(self, mock_container):
217 |         """Test user creation with validation."""
218 |         # Setup
 219 |         new_user = {"id": str(uuid.uuid4()), "email": "newuser@example.com"}
220 |         mock_container.sdk_client.call_auth_admin_method.return_value = {"user": new_user}
221 |         
222 |         # Execute
223 |         result = await mock_container.sdk_client.call_auth_admin_method(
224 |             "create_user", 
 225 |             {"email": "newuser@example.com", "password": "TestPass123!"}
226 |         )
227 |         
228 |         # Verify
 229 |         assert result["user"]["email"] == "newuser@example.com"
230 |         mock_container.sdk_client.call_auth_admin_method.assert_called_once()
231 | 
232 | 
233 | @pytest.mark.asyncio
234 | class TestSafetyToolsUnit:
235 |     """Unit tests for safety tools - these already work well."""
236 | 
237 |     @pytest.fixture
238 |     def mock_container(self):
239 |         """Create a mock container with safety manager."""
240 |         container = MagicMock(spec=ServicesContainer)
241 |         
242 |         # Mock safety manager with proper methods
243 |         container.safety_manager = MagicMock()
244 |         container.safety_manager.set_unsafe_mode = MagicMock()
245 |         container.safety_manager.get_mode = MagicMock(return_value=SafetyMode.SAFE)
246 |         container.safety_manager.confirm_operation = MagicMock()
247 |         container.safety_manager.is_unsafe_mode = MagicMock(return_value=False)
248 |         
249 |         return container
250 | 
251 |     async def test_live_dangerously_enables_unsafe_mode(self, mock_container):
252 |         """Test that live_dangerously enables unsafe mode."""
253 |         # Execute
254 |         mock_container.safety_manager.set_unsafe_mode(ClientType.DATABASE, True)
255 |         
256 |         # Verify
257 |         mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.DATABASE, True)
258 | 
259 |     async def test_confirm_operation_stores_confirmation(self, mock_container):
260 |         """Test that confirm operation stores the confirmation."""
261 |         # Setup
262 |         confirmation_id = str(uuid.uuid4())
263 |         
264 |         # Execute
265 |         mock_container.safety_manager.confirm_operation(confirmation_id)
266 |         
267 |         # Verify
268 |         mock_container.safety_manager.confirm_operation.assert_called_with(confirmation_id)
269 | 
270 |     async def test_safety_mode_switching(self, mock_container):
271 |         """Test switching between safe and unsafe modes."""
272 |         # Test enabling unsafe mode
273 |         mock_container.safety_manager.set_unsafe_mode(ClientType.API, True)
274 |         mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.API, True)
275 |         
276 |         # Test disabling unsafe mode
277 |         mock_container.safety_manager.set_unsafe_mode(ClientType.API, False)
278 |         mock_container.safety_manager.set_unsafe_mode.assert_called_with(ClientType.API, False)
```

--------------------------------------------------------------------------------
/supabase_mcp/core/feature_manager.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import TYPE_CHECKING, Any, Literal
  2 | 
  3 | from supabase_mcp.clients.api_client import ApiClient
  4 | from supabase_mcp.exceptions import APIError, ConfirmationRequiredError, FeatureAccessError, FeatureTemporaryError
  5 | from supabase_mcp.logger import logger
  6 | from supabase_mcp.services.database.postgres_client import QueryResult
  7 | from supabase_mcp.services.safety.models import ClientType, SafetyMode
  8 | from supabase_mcp.tools.manager import ToolName
  9 | 
 10 | if TYPE_CHECKING:
 11 |     from supabase_mcp.core.container import ServicesContainer
 12 | 
 13 | 
 14 | class FeatureManager:
 15 |     """Service for managing features, access to them and their configuration."""
 16 | 
 17 |     def __init__(self, api_client: ApiClient):
 18 |         """Initialize the feature service.
 19 | 
 20 |         Args:
 21 |             api_client: Client for communicating with the API
 22 |         """
 23 |         self.api_client = api_client
 24 | 
 25 |     async def check_feature_access(self, feature_name: str) -> None:
 26 |         """Check if the user has access to a feature.
 27 | 
 28 |         Args:
 29 |             feature_name: Name of the feature to check
 30 | 
 31 |         Raises:
 32 |             FeatureAccessError: If the user doesn't have access to the feature
 33 |         """
 34 |         try:
 35 |             # Use the API client to check feature access
 36 |             response = await self.api_client.check_feature_access(feature_name)
 37 | 
 38 |             # If access is not granted, raise an exception
 39 |             if not response.access_granted:
 40 |                 logger.info(f"Feature access denied: {feature_name}")
 41 |                 raise FeatureAccessError(feature_name)
 42 | 
 43 |             logger.debug(f"Feature access granted: {feature_name}")
 44 | 
 45 |         except APIError as e:
 46 |             logger.error(f"API error checking feature access: {feature_name} - {e}")
 47 |             raise FeatureTemporaryError(feature_name, e.status_code, e.response_body) from e
 48 |         except Exception as e:
 49 |             if not isinstance(e, FeatureAccessError):
 50 |                 logger.error(f"Unexpected error checking feature access: {feature_name} - {e}")
 51 |                 raise FeatureTemporaryError(feature_name) from e
 52 |             raise
 53 | 
 54 |     async def execute_tool(self, tool_name: ToolName, services_container: "ServicesContainer", **kwargs: Any) -> Any:
 55 |         """Execute a tool with feature access check.
 56 | 
 57 |         Args:
 58 |             tool_name: Name of the tool to execute
 59 |             services_container: Container with all services
 60 |             **kwargs: Arguments to pass to the tool
 61 | 
 62 |         Returns:
 63 |             Result of the tool execution
 64 |         """
 65 |         # Check feature access
 66 |         await self.check_feature_access(tool_name.value)
 67 | 
 68 |         # Execute the appropriate tool based on name
 69 |         if tool_name == ToolName.GET_SCHEMAS:
 70 |             return await self.get_schemas(services_container)
 71 |         elif tool_name == ToolName.GET_TABLES:
 72 |             return await self.get_tables(services_container, **kwargs)
 73 |         elif tool_name == ToolName.GET_TABLE_SCHEMA:
 74 |             return await self.get_table_schema(services_container, **kwargs)
 75 |         elif tool_name == ToolName.EXECUTE_POSTGRESQL:
 76 |             return await self.execute_postgresql(services_container, **kwargs)
 77 |         elif tool_name == ToolName.RETRIEVE_MIGRATIONS:
 78 |             return await self.retrieve_migrations(services_container, **kwargs)
 79 |         elif tool_name == ToolName.SEND_MANAGEMENT_API_REQUEST:
 80 |             return await self.send_management_api_request(services_container, **kwargs)
 81 |         elif tool_name == ToolName.GET_MANAGEMENT_API_SPEC:
 82 |             return await self.get_management_api_spec(services_container, **kwargs)
 83 |         elif tool_name == ToolName.GET_AUTH_ADMIN_METHODS_SPEC:
 84 |             return await self.get_auth_admin_methods_spec(services_container)
 85 |         elif tool_name == ToolName.CALL_AUTH_ADMIN_METHOD:
 86 |             return await self.call_auth_admin_method(services_container, **kwargs)
 87 |         elif tool_name == ToolName.LIVE_DANGEROUSLY:
 88 |             return await self.live_dangerously(services_container, **kwargs)
 89 |         elif tool_name == ToolName.CONFIRM_DESTRUCTIVE_OPERATION:
 90 |             return await self.confirm_destructive_operation(services_container, **kwargs)
 91 |         elif tool_name == ToolName.RETRIEVE_LOGS:
 92 |             return await self.retrieve_logs(services_container, **kwargs)
 93 |         else:
 94 |             raise ValueError(f"Unknown tool: {tool_name}")
 95 | 
 96 |     async def get_schemas(self, container: "ServicesContainer") -> QueryResult:
 97 |         """List all database schemas with their sizes and table counts."""
 98 |         query_manager = container.query_manager
 99 |         query = query_manager.get_schemas_query()
100 |         return await query_manager.handle_query(query)
101 | 
102 |     async def get_tables(self, container: "ServicesContainer", schema_name: str) -> QueryResult:
103 |         """List all tables, foreign tables, and views in a schema with their sizes, row counts, and metadata."""
104 |         query_manager = container.query_manager
105 |         query = query_manager.get_tables_query(schema_name)
106 |         return await query_manager.handle_query(query)
107 | 
108 |     async def get_table_schema(self, container: "ServicesContainer", schema_name: str, table: str) -> QueryResult:
109 |         """Get detailed table structure including columns, keys, and relationships."""
110 |         query_manager = container.query_manager
111 |         query = query_manager.get_table_schema_query(schema_name, table)
112 |         return await query_manager.handle_query(query)
113 | 
114 |     async def execute_postgresql(
115 |         self, container: "ServicesContainer", query: str, migration_name: str = ""
116 |     ) -> QueryResult:
117 |         """Execute PostgreSQL statements against your Supabase database."""
118 |         query_manager = container.query_manager
119 |         return await query_manager.handle_query(query, has_confirmation=False, migration_name=migration_name)
120 | 
121 |     async def retrieve_migrations(
122 |         self,
123 |         container: "ServicesContainer",
124 |         limit: int = 50,
125 |         offset: int = 0,
126 |         name_pattern: str = "",
127 |         include_full_queries: bool = False,
128 |     ) -> QueryResult:
129 |         """Retrieve migrations from Supabase, with optional pagination and name filtering."""
130 |         query_manager = container.query_manager
131 |         query = query_manager.get_migrations_query(
132 |             limit=limit, offset=offset, name_pattern=name_pattern, include_full_queries=include_full_queries
133 |         )
134 |         return await query_manager.handle_query(query)
135 | 
136 |     async def send_management_api_request(
137 |         self,
138 |         container: "ServicesContainer",
139 |         method: str,
140 |         path: str,
141 |         path_params: dict[str, str],
142 |         request_params: dict[str, Any],
143 |         request_body: dict[str, Any],
144 |     ) -> dict[str, Any]:
145 |         """Execute a Supabase Management API request."""
146 |         api_manager = container.api_manager
147 |         return await api_manager.execute_request(method, path, path_params, request_params, request_body)
148 | 
149 |     async def get_management_api_spec(
150 |         self, container: "ServicesContainer", params: dict[str, Any] | None = None
151 |     ) -> dict[str, Any]:
152 |         """Get the Supabase Management API specification."""
153 |         params = params or {}  # avoid a shared mutable default argument
154 |         path, method = params.get("path"), params.get("method")
155 |         domain = params.get("domain")
156 |         all_paths = params.get("all_paths", False)
157 | 
158 |         logger.debug(
159 |             f"Getting management API spec with path: {path}, method: {method}, domain: {domain}, all_paths: {all_paths}"
160 |         )
161 |         api_manager = container.api_manager
162 |         return await api_manager.handle_spec_request(path, method, domain, all_paths)
163 | 
164 |     async def get_auth_admin_methods_spec(self, container: "ServicesContainer") -> dict[str, Any]:
165 |         """Get Python SDK methods specification for Auth Admin."""
166 |         sdk_client = container.sdk_client
167 |         return sdk_client.return_python_sdk_spec()
168 | 
169 |     async def call_auth_admin_method(
170 |         self, container: "ServicesContainer", method: str, params: dict[str, Any]
171 |     ) -> dict[str, Any]:
172 |         """Call an Auth Admin method from Supabase Python SDK."""
173 |         sdk_client = container.sdk_client
174 |         return await sdk_client.call_auth_admin_method(method, params)
175 | 
176 |     async def live_dangerously(
177 |         self, container: "ServicesContainer", service: Literal["api", "database"], enable_unsafe_mode: bool = False
178 |     ) -> dict[str, Any]:
179 |         """
180 |         Toggle between safe and unsafe operation modes for API or Database services.
181 | 
182 |         This function controls the safety level for operations, allowing you to:
183 |         - Enable write operations for the database (INSERT, UPDATE, DELETE, schema changes)
184 |         - Enable state-changing operations for the Management API
185 |         """
186 |         safety_manager = container.safety_manager
187 |         if service == "api":
188 |             # Set the safety mode in the safety manager
189 |             new_mode = SafetyMode.UNSAFE if enable_unsafe_mode else SafetyMode.SAFE
190 |             safety_manager.set_safety_mode(ClientType.API, new_mode)
191 | 
192 |             # Return the actual mode that was set
193 |             return {"service": "api", "mode": safety_manager.get_safety_mode(ClientType.API)}
194 |         elif service == "database":
195 |             # Set the safety mode in the safety manager
196 |             new_mode = SafetyMode.UNSAFE if enable_unsafe_mode else SafetyMode.SAFE
197 |             safety_manager.set_safety_mode(ClientType.DATABASE, new_mode)
198 | 
199 |             # Return the actual mode that was set
200 |             return {"service": "database", "mode": safety_manager.get_safety_mode(ClientType.DATABASE)}
201 | 
202 |     async def confirm_destructive_operation(
203 |         self,
204 |         container: "ServicesContainer",
205 |         operation_type: Literal["api", "database"],
206 |         confirmation_id: str,
207 |         user_confirmation: bool = False,
208 |     ) -> QueryResult | dict[str, Any]:
209 |         """Execute a destructive operation after confirmation. Use this only after reviewing the risks with the user."""
210 |         api_manager = container.api_manager
211 |         query_manager = container.query_manager
212 |         if not user_confirmation:
213 |             raise ConfirmationRequiredError("Destructive operation requires explicit user confirmation.")
214 | 
215 |         if operation_type == "api":
216 |             return await api_manager.handle_confirmation(confirmation_id)
217 |         elif operation_type == "database":
218 |             return await query_manager.handle_confirmation(confirmation_id)
219 | 
220 |     async def retrieve_logs(
221 |         self,
222 |         container: "ServicesContainer",
223 |         collection: str,
224 |         limit: int = 20,
225 |         hours_ago: int = 1,
226 |         filters: list[dict[str, Any]] | None = None,
227 |         search: str = "",
228 |         custom_query: str = "",
229 |     ) -> dict[str, Any]:
230 |         """Retrieve logs from your Supabase project's services for debugging and monitoring."""
231 |         logger.info(
232 |             f"Tool called: retrieve_logs(collection={collection}, limit={limit}, hours_ago={hours_ago}, filters={filters}, search={search}, custom_query={'<custom>' if custom_query else None})"
233 |         )
234 | 
235 |         api_manager = container.api_manager
236 |         result = await api_manager.retrieve_logs(
237 |             collection=collection,
238 |             limit=limit,
239 |             hours_ago=hours_ago,
240 |             filters=filters or [],  # normalize the None default to an empty list
241 |             search=search,
242 |             custom_query=custom_query,
243 |         )
244 | 
245 |         logger.info(f"Tool completed: retrieve_logs - Retrieved log entries for collection={collection}")
246 | 
247 |         return result
248 | 
```
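
`execute_tool` above dispatches through a twelve-branch `if`/`elif` chain. A table-driven dispatch keeps the `ToolName`-to-handler mapping in one place and makes the unknown-tool case explicit. The sketch below illustrates that alternative, not the shipped implementation; handler names mirror the methods defined above:

```python
from typing import Any

from supabase_mcp.tools.manager import ToolName

# Maps each ToolName to the FeatureManager method that handles it.
_TOOL_HANDLERS: dict[ToolName, str] = {
    ToolName.GET_SCHEMAS: "get_schemas",
    ToolName.GET_TABLES: "get_tables",
    ToolName.GET_TABLE_SCHEMA: "get_table_schema",
    ToolName.EXECUTE_POSTGRESQL: "execute_postgresql",
    ToolName.RETRIEVE_MIGRATIONS: "retrieve_migrations",
    ToolName.SEND_MANAGEMENT_API_REQUEST: "send_management_api_request",
    ToolName.GET_MANAGEMENT_API_SPEC: "get_management_api_spec",
    ToolName.GET_AUTH_ADMIN_METHODS_SPEC: "get_auth_admin_methods_spec",
    ToolName.CALL_AUTH_ADMIN_METHOD: "call_auth_admin_method",
    ToolName.LIVE_DANGEROUSLY: "live_dangerously",
    ToolName.CONFIRM_DESTRUCTIVE_OPERATION: "confirm_destructive_operation",
    ToolName.RETRIEVE_LOGS: "retrieve_logs",
}


async def execute_tool(self: Any, tool_name: ToolName, services_container: Any, **kwargs: Any) -> Any:
    """Table-driven equivalent of the if/elif chain above."""
    await self.check_feature_access(tool_name.value)
    handler_name = _TOOL_HANDLERS.get(tool_name)
    if handler_name is None:
        raise ValueError(f"Unknown tool: {tool_name}")
    # Handlers that take no extra arguments simply receive an empty **kwargs.
    return await getattr(self, handler_name)(services_container, **kwargs)
```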

--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------

```python
  1 | import os
  2 | from collections.abc import AsyncGenerator, Generator
  3 | from pathlib import Path
  4 | from typing import Any
  5 | from unittest.mock import AsyncMock, MagicMock
  6 | 
  7 | import pytest
  8 | import pytest_asyncio
  9 | from dotenv import load_dotenv
 10 | from mcp.server.fastmcp import FastMCP
 11 | 
 12 | from supabase_mcp.clients.management_client import ManagementAPIClient
 13 | from supabase_mcp.clients.sdk_client import SupabaseSDKClient
 14 | from supabase_mcp.core.container import ServicesContainer
 15 | from supabase_mcp.logger import logger
 16 | from supabase_mcp.services.api.api_manager import SupabaseApiManager
 17 | from supabase_mcp.services.api.spec_manager import ApiSpecManager
 18 | from supabase_mcp.services.database.migration_manager import MigrationManager
 19 | from supabase_mcp.services.database.postgres_client import PostgresClient
 20 | from supabase_mcp.services.database.query_manager import QueryManager
 21 | from supabase_mcp.services.database.sql.loader import SQLLoader
 22 | from supabase_mcp.services.database.sql.validator import SQLValidator
 23 | from supabase_mcp.services.safety.safety_manager import SafetyManager
 24 | from supabase_mcp.settings import Settings, find_config_file
 25 | from supabase_mcp.tools import ToolManager
 26 | from supabase_mcp.tools.registry import ToolRegistry
 27 | 
 28 | # ======================
 29 | # Environment Fixtures
 30 | # ======================
 31 | 
 32 | 
 33 | @pytest.fixture
 34 | def clean_environment() -> Generator[None, None, None]:
 35 |     """Fixture to provide a clean environment without any Supabase-related env vars."""
 36 |     # Save original environment
 37 |     original_env = dict(os.environ)
 38 | 
 39 |     # Remove all Supabase-related environment variables
 40 |     for key in list(os.environ.keys()):
 41 |         if key.startswith("SUPABASE_"):
 42 |             del os.environ[key]
 43 | 
 44 |     yield
 45 | 
 46 |     # Restore original environment
 47 |     os.environ.clear()
 48 |     os.environ.update(original_env)
 49 | 
 50 | 
 51 | def load_test_env() -> dict[str, str | None]:
 52 |     """Load test environment variables from .env.test file"""
 53 |     env_test_path = Path(__file__).parent.parent / ".env.test"
 54 |     if not env_test_path.exists():
 55 |         raise FileNotFoundError(f"Test environment file not found at {env_test_path}")
 56 | 
 57 |     load_dotenv(env_test_path)
 58 |     return {
 59 |         "SUPABASE_PROJECT_REF": os.getenv("SUPABASE_PROJECT_REF"),
 60 |         "SUPABASE_DB_PASSWORD": os.getenv("SUPABASE_DB_PASSWORD"),
 61 |         "SUPABASE_SERVICE_ROLE_KEY": os.getenv("SUPABASE_SERVICE_ROLE_KEY"),
 62 |         "SUPABASE_ACCESS_TOKEN": os.getenv("SUPABASE_ACCESS_TOKEN"),
 63 |     }
 64 | 
 65 | 
 66 | @pytest.fixture(scope="session")
 67 | def settings_integration() -> Settings:
 68 |     """Fixture providing settings for integration tests.
 69 | 
 70 |     This fixture loads settings from environment variables or .env.test file.
 71 |     Uses session scope since settings don't change during tests.
 72 |     """
 73 |     return Settings.with_config(find_config_file(".env.test"))
 74 | 
 75 | 
 76 | @pytest.fixture
 77 | def mock_validator() -> SQLValidator:
 78 |     """Fixture providing a mock SQLValidator for integration tests."""
 79 |     return SQLValidator()
 80 | 
 81 | 
 82 | @pytest.fixture
 83 | def settings_integration_custom_env() -> Generator[Settings, None, None]:
 84 |     """Fixture that provides Settings instance for integration tests using .env.test"""
 85 | 
 86 |     # Load custom environment variables
 87 |     test_env = load_test_env()
 88 |     original_env = dict(os.environ)
 89 | 
 90 |     # Set up test environment
 91 |     for key, value in test_env.items():
 92 |         if value is not None:
 93 |             os.environ[key] = value
 94 | 
 95 |     # Create fresh settings instance
 96 |     settings = Settings()
 97 |     logger.info(f"Custom connection settings initialized: {settings}")
 98 | 
 99 |     yield settings
100 | 
101 |     # Restore original environment
102 |     os.environ.clear()
103 |     os.environ.update(original_env)
104 | 
105 | 
106 | # ======================
107 | # Service Fixtures
108 | # ======================
109 | 
110 | 
111 | @pytest_asyncio.fixture(scope="module")
112 | async def postgres_client_integration(settings_integration: Settings) -> AsyncGenerator[PostgresClient, None]:
113 |     # Reset before creation
114 |     await PostgresClient.reset()
115 | 
116 |     # Create a client
117 |     client = PostgresClient(settings=settings_integration)
118 | 
119 |     try:
120 |         yield client
121 |     finally:
122 |         await client.close()
123 | 
124 | 
125 | @pytest_asyncio.fixture(scope="module")
126 | async def spec_manager_integration() -> AsyncGenerator[ApiSpecManager, None]:
127 |     """Fixture providing an ApiSpecManager instance for tests."""
128 |     manager = ApiSpecManager()
129 |     yield manager
130 | 
131 | 
132 | @pytest_asyncio.fixture(scope="module")
133 | async def api_client_integration(settings_integration: Settings) -> AsyncGenerator[ManagementAPIClient, None]:
134 |     # We don't need to reset since it's not a singleton
135 |     client = ManagementAPIClient(settings=settings_integration)
136 | 
137 |     try:
138 |         yield client
139 |     finally:
140 |         await client.close()
141 | 
142 | 
143 | @pytest_asyncio.fixture(scope="module")
144 | async def sdk_client_integration(settings_integration: Settings) -> AsyncGenerator[SupabaseSDKClient, None]:
145 |     """Fixture providing a SupabaseSDKClient instance for tests.
146 | 
147 |     Uses module scope; the singleton is reset on teardown so later test modules get a fresh client.
148 |     """
149 |     client = SupabaseSDKClient.get_instance(settings=settings_integration)
150 |     try:
151 |         yield client
152 |     finally:
153 |         # Reset the singleton to ensure a fresh client for the next test
154 |         SupabaseSDKClient.reset()
155 | 
156 | 
157 | @pytest.fixture(scope="module")
158 | def safety_manager_integration() -> SafetyManager:
159 |     """Fixture providing a safety manager for integration tests."""
160 |     # Reset the safety manager singleton
161 |     SafetyManager.reset()
162 | 
163 |     # Create a new safety manager
164 |     safety_manager = SafetyManager.get_instance()
165 |     safety_manager.register_safety_configs()
166 | 
167 |     return safety_manager
168 | 
169 | 
170 | @pytest.fixture(scope="module")
171 | def tool_manager_integration() -> ToolManager:
172 |     """Fixture providing a tool manager for integration tests."""
173 |     # Reset the tool manager singleton
174 |     ToolManager.reset()
175 |     return ToolManager.get_instance()
176 | 
177 | 
178 | @pytest.fixture(scope="module")
179 | def query_manager_integration(
180 |     postgres_client_integration: PostgresClient,
181 |     safety_manager_integration: SafetyManager,
182 | ) -> QueryManager:
183 |     """Fixture providing a query manager for integration tests."""
184 |     query_manager = QueryManager(
185 |         postgres_client=postgres_client_integration,
186 |         safety_manager=safety_manager_integration,
187 |     )
188 |     return query_manager
189 | 
190 | 
191 | @pytest.fixture(scope="module")
192 | def mock_api_manager() -> SupabaseApiManager:
193 |     """Fixture providing a properly mocked API manager for unit tests."""
194 |     # Create mock dependencies
195 |     mock_client = MagicMock()
196 |     mock_safety_manager = MagicMock()
197 |     mock_spec_manager = MagicMock()
198 | 
199 |     # Create the API manager with proper constructor arguments
200 |     api_manager = SupabaseApiManager(api_client=mock_client, safety_manager=mock_safety_manager)
201 | 
202 |     # Add the spec_manager attribute
203 |     api_manager.spec_manager = mock_spec_manager
204 | 
205 |     return api_manager
206 | 
207 | 
208 | @pytest.fixture
209 | def mock_query_manager() -> QueryManager:
210 |     """Fixture providing a properly mocked Query manager for unit tests."""
211 |     # Create mock dependencies
212 |     mock_safety_manager = MagicMock()
213 |     mock_postgres_client = MagicMock()
214 |     mock_validator = MagicMock()
215 | 
216 |     # Create the Query manager with proper constructor arguments
217 |     query_manager = QueryManager(
218 |         postgres_client=mock_postgres_client,
219 |         safety_manager=mock_safety_manager,
220 |     )
221 | 
222 |     # Replace the validator with a mock
223 |     query_manager.validator = mock_validator
224 | 
225 |     # Store the postgres client as an attribute for tests to access
226 |     query_manager.db_client = mock_postgres_client
227 | 
228 |     # Make execute_query_async an AsyncMock
229 |     query_manager.db_client.execute_query_async = AsyncMock()
230 | 
231 |     return query_manager
232 | 
233 | 
234 | @pytest_asyncio.fixture(scope="module")
235 | async def api_manager_integration(
236 |     api_client_integration: ManagementAPIClient,
237 |     safety_manager_integration: SafetyManager,
238 | ) -> AsyncGenerator[SupabaseApiManager, None]:
239 |     """Fixture providing an API manager for integration tests."""
240 | 
241 |     # Create a new API manager
242 |     api_manager = SupabaseApiManager.get_instance(
243 |         api_client=api_client_integration,
244 |         safety_manager=safety_manager_integration,
245 |     )
246 | 
247 |     try:
248 |         yield api_manager
249 |     finally:
250 |         # Reset the API manager singleton
251 |         SupabaseApiManager.reset()
252 | 
253 | 
254 | # ======================
255 | # Mock MCP Server
256 | # ======================
257 | 
258 | 
259 | @pytest.fixture
260 | def mock_mcp_server() -> Any:
261 |     """Fixture providing a lightweight mock MCP server for unit tests."""
262 | 
263 |     # Create a simple mock MCP server that mimics the FastMCP interface
264 |     class MockMCP:
265 |         def __init__(self) -> None:
266 |             self.tools: dict[str, Any] = {}
267 |             self.name = "mock_mcp"
268 | 
269 |         def register_tool(self, name: str, func: Any, **kwargs: Any) -> None:
270 |             """Register a tool with the MCP server."""
271 |             self.tools[name] = func
272 | 
273 |         def run(self) -> None:
274 |             """Mock run method."""
275 |             pass
276 | 
277 |     return MockMCP()
278 | 
279 | 
280 | @pytest.fixture(scope="module")
281 | def mock_mcp_server_integration() -> Any:
282 |     """Fixture providing a mock MCP server for integration tests."""
283 |     return FastMCP(name="supabase")
284 | 
285 | 
286 | # ======================
287 | # Container Fixture
288 | # ======================
289 | 
290 | 
291 | @pytest.fixture(scope="module")
292 | def container_integration(
293 |     postgres_client_integration: PostgresClient,
294 |     api_client_integration: ManagementAPIClient,
295 |     sdk_client_integration: SupabaseSDKClient,
296 |     api_manager_integration: SupabaseApiManager,
297 |     safety_manager_integration: SafetyManager,
298 |     query_manager_integration: QueryManager,
299 |     tool_manager_integration: ToolManager,
300 |     mock_mcp_server_integration: FastMCP,
301 | ) -> ServicesContainer:
302 |     """Fixture providing a basic Container for integration tests.
303 | 
304 |     This container includes all services needed for integration testing,
305 |     but is not initialized.
306 |     """
307 |     # Create a new container with all the services
308 |     container = ServicesContainer(
309 |         mcp_server=mock_mcp_server_integration,
310 |         postgres_client=postgres_client_integration,
311 |         api_client=api_client_integration,
312 |         sdk_client=sdk_client_integration,
313 |         api_manager=api_manager_integration,
314 |         safety_manager=safety_manager_integration,
315 |         query_manager=query_manager_integration,
316 |         tool_manager=tool_manager_integration,
317 |     )
318 | 
319 |     logger.info("✓ Integration container created successfully.")
320 | 
321 |     return container
322 | 
323 | 
324 | @pytest.fixture(scope="module")
325 | def initialized_container_integration(
326 |     container_integration: ServicesContainer,
327 |     settings_integration: Settings,
328 | ) -> ServicesContainer:
329 |     """Fixture providing a fully initialized Container for integration tests.
330 | 
331 |     This container is initialized with all services and ready to use.
332 |     """
333 |     container_integration.initialize_services(settings_integration)
334 |     logger.info("✓ Integration container initialized successfully.")
335 | 
336 |     return container_integration
337 | 
338 | 
339 | @pytest.fixture(scope="module")
340 | def tools_registry_integration(
341 |     initialized_container_integration: ServicesContainer,
342 | ) -> ServicesContainer:
343 |     """Fixture providing a Container with tools registered for integration tests.
344 | 
345 |     This container has all tools registered with the MCP server.
346 |     """
347 |     container = initialized_container_integration
348 |     mcp_server = container.mcp_server
349 | 
350 |     registry = ToolRegistry(mcp_server, container)
351 |     registry.register_tools()
352 | 
353 |     logger.info("✓ Tools registered with MCP server successfully.")
354 | 
355 |     return container
356 | 
357 | 
358 | @pytest.fixture
359 | def sql_loader() -> SQLLoader:
360 |     """Fixture providing a SQLLoader instance for tests."""
361 |     return SQLLoader()
362 | 
363 | 
364 | @pytest.fixture
365 | def migration_manager(sql_loader: SQLLoader) -> MigrationManager:
366 |     """Fixture providing a MigrationManager instance for tests."""
367 |     return MigrationManager(loader=sql_loader)
368 | 
```
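
For orientation, here is roughly how the integration fixtures above compose in a test module: `tools_registry_integration` pulls in the initialized container, which in turn wires the Postgres, Management API, and SDK clients and their managers. A minimal sketch, assuming the `QueryManager` methods already used in `feature_manager.py`:

```python
import pytest

from supabase_mcp.core.container import ServicesContainer


@pytest.mark.asyncio
async def test_get_schemas_via_container(tools_registry_integration: ServicesContainer) -> None:
    # The fixture chain has already initialized services and registered tools.
    container = tools_registry_integration
    query = container.query_manager.get_schemas_query()
    result = await container.query_manager.handle_query(query)
    assert result is not None
```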

--------------------------------------------------------------------------------
/tests/services/api/test_spec_manager.py:
--------------------------------------------------------------------------------

```python
  1 | import json
  2 | from unittest.mock import AsyncMock, MagicMock, mock_open, patch
  3 | 
  4 | import httpx
  5 | import pytest
  6 | 
  7 | from supabase_mcp.services.api.spec_manager import ApiSpecManager
  8 | 
  9 | # Test data
 10 | SAMPLE_SPEC = {"openapi": "3.0.0", "paths": {"/v1/test": {"get": {"operationId": "test"}}}}
 11 | 
 12 | 
 13 | class TestApiSpecManager:
 14 |     """Tests for the API spec manager: local loading, remote fetching, fallback, and caching."""
 15 | 
 16 |     # Local Spec Tests
 17 |     def test_load_local_spec_success(self, spec_manager_integration: ApiSpecManager):
 18 |         """Test successful loading of local spec file"""
 19 |         mock_file = mock_open(read_data=json.dumps(SAMPLE_SPEC))
 20 | 
 21 |         with patch("builtins.open", mock_file):
 22 |             result = spec_manager_integration._load_local_spec()
 23 | 
 24 |         assert result == SAMPLE_SPEC
 25 |         mock_file.assert_called_once()
 26 | 
 27 |     def test_load_local_spec_file_not_found(self, spec_manager_integration: ApiSpecManager):
 28 |         """Test handling of missing local spec file"""
 29 |         with patch("builtins.open", side_effect=FileNotFoundError), pytest.raises(FileNotFoundError):
 30 |             spec_manager_integration._load_local_spec()
 31 | 
 32 |     def test_load_local_spec_invalid_json(self, spec_manager_integration: ApiSpecManager):
 33 |         """Test handling of invalid JSON in local spec"""
 34 |         mock_file = mock_open(read_data="invalid json")
 35 | 
 36 |         with patch("builtins.open", mock_file), pytest.raises(json.JSONDecodeError):
 37 |             spec_manager_integration._load_local_spec()
 38 | 
 39 |     # Remote Spec Tests
 40 |     @pytest.mark.asyncio
 41 |     async def test_fetch_remote_spec_success(self, spec_manager_integration: ApiSpecManager):
 42 |         """Test successful remote spec fetch"""
 43 |         mock_response = MagicMock()
 44 |         mock_response.status_code = 200
 45 |         mock_response.json.return_value = SAMPLE_SPEC
 46 | 
 47 |         mock_client = AsyncMock()
 48 |         mock_client.get.return_value = mock_response
 49 |         mock_client.__aenter__.return_value = mock_client  # Mock async context manager
 50 | 
 51 |         with patch("httpx.AsyncClient", return_value=mock_client):
 52 |             result = await spec_manager_integration._fetch_remote_spec()
 53 | 
 54 |         assert result == SAMPLE_SPEC
 55 |         mock_client.get.assert_called_once()
 56 | 
 57 |     @pytest.mark.asyncio
 58 |     async def test_fetch_remote_spec_api_error(self, spec_manager_integration: ApiSpecManager):
 59 |         """Test handling of API error during remote fetch"""
 60 |         mock_response = MagicMock()
 61 |         mock_response.status_code = 500
 62 | 
 63 |         mock_client = AsyncMock()
 64 |         mock_client.get.return_value = mock_response
 65 |         mock_client.__aenter__.return_value = mock_client  # Mock async context manager
 66 | 
 67 |         with patch("httpx.AsyncClient", return_value=mock_client):
 68 |             result = await spec_manager_integration._fetch_remote_spec()
 69 | 
 70 |         assert result is None
 71 | 
 72 |     @pytest.mark.asyncio
 73 |     async def test_fetch_remote_spec_network_error(self, spec_manager_integration: ApiSpecManager):
 74 |         """Test handling of network error during remote fetch"""
 75 |         mock_client = AsyncMock()
 76 |         mock_client.get.side_effect = httpx.NetworkError("Network error")
 77 |         mock_client.__aenter__.return_value = mock_client  # Mock async context manager, as in the tests above
 78 |         with patch("httpx.AsyncClient", return_value=mock_client):
 79 |             result = await spec_manager_integration._fetch_remote_spec()
 80 | 
 81 |         assert result is None
 82 | 
 83 |     # Startup Flow Tests
 84 |     @pytest.mark.asyncio
 85 |     async def test_startup_remote_success(self, spec_manager_integration: ApiSpecManager):
 86 |         """Test successful startup with remote fetch"""
 87 |         # Reset spec to None to ensure we're testing the fetch
 88 |         spec_manager_integration.spec = None
 89 | 
 90 |         # Mock the fetch method to return sample spec
 91 |         mock_fetch = AsyncMock(return_value=SAMPLE_SPEC)
 92 | 
 93 |         with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch):
 94 |             result = await spec_manager_integration.get_spec()
 95 | 
 96 |         assert result == SAMPLE_SPEC
 97 |         mock_fetch.assert_called_once()
 98 | 
 99 |     @pytest.mark.asyncio
100 |     async def test_get_spec_remote_fail_local_fallback(self, spec_manager_integration: ApiSpecManager):
101 |         """Test get_spec with remote failure and local fallback"""
102 |         # Reset spec to None to ensure we're testing the fetch and fallback
103 |         spec_manager_integration.spec = None
104 | 
105 |         # Mock fetch to fail and local to succeed
106 |         mock_fetch = AsyncMock(return_value=None)
107 |         mock_local = MagicMock(return_value=SAMPLE_SPEC)
108 | 
109 |         with (
110 |             patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch),
111 |             patch.object(spec_manager_integration, "_load_local_spec", mock_local),
112 |         ):
113 |             result = await spec_manager_integration.get_spec()
114 | 
115 |         assert result == SAMPLE_SPEC
116 |         mock_fetch.assert_called_once()
117 |         mock_local.assert_called_once()
118 | 
119 |     @pytest.mark.asyncio
120 |     async def test_get_spec_both_fail(self, spec_manager_integration: ApiSpecManager):
121 |         """Test get_spec with both remote and local failure"""
122 |         # Reset spec to None to ensure we're testing the fetch and fallback
123 |         spec_manager_integration.spec = None
124 | 
125 |         # Mock both fetch and local to fail
126 |         mock_fetch = AsyncMock(return_value=None)
127 |         mock_local = MagicMock(side_effect=FileNotFoundError("Test file not found"))
128 | 
129 |         with (
130 |             patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch),
131 |             patch.object(spec_manager_integration, "_load_local_spec", mock_local),
132 |             pytest.raises(FileNotFoundError),
133 |         ):
134 |             await spec_manager_integration.get_spec()
135 | 
136 |         mock_fetch.assert_called_once()
137 |         mock_local.assert_called_once()
138 | 
139 |     @pytest.mark.asyncio
140 |     async def test_get_spec_cached(self, spec_manager_integration: ApiSpecManager):
141 |         """Test that get_spec returns cached spec if available"""
142 |         # Set the spec directly
143 |         spec_manager_integration.spec = SAMPLE_SPEC
144 | 
145 |         # Mock the fetch method to verify it's not called
146 |         mock_fetch = AsyncMock()
147 | 
148 |         with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch):
149 |             result = await spec_manager_integration.get_spec()
150 | 
151 |         assert result == SAMPLE_SPEC
152 |         mock_fetch.assert_not_called()
153 | 
154 |     @pytest.mark.asyncio
155 |     async def test_get_spec_not_loaded(self, spec_manager_integration: ApiSpecManager):
156 |         """Test behavior when spec is not loaded but can be loaded"""
157 |         # Reset spec to None
158 |         spec_manager_integration.spec = None
159 | 
160 |         # Mock fetch to succeed
161 |         mock_fetch = AsyncMock(return_value=SAMPLE_SPEC)
162 | 
163 |         with patch.object(spec_manager_integration, "_fetch_remote_spec", mock_fetch):
164 |             result = await spec_manager_integration.get_spec()
165 | 
166 |         assert result == SAMPLE_SPEC
167 |         mock_fetch.assert_called_once()
168 | 
169 |     @pytest.mark.asyncio
170 |     async def test_comprehensive_spec_retrieval(self, spec_manager_integration: ApiSpecManager):
171 |         """
172 |         Comprehensive test of API spec retrieval and functionality.
173 |         This test exactly mirrors the main() function to ensure all aspects work correctly.
174 |         """
175 |         # Create a fresh instance to avoid any cached data from other tests
176 |         from supabase_mcp.services.api.spec_manager import LOCAL_SPEC_PATH, ApiSpecManager
177 | 
178 |         spec_manager = ApiSpecManager()
179 | 
180 |         # Print the path being used (for debugging)
181 |         print(f"\nTest is looking for spec at: {LOCAL_SPEC_PATH}")
182 | 
183 |         # Load the spec
184 |         spec = await spec_manager.get_spec()
185 |         assert spec is not None, "Spec should be loaded successfully"
186 | 
187 |         # 1. Test get_all_domains
188 |         all_domains = spec_manager.get_all_domains()
189 |         print(f"\nAll domains: {all_domains}")
190 |         assert len(all_domains) > 0, "Should have at least one domain"
191 | 
192 |         # Verify all expected domains are present
193 |         expected_domains = [
194 |             "Analytics",
195 |             "Auth",
196 |             "Database",
197 |             "Domains",
198 |             "Edge Functions",
199 |             "Environments",
200 |             "OAuth",
201 |             "Organizations",
202 |             "Projects",
203 |             "Rest",
204 |             "Secrets",
205 |             "Storage",
206 |         ]
207 |         for domain in expected_domains:
208 |             assert domain in all_domains, f"Domain '{domain}' should be in the list of domains"
209 | 
210 |         # 2. Test get_all_paths_and_methods
211 |         all_paths = spec_manager.get_all_paths_and_methods()
212 |         assert len(all_paths) > 0, "Should have at least one path"
213 | 
214 |         # Sample a few paths to verify
215 |         sample_paths = list(all_paths.keys())[:5]
216 |         print("\nSample paths:")
217 |         for path in sample_paths:
218 |             print(f"  {path}:")
219 |             assert path.startswith("/v1/"), f"Path {path} should start with /v1/"
220 |             assert len(all_paths[path]) > 0, f"Path {path} should have at least one method"
221 |             for method, operation_id in all_paths[path].items():
222 |                 print(f"    {method}: {operation_id}")
223 |                 assert method.lower() in ["get", "post", "put", "patch", "delete"], f"Method {method} should be valid"
224 |                 assert operation_id.startswith("v1-"), f"Operation ID {operation_id} should start with v1-"
225 | 
226 |         # 3. Test get_paths_and_methods_by_domain for each domain
227 |         for domain in expected_domains:
228 |             domain_paths = spec_manager.get_paths_and_methods_by_domain(domain)
229 |             assert len(domain_paths) > 0, f"Domain {domain} should have at least one path"
230 |             print(f"\n{domain} domain has {len(domain_paths)} paths")
231 | 
232 |         # 4. Test Edge Functions domain specifically
233 |         edge_paths = spec_manager.get_paths_and_methods_by_domain("Edge Functions")
234 |         print("\nEdge Functions Paths and Methods:")
235 |         for path in edge_paths:
236 |             print(f"  {path}")
237 |             for method, operation_id in edge_paths[path].items():
238 |                 print(f"    {method}: {operation_id}")
239 | 
240 |         # Verify specific Edge Functions paths exist
241 |         expected_edge_paths = [
242 |             "/v1/projects/{ref}/functions",
243 |             "/v1/projects/{ref}/functions/{function_slug}",
244 |             "/v1/projects/{ref}/functions/deploy",
245 |         ]
246 |         for path in expected_edge_paths:
247 |             assert path in edge_paths, f"Expected path {path} should be in Edge Functions domain"
248 | 
249 |         # 5. Test get_spec_for_path_and_method
250 |         # Test for Edge Functions
251 |         path = "/v1/projects/{ref}/functions"
252 |         method = "GET"
253 |         full_spec = spec_manager.get_spec_for_path_and_method(path, method)
254 |         assert full_spec is not None, f"Should find spec for {method} {path}"
255 |         assert "operationId" in full_spec, "Spec should include operationId"
256 |         assert full_spec["operationId"] == "v1-list-all-functions", "Should have correct operationId"
257 | 
258 |         # Test for another domain (Auth)
259 |         auth_path = "/v1/projects/{ref}/config/auth"
260 |         auth_method = "GET"
261 |         auth_spec = spec_manager.get_spec_for_path_and_method(auth_path, auth_method)
262 |         assert auth_spec is not None, f"Should find spec for {auth_method} {auth_path}"
263 |         assert "operationId" in auth_spec, "Auth spec should include operationId"
264 | 
265 |         # 6. Test get_spec_part
266 |         # Get a specific schema
267 |         schema = spec_manager.get_spec_part("components", "schemas", "FunctionResponse")
268 |         assert schema is not None, "Should find FunctionResponse schema"
269 |         assert "properties" in schema, "Schema should have properties"
270 | 
271 |         # 7. Test caching behavior: a second get_spec call should not refetch the spec
272 |         with patch.object(spec_manager, "_fetch_remote_spec", AsyncMock()) as mock_fetch:
273 |             cached = await spec_manager.get_spec()
274 | 
275 |         assert cached == spec, "Second get_spec call should return the cached spec"
276 |         mock_fetch.assert_not_called()
277 | 
```
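
The success and error tests above repeat the same `AsyncMock` setup for `httpx.AsyncClient`. A small factory would keep that boilerplate in one place; a sketch (the helper name is hypothetical, not part of the repository):

```python
from typing import Any
from unittest.mock import AsyncMock, MagicMock


def make_mock_http_client(status_code: int = 200, payload: dict[str, Any] | None = None) -> AsyncMock:
    """Build an AsyncMock httpx-style client whose `async with` yields itself."""
    response = MagicMock()
    response.status_code = status_code
    response.json.return_value = payload or {}

    client = AsyncMock()
    client.get.return_value = response
    client.__aenter__.return_value = client  # async context manager returns the client itself
    return client
```

With this in place, the success test reduces to patching `httpx.AsyncClient` with `make_mock_http_client(200, SAMPLE_SPEC)` and asserting on the result.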