This is page 4 of 29. Use http://codebase.md/wshobson/maverick-mcp?page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/providers/mocks/mock_cache.py:
--------------------------------------------------------------------------------
```python
"""
Mock cache manager implementation for testing.
"""
import time
from typing import Any
class MockCacheManager:
"""
Mock implementation of ICacheManager for testing.
This implementation uses in-memory storage and provides predictable
behavior for testing cache-dependent functionality.
"""
def __init__(self):
"""Initialize the mock cache manager."""
self._data: dict[str, dict[str, Any]] = {}
self._call_log: list[dict[str, Any]] = []
async def get(self, key: str) -> Any:
"""Get data from mock cache."""
self._log_call("get", {"key": key})
if key not in self._data:
return None
entry = self._data[key]
# Check if expired
if "expires_at" in entry and entry["expires_at"] < time.time():
del self._data[key]
return None
return entry["value"]
async def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
"""Store data in mock cache."""
self._log_call("set", {"key": key, "value": value, "ttl": ttl})
entry = {"value": value}
if ttl is not None:
entry["expires_at"] = time.time() + ttl
self._data[key] = entry
return True
async def delete(self, key: str) -> bool:
"""Delete a key from mock cache."""
self._log_call("delete", {"key": key})
if key in self._data:
del self._data[key]
return True
return False
async def exists(self, key: str) -> bool:
"""Check if a key exists in mock cache."""
self._log_call("exists", {"key": key})
if key not in self._data:
return False
entry = self._data[key]
# Check if expired
if "expires_at" in entry and entry["expires_at"] < time.time():
del self._data[key]
return False
return True
async def clear(self, pattern: str | None = None) -> int:
"""Clear cache entries."""
self._log_call("clear", {"pattern": pattern})
if pattern is None:
count = len(self._data)
self._data.clear()
return count
# Simple pattern matching (only supports prefix*)
if pattern.endswith("*"):
prefix = pattern[:-1]
keys_to_delete = [k for k in self._data.keys() if k.startswith(prefix)]
else:
keys_to_delete = [k for k in self._data.keys() if k == pattern]
for key in keys_to_delete:
del self._data[key]
return len(keys_to_delete)
async def get_many(self, keys: list[str]) -> dict[str, Any]:
"""Get multiple values at once."""
self._log_call("get_many", {"keys": keys})
results = {}
for key in keys:
value = await self.get(key)
if value is not None:
results[key] = value
return results
async def set_many(self, items: list[tuple[str, Any, int | None]]) -> int:
"""Set multiple values at once."""
self._log_call("set_many", {"items_count": len(items)})
success_count = 0
for key, value, ttl in items:
if await self.set(key, value, ttl):
success_count += 1
return success_count
async def delete_many(self, keys: list[str]) -> int:
"""Delete multiple keys."""
self._log_call("delete_many", {"keys": keys})
deleted_count = 0
for key in keys:
if await self.delete(key):
deleted_count += 1
return deleted_count
async def exists_many(self, keys: list[str]) -> dict[str, bool]:
"""Check existence of multiple keys."""
self._log_call("exists_many", {"keys": keys})
results = {}
for key in keys:
results[key] = await self.exists(key)
return results
async def count_keys(self, pattern: str) -> int:
"""Count keys matching a pattern."""
self._log_call("count_keys", {"pattern": pattern})
if pattern.endswith("*"):
prefix = pattern[:-1]
return len([k for k in self._data.keys() if k.startswith(prefix)])
else:
return 1 if pattern in self._data else 0
async def get_or_set(
self, key: str, default_value: Any, ttl: int | None = None
) -> Any:
"""Get value from cache, setting it if it doesn't exist."""
self._log_call(
"get_or_set", {"key": key, "default_value": default_value, "ttl": ttl}
)
value = await self.get(key)
if value is not None:
return value
await self.set(key, default_value, ttl)
return default_value
async def increment(self, key: str, amount: int = 1) -> int:
"""Increment a numeric value in cache."""
self._log_call("increment", {"key": key, "amount": amount})
current = await self.get(key)
if current is None:
new_value = amount
else:
try:
current_int = int(current)
new_value = current_int + amount
except (ValueError, TypeError):
raise ValueError(f"Key {key} contains non-numeric value: {current}")
await self.set(key, new_value)
return new_value
async def set_if_not_exists(
self, key: str, value: Any, ttl: int | None = None
) -> bool:
"""Set a value only if the key doesn't already exist."""
self._log_call("set_if_not_exists", {"key": key, "value": value, "ttl": ttl})
if await self.exists(key):
return False
return await self.set(key, value, ttl)
async def get_ttl(self, key: str) -> int | None:
"""Get the remaining time-to-live for a key."""
self._log_call("get_ttl", {"key": key})
if key not in self._data:
return None
entry = self._data[key]
if "expires_at" not in entry:
return None
remaining = int(entry["expires_at"] - time.time())
return max(0, remaining)
async def expire(self, key: str, ttl: int) -> bool:
"""Set expiration time for an existing key."""
self._log_call("expire", {"key": key, "ttl": ttl})
if key not in self._data:
return False
self._data[key]["expires_at"] = time.time() + ttl
return True
# Testing utilities
def _log_call(self, method: str, args: dict[str, Any]) -> None:
"""Log method calls for testing verification."""
self._call_log.append(
{
"method": method,
"args": args,
"timestamp": time.time(),
}
)
def get_call_log(self) -> list[dict[str, Any]]:
"""Get the log of method calls for testing verification."""
return self._call_log.copy()
def clear_call_log(self) -> None:
"""Clear the method call log."""
self._call_log.clear()
def get_cache_contents(self) -> dict[str, Any]:
"""Get all cache contents for testing verification."""
return {k: v["value"] for k, v in self._data.items()}
def set_cache_contents(self, contents: dict[str, Any]) -> None:
"""Set cache contents directly for testing setup."""
self._data.clear()
for key, value in contents.items():
self._data[key] = {"value": value}
def simulate_cache_expiry(self, key: str) -> None:
"""Simulate cache expiry for testing."""
if key in self._data:
self._data[key]["expires_at"] = time.time() - 1
```
--------------------------------------------------------------------------------
/maverick_mcp/infrastructure/sse_optimizer.py:
--------------------------------------------------------------------------------
```python
"""
SSE Transport Optimizer for FastMCP server stability.
Provides SSE-specific optimizations to prevent connection drops
and ensure persistent tool availability in Claude Desktop.
"""
import asyncio
import logging
from typing import Any
from fastmcp import FastMCP
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger(__name__)
class SSEStabilityMiddleware(BaseHTTPMiddleware):
"""
Middleware to enhance SSE connection stability.
Features:
- Connection keepalive headers
- Proper CORS for SSE
- Connection state tracking
- Automatic reconnection support
"""
async def dispatch(self, request: Request, call_next) -> Response:
# Add SSE-specific headers for stability
response = await call_next(request)
# SSE connection optimizations
if request.url.path.endswith("/sse"):
# Keepalive and caching headers
response.headers["Cache-Control"] = "no-cache"
response.headers["Connection"] = "keep-alive"
response.headers["Content-Type"] = "text/event-stream"
# CORS headers for cross-origin SSE
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Credentials"] = "false"
# Prevent proxy buffering
response.headers["X-Accel-Buffering"] = "no"
return response
class SSEHeartbeat:
"""
Heartbeat mechanism for SSE connections.
Sends periodic keepalive messages to maintain connection
and detect client disconnections early.
"""
def __init__(self, interval: float = 30.0):
self.interval = interval
self.active_connections: dict[str, asyncio.Task] = {}
async def start_heartbeat(self, connection_id: str, send_function):
"""Start heartbeat for a specific connection."""
try:
while True:
await asyncio.sleep(self.interval)
# Send heartbeat event
heartbeat_event = {
"event": "heartbeat",
"data": {
"timestamp": asyncio.get_event_loop().time(),
"connection_id": connection_id[:8],
},
}
await send_function(heartbeat_event)
except asyncio.CancelledError:
logger.info(f"Heartbeat stopped for connection: {connection_id[:8]}")
except Exception as e:
logger.error(f"Heartbeat error for {connection_id[:8]}: {e}")
def register_connection(self, connection_id: str, send_function) -> None:
"""Register a new connection for heartbeat."""
if connection_id in self.active_connections:
# Cancel existing heartbeat
self.active_connections[connection_id].cancel()
# Start new heartbeat task
task = asyncio.create_task(self.start_heartbeat(connection_id, send_function))
self.active_connections[connection_id] = task
logger.info(f"Heartbeat registered for connection: {connection_id[:8]}")
def unregister_connection(self, connection_id: str) -> None:
"""Unregister connection and stop heartbeat."""
if connection_id in self.active_connections:
self.active_connections[connection_id].cancel()
del self.active_connections[connection_id]
logger.info(f"Heartbeat unregistered for connection: {connection_id[:8]}")
async def shutdown(self):
"""Shutdown all heartbeats."""
for task in self.active_connections.values():
task.cancel()
if self.active_connections:
await asyncio.gather(
*self.active_connections.values(), return_exceptions=True
)
self.active_connections.clear()
logger.info("All heartbeats shutdown")
class SSEOptimizer:
"""
SSE Transport Optimizer for enhanced stability.
Provides comprehensive optimizations for SSE connections:
- Stability middleware
- Heartbeat mechanism
- Connection monitoring
- Automatic recovery
"""
def __init__(self, mcp_server: FastMCP):
self.mcp_server = mcp_server
self.heartbeat = SSEHeartbeat(interval=25.0) # 25-second heartbeat
self.connection_count = 0
def optimize_server(self) -> None:
"""Apply SSE optimizations to the FastMCP server."""
# Add stability middleware
if hasattr(self.mcp_server, "fastapi_app") and self.mcp_server.fastapi_app:
self.mcp_server.fastapi_app.add_middleware(SSEStabilityMiddleware)
logger.info("SSE stability middleware added")
# Register SSE event handlers
self._register_sse_handlers()
logger.info("SSE transport optimizations applied")
def _register_sse_handlers(self) -> None:
"""Register SSE-specific event handlers."""
@self.mcp_server.event("sse_connection_opened")
async def on_sse_connection_open(connection_id: str, send_function):
"""Handle SSE connection open with optimization."""
self.connection_count += 1
logger.info(
f"SSE connection opened: {connection_id[:8]} (total: {self.connection_count})"
)
# Register heartbeat
self.heartbeat.register_connection(connection_id, send_function)
# Send connection confirmation
await send_function(
{
"event": "connection_ready",
"data": {
"connection_id": connection_id[:8],
"server": "maverick-mcp",
"transport": "sse",
"optimization": "enabled",
},
}
)
@self.mcp_server.event("sse_connection_closed")
async def on_sse_connection_close(connection_id: str):
"""Handle SSE connection close with cleanup."""
self.connection_count = max(0, self.connection_count - 1)
logger.info(
f"SSE connection closed: {connection_id[:8]} (remaining: {self.connection_count})"
)
# Unregister heartbeat
self.heartbeat.unregister_connection(connection_id)
async def shutdown(self):
"""Shutdown SSE optimizer."""
await self.heartbeat.shutdown()
logger.info("SSE optimizer shutdown complete")
def get_sse_status(self) -> dict[str, Any]:
"""Get SSE connection status."""
return {
"active_connections": self.connection_count,
"heartbeat_connections": len(self.heartbeat.active_connections),
"heartbeat_interval": self.heartbeat.interval,
"optimization_status": "enabled",
}
# Global SSE optimizer instance
_sse_optimizer: SSEOptimizer | None = None
def get_sse_optimizer(mcp_server: FastMCP) -> SSEOptimizer:
"""Get or create the global SSE optimizer."""
global _sse_optimizer
if _sse_optimizer is None:
_sse_optimizer = SSEOptimizer(mcp_server)
return _sse_optimizer
def apply_sse_optimizations(mcp_server: FastMCP) -> SSEOptimizer:
"""Apply SSE transport optimizations to FastMCP server."""
optimizer = get_sse_optimizer(mcp_server)
optimizer.optimize_server()
logger.info("SSE transport optimizations applied for enhanced stability")
return optimizer
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/interfaces/stock_data.py:
--------------------------------------------------------------------------------
```python
"""
Stock data provider interfaces.
This module defines abstract interfaces for stock data fetching and screening operations.
These interfaces separate concerns between basic data retrieval and advanced screening logic,
following the Interface Segregation Principle.
"""
from abc import ABC, abstractmethod
from typing import Any, Protocol, runtime_checkable
import pandas as pd
@runtime_checkable
class IStockDataFetcher(Protocol):
"""
Interface for fetching basic stock data.
This interface defines the contract for retrieving historical price data,
real-time quotes, company information, and related financial data.
"""
async def get_stock_data(
self,
symbol: str,
start_date: str | None = None,
end_date: str | None = None,
period: str | None = None,
interval: str = "1d",
use_cache: bool = True,
) -> pd.DataFrame:
"""
Fetch historical stock data.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
period: Alternative to start/end dates (e.g., '1y', '6mo')
interval: Data interval ('1d', '1wk', '1mo', etc.)
use_cache: Whether to use cached data if available
Returns:
DataFrame with OHLCV data indexed by date
"""
...
async def get_realtime_data(self, symbol: str) -> dict[str, Any] | None:
"""
Get real-time stock data.
Args:
symbol: Stock ticker symbol
Returns:
Dictionary with current price, change, volume, etc. or None if unavailable
"""
...
async def get_stock_info(self, symbol: str) -> dict[str, Any]:
"""
Get detailed stock information and fundamentals.
Args:
symbol: Stock ticker symbol
Returns:
Dictionary with company info, financials, and market data
"""
...
async def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
"""
Get news articles for a stock.
Args:
symbol: Stock ticker symbol
limit: Maximum number of articles to return
Returns:
DataFrame with news articles
"""
...
async def get_earnings(self, symbol: str) -> dict[str, Any]:
"""
Get earnings information for a stock.
Args:
symbol: Stock ticker symbol
Returns:
Dictionary with earnings data and dates
"""
...
async def get_recommendations(self, symbol: str) -> pd.DataFrame:
"""
Get analyst recommendations for a stock.
Args:
symbol: Stock ticker symbol
Returns:
DataFrame with analyst recommendations
"""
...
async def is_market_open(self) -> bool:
"""
Check if the stock market is currently open.
Returns:
True if market is open, False otherwise
"""
...
async def is_etf(self, symbol: str) -> bool:
"""
Check if a symbol represents an ETF.
Args:
symbol: Stock ticker symbol
Returns:
True if symbol is an ETF, False otherwise
"""
...
@runtime_checkable
class IStockScreener(Protocol):
"""
Interface for stock screening and recommendation operations.
This interface defines the contract for generating stock recommendations
based on various technical and fundamental criteria.
"""
async def get_maverick_recommendations(
self, limit: int = 20, min_score: int | None = None
) -> list[dict[str, Any]]:
"""
Get bullish Maverick stock recommendations.
Args:
limit: Maximum number of recommendations
min_score: Minimum combined score filter
Returns:
List of stock recommendations with technical analysis
"""
...
async def get_maverick_bear_recommendations(
self, limit: int = 20, min_score: int | None = None
) -> list[dict[str, Any]]:
"""
Get bearish Maverick stock recommendations.
Args:
limit: Maximum number of recommendations
min_score: Minimum score filter
Returns:
List of bear stock recommendations
"""
...
async def get_trending_recommendations(
self, limit: int = 20, min_momentum_score: float | None = None
) -> list[dict[str, Any]]:
"""
Get trending stock recommendations.
Args:
limit: Maximum number of recommendations
min_momentum_score: Minimum momentum score filter
Returns:
List of trending stock recommendations
"""
...
async def get_all_screening_recommendations(
self,
) -> dict[str, list[dict[str, Any]]]:
"""
Get all screening recommendations in one call.
Returns:
Dictionary with all screening types and their recommendations
"""
...
class StockDataProviderBase(ABC):
"""
Abstract base class for stock data providers.
This class provides a foundation for implementing both IStockDataFetcher
and IStockScreener interfaces, with common functionality and error handling.
"""
@abstractmethod
def _fetch_stock_data_from_source(
self,
symbol: str,
start_date: str | None = None,
end_date: str | None = None,
period: str | None = None,
interval: str = "1d",
) -> pd.DataFrame:
"""
Fetch stock data from the underlying data source.
This method must be implemented by concrete providers to define
how data is actually retrieved (e.g., from yfinance, Alpha Vantage, etc.)
"""
pass
def _validate_symbol(self, symbol: str) -> str:
"""
Validate and normalize a stock symbol.
Args:
symbol: Raw stock symbol
Returns:
Normalized symbol (uppercase, stripped)
Raises:
ValueError: If symbol is invalid
"""
if not symbol or not isinstance(symbol, str):
raise ValueError("Symbol must be a non-empty string")
normalized = symbol.strip().upper()
if not normalized:
raise ValueError("Symbol cannot be empty after normalization")
return normalized
def _validate_date_range(
self, start_date: str | None, end_date: str | None
) -> tuple[str | None, str | None]:
"""
Validate date range parameters.
Args:
start_date: Start date string
end_date: End date string
Returns:
Tuple of validated dates
Raises:
ValueError: If date format is invalid
"""
# Basic validation - can be extended with actual date parsing
if start_date is not None and not isinstance(start_date, str):
raise ValueError("start_date must be a string in YYYY-MM-DD format")
if end_date is not None and not isinstance(end_date, str):
raise ValueError("end_date must be a string in YYYY-MM-DD format")
return start_date, end_date
def _handle_provider_error(self, error: Exception, context: str) -> None:
"""
Handle provider-specific errors with consistent logging.
Args:
error: The exception that occurred
context: Context information for debugging
"""
# This would integrate with the logging system
# For now, we'll re-raise to maintain existing behavior
raise error
```
--------------------------------------------------------------------------------
/tools/templates/screening_strategy_template.py:
--------------------------------------------------------------------------------
```python
"""
Template for creating new stock screening strategies.
Copy this file and modify it to create new screening strategies quickly.
"""
from datetime import datetime, timedelta
from typing import Any
import pandas as pd
from maverick_mcp.core.technical_analysis import (
calculate_atr,
calculate_rsi,
calculate_sma,
)
from maverick_mcp.data.models import Stock, get_db
from maverick_mcp.providers.stock_data import StockDataProvider
from maverick_mcp.utils.logging import get_logger
logger = get_logger(__name__)
class YourScreeningStrategy:
"""
Your custom screening strategy.
This strategy identifies stocks that meet specific criteria
based on technical indicators and price action.
"""
def __init__(
self,
min_price: float = 10.0,
min_volume: int = 1_000_000,
lookback_days: int = 90,
):
"""
Initialize the screening strategy.
Args:
min_price: Minimum stock price to consider
min_volume: Minimum average daily volume
lookback_days: Number of days to analyze
"""
self.min_price = min_price
self.min_volume = min_volume
self.lookback_days = lookback_days
self.stock_provider = StockDataProvider()
def calculate_score(self, symbol: str, data: pd.DataFrame) -> float:
"""
Calculate a composite score for the stock.
Args:
symbol: Stock symbol
data: Historical price data
Returns:
Score between 0 and 100
"""
score = 0.0
try:
# Price above moving averages
sma_20 = calculate_sma(data, 20).iloc[-1]
sma_50 = calculate_sma(data, 50).iloc[-1]
current_price = data["Close"].iloc[-1]
if current_price > sma_20:
score += 20
if current_price > sma_50:
score += 15
# RSI in optimal range (not overbought/oversold)
rsi = calculate_rsi(data, 14).iloc[-1]
if 40 <= rsi <= 70:
score += 20
elif 30 <= rsi <= 80:
score += 10
# MACD bullish (using pandas_ta as alternative)
try:
import pandas_ta as ta
macd = ta.macd(data["close"])
if macd["MACD_12_26_9"].iloc[-1] > macd["MACDs_12_26_9"].iloc[-1]:
score += 15
except ImportError:
# Skip MACD if pandas_ta not available
pass
# Volume increasing
avg_volume_recent = data["Volume"].iloc[-5:].mean()
avg_volume_prior = data["Volume"].iloc[-20:-5].mean()
if avg_volume_recent > avg_volume_prior * 1.2:
score += 15
# Price momentum
price_change_1m = (current_price / data["Close"].iloc[-20] - 1) * 100
if price_change_1m > 10:
score += 15
elif price_change_1m > 5:
score += 10
logger.debug(
f"Score calculated for {symbol}: {score}",
extra={
"symbol": symbol,
"price": current_price,
"rsi": rsi,
"score": score,
},
)
except Exception as e:
logger.error(f"Error calculating score for {symbol}: {e}")
score = 0.0
return min(score, 100.0)
def screen_stocks(
self,
symbols: list[str] | None = None,
min_score: float = 70.0,
) -> list[dict[str, Any]]:
"""
Screen stocks based on the strategy criteria.
Args:
symbols: List of symbols to screen (None for all)
min_score: Minimum score to include in results
Returns:
List of stocks meeting criteria with scores
"""
results = []
end_date = datetime.now().strftime("%Y-%m-%d")
start_date = (datetime.now() - timedelta(days=self.lookback_days)).strftime(
"%Y-%m-%d"
)
# Get list of symbols to screen
if symbols is None:
# Get all active stocks from database
db = next(get_db())
try:
stocks = db.query(Stock).filter(Stock.is_active).all()
symbols = [stock.symbol for stock in stocks]
finally:
db.close()
logger.info(f"Screening {len(symbols)} stocks")
# Screen each stock
for symbol in symbols:
try:
# Get historical data
data = self.stock_provider.get_stock_data(symbol, start_date, end_date)
if len(data) < 50: # Need enough data for indicators
continue
# Check basic criteria
current_price = data["Close"].iloc[-1]
avg_volume = data["Volume"].iloc[-20:].mean()
if current_price < self.min_price or avg_volume < self.min_volume:
continue
# Calculate score
score = self.calculate_score(symbol, data)
if score >= min_score:
# Calculate additional metrics
atr = calculate_atr(data, 14).iloc[-1]
price_change_5d = (
data["Close"].iloc[-1] / data["Close"].iloc[-5] - 1
) * 100
result = {
"symbol": symbol,
"score": round(score, 2),
"price": round(current_price, 2),
"volume": int(avg_volume),
"atr": round(atr, 2),
"price_change_5d": round(price_change_5d, 2),
"rsi": round(calculate_rsi(data, 14).iloc[-1], 2),
"above_sma_20": current_price
> calculate_sma(data, 20).iloc[-1],
"above_sma_50": current_price
> calculate_sma(data, 50).iloc[-1],
}
results.append(result)
logger.info(f"Stock passed screening: {symbol} (score: {score})")
except Exception as e:
logger.error(f"Error screening {symbol}: {e}")
continue
# Sort by score descending
results.sort(key=lambda x: x["score"], reverse=True)
logger.info(f"Screening complete: {len(results)} stocks found")
return results
def get_entry_exit_levels(
self, symbol: str, data: pd.DataFrame
) -> dict[str, float]:
"""
Calculate entry, stop loss, and target levels.
Args:
symbol: Stock symbol
data: Historical price data
Returns:
Dictionary with entry, stop, and target levels
"""
current_price = data["Close"].iloc[-1]
atr = calculate_atr(data, 14).iloc[-1]
# Find recent support/resistance
recent_low = data["Low"].iloc[-20:].min()
# Calculate levels
entry = current_price
stop_loss = max(current_price - (2 * atr), recent_low * 0.98)
target1 = current_price + (2 * atr)
target2 = current_price + (3 * atr)
# Ensure minimum risk/reward
risk = entry - stop_loss
reward = target1 - entry
if reward / risk < 2:
target1 = entry + (2 * risk)
target2 = entry + (3 * risk)
return {
"entry": round(entry, 2),
"stop_loss": round(stop_loss, 2),
"target1": round(target1, 2),
"target2": round(target2, 2),
"risk_reward_ratio": round(reward / risk, 2),
}
```
--------------------------------------------------------------------------------
/tests/test_session_management.py:
--------------------------------------------------------------------------------
```python
"""
Tests for enhanced database session management.
Tests the new context managers and connection pool monitoring
introduced to fix Issue #55: Database Session Management.
"""
from unittest.mock import Mock, patch
import pytest
from maverick_mcp.data.session_management import (
check_connection_pool_health,
get_connection_pool_status,
get_db_session,
get_db_session_read_only,
)
class TestSessionManagement:
"""Test suite for database session management context managers."""
@patch("maverick_mcp.data.session_management.SessionLocal")
def test_get_db_session_success(self, mock_session_local):
"""Test successful database session with automatic commit."""
mock_session = Mock()
mock_session_local.return_value = mock_session
with get_db_session() as session:
assert session == mock_session
# Simulate some database operation
# Verify session lifecycle
mock_session.commit.assert_called_once()
mock_session.close.assert_called_once()
mock_session.rollback.assert_not_called()
@patch("maverick_mcp.data.session_management.SessionLocal")
def test_get_db_session_exception_rollback(self, mock_session_local):
"""Test database session rollback on exception."""
mock_session = Mock()
mock_session_local.return_value = mock_session
with pytest.raises(ValueError):
with get_db_session() as session:
assert session == mock_session
raise ValueError("Test exception")
# Verify rollback was called, but not commit
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
mock_session.close.assert_called_once()
@patch("maverick_mcp.data.session_management.SessionLocal")
def test_get_db_session_read_only_success(self, mock_session_local):
"""Test read-only database session (no commit)."""
mock_session = Mock()
mock_session_local.return_value = mock_session
with get_db_session_read_only() as session:
assert session == mock_session
# Simulate some read-only operation
# Verify no commit for read-only operations
mock_session.commit.assert_not_called()
mock_session.close.assert_called_once()
mock_session.rollback.assert_not_called()
@patch("maverick_mcp.data.session_management.SessionLocal")
def test_get_db_session_read_only_exception_rollback(self, mock_session_local):
"""Test read-only database session rollback on exception."""
mock_session = Mock()
mock_session_local.return_value = mock_session
with pytest.raises(RuntimeError):
with get_db_session_read_only() as session:
assert session == mock_session
raise RuntimeError("Read operation failed")
# Verify rollback was called, but not commit
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
mock_session.close.assert_called_once()
class TestConnectionPoolMonitoring:
"""Test suite for connection pool monitoring functionality."""
@patch("maverick_mcp.data.models.engine")
def test_get_connection_pool_status(self, mock_engine):
"""Test connection pool status reporting."""
mock_pool = Mock()
mock_pool.size.return_value = 10
mock_pool.checkedin.return_value = 5
mock_pool.checkedout.return_value = 3
mock_pool.overflow.return_value = 0
mock_pool.invalid.return_value = 0
mock_engine.pool = mock_pool
status = get_connection_pool_status()
expected = {
"pool_size": 10,
"checked_in": 5,
"checked_out": 3,
"overflow": 0,
"invalid": 0,
"pool_status": "healthy", # 3/10 = 30% < 80%
}
assert status == expected
@patch("maverick_mcp.data.models.engine")
def test_get_connection_pool_status_warning(self, mock_engine):
"""Test connection pool status with high utilization warning."""
mock_pool = Mock()
mock_pool.size.return_value = 10
mock_pool.checkedin.return_value = 1
mock_pool.checkedout.return_value = 9 # 90% utilization
mock_pool.overflow.return_value = 0
mock_pool.invalid.return_value = 0
mock_engine.pool = mock_pool
status = get_connection_pool_status()
assert status["pool_status"] == "warning"
assert status["checked_out"] == 9
@patch("maverick_mcp.data.session_management.get_connection_pool_status")
def test_check_connection_pool_health_healthy(self, mock_get_status):
"""Test connection pool health check - healthy scenario."""
mock_get_status.return_value = {
"pool_size": 10,
"checked_out": 5, # 50% utilization
"invalid": 0,
}
assert check_connection_pool_health() is True
@patch("maverick_mcp.data.session_management.get_connection_pool_status")
def test_check_connection_pool_health_high_utilization(self, mock_get_status):
"""Test connection pool health check - high utilization."""
mock_get_status.return_value = {
"pool_size": 10,
"checked_out": 9, # 90% utilization > 80% threshold
"invalid": 0,
}
assert check_connection_pool_health() is False
@patch("maverick_mcp.data.session_management.get_connection_pool_status")
def test_check_connection_pool_health_invalid_connections(self, mock_get_status):
"""Test connection pool health check - invalid connections detected."""
mock_get_status.return_value = {
"pool_size": 10,
"checked_out": 3, # Low utilization
"invalid": 2, # But has invalid connections
}
assert check_connection_pool_health() is False
@patch("maverick_mcp.data.session_management.get_connection_pool_status")
def test_check_connection_pool_health_exception(self, mock_get_status):
"""Test connection pool health check with exception handling."""
mock_get_status.side_effect = Exception("Pool access failed")
assert check_connection_pool_health() is False
class TestSessionManagementIntegration:
"""Integration tests for session management with real database."""
@pytest.mark.integration
def test_session_context_manager_real_db(self):
"""Test session context manager with real database connection."""
try:
with get_db_session_read_only() as session:
# Simple test query that should work on any PostgreSQL database
result = session.execute("SELECT 1 as test_value")
row = result.fetchone()
assert row[0] == 1
except Exception as e:
# If database is not available, skip this test
pytest.skip(f"Database not available for integration test: {e}")
@pytest.mark.integration
def test_connection_pool_status_real(self):
"""Test connection pool status with real database."""
try:
status = get_connection_pool_status()
# Verify the status has expected keys
required_keys = [
"pool_size",
"checked_in",
"checked_out",
"overflow",
"invalid",
"pool_status",
]
for key in required_keys:
assert key in status
# Verify status values are reasonable
assert isinstance(status["pool_size"], int)
assert status["pool_size"] > 0
assert status["pool_status"] in ["healthy", "warning"]
except Exception as e:
# If database is not available, skip this test
pytest.skip(f"Database not available for integration test: {e}")
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/mocks/mock_config.py:
--------------------------------------------------------------------------------
```python
"""
Mock configuration provider implementation for testing.
"""
from typing import Any
class MockConfigurationProvider:
"""
Mock implementation of IConfigurationProvider for testing.
This implementation provides safe test defaults and allows
easy configuration overrides for specific test scenarios.
"""
def __init__(self, overrides: dict[str, Any] | None = None):
"""
Initialize the mock configuration provider.
Args:
overrides: Optional dictionary of configuration overrides
"""
self._overrides = overrides or {}
self._defaults = {
"DATABASE_URL": "sqlite:///:memory:",
"REDIS_HOST": "localhost",
"REDIS_PORT": 6379,
"REDIS_DB": 1, # Use different DB for tests
"REDIS_PASSWORD": None,
"REDIS_SSL": False,
"CACHE_ENABLED": False, # Disable cache in tests by default
"CACHE_TTL_SECONDS": 300, # 5 minutes for tests
"FRED_API_KEY": "",
"CAPITAL_COMPANION_API_KEY": "",
"TIINGO_API_KEY": "",
"AUTH_ENABLED": False,
"JWT_SECRET_KEY": "test-secret-key",
"LOG_LEVEL": "DEBUG",
"ENVIRONMENT": "test",
"REQUEST_TIMEOUT": 5,
"MAX_RETRIES": 1,
"DB_POOL_SIZE": 1,
"DB_MAX_OVERFLOW": 0,
}
self._call_log: list[dict[str, Any]] = []
def get_database_url(self) -> str:
"""Get mock database URL."""
self._log_call("get_database_url", {})
return self._get_value("DATABASE_URL")
def get_redis_host(self) -> str:
"""Get mock Redis host."""
self._log_call("get_redis_host", {})
return self._get_value("REDIS_HOST")
def get_redis_port(self) -> int:
"""Get mock Redis port."""
self._log_call("get_redis_port", {})
return int(self._get_value("REDIS_PORT"))
def get_redis_db(self) -> int:
"""Get mock Redis database."""
self._log_call("get_redis_db", {})
return int(self._get_value("REDIS_DB"))
def get_redis_password(self) -> str | None:
"""Get mock Redis password."""
self._log_call("get_redis_password", {})
return self._get_value("REDIS_PASSWORD")
def get_redis_ssl(self) -> bool:
"""Get mock Redis SSL setting."""
self._log_call("get_redis_ssl", {})
return bool(self._get_value("REDIS_SSL"))
def is_cache_enabled(self) -> bool:
"""Check if mock caching is enabled."""
self._log_call("is_cache_enabled", {})
return bool(self._get_value("CACHE_ENABLED"))
def get_cache_ttl(self) -> int:
"""Get mock cache TTL."""
self._log_call("get_cache_ttl", {})
return int(self._get_value("CACHE_TTL_SECONDS"))
def get_fred_api_key(self) -> str:
"""Get mock FRED API key."""
self._log_call("get_fred_api_key", {})
return str(self._get_value("FRED_API_KEY"))
def get_external_api_key(self) -> str:
"""Get mock External API key."""
self._log_call("get_external_api_key", {})
return str(self._get_value("CAPITAL_COMPANION_API_KEY"))
def get_tiingo_api_key(self) -> str:
"""Get mock Tiingo API key."""
self._log_call("get_tiingo_api_key", {})
return str(self._get_value("TIINGO_API_KEY"))
def is_auth_enabled(self) -> bool:
"""Check if mock auth is enabled."""
self._log_call("is_auth_enabled", {})
return bool(self._get_value("AUTH_ENABLED"))
def get_jwt_secret_key(self) -> str:
"""Get mock JWT secret key."""
self._log_call("get_jwt_secret_key", {})
return str(self._get_value("JWT_SECRET_KEY"))
def get_log_level(self) -> str:
"""Get mock log level."""
self._log_call("get_log_level", {})
return str(self._get_value("LOG_LEVEL"))
def is_development_mode(self) -> bool:
"""Check if in mock development mode."""
self._log_call("is_development_mode", {})
env = str(self._get_value("ENVIRONMENT")).lower()
return env in ("development", "dev", "test")
def is_production_mode(self) -> bool:
"""Check if in mock production mode."""
self._log_call("is_production_mode", {})
env = str(self._get_value("ENVIRONMENT")).lower()
return env in ("production", "prod")
def get_request_timeout(self) -> int:
"""Get mock request timeout."""
self._log_call("get_request_timeout", {})
return int(self._get_value("REQUEST_TIMEOUT"))
def get_max_retries(self) -> int:
"""Get mock max retries."""
self._log_call("get_max_retries", {})
return int(self._get_value("MAX_RETRIES"))
def get_pool_size(self) -> int:
"""Get mock pool size."""
self._log_call("get_pool_size", {})
return int(self._get_value("DB_POOL_SIZE"))
def get_max_overflow(self) -> int:
"""Get mock max overflow."""
self._log_call("get_max_overflow", {})
return int(self._get_value("DB_MAX_OVERFLOW"))
def get_config_value(self, key: str, default: Any = None) -> Any:
"""Get mock configuration value."""
self._log_call("get_config_value", {"key": key, "default": default})
if key in self._overrides:
return self._overrides[key]
elif key in self._defaults:
return self._defaults[key]
else:
return default
def set_config_value(self, key: str, value: Any) -> None:
"""Set mock configuration value."""
self._log_call("set_config_value", {"key": key, "value": value})
self._overrides[key] = value
def get_all_config(self) -> dict[str, Any]:
"""Get all mock configuration."""
self._log_call("get_all_config", {})
config = self._defaults.copy()
config.update(self._overrides)
return config
def reload_config(self) -> None:
"""Reload mock configuration (no-op)."""
self._log_call("reload_config", {})
# No-op for mock implementation
def _get_value(self, key: str) -> Any:
"""Get a configuration value with override support."""
if key in self._overrides:
return self._overrides[key]
return self._defaults.get(key)
# Testing utilities
def _log_call(self, method: str, args: dict[str, Any]) -> None:
"""Log method calls for testing verification."""
self._call_log.append(
{
"method": method,
"args": args,
}
)
def get_call_log(self) -> list[dict[str, Any]]:
"""Get the log of method calls."""
return self._call_log.copy()
def clear_call_log(self) -> None:
"""Clear the method call log."""
self._call_log.clear()
def set_override(self, key: str, value: Any) -> None:
"""Set a configuration override for testing."""
self._overrides[key] = value
def clear_overrides(self) -> None:
"""Clear all configuration overrides."""
self._overrides.clear()
def enable_cache(self) -> None:
"""Enable caching for testing."""
self.set_override("CACHE_ENABLED", True)
def disable_cache(self) -> None:
"""Disable caching for testing."""
self.set_override("CACHE_ENABLED", False)
def enable_auth(self) -> None:
"""Enable authentication for testing."""
self.set_override("AUTH_ENABLED", True)
def disable_auth(self) -> None:
"""Disable authentication for testing."""
self.set_override("AUTH_ENABLED", False)
def set_production_mode(self) -> None:
"""Set production mode for testing."""
self.set_override("ENVIRONMENT", "production")
def set_development_mode(self) -> None:
"""Set development mode for testing."""
self.set_override("ENVIRONMENT", "development")
```
--------------------------------------------------------------------------------
/tests/domain/test_technical_analysis_service.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for the TechnicalAnalysisService domain service.
These tests demonstrate that the domain service can be tested
without any infrastructure dependencies (no mocks needed).
"""
import numpy as np
import pandas as pd
import pytest
from maverick_mcp.domain.services.technical_analysis_service import (
TechnicalAnalysisService,
)
from maverick_mcp.domain.value_objects.technical_indicators import (
Signal,
TrendDirection,
)
class TestTechnicalAnalysisService:
"""Test the technical analysis domain service."""
@pytest.fixture
def service(self):
"""Create a technical analysis service instance."""
return TechnicalAnalysisService()
@pytest.fixture
def sample_prices(self):
"""Create sample price data for testing."""
# Generate synthetic price data
dates = pd.date_range(start="2024-01-01", periods=100, freq="D")
prices = 100 + np.cumsum(np.random.randn(100) * 2)
return pd.Series(prices, index=dates)
@pytest.fixture
def sample_ohlc(self):
"""Create sample OHLC data for testing."""
dates = pd.date_range(start="2024-01-01", periods=100, freq="D")
close = 100 + np.cumsum(np.random.randn(100) * 2)
# Generate high/low based on close
high = close + np.abs(np.random.randn(100))
low = close - np.abs(np.random.randn(100))
return pd.DataFrame(
{
"high": high,
"low": low,
"close": close,
},
index=dates,
)
def test_calculate_rsi(self, service, sample_prices):
"""Test RSI calculation."""
rsi = service.calculate_rsi(sample_prices, period=14)
# RSI should be between 0 and 100
assert 0 <= rsi.value <= 100
assert rsi.period == 14
# Check signal logic
if rsi.value >= 70:
assert rsi.is_overbought
if rsi.value <= 30:
assert rsi.is_oversold
def test_calculate_rsi_insufficient_data(self, service):
"""Test RSI with insufficient data."""
prices = pd.Series([100, 101, 102]) # Only 3 prices
with pytest.raises(ValueError, match="Need at least 14 prices"):
service.calculate_rsi(prices, period=14)
def test_calculate_macd(self, service, sample_prices):
"""Test MACD calculation."""
macd = service.calculate_macd(sample_prices)
# Check structure
assert hasattr(macd, "macd_line")
assert hasattr(macd, "signal_line")
assert hasattr(macd, "histogram")
# Histogram should be difference between MACD and signal
assert abs(macd.histogram - (macd.macd_line - macd.signal_line)) < 0.01
# Check signal logic
if macd.macd_line > macd.signal_line and macd.histogram > 0:
assert macd.is_bullish_crossover
if macd.macd_line < macd.signal_line and macd.histogram < 0:
assert macd.is_bearish_crossover
def test_calculate_bollinger_bands(self, service, sample_prices):
"""Test Bollinger Bands calculation."""
bb = service.calculate_bollinger_bands(sample_prices)
# Check structure
assert bb.upper_band > bb.middle_band
assert bb.middle_band > bb.lower_band
assert bb.period == 20
assert bb.std_dev == 2
# Check bandwidth calculation
expected_bandwidth = (bb.upper_band - bb.lower_band) / bb.middle_band
assert abs(bb.bandwidth - expected_bandwidth) < 0.01
# Check %B calculation
expected_percent_b = (bb.current_price - bb.lower_band) / (
bb.upper_band - bb.lower_band
)
assert abs(bb.percent_b - expected_percent_b) < 0.01
def test_calculate_stochastic(self, service, sample_ohlc):
"""Test Stochastic Oscillator calculation."""
stoch = service.calculate_stochastic(
sample_ohlc["high"],
sample_ohlc["low"],
sample_ohlc["close"],
period=14,
)
# Values should be between 0 and 100
assert 0 <= stoch.k_value <= 100
assert 0 <= stoch.d_value <= 100
assert stoch.period == 14
# Check overbought/oversold logic
if stoch.k_value >= 80:
assert stoch.is_overbought
if stoch.k_value <= 20:
assert stoch.is_oversold
def test_identify_trend_uptrend(self, service):
"""Test trend identification for uptrend."""
# Create clear uptrend data
dates = pd.date_range(start="2024-01-01", periods=100, freq="D")
prices = pd.Series(range(100, 200), index=dates) # Linear uptrend
trend = service.identify_trend(prices, period=50)
assert trend in [TrendDirection.UPTREND, TrendDirection.STRONG_UPTREND]
def test_identify_trend_downtrend(self, service):
"""Test trend identification for downtrend."""
# Create clear downtrend data
dates = pd.date_range(start="2024-01-01", periods=100, freq="D")
prices = pd.Series(range(200, 100, -1), index=dates) # Linear downtrend
trend = service.identify_trend(prices, period=50)
assert trend in [TrendDirection.DOWNTREND, TrendDirection.STRONG_DOWNTREND]
def test_analyze_volume(self, service):
"""Test volume analysis."""
# Create volume data with spike
dates = pd.date_range(start="2024-01-01", periods=30, freq="D")
volume = pd.Series([1000000] * 29 + [3000000], index=dates) # Spike at end
volume_profile = service.analyze_volume(volume, period=20)
assert volume_profile.current_volume == 3000000
assert volume_profile.average_volume < 1500000
assert volume_profile.relative_volume > 2.0
assert volume_profile.unusual_activity # 3x average is unusual
def test_calculate_composite_signal_bullish(self, service):
"""Test composite signal calculation with bullish indicators."""
# Manually create bullish indicators for testing
from maverick_mcp.domain.value_objects.technical_indicators import (
MACDIndicator,
RSIIndicator,
)
bullish_rsi = RSIIndicator(value=25, period=14) # Oversold
bullish_macd = MACDIndicator(
macd_line=1.0,
signal_line=0.5,
histogram=0.5,
) # Bullish crossover
signal = service.calculate_composite_signal(
rsi=bullish_rsi,
macd=bullish_macd,
)
assert signal in [Signal.BUY, Signal.STRONG_BUY]
def test_calculate_composite_signal_mixed(self, service):
"""Test composite signal with mixed indicators."""
from maverick_mcp.domain.value_objects.technical_indicators import (
BollingerBands,
MACDIndicator,
RSIIndicator,
)
# Create mixed signals
neutral_rsi = RSIIndicator(value=50, period=14) # Neutral
bearish_macd = MACDIndicator(
macd_line=-0.5,
signal_line=0.0,
histogram=-0.5,
) # Bearish
neutral_bb = BollingerBands(
upper_band=110,
middle_band=100,
lower_band=90,
current_price=100,
) # Neutral
signal = service.calculate_composite_signal(
rsi=neutral_rsi,
macd=bearish_macd,
bollinger=neutral_bb,
)
# With mixed signals, should be neutral or slightly bearish
assert signal in [Signal.NEUTRAL, Signal.SELL]
def test_domain_service_has_no_infrastructure_dependencies(self, service):
"""Verify the domain service has no infrastructure dependencies."""
# Check that the service has no database, API, or cache attributes
assert not hasattr(service, "db")
assert not hasattr(service, "session")
assert not hasattr(service, "cache")
assert not hasattr(service, "api_client")
assert not hasattr(service, "http_client")
# Check that all methods are pure functions (no side effects)
# This is verified by the fact that all tests above work without mocks
```
--------------------------------------------------------------------------------
/tests/test_financial_search.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Test script for enhanced financial search capabilities in DeepResearchAgent.
This script demonstrates the improved Exa client usage for financial records search
with different strategies and optimizations.
"""
import asyncio
import os
import sys
from datetime import datetime
# Add the project root to the Python path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from maverick_mcp.agents.deep_research import DeepResearchAgent, ExaSearchProvider
async def test_financial_search_strategies():
"""Test different financial search strategies."""
# Initialize the search provider
exa_api_key = os.getenv("EXA_API_KEY")
if not exa_api_key:
print("❌ EXA_API_KEY environment variable not set")
return
print("🔍 Testing Enhanced Financial Search Capabilities")
print("=" * 60)
# Test queries for different financial scenarios
test_queries = [
("AAPL financial performance", "Apple stock analysis"),
("Tesla quarterly earnings 2024", "Tesla earnings report"),
("Microsoft revenue growth", "Microsoft financial growth"),
("S&P 500 market analysis", "Market index analysis"),
("Federal Reserve interest rates", "Fed policy analysis"),
]
# Test different search strategies
strategies = ["hybrid", "authoritative", "comprehensive", "auto"]
provider = ExaSearchProvider(exa_api_key)
for query, description in test_queries:
print(f"\n📊 Testing Query: {description}")
print(f" Query: '{query}'")
print("-" * 40)
for strategy in strategies:
try:
start_time = datetime.now()
# Test the enhanced financial search
results = await provider.search_financial(
query=query, num_results=5, strategy=strategy
)
duration = (datetime.now() - start_time).total_seconds()
print(f" 🎯 Strategy: {strategy.upper()}")
print(f" Results: {len(results)}")
print(f" Duration: {duration:.2f}s")
if results:
# Show top result with enhanced metadata
top_result = results[0]
print(" Top Result:")
print(f" Title: {top_result.get('title', 'N/A')[:80]}...")
print(f" Domain: {top_result.get('domain', 'N/A')}")
print(
f" Financial Relevance: {top_result.get('financial_relevance', 0):.2f}"
)
print(
f" Authoritative: {top_result.get('is_authoritative', False)}"
)
print(f" Score: {top_result.get('score', 0):.2f}")
print()
except Exception as e:
print(f" ❌ Strategy {strategy} failed: {str(e)}")
print()
async def test_query_enhancement():
"""Test the financial query enhancement feature."""
print("\n🔧 Testing Query Enhancement")
print("=" * 40)
exa_api_key = os.getenv("EXA_API_KEY")
if not exa_api_key:
print("❌ EXA_API_KEY environment variable not set")
return
provider = ExaSearchProvider(exa_api_key)
# Test queries that should be enhanced
test_queries = [
"AAPL", # Stock symbol
"Tesla company", # Company name
"Microsoft analysis", # Analysis request
"Amazon earnings financial", # Already has financial context
]
for query in test_queries:
enhanced = provider._enhance_financial_query(query)
print(f"Original: '{query}'")
print(f"Enhanced: '{enhanced}'")
print(f"Changed: {'Yes' if enhanced != query else 'No'}")
print()
async def test_financial_relevance_scoring():
"""Test the financial relevance scoring system."""
print("\n📈 Testing Financial Relevance Scoring")
print("=" * 45)
exa_api_key = os.getenv("EXA_API_KEY")
if not exa_api_key:
print("❌ EXA_API_KEY environment variable not set")
return
provider = ExaSearchProvider(exa_api_key)
# Mock result objects for testing
class MockResult:
def __init__(self, url, title, text, published_date=None):
self.url = url
self.title = title
self.text = text
self.published_date = published_date
test_results = [
MockResult(
"https://sec.gov/filing/aapl-10k-2024",
"Apple Inc. Annual Report (Form 10-K)",
"Apple Inc. reported quarterly earnings of $1.50 per share, with revenue of $95 billion for the quarter ending March 31, 2024.",
"2024-01-15T00:00:00Z",
),
MockResult(
"https://bloomberg.com/news/apple-stock-analysis",
"Apple Stock Analysis: Strong Financial Performance",
"Apple's financial performance continues to show strong growth with increased market cap and dividend distributions.",
"2024-01-10T00:00:00Z",
),
MockResult(
"https://example.com/random-article",
"Random Article About Technology",
"This is just a random article about technology trends without specific financial information.",
"2024-01-01T00:00:00Z",
),
]
for i, result in enumerate(test_results, 1):
relevance = provider._calculate_financial_relevance(result)
is_auth = provider._is_authoritative_source(result.url)
domain = provider._extract_domain(result.url)
print(f"Result {i}:")
print(f" URL: {result.url}")
print(f" Domain: {domain}")
print(f" Title: {result.title}")
print(f" Financial Relevance: {relevance:.2f}")
print(f" Authoritative: {is_auth}")
print()
async def test_deep_research_agent_integration():
"""Test the integration with DeepResearchAgent."""
print("\n🤖 Testing DeepResearchAgent Integration")
print("=" * 45)
exa_api_key = os.getenv("EXA_API_KEY")
if not exa_api_key:
print("❌ EXA_API_KEY environment variable not set")
return
try:
# Initialize the agent
agent = DeepResearchAgent(
llm=None, # Will be set by initialize if needed
persona="financial_analyst",
exa_api_key=exa_api_key,
)
await agent.initialize()
# Test the enhanced financial search tool
result = await agent._perform_financial_search(
query="Apple quarterly earnings Q4 2024",
num_results=3,
provider="exa",
strategy="authoritative",
)
print(f"Search Results: {result.get('total_results', 0)} found")
print(f"Strategy Used: {result.get('search_strategy', 'N/A')}")
print(f"Duration: {result.get('search_duration', 0):.2f}s")
print(f"Enhanced Search: {result.get('enhanced_search', False)}")
if result.get("results"):
print("\nTop Result:")
top = result["results"][0]
print(f" Title: {top.get('title', 'N/A')[:80]}...")
print(f" Financial Relevance: {top.get('financial_relevance', 0):.2f}")
print(f" Authoritative: {top.get('is_authoritative', False)}")
except Exception as e:
print(f"❌ Integration test failed: {str(e)}")
async def main():
"""Run all tests."""
print("🚀 Enhanced Financial Search Testing Suite")
print("=" * 60)
print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
await test_financial_search_strategies()
await test_query_enhancement()
await test_financial_relevance_scoring()
await test_deep_research_agent_integration()
print("\n✅ All tests completed successfully!")
except Exception as e:
print(f"\n❌ Test suite failed: {str(e)}")
import traceback
traceback.print_exc()
print(f"\nCompleted at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
if __name__ == "__main__":
asyncio.run(main())
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/shutdown.py:
--------------------------------------------------------------------------------
```python
"""
Graceful shutdown handler for MaverickMCP servers.
This module provides signal handling and graceful shutdown capabilities
for all server components to ensure safe deployments and prevent data loss.
"""
import asyncio
import signal
import sys
import time
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any
from maverick_mcp.utils.logging import get_logger
logger = get_logger(__name__)
class GracefulShutdownHandler:
"""Handles graceful shutdown for server components."""
def __init__(
self,
name: str,
shutdown_timeout: float = 30.0,
drain_timeout: float = 10.0,
):
"""
Initialize shutdown handler.
Args:
name: Name of the component for logging
shutdown_timeout: Maximum time to wait for shutdown (seconds)
drain_timeout: Time to wait for connection draining (seconds)
"""
self.name = name
self.shutdown_timeout = shutdown_timeout
self.drain_timeout = drain_timeout
self._shutdown_event = asyncio.Event()
self._cleanup_callbacks: list[Callable] = []
self._active_requests: set[asyncio.Task] = set()
self._original_handlers: dict[int, Any] = {}
self._shutdown_in_progress = False
self._start_time = time.time()
def register_cleanup(self, callback: Callable) -> None:
"""Register a cleanup callback to run during shutdown."""
self._cleanup_callbacks.append(callback)
logger.debug(f"Registered cleanup callback: {callback.__name__}")
def track_request(self, task: asyncio.Task) -> None:
"""Track an active request/task."""
self._active_requests.add(task)
task.add_done_callback(self._active_requests.discard)
@contextmanager
def track_sync_request(self):
"""Context manager to track synchronous requests."""
request_id = id(asyncio.current_task()) if asyncio.current_task() else None
try:
if request_id:
logger.debug(f"Tracking sync request: {request_id}")
yield
finally:
if request_id:
logger.debug(f"Completed sync request: {request_id}")
async def wait_for_shutdown(self) -> None:
"""Wait for shutdown signal."""
await self._shutdown_event.wait()
def is_shutting_down(self) -> bool:
"""Check if shutdown is in progress."""
return self._shutdown_in_progress
def install_signal_handlers(self) -> None:
"""Install signal handlers for graceful shutdown."""
# Store original handlers
for sig in (signal.SIGTERM, signal.SIGINT):
self._original_handlers[sig] = signal.signal(sig, self._signal_handler)
# Also handle SIGHUP for reload scenarios
if hasattr(signal, "SIGHUP"):
self._original_handlers[signal.SIGHUP] = signal.signal(
signal.SIGHUP, self._signal_handler
)
logger.info(f"{self.name}: Signal handlers installed")
def _signal_handler(self, signum: int, frame: Any) -> None:
"""Handle shutdown signals."""
signal_name = signal.Signals(signum).name
logger.info(f"{self.name}: Received {signal_name} signal")
if self._shutdown_in_progress:
logger.warning(
f"{self.name}: Shutdown already in progress, ignoring signal"
)
return
# Trigger async shutdown
if asyncio.get_event_loop().is_running():
asyncio.create_task(self._async_shutdown(signal_name))
else:
# Fallback for non-async context
self._sync_shutdown(signal_name)
async def _async_shutdown(self, signal_name: str) -> None:
"""Perform async graceful shutdown."""
if self._shutdown_in_progress:
return
self._shutdown_in_progress = True
shutdown_start = time.time()
logger.info(
f"{self.name}: Starting graceful shutdown (signal: {signal_name}, "
f"uptime: {shutdown_start - self._start_time:.1f}s)"
)
# Set shutdown event to notify waiting coroutines
self._shutdown_event.set()
# Phase 1: Stop accepting new requests
logger.info(f"{self.name}: Phase 1 - Stopping new requests")
# Phase 2: Drain active requests
if self._active_requests:
logger.info(
f"{self.name}: Phase 2 - Draining {len(self._active_requests)} "
f"active requests (timeout: {self.drain_timeout}s)"
)
try:
await asyncio.wait_for(
self._wait_for_requests(),
timeout=self.drain_timeout,
)
logger.info(f"{self.name}: All requests completed")
except TimeoutError:
remaining = len(self._active_requests)
logger.warning(
f"{self.name}: Drain timeout reached, {remaining} requests remaining"
)
# Cancel remaining requests
for task in self._active_requests:
task.cancel()
# Phase 3: Run cleanup callbacks
logger.info(f"{self.name}: Phase 3 - Running cleanup callbacks")
for callback in self._cleanup_callbacks:
try:
logger.debug(f"Running cleanup: {callback.__name__}")
if asyncio.iscoroutinefunction(callback):
await asyncio.wait_for(callback(), timeout=5.0)
else:
callback()
except Exception as e:
logger.error(f"Error in cleanup callback {callback.__name__}: {e}")
# Phase 4: Final shutdown
shutdown_duration = time.time() - shutdown_start
logger.info(
f"{self.name}: Graceful shutdown completed in {shutdown_duration:.1f}s"
)
# Exit the process
sys.exit(0)
def _sync_shutdown(self, signal_name: str) -> None:
"""Perform synchronous shutdown (fallback)."""
if self._shutdown_in_progress:
return
self._shutdown_in_progress = True
logger.info(f"{self.name}: Starting sync shutdown (signal: {signal_name})")
# Run sync cleanup callbacks
for callback in self._cleanup_callbacks:
if not asyncio.iscoroutinefunction(callback):
try:
callback()
except Exception as e:
logger.error(f"Error in cleanup callback: {e}")
logger.info(f"{self.name}: Sync shutdown completed")
sys.exit(0)
async def _wait_for_requests(self) -> None:
"""Wait for all active requests to complete."""
while self._active_requests:
# Wait a bit and check again
await asyncio.sleep(0.1)
# Log progress periodically
if int(time.time()) % 5 == 0:
logger.info(
f"{self.name}: Waiting for {len(self._active_requests)} requests"
)
def restore_signal_handlers(self) -> None:
"""Restore original signal handlers."""
for sig, handler in self._original_handlers.items():
signal.signal(sig, handler)
logger.debug(f"{self.name}: Signal handlers restored")
# Global shutdown handler instance
_shutdown_handler: GracefulShutdownHandler | None = None
def get_shutdown_handler(
name: str = "Server",
shutdown_timeout: float = 30.0,
drain_timeout: float = 10.0,
) -> GracefulShutdownHandler:
"""Get or create the global shutdown handler."""
global _shutdown_handler
if _shutdown_handler is None:
_shutdown_handler = GracefulShutdownHandler(
name, shutdown_timeout, drain_timeout
)
return _shutdown_handler
@contextmanager
def graceful_shutdown(
name: str = "Server",
shutdown_timeout: float = 30.0,
drain_timeout: float = 10.0,
):
"""Context manager for graceful shutdown handling."""
handler = get_shutdown_handler(name, shutdown_timeout, drain_timeout)
handler.install_signal_handlers()
try:
yield handler
finally:
handler.restore_signal_handlers()
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/implementations/cache_adapter.py:
--------------------------------------------------------------------------------
```python
"""
Cache manager adapter.
This module provides adapters that make the existing cache system
compatible with the new ICacheManager interface.
"""
import asyncio
import logging
from typing import Any
from maverick_mcp.data.cache import (
CacheManager as ExistingCacheManager,
)
from maverick_mcp.data.cache import (
clear_cache,
get_from_cache,
save_to_cache,
)
from maverick_mcp.providers.interfaces.cache import CacheConfig, ICacheManager
logger = logging.getLogger(__name__)
class RedisCacheAdapter(ICacheManager):
"""
Adapter that makes the existing cache system compatible with ICacheManager interface.
This adapter wraps the existing cache functions and CacheManager class
to provide the new interface while maintaining all existing functionality.
"""
def __init__(self, config: CacheConfig | None = None):
"""
Initialize the cache adapter.
Args:
config: Cache configuration (optional, defaults to environment)
"""
self._config = config
self._cache_manager = ExistingCacheManager()
logger.debug("RedisCacheAdapter initialized")
async def get(self, key: str) -> Any:
"""
Get data from cache (async wrapper).
Args:
key: Cache key to retrieve
Returns:
Cached data or None if not found or expired
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, get_from_cache, key)
async def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
"""
Store data in cache (async wrapper).
Args:
key: Cache key
value: Data to cache (must be JSON serializable)
ttl: Time-to-live in seconds (None for default TTL)
Returns:
True if successfully cached, False otherwise
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, save_to_cache, key, value, ttl)
async def delete(self, key: str) -> bool:
"""
Delete a key from cache.
Args:
key: Cache key to delete
Returns:
True if key was deleted, False if key didn't exist
"""
return await self._cache_manager.delete(key)
async def exists(self, key: str) -> bool:
"""
Check if a key exists in cache.
Args:
key: Cache key to check
Returns:
True if key exists and hasn't expired, False otherwise
"""
return await self._cache_manager.exists(key)
async def clear(self, pattern: str | None = None) -> int:
"""
Clear cache entries.
Args:
pattern: Pattern to match keys (e.g., "stock:*")
If None, clears all cache entries
Returns:
Number of entries cleared
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, clear_cache, pattern)
async def get_many(self, keys: list[str]) -> dict[str, Any]:
"""
Get multiple values at once for better performance.
Args:
keys: List of cache keys to retrieve
Returns:
Dictionary mapping keys to their cached values
(missing keys will not be in the result)
"""
return await self._cache_manager.get_many(keys)
async def set_many(self, items: list[tuple[str, Any, int | None]]) -> int:
"""
Set multiple values at once for better performance.
Args:
items: List of tuples (key, value, ttl)
Returns:
Number of items successfully cached
"""
return await self._cache_manager.batch_save(items)
async def delete_many(self, keys: list[str]) -> int:
"""
Delete multiple keys for better performance.
Args:
keys: List of keys to delete
Returns:
Number of keys successfully deleted
"""
return await self._cache_manager.batch_delete(keys)
async def exists_many(self, keys: list[str]) -> dict[str, bool]:
"""
Check existence of multiple keys for better performance.
Args:
keys: List of keys to check
Returns:
Dictionary mapping keys to their existence status
"""
return await self._cache_manager.batch_exists(keys)
async def count_keys(self, pattern: str) -> int:
"""
Count keys matching a pattern.
Args:
pattern: Pattern to match (e.g., "stock:*")
Returns:
Number of matching keys
"""
return await self._cache_manager.count_keys(pattern)
async def get_or_set(
self, key: str, default_value: Any, ttl: int | None = None
) -> Any:
"""
Get value from cache, setting it if it doesn't exist.
Args:
key: Cache key
default_value: Value to set if key doesn't exist
ttl: Time-to-live for the default value
Returns:
Either the existing cached value or the default value
"""
# Check if key exists
existing_value = await self.get(key)
if existing_value is not None:
return existing_value
# Set default value and return it
await self.set(key, default_value, ttl)
return default_value
async def increment(self, key: str, amount: int = 1) -> int:
"""
Increment a numeric value in cache.
Args:
key: Cache key
amount: Amount to increment by
Returns:
New value after increment
Raises:
ValueError: If the key exists but doesn't contain a numeric value
"""
# Get current value
current = await self.get(key)
if current is None:
# Key doesn't exist, start from 0
new_value = amount
else:
# Try to convert to int and increment
try:
current_int = int(current)
new_value = current_int + amount
except (ValueError, TypeError):
raise ValueError(f"Key {key} contains non-numeric value: {current}")
# Set the new value
await self.set(key, new_value)
return new_value
async def set_if_not_exists(
self, key: str, value: Any, ttl: int | None = None
) -> bool:
"""
Set a value only if the key doesn't already exist.
Args:
key: Cache key
value: Value to set
ttl: Time-to-live in seconds
Returns:
True if the value was set, False if key already existed
"""
# Check if key already exists
if await self.exists(key):
return False
# Key doesn't exist, set the value
return await self.set(key, value, ttl)
async def get_ttl(self, key: str) -> int | None:
"""
Get the remaining time-to-live for a key.
Args:
key: Cache key
Returns:
Remaining TTL in seconds, None if key doesn't exist or has no TTL
"""
# This would need to be implemented in the underlying cache manager
# For now, return None as we don't have TTL introspection in the existing system
logger.warning(f"TTL introspection not implemented for key: {key}")
return None
async def expire(self, key: str, ttl: int) -> bool:
"""
Set expiration time for an existing key.
Args:
key: Cache key
ttl: Time-to-live in seconds
Returns:
True if expiration was set, False if key doesn't exist
"""
# Check if key exists
if not await self.exists(key):
return False
# Get current value and re-set with new TTL
current_value = await self.get(key)
if current_value is not None:
return await self.set(key, current_value, ttl)
return False
def get_sync_cache_manager(self) -> ExistingCacheManager:
"""
Get the underlying synchronous cache manager for backward compatibility.
Returns:
The wrapped CacheManager instance
"""
return self._cache_manager
```
--------------------------------------------------------------------------------
/maverick_mcp/config/security_utils.py:
--------------------------------------------------------------------------------
```python
"""
Security utilities for applying centralized security configuration.
This module provides utility functions to apply the SecurityConfig
across different server implementations consistently.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware as StarletteCORSMiddleware
from starlette.requests import Request
from maverick_mcp.config.security import get_security_config, validate_security_config
from maverick_mcp.utils.logging import get_logger
logger = get_logger(__name__)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Middleware to add security headers based on SecurityConfig."""
def __init__(self, app, security_config=None):
super().__init__(app)
self.security_config = security_config or get_security_config()
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
# Add security headers
headers = self.security_config.get_security_headers()
for name, value in headers.items():
response.headers[name] = value
return response
def apply_cors_to_fastapi(app: FastAPI, security_config=None) -> None:
"""Apply CORS configuration to FastAPI app using SecurityConfig."""
config = security_config or get_security_config()
# Validate security before applying
validation = validate_security_config()
if not validation["valid"]:
logger.error(f"Security validation failed: {validation['issues']}")
for issue in validation["issues"]:
logger.error(f"SECURITY ISSUE: {issue}")
raise ValueError(f"Security configuration is invalid: {validation['issues']}")
if validation["warnings"]:
for warning in validation["warnings"]:
logger.warning(f"SECURITY WARNING: {warning}")
# Apply CORS middleware
cors_config = config.get_cors_middleware_config()
app.add_middleware(CORSMiddleware, **cors_config)
logger.info(
f"CORS configured for {config.environment} environment: "
f"origins={cors_config['allow_origins']}, "
f"credentials={cors_config['allow_credentials']}"
)
def apply_cors_to_starlette(app: Starlette, security_config=None) -> list[Middleware]:
"""Get CORS middleware configuration for Starlette app using SecurityConfig."""
config = security_config or get_security_config()
# Validate security before applying
validation = validate_security_config()
if not validation["valid"]:
logger.error(f"Security validation failed: {validation['issues']}")
for issue in validation["issues"]:
logger.error(f"SECURITY ISSUE: {issue}")
raise ValueError(f"Security configuration is invalid: {validation['issues']}")
if validation["warnings"]:
for warning in validation["warnings"]:
logger.warning(f"SECURITY WARNING: {warning}")
# Create middleware configuration
cors_config = config.get_cors_middleware_config()
middleware_list = [
Middleware(StarletteCORSMiddleware, **cors_config),
Middleware(SecurityHeadersMiddleware, security_config=config),
]
logger.info(
f"Starlette CORS configured for {config.environment} environment: "
f"origins={cors_config['allow_origins']}, "
f"credentials={cors_config['allow_credentials']}"
)
return middleware_list
def apply_trusted_hosts_to_fastapi(app: FastAPI, security_config=None) -> None:
"""Apply trusted hosts configuration to FastAPI app."""
config = security_config or get_security_config()
# Only enforce in production or when strict security is enabled
if config.is_production() or config.strict_security:
app.add_middleware(
TrustedHostMiddleware, allowed_hosts=config.trusted_hosts.allowed_hosts
)
logger.info(f"Trusted hosts configured: {config.trusted_hosts.allowed_hosts}")
elif config.trusted_hosts.enforce_in_development:
app.add_middleware(
TrustedHostMiddleware, allowed_hosts=config.trusted_hosts.allowed_hosts
)
logger.info(
f"Trusted hosts configured for development: {config.trusted_hosts.allowed_hosts}"
)
else:
logger.info("Trusted hosts validation disabled for development")
def apply_security_headers_to_fastapi(app: FastAPI, security_config=None) -> None:
"""Apply security headers middleware to FastAPI app."""
config = security_config or get_security_config()
app.add_middleware(SecurityHeadersMiddleware, security_config=config)
logger.info("Security headers middleware applied")
def get_safe_cors_config() -> dict:
"""Get a safe CORS configuration that prevents common vulnerabilities."""
config = get_security_config()
# Validate the configuration
validation = validate_security_config()
if not validation["valid"]:
logger.error("Using fallback safe CORS configuration due to validation errors")
# Return a safe fallback configuration
if config.is_production():
return {
"allow_origins": ["https://maverick-mcp.com"],
"allow_credentials": True,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
"allow_headers": ["Authorization", "Content-Type"],
"expose_headers": [],
"max_age": 86400,
}
else:
return {
"allow_origins": ["http://localhost:3000"],
"allow_credentials": True,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
"allow_headers": ["Authorization", "Content-Type"],
"expose_headers": [],
"max_age": 86400,
}
return config.get_cors_middleware_config()
def log_security_status() -> None:
"""Log current security configuration status."""
config = get_security_config()
validation = validate_security_config()
logger.info("=== Security Configuration Status ===")
logger.info(f"Environment: {config.environment}")
logger.info(f"Force HTTPS: {config.force_https}")
logger.info(f"Strict Security: {config.strict_security}")
logger.info(f"CORS Origins: {config.cors.allowed_origins}")
logger.info(f"CORS Credentials: {config.cors.allow_credentials}")
logger.info(f"Rate Limiting: {config.rate_limiting.enabled}")
logger.info(f"Trusted Hosts: {config.trusted_hosts.allowed_hosts}")
if validation["valid"]:
logger.info("✅ Security validation: PASSED")
else:
logger.error("❌ Security validation: FAILED")
for issue in validation["issues"]:
logger.error(f" - {issue}")
if validation["warnings"]:
logger.warning("⚠️ Security warnings:")
for warning in validation["warnings"]:
logger.warning(f" - {warning}")
logger.info("=====================================")
def create_secure_fastapi_app(
title: str = "Maverick MCP API",
description: str = "Secure API with centralized security configuration",
version: str = "1.0.0",
**kwargs,
) -> FastAPI:
"""Create a FastAPI app with security configuration applied."""
app = FastAPI(title=title, description=description, version=version, **kwargs)
# Apply security configuration
apply_trusted_hosts_to_fastapi(app)
apply_cors_to_fastapi(app)
apply_security_headers_to_fastapi(app)
# Log security status
log_security_status()
return app
def create_secure_starlette_middleware() -> list[Middleware]:
"""Create Starlette middleware list with security configuration."""
config = get_security_config()
# Start with CORS and security headers
middleware_list = apply_cors_to_starlette(None, config)
# Log security status
log_security_status()
return middleware_list
# Export validation function for easy access
def check_security_config() -> bool:
"""Check if security configuration is valid."""
validation = validate_security_config()
return validation["valid"]
```
--------------------------------------------------------------------------------
/scripts/dev.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash
# Maverick-MCP Development Script
# This script starts the backend MCP server for personal stock analysis
set -e
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo -e "${GREEN}Starting Maverick-MCP Development Environment${NC}"
# Kill any existing processes on port 8003 to avoid conflicts
echo -e "${YELLOW}Checking for existing processes on port 8003...${NC}"
EXISTING_PID=$(lsof -ti:8003 2>/dev/null || true)
if [ ! -z "$EXISTING_PID" ]; then
echo -e "${YELLOW}Found existing process(es) on port 8003: $EXISTING_PID${NC}"
echo -e "${YELLOW}Killing existing processes...${NC}"
kill -9 $EXISTING_PID 2>/dev/null || true
sleep 1
else
echo -e "${GREEN}No existing processes found on port 8003${NC}"
fi
# Check if Redis is running
if ! pgrep -x "redis-server" > /dev/null; then
echo -e "${YELLOW}Starting Redis...${NC}"
if command -v brew &> /dev/null; then
brew services start redis
else
redis-server --daemonize yes
fi
else
echo -e "${GREEN}Redis is already running${NC}"
fi
# Function to cleanup on exit
cleanup() {
echo -e "\n${YELLOW}Shutting down services...${NC}"
# Kill backend process
if [ ! -z "$BACKEND_PID" ]; then
kill $BACKEND_PID 2>/dev/null || true
fi
echo -e "${GREEN}Development environment stopped${NC}"
exit 0
}
# Set trap to cleanup on script exit
trap cleanup EXIT INT TERM
# Start backend
echo -e "${YELLOW}Starting backend MCP server...${NC}"
cd "$(dirname "$0")/.."
echo -e "${YELLOW}Current directory: $(pwd)${NC}"
# Source .env if it exists
if [ -f .env ]; then
source .env
fi
# Check if uv is available (more relevant than python since we use uv run)
if ! command -v uv &> /dev/null; then
echo -e "${RED}uv not found! Please install uv: curl -LsSf https://astral.sh/uv/install.sh | sh${NC}"
exit 1
fi
# Validate critical environment variables
echo -e "${YELLOW}Validating environment...${NC}"
if [ -z "$TIINGO_API_KEY" ]; then
echo -e "${RED}Warning: TIINGO_API_KEY not set - stock data tools may not work${NC}"
fi
if [ -z "$EXA_API_KEY" ] && [ -z "$TAVILY_API_KEY" ]; then
echo -e "${RED}Warning: Neither EXA_API_KEY nor TAVILY_API_KEY set - research tools may be limited${NC}"
fi
# Choose transport based on environment variable or default to SSE for reliability
TRANSPORT=${MAVERICK_TRANSPORT:-sse}
echo -e "${YELLOW}Starting backend with: uv run python -m maverick_mcp.api.server --transport ${TRANSPORT} --host 0.0.0.0 --port 8003${NC}"
echo -e "${YELLOW}Transport: ${TRANSPORT} (recommended for Claude Desktop stability)${NC}"
# Run backend with FastMCP in development mode (show real-time output)
echo -e "${YELLOW}Starting server with real-time output...${NC}"
# Set PYTHONWARNINGS to suppress websockets deprecation warnings from uvicorn
PYTHONWARNINGS="ignore::DeprecationWarning:websockets.*,ignore::DeprecationWarning:uvicorn.*" \
uv run python -m maverick_mcp.api.server --transport ${TRANSPORT} --host 0.0.0.0 --port 8003 2>&1 | tee backend.log &
BACKEND_PID=$!
echo -e "${YELLOW}Backend PID: $BACKEND_PID${NC}"
# Wait for backend to start
echo -e "${YELLOW}Waiting for backend to start...${NC}"
# Wait up to 45 seconds for the backend to start and tools to register
TOOLS_REGISTERED=false
for i in {1..45}; do
# Check if backend process is still running first
if ! kill -0 $BACKEND_PID 2>/dev/null; then
echo -e "${RED}Backend process died! Check output above for errors.${NC}"
exit 1
fi
# Check if port is open
if nc -z localhost 8003 2>/dev/null || curl -s http://localhost:8003/health >/dev/null 2>&1; then
if [ "$TOOLS_REGISTERED" = false ]; then
echo -e "${GREEN}Backend port is open, checking for tool registration...${NC}"
# Check backend.log for tool registration messages
if grep -q "Research tools registered successfully" backend.log 2>/dev/null ||
grep -q "Tool registration process completed" backend.log 2>/dev/null ||
grep -q "Tools registered successfully" backend.log 2>/dev/null; then
echo -e "${GREEN}Research tools successfully registered!${NC}"
TOOLS_REGISTERED=true
break
else
echo -e "${YELLOW}Backend running but tools not yet registered... ($i/45)${NC}"
fi
fi
else
echo -e "${YELLOW}Still waiting for backend to start... ($i/45)${NC}"
fi
if [ $i -eq 45 ]; then
echo -e "${RED}Backend failed to fully initialize after 45 seconds!${NC}"
echo -e "${RED}Server may be running but tools not registered. Check output above.${NC}"
# Don't exit - let it continue in case tools load later
fi
sleep 1
done
if [ "$TOOLS_REGISTERED" = true ]; then
echo -e "${GREEN}Backend is ready with tools registered!${NC}"
else
echo -e "${YELLOW}Backend appears to be running but tool registration status unclear${NC}"
fi
echo -e "${GREEN}Backend started successfully on http://localhost:8003${NC}"
# Show information
echo -e "\n${GREEN}Development environment is running!${NC}"
echo -e "${YELLOW}MCP Server:${NC} http://localhost:8003"
echo -e "${YELLOW}Health Check:${NC} http://localhost:8003/health"
# Show endpoint based on transport type
if [ "$TRANSPORT" = "sse" ]; then
echo -e "${YELLOW}MCP SSE Endpoint:${NC} http://localhost:8003/sse/"
elif [ "$TRANSPORT" = "streamable-http" ]; then
echo -e "${YELLOW}MCP HTTP Endpoint:${NC} http://localhost:8003/mcp"
echo -e "${YELLOW}Test with curl:${NC} curl -X POST http://localhost:8003/mcp"
elif [ "$TRANSPORT" = "stdio" ]; then
echo -e "${YELLOW}MCP Transport:${NC} STDIO (no HTTP endpoint)"
fi
echo -e "${YELLOW}Logs:${NC} tail -f backend.log"
if [ "$TOOLS_REGISTERED" = true ]; then
echo -e "\n${GREEN}✓ Research tools are registered and ready${NC}"
else
echo -e "\n${YELLOW}⚠ Tool registration status unclear${NC}"
echo -e "${YELLOW}Debug: Check backend.log for tool registration messages${NC}"
echo -e "${YELLOW}Debug: Look for 'Successfully registered' or 'research tools' in logs${NC}"
fi
echo -e "\n${YELLOW}Claude Desktop Configuration:${NC}"
if [ "$TRANSPORT" = "sse" ]; then
echo -e "${GREEN}SSE Transport (tested and stable):${NC}"
echo -e '{"mcpServers": {"maverick-mcp": {"command": "npx", "args": ["-y", "mcp-remote", "http://localhost:8003/sse/"]}}}'
elif [ "$TRANSPORT" = "stdio" ]; then
echo -e "${GREEN}STDIO Transport (direct connection):${NC}"
echo -e '{"mcpServers": {"maverick-mcp": {"command": "uv", "args": ["run", "python", "-m", "maverick_mcp.api.server", "--transport", "stdio"], "cwd": "'$(pwd)'"}}}'
elif [ "$TRANSPORT" = "streamable-http" ]; then
echo -e "${GREEN}Streamable-HTTP Transport (for testing):${NC}"
echo -e '{"mcpServers": {"maverick-mcp": {"command": "npx", "args": ["-y", "mcp-remote", "http://localhost:8003/mcp"]}}}'
else
echo -e '{"mcpServers": {"maverick-mcp": {"command": "npx", "args": ["-y", "mcp-remote", "http://localhost:8003/mcp"]}}}'
fi
echo -e "\n${YELLOW}Connection Stability Features:${NC}"
if [ "$TRANSPORT" = "sse" ]; then
echo -e " • SSE transport (tested and stable for Claude Desktop)"
echo -e " • Uses mcp-remote bridge for reliable connection"
echo -e " • Prevents tools from disappearing"
echo -e " • Persistent connection with session management"
echo -e " • Adaptive timeout system for research tools"
elif [ "$TRANSPORT" = "stdio" ]; then
echo -e " • Direct STDIO transport (no network layer)"
echo -e " • No mcp-remote needed (direct Claude Desktop integration)"
echo -e " • No session management issues"
echo -e " • No timeout problems"
elif [ "$TRANSPORT" = "streamable-http" ]; then
echo -e " • Streamable-HTTP transport (FastMCP 2.0 standard)"
echo -e " • Uses mcp-remote bridge for Claude Desktop"
echo -e " • Ideal for testing with curl/Postman/REST clients"
echo -e " • Good for debugging transport-specific issues"
echo -e " • Alternative to SSE for compatibility testing"
else
echo -e " • HTTP transport with mcp-remote bridge"
echo -e " • Alternative to SSE for compatibility"
echo -e " • Single process management"
fi
echo -e "\nPress Ctrl+C to stop the server"
# Wait for process
wait
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/screening.py:
--------------------------------------------------------------------------------
```python
"""
Stock screening router for Maverick-MCP.
This module contains all stock screening related tools including
Maverick, supply/demand breakouts, and other screening strategies.
"""
import logging
from typing import Any
from fastmcp import FastMCP
logger = logging.getLogger(__name__)
# Create the screening router
screening_router: FastMCP = FastMCP("Stock_Screening")
def get_maverick_stocks(limit: int = 20) -> dict[str, Any]:
"""
Get top Maverick stocks from the screening results.
DISCLAIMER: Stock screening results are for educational and research purposes only.
This is not investment advice. Past performance does not guarantee future results.
Always conduct thorough research and consult financial professionals before investing.
The Maverick screening strategy identifies stocks with:
- High momentum strength
- Technical patterns (Cup & Handle, consolidation, etc.)
- Momentum characteristics
- Strong combined scores
Args:
limit: Maximum number of stocks to return (default: 20)
Returns:
Dictionary containing Maverick stock screening results
"""
try:
from maverick_mcp.data.models import MaverickStocks, SessionLocal
with SessionLocal() as session:
stocks = MaverickStocks.get_top_stocks(session, limit=limit)
return {
"status": "success",
"count": len(stocks),
"stocks": [stock.to_dict() for stock in stocks],
"screening_type": "maverick_bullish",
"description": "High momentum stocks with bullish technical setups",
}
except Exception as e:
logger.error(f"Error fetching Maverick stocks: {str(e)}")
return {"error": str(e), "status": "error"}
def get_maverick_bear_stocks(limit: int = 20) -> dict[str, Any]:
"""
Get top Maverick Bear stocks from the screening results.
DISCLAIMER: Bearish screening results are for educational purposes only.
This is not advice to sell short or make bearish trades. Short selling involves
unlimited risk potential. Always consult financial professionals before trading.
The Maverick Bear screening identifies stocks with:
- Weak momentum strength
- Bearish technical patterns
- Distribution characteristics
- High bear scores
Args:
limit: Maximum number of stocks to return (default: 20)
Returns:
Dictionary containing Maverick Bear stock screening results
"""
try:
from maverick_mcp.data.models import MaverickBearStocks, SessionLocal
with SessionLocal() as session:
stocks = MaverickBearStocks.get_top_stocks(session, limit=limit)
return {
"status": "success",
"count": len(stocks),
"stocks": [stock.to_dict() for stock in stocks],
"screening_type": "maverick_bearish",
"description": "Weak stocks with bearish technical setups",
}
except Exception as e:
logger.error(f"Error fetching Maverick Bear stocks: {str(e)}")
return {"error": str(e), "status": "error"}
def get_supply_demand_breakouts(
limit: int = 20, filter_moving_averages: bool = False
) -> dict[str, Any]:
"""
Get stocks showing supply/demand breakout patterns from accumulation.
This screening identifies stocks in the demand expansion phase with:
- Price above all major moving averages (demand zone)
- Moving averages in proper alignment indicating accumulation (50 > 150 > 200)
- Strong momentum strength showing institutional interest
- Market structure indicating supply absorption and demand dominance
Args:
limit: Maximum number of stocks to return (default: 20)
filter_moving_averages: If True, only return stocks above all moving averages
Returns:
Dictionary containing supply/demand breakout screening results
"""
try:
from maverick_mcp.data.models import SessionLocal, SupplyDemandBreakoutStocks
with SessionLocal() as session:
if filter_moving_averages:
stocks = SupplyDemandBreakoutStocks.get_stocks_above_moving_averages(
session
)[:limit]
else:
stocks = SupplyDemandBreakoutStocks.get_top_stocks(session, limit=limit)
return {
"status": "success",
"count": len(stocks),
"stocks": [stock.to_dict() for stock in stocks],
"screening_type": "supply_demand_breakout",
"description": "Stocks breaking out from accumulation with strong demand dynamics",
}
except Exception as e:
logger.error(f"Error fetching supply/demand breakout stocks: {str(e)}")
return {"error": str(e), "status": "error"}
def get_all_screening_recommendations() -> dict[str, Any]:
"""
Get comprehensive screening results from all strategies.
This tool returns the top stocks from each screening strategy:
- Maverick Bullish: High momentum growth stocks
- Maverick Bearish: Weak stocks for short opportunities
- Supply/Demand Breakouts: Stocks breaking out from accumulation phases
Returns:
Dictionary containing all screening results organized by strategy
"""
try:
from maverick_mcp.providers.stock_data import StockDataProvider
provider = StockDataProvider()
return provider.get_all_screening_recommendations()
except Exception as e:
logger.error(f"Error getting all screening recommendations: {e}")
return {
"error": str(e),
"status": "error",
"maverick_stocks": [],
"maverick_bear_stocks": [],
"supply_demand_breakouts": [],
}
def get_screening_by_criteria(
min_momentum_score: float | str | None = None,
min_volume: int | str | None = None,
max_price: float | str | None = None,
sector: str | None = None,
limit: int | str = 20,
) -> dict[str, Any]:
"""
Get stocks filtered by specific screening criteria.
This tool allows custom filtering across all screening results based on:
- Momentum score rating
- Volume requirements
- Price constraints
- Sector preferences
Args:
min_momentum_score: Minimum momentum score rating (0-100)
min_volume: Minimum average daily volume
max_price: Maximum stock price
sector: Specific sector to filter (e.g., "Technology")
limit: Maximum number of results
Returns:
Dictionary containing filtered screening results
"""
try:
from maverick_mcp.data.models import MaverickStocks, SessionLocal
# Convert string inputs to appropriate numeric types
if min_momentum_score is not None:
min_momentum_score = float(min_momentum_score)
if min_volume is not None:
min_volume = int(min_volume)
if max_price is not None:
max_price = float(max_price)
if isinstance(limit, str):
limit = int(limit)
with SessionLocal() as session:
query = session.query(MaverickStocks)
if min_momentum_score:
query = query.filter(
MaverickStocks.momentum_score >= min_momentum_score
)
if min_volume:
query = query.filter(MaverickStocks.avg_vol_30d >= min_volume)
if max_price:
query = query.filter(MaverickStocks.close_price <= max_price)
# Note: Sector filtering would require joining with Stock table
# This is a simplified version
stocks = (
query.order_by(MaverickStocks.combined_score.desc()).limit(limit).all()
)
return {
"status": "success",
"count": len(stocks),
"stocks": [stock.to_dict() for stock in stocks],
"criteria": {
"min_momentum_score": min_momentum_score,
"min_volume": min_volume,
"max_price": max_price,
"sector": sector,
},
}
except Exception as e:
logger.error(f"Error in custom screening: {str(e)}")
return {"error": str(e), "status": "error"}
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/interfaces/config.py:
--------------------------------------------------------------------------------
```python
"""
Configuration provider interface.
This module defines the abstract interface for configuration management,
enabling different configuration sources (environment variables, files, etc.)
to be used interchangeably throughout the application.
"""
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class IConfigurationProvider(Protocol):
"""
Interface for configuration management.
This interface abstracts configuration access to enable different
sources (environment variables, config files, etc.) to be used interchangeably.
"""
def get_database_url(self) -> str:
"""
Get database connection URL.
Returns:
Database connection URL string
"""
...
def get_redis_host(self) -> str:
"""Get Redis server host."""
...
def get_redis_port(self) -> int:
"""Get Redis server port."""
...
def get_redis_db(self) -> int:
"""Get Redis database number."""
...
def get_redis_password(self) -> str | None:
"""Get Redis password."""
...
def get_redis_ssl(self) -> bool:
"""Get Redis SSL setting."""
...
def is_cache_enabled(self) -> bool:
"""Check if caching is enabled."""
...
def get_cache_ttl(self) -> int:
"""Get default cache TTL in seconds."""
...
def get_fred_api_key(self) -> str:
"""Get FRED API key for macroeconomic data."""
...
def get_external_api_key(self) -> str:
"""Get External API key for market data."""
...
def get_tiingo_api_key(self) -> str:
"""Get Tiingo API key for market data."""
...
def get_log_level(self) -> str:
"""Get logging level."""
...
def is_development_mode(self) -> bool:
"""Check if running in development mode."""
...
def is_production_mode(self) -> bool:
"""Check if running in production mode."""
...
def get_request_timeout(self) -> int:
"""Get default request timeout in seconds."""
...
def get_max_retries(self) -> int:
"""Get maximum retry attempts for API calls."""
...
def get_pool_size(self) -> int:
"""Get database connection pool size."""
...
def get_max_overflow(self) -> int:
"""Get database connection pool overflow."""
...
def get_config_value(self, key: str, default: Any = None) -> Any:
"""
Get a configuration value by key.
Args:
key: Configuration key
default: Default value if key not found
Returns:
Configuration value or default
"""
...
def set_config_value(self, key: str, value: Any) -> None:
"""
Set a configuration value.
Args:
key: Configuration key
value: Value to set
"""
...
def get_all_config(self) -> dict[str, Any]:
"""
Get all configuration as a dictionary.
Returns:
Dictionary of all configuration values
"""
...
def reload_config(self) -> None:
"""Reload configuration from source."""
...
class ConfigurationError(Exception):
"""Base exception for configuration-related errors."""
pass
class MissingConfigurationError(ConfigurationError):
"""Raised when required configuration is missing."""
def __init__(self, key: str, message: str | None = None):
self.key = key
super().__init__(message or f"Missing required configuration: {key}")
class InvalidConfigurationError(ConfigurationError):
"""Raised when configuration value is invalid."""
def __init__(self, key: str, value: Any, message: str | None = None):
self.key = key
self.value = value
super().__init__(message or f"Invalid configuration value for {key}: {value}")
class EnvironmentConfigurationProvider:
"""
Environment-based configuration provider.
This is a concrete implementation that can be used as a default
or reference implementation for the IConfigurationProvider interface.
"""
def __init__(self):
"""Initialize with environment variables."""
import os
self._env = os.environ
self._cache: dict[str, Any] = {}
def get_database_url(self) -> str:
"""Get database URL from DATABASE_URL environment variable."""
return self._env.get("DATABASE_URL", "sqlite:///maverick_mcp.db")
def get_redis_host(self) -> str:
"""Get Redis host from REDIS_HOST environment variable."""
return self._env.get("REDIS_HOST", "localhost")
def get_redis_port(self) -> int:
"""Get Redis port from REDIS_PORT environment variable."""
return int(self._env.get("REDIS_PORT", "6379"))
def get_redis_db(self) -> int:
"""Get Redis database from REDIS_DB environment variable."""
return int(self._env.get("REDIS_DB", "0"))
def get_redis_password(self) -> str | None:
"""Get Redis password from REDIS_PASSWORD environment variable."""
password = self._env.get("REDIS_PASSWORD", "")
return password if password else None
def get_redis_ssl(self) -> bool:
"""Get Redis SSL setting from REDIS_SSL environment variable."""
return self._env.get("REDIS_SSL", "False").lower() == "true"
def is_cache_enabled(self) -> bool:
"""Check if caching is enabled from CACHE_ENABLED environment variable."""
return self._env.get("CACHE_ENABLED", "True").lower() == "true"
def get_cache_ttl(self) -> int:
"""Get cache TTL from CACHE_TTL_SECONDS environment variable."""
return int(self._env.get("CACHE_TTL_SECONDS", "604800"))
def get_fred_api_key(self) -> str:
"""Get FRED API key from FRED_API_KEY environment variable."""
return self._env.get("FRED_API_KEY", "")
def get_external_api_key(self) -> str:
"""Get External API key from CAPITAL_COMPANION_API_KEY environment variable."""
return self._env.get("CAPITAL_COMPANION_API_KEY", "")
def get_tiingo_api_key(self) -> str:
"""Get Tiingo API key from TIINGO_API_KEY environment variable."""
return self._env.get("TIINGO_API_KEY", "")
def get_log_level(self) -> str:
"""Get log level from LOG_LEVEL environment variable."""
return self._env.get("LOG_LEVEL", "INFO")
def is_development_mode(self) -> bool:
"""Check if in development mode from ENVIRONMENT environment variable."""
env = self._env.get("ENVIRONMENT", "development").lower()
return env in ("development", "dev", "test")
def is_production_mode(self) -> bool:
"""Check if in production mode from ENVIRONMENT environment variable."""
env = self._env.get("ENVIRONMENT", "development").lower()
return env in ("production", "prod")
def get_request_timeout(self) -> int:
"""Get request timeout from REQUEST_TIMEOUT environment variable."""
return int(self._env.get("REQUEST_TIMEOUT", "30"))
def get_max_retries(self) -> int:
"""Get max retries from MAX_RETRIES environment variable."""
return int(self._env.get("MAX_RETRIES", "3"))
def get_pool_size(self) -> int:
"""Get pool size from DB_POOL_SIZE environment variable."""
return int(self._env.get("DB_POOL_SIZE", "5"))
def get_max_overflow(self) -> int:
"""Get max overflow from DB_MAX_OVERFLOW environment variable."""
return int(self._env.get("DB_MAX_OVERFLOW", "10"))
def get_config_value(self, key: str, default: Any = None) -> Any:
"""Get configuration value from environment variables."""
if key in self._cache:
return self._cache[key]
value = self._env.get(key, default)
self._cache[key] = value
return value
def set_config_value(self, key: str, value: Any) -> None:
"""Set configuration value (updates cache, not environment)."""
self._cache[key] = value
def get_all_config(self) -> dict[str, Any]:
"""Get all configuration as dictionary."""
config = {}
config.update(self._env)
config.update(self._cache)
return config
def reload_config(self) -> None:
"""Clear cache to force reload from environment."""
self._cache.clear()
```
--------------------------------------------------------------------------------
/tests/integration/base.py:
--------------------------------------------------------------------------------
```python
"""Base classes and utilities for integration testing."""
from __future__ import annotations
import asyncio
import fnmatch
import time
from collections import defaultdict
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
class InMemoryPubSub:
"""Lightweight pub/sub implementation for the in-memory Redis stub."""
def __init__(self, redis: InMemoryRedis) -> None:
self._redis = redis
self._queues: dict[str, asyncio.Queue[dict[str, Any]]] = {}
self._active = True
async def subscribe(self, channel: str) -> None:
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._queues[channel] = queue
self._redis.register_subscriber(channel, queue)
async def unsubscribe(self, channel: str) -> None:
queue = self._queues.pop(channel, None)
if queue is not None:
self._redis.unregister_subscriber(channel, queue)
async def close(self) -> None:
self._active = False
for channel, _queue in list(self._queues.items()):
await self.unsubscribe(channel)
async def listen(self): # pragma: no cover - simple async generator
while self._active:
tasks = [
asyncio.create_task(queue.get()) for queue in self._queues.values()
]
if not tasks:
await asyncio.sleep(0.01)
continue
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
for task in done:
message = task.result()
yield message
class InMemoryRedis:
"""A minimal asynchronous Redis replacement used in tests."""
def __init__(self) -> None:
self._data: dict[str, bytes] = {}
self._hashes: dict[str, dict[str, str]] = defaultdict(dict)
self._expiry: dict[str, float] = {}
self._pubsub_channels: dict[str, list[asyncio.Queue[dict[str, Any]]]] = (
defaultdict(list)
)
def _is_expired(self, key: str) -> bool:
expiry = self._expiry.get(key)
if expiry is None:
return False
if expiry < time.time():
self._data.pop(key, None)
self._hashes.pop(key, None)
self._expiry.pop(key, None)
return True
return False
def register_subscriber(
self, channel: str, queue: asyncio.Queue[dict[str, Any]]
) -> None:
self._pubsub_channels[channel].append(queue)
def unregister_subscriber(
self, channel: str, queue: asyncio.Queue[dict[str, Any]]
) -> None:
if channel in self._pubsub_channels:
try:
self._pubsub_channels[channel].remove(queue)
except ValueError:
pass
if not self._pubsub_channels[channel]:
del self._pubsub_channels[channel]
async def setex(self, key: str, ttl: int, value: Any) -> None:
self._data[key] = self._encode(value)
self._expiry[key] = time.time() + ttl
async def set(
self,
key: str,
value: Any,
*,
nx: bool = False,
ex: int | None = None,
) -> str | None:
if nx and key in self._data and not self._is_expired(key):
return None
self._data[key] = self._encode(value)
if ex is not None:
self._expiry[key] = time.time() + ex
return "OK"
async def get(self, key: str) -> bytes | None:
if self._is_expired(key):
return None
return self._data.get(key)
async def delete(self, *keys: str) -> int:
removed = 0
for key in keys:
if key in self._data and not self._is_expired(key):
removed += 1
self._data.pop(key, None)
self._hashes.pop(key, None)
self._expiry.pop(key, None)
return removed
async def scan(
self, cursor: int, match: str | None = None, count: int = 100
) -> tuple[int, list[str]]:
keys = [key for key in self._data.keys() if not self._is_expired(key)]
if match:
keys = [key for key in keys if fnmatch.fnmatch(key, match)]
return 0, keys[:count]
async def mget(self, keys: list[str]) -> list[bytes | None]:
return [await self.get(key) for key in keys]
async def hincrby(self, key: str, field: str, amount: int) -> int:
current = int(self._hashes[key].get(field, "0"))
current += amount
self._hashes[key][field] = str(current)
return current
async def hgetall(self, key: str) -> dict[bytes, bytes]:
if self._is_expired(key):
return {}
mapping = self._hashes.get(key, {})
return {
field.encode("utf-8"): value.encode("utf-8")
for field, value in mapping.items()
}
async def hset(self, key: str, mapping: dict[str, Any]) -> None:
for field, value in mapping.items():
self._hashes[key][field] = str(value)
async def eval(self, script: str, keys: list[str], args: list[str]) -> int:
if not keys:
return 0
key = keys[0]
expected = args[0] if args else ""
stored = await self.get(key)
if stored is not None and stored.decode("utf-8") == expected:
await self.delete(key)
return 1
return 0
async def publish(self, channel: str, message: Any) -> None:
encoded = self._encode(message)
for queue in self._pubsub_channels.get(channel, []):
await queue.put(
{"type": "message", "channel": channel, "data": encoded.decode("utf-8")}
)
def pubsub(self) -> InMemoryPubSub:
return InMemoryPubSub(self)
def _encode(self, value: Any) -> bytes:
if isinstance(value, bytes):
return value
if isinstance(value, str):
return value.encode("utf-8")
return str(value).encode("utf-8")
async def close(self) -> None:
self._data.clear()
self._hashes.clear()
self._expiry.clear()
self._pubsub_channels.clear()
class BaseIntegrationTest:
"""Base class for integration tests with common utilities."""
def setup_test(self):
"""Set up test environment for each test."""
return None
def assert_response_success(self, response, expected_status: int = 200):
"""Assert that a response is successful."""
if hasattr(response, "status_code"):
assert response.status_code == expected_status, (
f"Expected status {expected_status}, got {response.status_code}. "
f"Response: {response.json() if hasattr(response, 'content') and response.content else 'No content'}"
)
class RedisIntegrationTest(BaseIntegrationTest):
"""Integration tests that rely on a Redis-like backend."""
redis_client: InMemoryRedis
@pytest.fixture(autouse=True)
async def _setup_redis(self):
self.redis_client = InMemoryRedis()
yield
await self.redis_client.close()
async def assert_cache_exists(self, key: str) -> None:
value = await self.redis_client.get(key)
assert value is not None, f"Expected cache key {key} to exist"
async def assert_cache_not_exists(self, key: str) -> None:
value = await self.redis_client.get(key)
assert value is None, f"Expected cache key {key} to be absent"
class MockLLMBase:
"""Base mock LLM for consistent testing."""
def __init__(self):
self.ainvoke = AsyncMock()
self.bind_tools = MagicMock(return_value=self)
self.invoke = MagicMock()
mock_response = MagicMock()
mock_response.content = '{"insights": ["Test insight"], "sentiment": {"direction": "neutral", "confidence": 0.5}}'
self.ainvoke.return_value = mock_response
class MockCacheManager:
"""Mock cache manager for testing."""
def __init__(self):
self.get = AsyncMock(return_value=None)
self.set = AsyncMock()
self._cache: dict[str, Any] = {}
async def get_cached(self, key: str) -> Any:
"""Get value from mock cache."""
return self._cache.get(key)
async def set_cached(self, key: str, value: Any) -> None:
"""Set value in mock cache."""
self._cache[key] = value
```
--------------------------------------------------------------------------------
/maverick_mcp/data/health.py:
--------------------------------------------------------------------------------
```python
"""
Database health monitoring and connection pool management.
This module provides utilities for monitoring database health,
connection pool statistics, and performance metrics.
"""
import logging
import time
from contextlib import contextmanager
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import event, text
from sqlalchemy import pool as sql_pool
from sqlalchemy.engine import Engine
from maverick_mcp.data.models import SessionLocal, engine
logger = logging.getLogger(__name__)
class DatabaseHealthMonitor:
"""Monitor database health and connection pool statistics."""
def __init__(self, engine: Engine):
self.engine = engine
self.connection_times: list[float] = []
self.query_times: list[float] = []
self.active_connections = 0
self.total_connections = 0
self.failed_connections = 0
# Register event listeners
self._register_events()
def _register_events(self):
"""Register SQLAlchemy event listeners for monitoring."""
@event.listens_for(self.engine, "connect")
def receive_connect(dbapi_conn, connection_record):
"""Track successful connections."""
self.total_connections += 1
self.active_connections += 1
connection_record.info["connect_time"] = time.time()
@event.listens_for(self.engine, "close")
def receive_close(dbapi_conn, connection_record):
"""Track connection closures."""
self.active_connections -= 1
if "connect_time" in connection_record.info:
duration = time.time() - connection_record.info["connect_time"]
self.connection_times.append(duration)
# Keep only last 100 measurements
if len(self.connection_times) > 100:
self.connection_times.pop(0)
# Only register connect_error for databases that support it
# SQLite doesn't support connect_error event
if not self.engine.url.drivername.startswith("sqlite"):
@event.listens_for(self.engine, "connect_error")
def receive_connect_error(dbapi_conn, connection_record, exception):
"""Track connection failures."""
self.failed_connections += 1
logger.error(f"Database connection failed: {exception}")
def get_pool_status(self) -> dict[str, Any]:
"""Get current connection pool status."""
pool = self.engine.pool
if isinstance(pool, sql_pool.QueuePool):
return {
"type": "QueuePool",
"size": pool.size(),
"checked_in": pool.checkedin(),
"checked_out": pool.checkedout(),
"overflow": pool.overflow(),
"total": pool.size() + pool.overflow(),
}
elif isinstance(pool, sql_pool.NullPool):
return {
"type": "NullPool",
"message": "No connection pooling (each request creates new connection)",
}
else:
return {
"type": type(pool).__name__,
"message": "Pool statistics not available",
}
def check_database_health(self) -> dict[str, Any]:
"""Perform comprehensive database health check."""
health_status: dict[str, Any] = {
"status": "unknown",
"timestamp": datetime.now(UTC).isoformat(),
"checks": {},
}
# Check 1: Basic connectivity
try:
start_time = time.time()
with SessionLocal() as session:
result = session.execute(text("SELECT 1"))
result.fetchone()
connect_time = (time.time() - start_time) * 1000 # Convert to ms
health_status["checks"]["connectivity"] = {
"status": "healthy",
"response_time_ms": round(connect_time, 2),
"message": "Database is reachable",
}
except Exception as e:
health_status["checks"]["connectivity"] = {
"status": "unhealthy",
"error": str(e),
"message": "Cannot connect to database",
}
health_status["status"] = "unhealthy"
return health_status
# Check 2: Connection pool
pool_status = self.get_pool_status()
health_status["checks"]["connection_pool"] = {
"status": "healthy",
"details": pool_status,
}
# Check 3: Query performance
try:
start_time = time.time()
with SessionLocal() as session:
# Test a simple query on a core table
result = session.execute(text("SELECT COUNT(*) FROM stocks_stock"))
count = result.scalar()
query_time = (time.time() - start_time) * 1000
self.query_times.append(query_time)
if len(self.query_times) > 100:
self.query_times.pop(0)
avg_query_time = (
sum(self.query_times) / len(self.query_times) if self.query_times else 0
)
health_status["checks"]["query_performance"] = {
"status": "healthy" if query_time < 1000 else "degraded",
"last_query_ms": round(query_time, 2),
"avg_query_ms": round(avg_query_time, 2),
"stock_count": count,
}
except Exception as e:
health_status["checks"]["query_performance"] = {
"status": "unhealthy",
"error": str(e),
}
# Check 4: Connection statistics
health_status["checks"]["connection_stats"] = {
"total_connections": self.total_connections,
"active_connections": self.active_connections,
"failed_connections": self.failed_connections,
"failure_rate": round(
self.failed_connections / max(self.total_connections, 1) * 100, 2
),
}
# Determine overall status
if all(
check.get("status") == "healthy"
for check in health_status["checks"].values()
if isinstance(check, dict) and "status" in check
):
health_status["status"] = "healthy"
elif any(
check.get("status") == "unhealthy"
for check in health_status["checks"].values()
if isinstance(check, dict) and "status" in check
):
health_status["status"] = "unhealthy"
else:
health_status["status"] = "degraded"
return health_status
def reset_statistics(self):
"""Reset all collected statistics."""
self.connection_times.clear()
self.query_times.clear()
self.total_connections = 0
self.failed_connections = 0
logger.info("Database health statistics reset")
# Global health monitor instance
db_health_monitor = DatabaseHealthMonitor(engine)
@contextmanager
def timed_query(name: str):
"""Context manager for timing database queries."""
start_time = time.time()
try:
yield
finally:
duration = (time.time() - start_time) * 1000
logger.debug(f"Query '{name}' completed in {duration:.2f}ms")
def get_database_health() -> dict[str, Any]:
"""Get current database health status."""
return db_health_monitor.check_database_health()
def get_pool_statistics() -> dict[str, Any]:
"""Get current connection pool statistics."""
return db_health_monitor.get_pool_status()
def warmup_connection_pool(num_connections: int = 5):
"""
Warm up the connection pool by pre-establishing connections.
This is useful after server startup to avoid cold start latency.
"""
logger.info(f"Warming up connection pool with {num_connections} connections")
connections = []
try:
for _ in range(num_connections):
conn = engine.connect()
conn.execute(text("SELECT 1"))
connections.append(conn)
# Close all connections to return them to the pool
for conn in connections:
conn.close()
logger.info("Connection pool warmup completed")
except Exception as e:
logger.error(f"Error during connection pool warmup: {e}")
# Clean up any established connections
for conn in connections:
try:
conn.close()
except Exception:
pass
```
--------------------------------------------------------------------------------
/maverick_mcp/core/visualization.py:
--------------------------------------------------------------------------------
```python
"""
Visualization utilities for Maverick-MCP.
This module contains functions for generating charts and visualizations
for financial data, including technical analysis charts.
"""
import base64
import logging
import os
import tempfile
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
from maverick_mcp.config.plotly_config import setup_plotly
# Set up logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("maverick_mcp.visualization")
# Configure Plotly to use modern defaults and suppress warnings
setup_plotly()
def plotly_fig_to_base64(fig: go.Figure, format: str = "png") -> str:
"""
Convert a Plotly figure to a base64 encoded data URI string.
Args:
fig: The Plotly figure to convert
format: Image format (default: 'png')
Returns:
Base64 encoded data URI string of the figure
"""
img_bytes = None
with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmpfile:
try:
fig.write_image(tmpfile.name)
tmpfile.seek(0)
img_bytes = tmpfile.read()
except Exception as e:
logger.error(f"Error writing image: {e}")
raise
os.remove(tmpfile.name)
if not img_bytes:
logger.error("No image bytes were written. Is kaleido installed?")
raise RuntimeError(
"Plotly failed to write image. Ensure 'kaleido' is installed."
)
base64_str = base64.b64encode(img_bytes).decode("utf-8")
return f"data:image/{format};base64,{base64_str}"
def create_plotly_technical_chart(
df: pd.DataFrame, ticker: str, height: int = 400, width: int = 600
) -> go.Figure:
"""
Generate a Plotly technical analysis chart for financial data visualization.
Args:
df: DataFrame with price and technical indicator data
ticker: The ticker symbol to display in the chart title
height: Chart height
width: Chart width
Returns:
A Plotly figure with the technical analysis chart
"""
df = df.copy()
df.columns = [col.lower() for col in df.columns]
df = df.iloc[-126:].copy() # Ensure we keep DataFrame structure
fig = sp.make_subplots(
rows=4,
cols=1,
shared_xaxes=True,
vertical_spacing=0.03,
subplot_titles=("", "", "", ""),
row_heights=[0.6, 0.15, 0.15, 0.1],
)
bg_color = "#FFFFFF"
text_color = "#000000"
grid_color = "rgba(0, 0, 0, 0.35)"
colors = {
"green": "#00796B",
"red": "#D32F2F",
"blue": "#1565C0",
"orange": "#E65100",
"purple": "#6A1B9A",
"gray": "#424242",
"black": "#000000",
}
line_width = 1
# Candlestick chart
fig.add_trace(
go.Candlestick(
x=df.index,
name="Price",
open=df["open"],
high=df["high"],
low=df["low"],
close=df["close"],
increasing_line_color=colors["green"],
decreasing_line_color=colors["red"],
line={"width": line_width},
),
row=1,
col=1,
)
# Moving averages
for i, (col, name) in enumerate(
[("ema_21", "EMA 21"), ("sma_50", "SMA 50"), ("sma_200", "SMA 200")]
):
color = [colors["blue"], colors["green"], colors["red"]][i]
fig.add_trace(
go.Scatter(
x=df.index,
y=df[col],
mode="lines",
name=name,
line={"color": color, "width": line_width},
),
row=1,
col=1,
)
# Bollinger Bands
light_blue = "rgba(21, 101, 192, 0.6)"
fill_color = "rgba(21, 101, 192, 0.1)"
fig.add_trace(
go.Scatter(
x=df.index,
y=df["bbu_20_2.0"],
mode="lines",
line={"color": light_blue, "width": line_width},
name="Upper BB",
legendgroup="bollinger",
showlegend=True,
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=df.index,
y=df["bbl_20_2.0"],
mode="lines",
line={"color": light_blue, "width": line_width},
name="Lower BB",
legendgroup="bollinger",
showlegend=False,
fill="tonexty",
fillcolor=fill_color,
),
row=1,
col=1,
)
# Volume
volume_colors = np.where(df["close"] >= df["open"], colors["green"], colors["red"])
fig.add_trace(
go.Bar(
x=df.index,
y=df["volume"],
name="Volume",
marker={"color": volume_colors},
opacity=0.75,
showlegend=False,
),
row=2,
col=1,
)
# RSI
fig.add_trace(
go.Scatter(
x=df.index,
y=df["rsi"],
mode="lines",
name="RSI",
line={"color": colors["blue"], "width": line_width},
),
row=3,
col=1,
)
fig.add_hline(
y=70,
line_dash="dash",
line_color=colors["red"],
line_width=line_width,
row=3,
col=1,
)
fig.add_hline(
y=30,
line_dash="dash",
line_color=colors["green"],
line_width=line_width,
row=3,
col=1,
)
# MACD
fig.add_trace(
go.Scatter(
x=df.index,
y=df["macd_12_26_9"],
mode="lines",
name="MACD",
line={"color": colors["blue"], "width": line_width},
),
row=4,
col=1,
)
fig.add_trace(
go.Scatter(
x=df.index,
y=df["macds_12_26_9"],
mode="lines",
name="Signal",
line={"color": colors["orange"], "width": line_width},
),
row=4,
col=1,
)
fig.add_trace(
go.Bar(
x=df.index,
y=df["macdh_12_26_9"],
name="Histogram",
showlegend=False,
marker={"color": df["macdh_12_26_9"], "colorscale": "RdYlGn"},
),
row=4,
col=1,
)
# Layout
import datetime
now = datetime.datetime.now(datetime.UTC).strftime("%m/%d/%Y")
fig.update_layout(
height=height,
width=width,
title={
"text": f"<b>{ticker.upper()} | {now} | Technical Analysis | Maverick-MCP</b>",
"font": {"size": 12, "color": text_color, "family": "Arial, sans-serif"},
"y": 0.98,
},
plot_bgcolor=bg_color,
paper_bgcolor=bg_color,
xaxis_rangeslider_visible=False,
legend={
"orientation": "h",
"yanchor": "bottom",
"y": 1,
"xanchor": "left",
"x": 0,
"font": {"size": 10, "color": text_color, "family": "Arial, sans-serif"},
"itemwidth": 30,
"itemsizing": "constant",
"borderwidth": 0,
"tracegroupgap": 1,
},
font={"size": 10, "color": text_color, "family": "Arial, sans-serif"},
margin={"r": 20, "l": 40, "t": 80, "b": 0},
)
fig.update_xaxes(
gridcolor=grid_color,
zerolinecolor=grid_color,
zerolinewidth=line_width,
gridwidth=1,
griddash="dot",
)
fig.update_yaxes(
gridcolor=grid_color,
zerolinecolor=grid_color,
zerolinewidth=line_width,
gridwidth=1,
griddash="dot",
)
y_axis_titles = ["Price", "Volume", "RSI", "MACD"]
for i, title in enumerate(y_axis_titles, start=1):
if title:
fig.update_yaxes(
title={
"text": f"<b>{title}</b>",
"font": {"size": 8, "color": text_color},
"standoff": 0,
},
side="left",
position=0,
automargin=True,
row=i,
col=1,
tickfont={"size": 8},
)
fig.update_xaxes(showticklabels=False, row=1, col=1)
fig.update_xaxes(showticklabels=False, row=2, col=1)
fig.update_xaxes(showticklabels=False, row=3, col=1)
fig.update_xaxes(
title={"text": "Date", "font": {"size": 8, "color": text_color}, "standoff": 5},
row=4,
col=1,
tickfont={"size": 8},
showticklabels=True,
tickangle=45,
tickformat="%Y-%m-%d",
)
return fig
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/yfinance_pool.py:
--------------------------------------------------------------------------------
```python
"""
Optimized yfinance connection pooling and caching.
Provides thread-safe connection pooling and request optimization for yfinance.
"""
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from typing import Any
import pandas as pd
import yfinance as yf
from requests import Session
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
logger = logging.getLogger(__name__)
class YFinancePool:
"""Thread-safe yfinance connection pool with optimized session management."""
_instance = None
_lock = threading.Lock()
def __new__(cls):
"""Singleton pattern to ensure single connection pool."""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Initialize the connection pool once."""
if self._initialized:
return
# Create optimized session with connection pooling
self.session = self._create_optimized_session()
# Thread pool for parallel requests
self.executor = ThreadPoolExecutor(
max_workers=10, thread_name_prefix="yfinance_pool"
)
# Request cache (simple TTL cache)
self._request_cache: dict[str, tuple[Any, float]] = {}
self._cache_lock = threading.Lock()
self._cache_ttl = 60 # 1 minute cache for quotes
self._initialized = True
logger.info("YFinance connection pool initialized")
def _create_optimized_session(self) -> Session:
"""Create an optimized requests session with retry logic and connection pooling."""
session = Session()
# Configure retry strategy
retry_strategy = Retry(
total=3,
backoff_factor=0.3,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["GET", "POST"],
)
# Configure adapter with connection pooling
adapter = HTTPAdapter(
pool_connections=10, # Number of connection pools
pool_maxsize=50, # Max connections per pool
max_retries=retry_strategy,
pool_block=False, # Don't block when pool is full
)
# Mount adapter for HTTP and HTTPS
session.mount("http://", adapter)
session.mount("https://", adapter)
# Set headers to avoid rate limiting
session.headers.update(
{
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
"Accept-Language": "en-US,en;q=0.5",
"Accept-Encoding": "gzip, deflate",
"DNT": "1",
"Connection": "keep-alive",
"Upgrade-Insecure-Requests": "1",
}
)
return session
def get_ticker(self, symbol: str) -> yf.Ticker:
"""Get a ticker object - let yfinance handle session for compatibility."""
# Check cache first
cache_key = f"ticker_{symbol}"
cached = self._get_from_cache(cache_key)
if cached:
return cached
# Create ticker without custom session (yfinance now requires curl_cffi)
ticker = yf.Ticker(symbol)
# Cache for short duration
self._add_to_cache(cache_key, ticker, ttl=300) # 5 minutes
return ticker
def get_history(
self,
symbol: str,
start: str | None = None,
end: str | None = None,
period: str | None = None,
interval: str = "1d",
) -> pd.DataFrame:
"""Get historical data with connection pooling."""
# Create cache key
cache_key = f"history_{symbol}_{start}_{end}_{period}_{interval}"
# Check cache
cached = self._get_from_cache(cache_key)
if cached is not None and not cached.empty:
return cached
# Get ticker with optimized session
ticker = self.get_ticker(symbol)
# Fetch data
if period:
df = ticker.history(period=period, interval=interval)
else:
if start is None:
start = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
if end is None:
end = datetime.now().strftime("%Y-%m-%d")
df = ticker.history(start=start, end=end, interval=interval)
# Cache the result (longer TTL for historical data)
if not df.empty:
ttl = (
3600 if interval == "1d" else 300
) # 1 hour for daily, 5 min for intraday
self._add_to_cache(cache_key, df, ttl=ttl)
return df
def get_info(self, symbol: str) -> dict:
"""Get stock info with caching."""
cache_key = f"info_{symbol}"
# Check cache
cached = self._get_from_cache(cache_key)
if cached:
return cached
# Get ticker and info
ticker = self.get_ticker(symbol)
info = ticker.info
# Cache for longer duration (info doesn't change often)
self._add_to_cache(cache_key, info, ttl=3600) # 1 hour
return info
def batch_download(
self,
symbols: list[str],
start: str | None = None,
end: str | None = None,
period: str | None = None,
interval: str = "1d",
group_by: str = "ticker",
threads: bool = True,
) -> pd.DataFrame:
"""Download data for multiple symbols efficiently."""
# Use yfinance's batch download without custom session
if period:
data = yf.download(
tickers=symbols,
period=period,
interval=interval,
group_by=group_by,
threads=threads,
progress=False,
)
else:
if start is None:
start = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
if end is None:
end = datetime.now().strftime("%Y-%m-%d")
data = yf.download(
tickers=symbols,
start=start,
end=end,
interval=interval,
group_by=group_by,
threads=threads,
progress=False,
)
return data
def _get_from_cache(self, key: str) -> Any | None:
"""Get item from cache if not expired."""
with self._cache_lock:
if key in self._request_cache:
value, expiry = self._request_cache[key]
if datetime.now().timestamp() < expiry:
logger.debug(f"Cache hit for {key}")
return value
else:
del self._request_cache[key]
return None
def _add_to_cache(self, key: str, value: Any, ttl: int = 60):
"""Add item to cache with TTL."""
with self._cache_lock:
expiry = datetime.now().timestamp() + ttl
self._request_cache[key] = (value, expiry)
# Clean up old entries if cache is too large
if len(self._request_cache) > 1000:
self._cleanup_cache()
def _cleanup_cache(self):
"""Remove expired entries from cache."""
current_time = datetime.now().timestamp()
expired_keys = [
k for k, (_, expiry) in self._request_cache.items() if expiry < current_time
]
for key in expired_keys:
del self._request_cache[key]
# If still too large, remove oldest entries
if len(self._request_cache) > 800:
sorted_items = sorted(
self._request_cache.items(),
key=lambda x: x[1][1], # Sort by expiry time
)
# Keep only the newest 600 entries
self._request_cache = dict(sorted_items[-600:])
def close(self):
"""Clean up resources."""
try:
self.session.close()
self.executor.shutdown(wait=False)
logger.info("YFinance connection pool closed")
except Exception as e:
logger.warning(f"Error closing connection pool: {e}")
# Global instance
_yfinance_pool: YFinancePool | None = None
def get_yfinance_pool() -> YFinancePool:
"""Get or create the global yfinance connection pool."""
global _yfinance_pool
if _yfinance_pool is None:
_yfinance_pool = YFinancePool()
return _yfinance_pool
def cleanup_yfinance_pool():
"""Clean up the global connection pool."""
global _yfinance_pool
if _yfinance_pool:
_yfinance_pool.close()
_yfinance_pool = None
```
--------------------------------------------------------------------------------
/maverick_mcp/validation/middleware.py:
--------------------------------------------------------------------------------
```python
"""
Validation middleware for FastAPI to standardize error handling.
This module provides middleware to catch validation errors and
return standardized error responses.
"""
import logging
import time
import traceback
import uuid
from fastapi import Request, Response, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from maverick_mcp.exceptions import MaverickException
from .responses import error_response, validation_error_response
logger = logging.getLogger(__name__)
class ValidationMiddleware(BaseHTTPMiddleware):
"""Middleware to handle validation errors and API exceptions."""
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request and handle exceptions."""
# Generate trace ID for request tracking
trace_id = str(uuid.uuid4())
request.state.trace_id = trace_id
try:
response = await call_next(request)
return response
except MaverickException as e:
logger.warning(
f"API error: {e.error_code} - {e.message}",
extra={
"trace_id": trace_id,
"path": request.url.path,
"method": request.method,
"error_code": e.error_code,
},
)
return JSONResponse(
status_code=e.status_code,
content=error_response(
code=e.error_code,
message=e.message,
status_code=e.status_code,
field=e.field,
context=e.context,
trace_id=trace_id,
),
)
except RequestValidationError as e:
logger.warning(
f"Request validation error: {str(e)}",
extra={
"trace_id": trace_id,
"path": request.url.path,
"method": request.method,
},
)
# Convert Pydantic validation errors to our format
errors = []
for error in e.errors():
errors.append(
{
"code": "VALIDATION_ERROR",
"field": ".".join(str(x) for x in error["loc"]),
"message": error["msg"],
"context": {"input": error.get("input"), "type": error["type"]},
}
)
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=validation_error_response(errors=errors, trace_id=trace_id),
)
except ValidationError as e:
logger.warning(
f"Pydantic validation error: {str(e)}",
extra={
"trace_id": trace_id,
"path": request.url.path,
"method": request.method,
},
)
# Convert Pydantic validation errors
errors = []
for error in e.errors():
errors.append(
{
"code": "VALIDATION_ERROR",
"field": ".".join(str(x) for x in error["loc"]),
"message": error["msg"],
"context": {"input": error.get("input"), "type": error["type"]},
}
)
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=validation_error_response(errors=errors, trace_id=trace_id),
)
except Exception as e:
logger.error(
f"Unexpected error: {str(e)}",
extra={
"trace_id": trace_id,
"path": request.url.path,
"method": request.method,
"traceback": traceback.format_exc(),
},
)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=error_response(
code="INTERNAL_ERROR",
message="An unexpected error occurred",
status_code=500,
trace_id=trace_id,
),
)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Middleware for rate limiting based on API keys."""
def __init__(self, app, rate_limit_store=None):
super().__init__(app)
self.rate_limit_store = rate_limit_store or {}
async def dispatch(self, request: Request, call_next) -> Response:
"""Check rate limits before processing request."""
# Skip rate limiting for health checks and internal endpoints
if request.url.path in ["/health", "/metrics", "/docs", "/openapi.json"]:
return await call_next(request)
# Extract API key from headers
api_key = None
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
api_key = auth_header[7:]
elif "x-api-key" in request.headers:
api_key = request.headers["x-api-key"]
if api_key:
# Check rate limit (simplified implementation)
# In production, use Redis or similar for distributed rate limiting
current_time = int(time.time())
window_start = current_time - 60 # 1-minute window
# Clean old entries
key_requests = self.rate_limit_store.get(api_key, [])
key_requests = [ts for ts in key_requests if ts > window_start]
# Check limit (default 60 requests per minute)
if len(key_requests) >= 60:
trace_id = getattr(request.state, "trace_id", str(uuid.uuid4()))
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content=error_response(
code="RATE_LIMIT_EXCEEDED",
message="Rate limit exceeded",
status_code=429,
context={
"limit": 60,
"window": "1 minute",
"retry_after": 60 - (current_time % 60),
},
trace_id=trace_id,
),
headers={"Retry-After": "60"},
)
# Add current request
key_requests.append(current_time)
self.rate_limit_store[api_key] = key_requests
return await call_next(request)
class SecurityMiddleware(BaseHTTPMiddleware):
"""Security middleware for headers and request validation."""
async def dispatch(self, request: Request, call_next) -> Response:
"""Add security headers and validate requests."""
# Validate content type for POST/PUT requests
if request.method in ["POST", "PUT", "PATCH"]:
content_type = request.headers.get("content-type", "")
if not content_type.startswith("application/json"):
trace_id = getattr(request.state, "trace_id", str(uuid.uuid4()))
return JSONResponse(
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
content=error_response(
code="UNSUPPORTED_MEDIA_TYPE",
message="Content-Type must be application/json",
status_code=415,
trace_id=trace_id,
),
)
# Validate request size (10MB limit)
content_length = request.headers.get("content-length")
if content_length and int(content_length) > 10 * 1024 * 1024:
trace_id = getattr(request.state, "trace_id", str(uuid.uuid4()))
return JSONResponse(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content=error_response(
code="REQUEST_TOO_LARGE",
message="Request entity too large (max 10MB)",
status_code=413,
trace_id=trace_id,
),
)
response = await call_next(request)
# Add security headers
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains"
)
return response
```
--------------------------------------------------------------------------------
/maverick_mcp/agents/circuit_breaker.py:
--------------------------------------------------------------------------------
```python
"""
Circuit Breaker pattern for resilient external API calls.
"""
import asyncio
import logging
import time
from collections.abc import Callable
from enum import Enum
from typing import Any
from maverick_mcp.config.settings import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class CircuitState(Enum):
"""Circuit breaker states."""
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, reject calls
HALF_OPEN = "half_open" # Testing if service recovered
class CircuitBreaker:
"""
Circuit breaker for protecting against cascading failures.
Implements the circuit breaker pattern to prevent repeated calls
to failing services and allow them time to recover.
"""
def __init__(
self,
failure_threshold: int | None = None,
recovery_timeout: int | None = None,
expected_exception: type[Exception] = Exception,
name: str = "CircuitBreaker",
):
"""
Initialize circuit breaker.
Args:
failure_threshold: Number of failures before opening circuit (uses config default if None)
recovery_timeout: Seconds to wait before testing recovery (uses config default if None)
expected_exception: Exception type to catch
name: Name for logging
"""
self.failure_threshold = (
failure_threshold or settings.agent.circuit_breaker_failure_threshold
)
self.recovery_timeout = (
recovery_timeout or settings.agent.circuit_breaker_recovery_timeout
)
self.expected_exception = expected_exception
self.name = name
self._failure_count = 0
self._last_failure_time: float | None = None
self._state = CircuitState.CLOSED
self._lock = asyncio.Lock()
@property
def state(self) -> CircuitState:
"""Get current circuit state."""
return self._state
@property
def failure_count(self) -> int:
"""Get current failure count."""
return self._failure_count
async def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Call function through circuit breaker.
Args:
func: Function to call
*args: Function arguments
**kwargs: Function keyword arguments
Returns:
Function result
Raises:
Exception: If circuit is open or function fails
"""
async with self._lock:
if self._state == CircuitState.OPEN:
if self._should_attempt_reset():
self._state = CircuitState.HALF_OPEN
logger.info(f"{self.name}: Attempting reset (half-open)")
else:
raise Exception(f"{self.name}: Circuit breaker is OPEN")
try:
# Execute the function
if asyncio.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
# Success - reset on half-open or reduce failure count
await self._on_success()
return result
except self.expected_exception as e:
# Failure - increment counter and possibly open circuit
await self._on_failure()
raise e
async def _on_success(self):
"""Handle successful call."""
async with self._lock:
if self._state == CircuitState.HALF_OPEN:
self._state = CircuitState.CLOSED
self._failure_count = 0
logger.info(f"{self.name}: Circuit breaker CLOSED after recovery")
elif self._failure_count > 0:
self._failure_count = max(0, self._failure_count - 1)
async def _on_failure(self):
"""Handle failed call."""
async with self._lock:
self._failure_count += 1
self._last_failure_time = time.time()
if self._failure_count >= self.failure_threshold:
self._state = CircuitState.OPEN
logger.warning(
f"{self.name}: Circuit breaker OPEN after {self._failure_count} failures"
)
elif self._state == CircuitState.HALF_OPEN:
self._state = CircuitState.OPEN
logger.warning(
f"{self.name}: Circuit breaker OPEN after half-open test failed"
)
def _should_attempt_reset(self) -> bool:
"""Check if enough time has passed to attempt reset."""
if self._last_failure_time is None:
return False
return (time.time() - self._last_failure_time) >= self.recovery_timeout
async def reset(self):
"""Manually reset the circuit breaker."""
async with self._lock:
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time = None
logger.info(f"{self.name}: Circuit breaker manually RESET")
def get_status(self) -> dict[str, Any]:
"""Get circuit breaker status."""
return {
"name": self.name,
"state": self._state.value,
"failure_count": self._failure_count,
"failure_threshold": self.failure_threshold,
"recovery_timeout": self.recovery_timeout,
"time_until_retry": self._get_time_until_retry(),
}
def _get_time_until_retry(self) -> float | None:
"""Get seconds until retry is allowed."""
if self._state != CircuitState.OPEN or self._last_failure_time is None:
return None
elapsed = time.time() - self._last_failure_time
remaining = self.recovery_timeout - elapsed
return max(0, remaining)
class CircuitBreakerManager:
"""Manage multiple circuit breakers."""
def __init__(self):
"""Initialize circuit breaker manager."""
self._breakers: dict[str, CircuitBreaker] = {}
self._lock = asyncio.Lock()
async def get_or_create(
self,
name: str,
failure_threshold: int = 5,
recovery_timeout: int = 60,
expected_exception: type[Exception] = Exception,
) -> CircuitBreaker:
"""Get existing or create new circuit breaker."""
async with self._lock:
if name not in self._breakers:
self._breakers[name] = CircuitBreaker(
failure_threshold=failure_threshold,
recovery_timeout=recovery_timeout,
expected_exception=expected_exception,
name=name,
)
return self._breakers[name]
def get_all_status(self) -> dict[str, dict[str, Any]]:
"""Get status of all circuit breakers."""
return {name: breaker.get_status() for name, breaker in self._breakers.items()}
async def reset_all(self):
"""Reset all circuit breakers."""
for breaker in self._breakers.values():
await breaker.reset()
# Global circuit breaker manager
circuit_manager = CircuitBreakerManager()
def circuit_breaker(
name: str | None = None,
failure_threshold: int = 5,
recovery_timeout: int = 60,
expected_exception: type[Exception] = Exception,
):
"""
Decorator to wrap functions with circuit breaker protection.
Args:
name: Circuit breaker name (uses function name if None)
failure_threshold: Number of failures before opening circuit
recovery_timeout: Seconds to wait before testing recovery
expected_exception: Exception type to catch
Example:
@circuit_breaker("api_call", failure_threshold=3, recovery_timeout=30)
async def call_external_api():
# API call logic
pass
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
breaker_name = (
name or f"{func.__module__}.{getattr(func, '__name__', 'unknown')}"
)
if asyncio.iscoroutinefunction(func):
async def async_wrapper(*args, **kwargs):
breaker = await circuit_manager.get_or_create(
breaker_name,
failure_threshold=failure_threshold,
recovery_timeout=recovery_timeout,
expected_exception=expected_exception,
)
return await breaker.call(func, *args, **kwargs)
return async_wrapper
else:
def sync_wrapper(*args, **kwargs):
# For sync functions, we need to handle async breaker differently
# This is a simplified version - in production you'd want proper async handling
try:
return func(*args, **kwargs)
except expected_exception as e:
logger.warning(f"Circuit breaker {breaker_name}: {e}")
raise
return sync_wrapper
return decorator
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/implementations/stock_data_adapter.py:
--------------------------------------------------------------------------------
```python
"""
Stock data provider adapter.
This module provides adapters that make the existing StockDataProvider
compatible with the new interface-based architecture while maintaining
all existing functionality.
"""
import asyncio
import logging
from typing import Any
import pandas as pd
from sqlalchemy.orm import Session
from maverick_mcp.providers.interfaces.cache import ICacheManager
from maverick_mcp.providers.interfaces.config import IConfigurationProvider
from maverick_mcp.providers.interfaces.persistence import IDataPersistence
from maverick_mcp.providers.interfaces.stock_data import (
IStockDataFetcher,
IStockScreener,
)
from maverick_mcp.providers.stock_data import StockDataProvider
logger = logging.getLogger(__name__)
class StockDataAdapter(IStockDataFetcher, IStockScreener):
"""
Adapter that makes the existing StockDataProvider compatible with new interfaces.
This adapter wraps the existing provider and exposes it through the new
interface contracts, enabling gradual migration to the new architecture.
"""
def __init__(
self,
cache_manager: ICacheManager | None = None,
persistence: IDataPersistence | None = None,
config: IConfigurationProvider | None = None,
db_session: Session | None = None,
):
"""
Initialize the stock data adapter.
Args:
cache_manager: Cache manager for data caching
persistence: Persistence layer for database operations
config: Configuration provider
db_session: Optional database session for dependency injection
"""
self._cache_manager = cache_manager
self._persistence = persistence
self._config = config
self._db_session = db_session
# Initialize the existing provider
self._provider = StockDataProvider(db_session=db_session)
logger.debug("StockDataAdapter initialized")
async def get_stock_data(
self,
symbol: str,
start_date: str | None = None,
end_date: str | None = None,
period: str | None = None,
interval: str = "1d",
use_cache: bool = True,
) -> pd.DataFrame:
"""
Fetch historical stock data (async wrapper).
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
period: Alternative to start/end dates (e.g., '1y', '6mo')
interval: Data interval ('1d', '1wk', '1mo', etc.)
use_cache: Whether to use cached data if available
Returns:
DataFrame with OHLCV data indexed by date
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
self._provider.get_stock_data,
symbol,
start_date,
end_date,
period,
interval,
use_cache,
)
async def get_realtime_data(self, symbol: str) -> dict[str, Any] | None:
"""
Get real-time stock data (async wrapper).
Args:
symbol: Stock ticker symbol
Returns:
Dictionary with current price, change, volume, etc. or None if unavailable
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self._provider.get_realtime_data, symbol
)
async def get_stock_info(self, symbol: str) -> dict[str, Any]:
"""
Get detailed stock information and fundamentals (async wrapper).
Args:
symbol: Stock ticker symbol
Returns:
Dictionary with company info, financials, and market data
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._provider.get_stock_info, symbol)
async def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
"""
Get news articles for a stock (async wrapper).
Args:
symbol: Stock ticker symbol
limit: Maximum number of articles to return
Returns:
DataFrame with news articles
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._provider.get_news, symbol, limit)
async def get_earnings(self, symbol: str) -> dict[str, Any]:
"""
Get earnings information for a stock (async wrapper).
Args:
symbol: Stock ticker symbol
Returns:
Dictionary with earnings data and dates
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._provider.get_earnings, symbol)
async def get_recommendations(self, symbol: str) -> pd.DataFrame:
"""
Get analyst recommendations for a stock (async wrapper).
Args:
symbol: Stock ticker symbol
Returns:
DataFrame with analyst recommendations
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self._provider.get_recommendations, symbol
)
async def is_market_open(self) -> bool:
"""
Check if the stock market is currently open (async wrapper).
Returns:
True if market is open, False otherwise
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._provider.is_market_open)
async def is_etf(self, symbol: str) -> bool:
"""
Check if a symbol represents an ETF (async wrapper).
Args:
symbol: Stock ticker symbol
Returns:
True if symbol is an ETF, False otherwise
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._provider.is_etf, symbol)
# IStockScreener implementation
async def get_maverick_recommendations(
self, limit: int = 20, min_score: int | None = None
) -> list[dict[str, Any]]:
"""
Get bullish Maverick stock recommendations (async wrapper).
Args:
limit: Maximum number of recommendations
min_score: Minimum combined score filter
Returns:
List of stock recommendations with technical analysis
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self._provider.get_maverick_recommendations, limit, min_score
)
async def get_maverick_bear_recommendations(
self, limit: int = 20, min_score: int | None = None
) -> list[dict[str, Any]]:
"""
Get bearish Maverick stock recommendations (async wrapper).
Args:
limit: Maximum number of recommendations
min_score: Minimum score filter
Returns:
List of bear stock recommendations
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self._provider.get_maverick_bear_recommendations, limit, min_score
)
async def get_trending_recommendations(
self, limit: int = 20, min_momentum_score: float | None = None
) -> list[dict[str, Any]]:
"""
Get trending stock recommendations (async wrapper).
Args:
limit: Maximum number of recommendations
min_momentum_score: Minimum momentum score filter
Returns:
List of trending stock recommendations
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
self._provider.get_supply_demand_breakout_recommendations,
limit,
min_momentum_score,
)
async def get_all_screening_recommendations(
self,
) -> dict[str, list[dict[str, Any]]]:
"""
Get all screening recommendations in one call (async wrapper).
Returns:
Dictionary with all screening types and their recommendations
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self._provider.get_all_screening_recommendations
)
# Additional methods to expose provider functionality
def get_sync_provider(self) -> StockDataProvider:
"""
Get the underlying synchronous provider for backward compatibility.
Returns:
The wrapped StockDataProvider instance
"""
return self._provider
async def get_all_realtime_data(self, symbols: list[str]) -> dict[str, Any]:
"""
Get real-time data for multiple symbols (async wrapper).
Args:
symbols: List of stock ticker symbols
Returns:
Dictionary mapping symbols to their real-time data
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self._provider.get_all_realtime_data, symbols
)
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/parallel_screening.py:
--------------------------------------------------------------------------------
```python
"""
Parallel stock screening utilities for Maverick-MCP.
This module provides utilities for running stock screening operations
in parallel using ProcessPoolExecutor for significant performance gains.
"""
import asyncio
import time
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Any
from maverick_mcp.utils.logging import get_logger
logger = get_logger(__name__)
class ParallelScreener:
"""
Parallel stock screening executor.
This class provides methods to run screening functions in parallel
across multiple processes for better performance.
"""
def __init__(self, max_workers: int | None = None):
"""
Initialize the parallel screener.
Args:
max_workers: Maximum number of worker processes.
Defaults to CPU count.
"""
self.max_workers = max_workers
self._executor: ProcessPoolExecutor | None = None
def __enter__(self):
"""Context manager entry."""
self._executor = ProcessPoolExecutor(max_workers=self.max_workers)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
if self._executor:
self._executor.shutdown(wait=True)
self._executor = None
def screen_batch(
self,
symbols: list[str],
screening_func: Callable[[str], dict[str, Any]],
batch_size: int = 10,
timeout: float = 30.0,
) -> list[dict[str, Any]]:
"""
Screen a batch of symbols in parallel.
Args:
symbols: List of stock symbols to screen
screening_func: Function that takes a symbol and returns screening results
batch_size: Number of symbols to process per worker
timeout: Timeout for each screening operation
Returns:
List of screening results for symbols that passed
"""
if not self._executor:
raise RuntimeError("ParallelScreener must be used as context manager")
start_time = time.time()
results = []
failed_symbols = []
# Create batches
batches = [
symbols[i : i + batch_size] for i in range(0, len(symbols), batch_size)
]
logger.info(
f"Starting parallel screening of {len(symbols)} symbols "
f"in {len(batches)} batches"
)
# Submit batch processing jobs
future_to_batch = {
self._executor.submit(self._process_batch, batch, screening_func): batch
for batch in batches
}
# Collect results as they complete
for future in as_completed(future_to_batch, timeout=timeout * len(batches)):
batch = future_to_batch[future]
try:
batch_results = future.result()
results.extend(batch_results)
except Exception as e:
logger.error(f"Batch processing failed: {e}")
failed_symbols.extend(batch)
elapsed = time.time() - start_time
success_rate = (len(results) / len(symbols)) * 100 if symbols else 0
logger.info(
f"Parallel screening completed in {elapsed:.2f}s "
f"({len(results)}/{len(symbols)} succeeded, "
f"{success_rate:.1f}% success rate)"
)
if failed_symbols:
logger.warning(f"Failed to screen symbols: {failed_symbols[:10]}...")
return results
@staticmethod
def _process_batch(
symbols: list[str], screening_func: Callable[[str], dict[str, Any]]
) -> list[dict[str, Any]]:
"""
Process a batch of symbols.
This runs in a separate process.
"""
results = []
for symbol in symbols:
try:
result = screening_func(symbol)
if result and result.get("passed", False):
results.append(result)
except Exception as e:
# Log errors but continue processing
logger.debug(f"Screening failed for {symbol}: {e}")
return results
async def parallel_screen_async(
symbols: list[str],
screening_func: Callable[[str], dict[str, Any]],
max_workers: int | None = None,
batch_size: int = 10,
) -> list[dict[str, Any]]:
"""
Async wrapper for parallel screening.
Args:
symbols: List of stock symbols to screen
screening_func: Screening function (must be picklable)
max_workers: Maximum number of worker processes
batch_size: Number of symbols per batch
Returns:
List of screening results
"""
loop = asyncio.get_event_loop()
# Run screening in thread pool to avoid blocking
def run_screening():
with ParallelScreener(max_workers=max_workers) as screener:
return screener.screen_batch(symbols, screening_func, batch_size)
results = await loop.run_in_executor(None, run_screening)
return results
# Example screening function (must be at module level for pickling)
def example_momentum_screen(symbol: str) -> dict[str, Any]:
"""
Example momentum screening function.
This must be defined at module level to be picklable for multiprocessing.
"""
from maverick_mcp.core.technical_analysis import calculate_rsi, calculate_sma
from maverick_mcp.providers.stock_data import StockDataProvider
try:
# Get stock data
provider = StockDataProvider(use_cache=False)
data = provider.get_stock_data(
symbol, start_date="2023-01-01", end_date="2024-01-01"
)
if len(data) < 50:
return {"symbol": symbol, "passed": False, "reason": "Insufficient data"}
# Calculate indicators
current_price = data["Close"].iloc[-1]
sma_50 = calculate_sma(data, 50).iloc[-1]
rsi = calculate_rsi(data, 14).iloc[-1]
# Momentum criteria
passed = (
current_price > sma_50 # Price above 50-day SMA
and 40 <= rsi <= 70 # RSI in healthy range
)
return {
"symbol": symbol,
"passed": passed,
"price": round(current_price, 2),
"sma_50": round(sma_50, 2),
"rsi": round(rsi, 2),
"above_sma": current_price > sma_50,
}
except Exception as e:
return {"symbol": symbol, "passed": False, "error": str(e)}
# Decorator for making functions parallel-friendly
def make_parallel_safe(func: Callable) -> Callable:
"""
Decorator to make a function safe for parallel execution.
This ensures the function:
1. Doesn't rely on shared state
2. Handles its own database connections
3. Returns picklable results
"""
from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
# Ensure clean execution environment
import os
os.environ["AUTH_ENABLED"] = "false"
try:
result = func(*args, **kwargs)
# Ensure result is serializable
import json
json.dumps(result) # Test serializability
return result
except Exception as e:
logger.error(f"Parallel execution error in {func.__name__}: {e}")
return {"error": str(e), "passed": False}
return wrapper
# Batch screening with progress tracking
class BatchScreener:
"""Enhanced batch screener with progress tracking."""
def __init__(self, screening_func: Callable, max_workers: int = 4):
self.screening_func = screening_func
self.max_workers = max_workers
self.results = []
self.progress = 0
self.total = 0
def screen_with_progress(
self,
symbols: list[str],
progress_callback: Callable[[int, int], None] | None = None,
) -> list[dict[str, Any]]:
"""
Screen symbols with progress tracking.
Args:
symbols: List of symbols to screen
progress_callback: Optional callback for progress updates
Returns:
List of screening results
"""
self.total = len(symbols)
self.progress = 0
with ParallelScreener(max_workers=self.max_workers) as screener:
# Process in smaller batches for better progress tracking
batch_size = max(1, len(symbols) // (self.max_workers * 4))
for i in range(0, len(symbols), batch_size):
batch = symbols[i : i + batch_size]
batch_results = screener.screen_batch(
batch,
self.screening_func,
batch_size=1, # Process one at a time within batch
)
self.results.extend(batch_results)
self.progress = min(i + batch_size, self.total)
if progress_callback:
progress_callback(self.progress, self.total)
return self.results
```
--------------------------------------------------------------------------------
/PLANS.md:
--------------------------------------------------------------------------------
```markdown
# PLANS.md
The detailed Execution Plan (`PLANS.md`) is a **living document** and the **memory** that helps Codex steer toward a completed project. Fel mentioned his actual `plans.md` file was about **160 lines** in length, expanded to approximate the detail required for a major project, such as the 15,000-line change to the JSON parser for streaming tool calls.
## 1. Big Picture / Goal
- **Objective:** To execute a core refactor of the existing streaming JSON parser architecture to seamlessly integrate the specialized `ToolCall_V2` library, enabling advanced, concurrent tool call processing and maintaining robust performance characteristics suitable for the "AI age". This refactor must minimize latency introduced during intermediate stream buffering.
- **Architectural Goal:** Transition the core tokenization and parsing logic from synchronous, block-based handling to a fully asynchronous, state-machine-driven model, specifically targeting non-blocking tool call detection within the stream.
- **Success Criteria (Mandatory):**
- All existing unit, property, and fuzzing tests must pass successfully post-refactor.
- New comprehensive integration tests must be written and passed to fully validate `ToolCall_V2` library functionality and streaming integration.
- Performance benchmarks must demonstrate no more than a 5% regression in parsing speed under high-concurrency streaming loads.
- The `plans.md` document must be fully updated upon completion, serving as the executive summary of the work accomplished.
- A high-quality summary and documentation updates (e.g., Readme, API guides) reflecting the new architecture must be generated and committed.
## 2. To-Do List (High Level)
- [ ] **Spike 1:** Comprehensive research and PoC for `ToolCall_V2` integration points.
- [ ] **Refactor Core:** Implement the new asynchronous state machine for streaming tokenization.
- [ ] **Feature A:** Implement the parsing hook necessary to detect `ToolCall_V2` structures mid-stream.
- [ ] **Feature B:** Develop the compatibility layer (shim) for backward support of legacy tool call formats.
- [ ] **Testing:** Write extensive property tests specifically targeting concurrency and error handling around tool calls.
- [ ] **Documentation:** Update all internal and external documentation, including `README.md` and inline comments.
## 3. Plan Details (Spikes & Features)
### Spike 1: Research `ToolCall_V2` Integration
- **Action:** Investigate the API signature of the `ToolCall_V2` library, focusing on its memory allocation strategies and compatibility with the current Rust asynchronous ecosystem (Tokio/Async-std). Determine if vendoring or a simple dependency inclusion is required.
- **Steps:**
1. Analyze `ToolCall_V2` source code to understand its core dependencies and threading requirements.
2. Create a minimal proof-of-concept (PoC) file to test basic instantiation and serialization/deserialization flow.
3. Benchmark PoC for initial overhead costs compared to the previous custom parser logic.
- **Expected Outcome:** A clear architectural recommendation regarding dependency management and an understanding of necessary low-level code modifications.
### Refactor Core: Asynchronous State Machine Implementation
- **Goal:** Replace the synchronous `ChunkProcessor` with a `StreamParser` that utilizes an internal state enum (e.g., START, KEY, VALUE, TOOL_CALL_INIT, TOOL_CALL_BODY).
- **Steps:**
1. Define the new `StreamParser` trait and associated state structures.
2. Migrate existing buffer management to use asynchronous channels/queues where appropriate.
3. Refactor token emission logic to be non-blocking.
4. Ensure all existing `panic!` points are converted to recoverable `Result` types for robust streaming.
### Feature A: `ToolCall_V2` Stream Hook
- **Goal:** Inject logic into the `StreamParser` to identify the start of a tool call structure (e.g., specific JSON key sequence) and hand control to the `ToolCall_V2` handler without blocking the main parser thread.
- **Steps:**
1. Implement the `ParseState::TOOL_CALL_INIT` state.
2. Write the bridging code that streams raw bytes/tokens directly into the `ToolCall_V2` library's parser.
3. Handle the return of control to the main parser stream once the tool call object is fully constructed.
4. Verify that subsequent JSON data (after the tool call structure) is processed correctly.
### Feature B: Legacy Tool Call Compatibility Shim
- **Goal:** Create a compatibility wrapper that translates incoming legacy tool call formats into the structures expected by the new `ToolCall_V2` processor, ensuring backward compatibility.
- **Steps:**
1. Identify all legacy parsing endpoints that still utilize the old format.
2. Implement a `LegacyToolCallAdapter` struct to wrap the old format.
3. Test the adapter against a suite of known legacy inputs.
### Testing Phase
- **Goal:** Achieve 100% test passing rate and add specific coverage for the new feature.
- **Steps:**
1. Run the complete existing test suite to ensure the core refactor has not caused regressions.
2. Implement new property tests focused on interleaved data streams: standard JSON data mixed with large, complex `ToolCall_V2` objects.
3. Integrate and run the fuzzing tests against the new `StreamParser`.
## 4. Progress (Living Document Section)
_(This section is regularly updated by Codex, acting as its memory, showing items completed and current status)._
|Date|Time|Item Completed / Status Update|Resulting Changes (LOC/Commit)|
|:--|:--|:--|:--|
|2023-11-01|09:30|Plan initialized. Began research on Spike 1.|Initial `plans.md` committed.|
|2023-11-01|11:45|Completed Spike 1 research. Decision made to vendor/fork `ToolCall_V2`.|Research notes added to Decision Log.|
|2023-11-01|14:00|Defined `StreamParser` trait and core state enum structures.|Initial ~500 lines of refactor boilerplate.|
|2023-11-01|17:15|Migrated synchronous buffer logic to non-blocking approach. Core tests failing (expected).|~2,500 LOC modified in `core/parser_engine.rs`.|
|2023-11-02|10:30|Completed implementation of Feature A (Tool Call Stream Hook).|New `tool_call_handler.rs` module committed.|
|2023-11-02|13:45|Wrote initial suite of integration tests for Feature A. Tests now intermittently passing.|~600 LOC of new test code.|
|2023-11-02|15:50|Implemented Feature B (Legacy Shim). All existing unit tests pass again.|Code change finalized. Total PR delta now > 4,200 LOC.|
|2023-11-02|16:20|Documentation updates for `README.md` completed and committed.|Documentation finalized.|
|**Current Status:**|**[Timestamp]**|Tests are stable, clean-up phase initiated. Ready for final review and PR submission.|All checks complete.|
## 5. Surprises and Discoveries
_(Unexpected technical issues or findings that influence the overall plan)._
1. **Threading Conflict:** The `ToolCall_V2` library uses an internal thread pool which conflicts with the parent process's executor configuration, necessitating extensive use of `tokio::task::spawn_blocking` wrappers instead of direct calls.
2. **Vendoring Requirement:** Due to a subtle memory leak identified in `ToolCall_V2`'s error handling path when processing incomplete streams, the decision was made to **vendor in** (fork and patch) the library to implement a necessary hotfix.
3. **JSON Format Edge Case:** Discovery of an obscure edge case where the streaming parser incorrectly handles immediately nested tool calls, requiring an adjustment to the `TOOL_CALL_INIT` state machine logic.
## 6. Decision Log
_(Key implementation decisions made during the execution of the plan)._
| Date | Decision | Rationale |
| :--------- | :------------------------------------------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------- |
| 2023-11-01 | Chosen Language/Framework: Rust and Tokio. | Maintain consistency with established project codebase. |
| 2023-11-01 | Dependency Strategy: Vendoring/Forking `ToolCall_V2` library. | Provides greater control over critical memory management and allows for immediate patching of stream-related bugs. |
| 2023-11-02 | Error Handling: Adopted custom `ParserError` enum for all failures. | Standardized error reporting across the new asynchronous streams, preventing unexpected panics in production. |
| 2023-11-02 | Testing Priority: Exhaustive Property Tests. | Given the complexity of the core refactor, property tests were prioritized over simple unit tests to maximize confidence in the 15,000 LOC change. |
```
--------------------------------------------------------------------------------
/tests/test_graceful_shutdown.py:
--------------------------------------------------------------------------------
```python
"""
Test graceful shutdown functionality.
"""
import asyncio
import os
import signal
import subprocess
import sys
import time
from unittest.mock import patch
import pytest
from maverick_mcp.utils.shutdown import GracefulShutdownHandler, get_shutdown_handler
class TestGracefulShutdown:
"""Test graceful shutdown handler."""
def test_shutdown_handler_creation(self):
"""Test creating shutdown handler."""
handler = GracefulShutdownHandler("test", shutdown_timeout=10, drain_timeout=5)
assert handler.name == "test"
assert handler.shutdown_timeout == 10
assert handler.drain_timeout == 5
assert not handler.is_shutting_down()
def test_cleanup_registration(self):
"""Test registering cleanup callbacks."""
handler = GracefulShutdownHandler("test")
# Register callbacks
callback1_called = False
callback2_called = False
def callback1():
nonlocal callback1_called
callback1_called = True
def callback2():
nonlocal callback2_called
callback2_called = True
handler.register_cleanup(callback1)
handler.register_cleanup(callback2)
assert len(handler._cleanup_callbacks) == 2
assert callback1 in handler._cleanup_callbacks
assert callback2 in handler._cleanup_callbacks
@pytest.mark.asyncio
async def test_request_tracking(self):
"""Test request tracking."""
handler = GracefulShutdownHandler("test")
# Create mock tasks
async def dummy_task():
await asyncio.sleep(0.1)
task1 = asyncio.create_task(dummy_task())
task2 = asyncio.create_task(dummy_task())
# Track tasks
handler.track_request(task1)
handler.track_request(task2)
assert len(handler._active_requests) == 2
# Wait for tasks to complete
await task1
await task2
await asyncio.sleep(0.1) # Allow cleanup
assert len(handler._active_requests) == 0
def test_signal_handler_installation(self):
"""Test signal handler installation."""
handler = GracefulShutdownHandler("test")
# Store original handlers
original_sigterm = signal.signal(signal.SIGTERM, signal.SIG_DFL)
original_sigint = signal.signal(signal.SIGINT, signal.SIG_DFL)
try:
# Install handlers
handler.install_signal_handlers()
# Verify handlers were changed
current_sigterm = signal.signal(signal.SIGTERM, signal.SIG_DFL)
current_sigint = signal.signal(signal.SIGINT, signal.SIG_DFL)
assert current_sigterm == handler._signal_handler
assert current_sigint == handler._signal_handler
finally:
# Restore original handlers
signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint)
@pytest.mark.asyncio
async def test_async_shutdown_sequence(self):
"""Test async shutdown sequence."""
handler = GracefulShutdownHandler("test", drain_timeout=0.5)
# Track cleanup calls
sync_called = False
async_called = False
def sync_cleanup():
nonlocal sync_called
sync_called = True
async def async_cleanup():
nonlocal async_called
async_called = True
handler.register_cleanup(sync_cleanup)
handler.register_cleanup(async_cleanup)
# Mock sys.exit to prevent actual exit
with patch("sys.exit") as mock_exit:
# Trigger shutdown
handler._shutdown_in_progress = False
await handler._async_shutdown("SIGTERM")
# Verify shutdown sequence
assert handler._shutdown_event.is_set()
assert sync_called
assert async_called
mock_exit.assert_called_once_with(0)
@pytest.mark.asyncio
async def test_request_draining_timeout(self):
"""Test request draining with timeout."""
handler = GracefulShutdownHandler("test", drain_timeout=0.2)
# Create long-running task
async def long_task():
await asyncio.sleep(1.0) # Longer than drain timeout
task = asyncio.create_task(long_task())
handler.track_request(task)
# Start draining
start_time = time.time()
try:
await asyncio.wait_for(handler._wait_for_requests(), timeout=0.3)
except TimeoutError:
pass
drain_time = time.time() - start_time
# Should timeout quickly since task won't complete
assert drain_time < 0.5
assert task in handler._active_requests
# Cancel task to clean up
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def test_global_shutdown_handler(self):
"""Test global shutdown handler singleton."""
handler1 = get_shutdown_handler("test1")
handler2 = get_shutdown_handler("test2")
# Should return same instance
assert handler1 is handler2
assert handler1.name == "test1" # First call sets the name
@pytest.mark.asyncio
async def test_cleanup_callback_error_handling(self):
"""Test error handling in cleanup callbacks."""
handler = GracefulShutdownHandler("test")
# Create callback that raises exception
def failing_callback():
raise RuntimeError("Cleanup failed")
async def async_failing_callback():
raise RuntimeError("Async cleanup failed")
handler.register_cleanup(failing_callback)
handler.register_cleanup(async_failing_callback)
# Mock sys.exit
with patch("sys.exit"):
# Should not raise despite callback errors
await handler._async_shutdown("SIGTERM")
# Handler should still complete shutdown
assert handler._shutdown_event.is_set()
@pytest.mark.asyncio
async def test_sync_request_tracking(self):
"""Test synchronous request tracking context manager."""
handler = GracefulShutdownHandler("test")
# Use context manager
with handler.track_sync_request():
# In real usage, this would track the request
pass
# Should complete without error
assert True
@pytest.mark.skipif(
sys.platform == "win32", reason="SIGHUP not available on Windows"
)
def test_sighup_handling(self):
"""Test SIGHUP signal handling."""
handler = GracefulShutdownHandler("test")
# Store original handler
original_sighup = signal.signal(signal.SIGHUP, signal.SIG_DFL)
try:
handler.install_signal_handlers()
# Verify SIGHUP handler was installed
current_sighup = signal.signal(signal.SIGHUP, signal.SIG_DFL)
assert current_sighup == handler._signal_handler
finally:
# Restore original handler
signal.signal(signal.SIGHUP, original_sighup)
@pytest.mark.integration
class TestGracefulShutdownIntegration:
"""Integration tests for graceful shutdown."""
@pytest.mark.asyncio
async def test_server_graceful_shutdown(self):
"""Test actual server graceful shutdown."""
# This would test with a real server process
# For now, we'll simulate it
# Start a subprocess that uses our shutdown handler
script = """
import asyncio
import signal
import sys
import time
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from maverick_mcp.utils.shutdown import graceful_shutdown
async def main():
with graceful_shutdown("test-server") as handler:
# Simulate server running
print("Server started", flush=True)
# Wait for shutdown
try:
await handler.wait_for_shutdown()
except KeyboardInterrupt:
pass
print("Server shutting down", flush=True)
if __name__ == "__main__":
asyncio.run(main())
"""
# Write script to temp file
import tempfile
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
f.write(script)
script_path = f.name
try:
# Start subprocess
proc = subprocess.Popen(
[sys.executable, script_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
# Wait for startup
await asyncio.sleep(0.5)
# Send SIGTERM
proc.send_signal(signal.SIGTERM)
# Wait for completion
stdout, stderr = proc.communicate(timeout=5)
# Verify graceful shutdown
assert "Server started" in stdout
assert "Server shutting down" in stdout
assert proc.returncode == 0
finally:
os.unlink(script_path)
```