This is page 6 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/data/health.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Database health monitoring and connection pool management.
3 |
4 | This module provides utilities for monitoring database health,
5 | connection pool statistics, and performance metrics.
6 | """
7 |
8 | import logging
9 | import time
10 | from contextlib import contextmanager
11 | from datetime import UTC, datetime
12 | from typing import Any
13 |
14 | from sqlalchemy import event, text
15 | from sqlalchemy import pool as sql_pool
16 | from sqlalchemy.engine import Engine
17 |
18 | from maverick_mcp.data.models import SessionLocal, engine
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
class DatabaseHealthMonitor:
    """Monitor database health and connection pool statistics.

    Instantiating this class registers SQLAlchemy event listeners on the
    given engine, so connection counts are tracked from construction onward.
    All counters are plain instance attributes updated from listener
    callbacks; no locking is performed.
    """

    def __init__(self, engine: Engine):
        # Engine whose pool and connections are being monitored.
        self.engine = engine
        # Lifetimes (seconds) of closed connections; capped at 100 samples.
        self.connection_times: list[float] = []
        # Durations (ms) of the health-check query; capped at 100 samples.
        self.query_times: list[float] = []
        # Currently open connections (connect events minus close events).
        self.active_connections = 0
        # Total connections opened since construction (or last reset).
        self.total_connections = 0
        # Failed connection attempts (only tracked for non-SQLite engines).
        self.failed_connections = 0

        # Register event listeners
        self._register_events()

    def _register_events(self):
        """Register SQLAlchemy event listeners for monitoring.

        The listeners close over ``self`` and mutate the counters above as
        connections are opened and closed on the engine.
        """

        @event.listens_for(self.engine, "connect")
        def receive_connect(dbapi_conn, connection_record):
            """Track successful connections."""
            self.total_connections += 1
            self.active_connections += 1
            # Stamp the record so the close listener can compute lifetime.
            connection_record.info["connect_time"] = time.time()

        @event.listens_for(self.engine, "close")
        def receive_close(dbapi_conn, connection_record):
            """Track connection closures."""
            self.active_connections -= 1
            if "connect_time" in connection_record.info:
                # Lifetime of the connection, from connect to close.
                duration = time.time() - connection_record.info["connect_time"]
                self.connection_times.append(duration)
                # Keep only last 100 measurements
                if len(self.connection_times) > 100:
                    self.connection_times.pop(0)

        # Only register connect_error for databases that support it
        # SQLite doesn't support connect_error event
        # NOTE(review): "connect_error" is not among SQLAlchemy's documented
        # pool/engine events — confirm this registration actually fires (and
        # does not raise) on the non-SQLite backends used in production.
        if not self.engine.url.drivername.startswith("sqlite"):

            @event.listens_for(self.engine, "connect_error")
            def receive_connect_error(dbapi_conn, connection_record, exception):
                """Track connection failures."""
                self.failed_connections += 1
                logger.error(f"Database connection failed: {exception}")

    def get_pool_status(self) -> dict[str, Any]:
        """Get current connection pool status.

        Returns:
            A dict describing the pool. For ``QueuePool`` it includes live
            size/checked-in/checked-out/overflow figures; for other pool
            types only a descriptive message is available.
        """
        pool = self.engine.pool

        if isinstance(pool, sql_pool.QueuePool):
            return {
                "type": "QueuePool",
                "size": pool.size(),
                "checked_in": pool.checkedin(),
                "checked_out": pool.checkedout(),
                "overflow": pool.overflow(),
                "total": pool.size() + pool.overflow(),
            }
        elif isinstance(pool, sql_pool.NullPool):
            return {
                "type": "NullPool",
                "message": "No connection pooling (each request creates new connection)",
            }
        else:
            # Fallback for pool implementations without inspectable stats.
            return {
                "type": type(pool).__name__,
                "message": "Pool statistics not available",
            }

    def check_database_health(self) -> dict[str, Any]:
        """Perform comprehensive database health check.

        Runs four checks — connectivity, pool status, query performance, and
        connection statistics — and aggregates them into a single report.

        Returns:
            A dict with ``status`` ("healthy" / "degraded" / "unhealthy"),
            an ISO-8601 UTC ``timestamp``, and per-check details under
            ``checks``. If basic connectivity fails, the remaining checks
            are skipped and the report is returned immediately.
        """
        health_status: dict[str, Any] = {
            "status": "unknown",
            "timestamp": datetime.now(UTC).isoformat(),
            "checks": {},
        }

        # Check 1: Basic connectivity
        try:
            start_time = time.time()
            with SessionLocal() as session:
                result = session.execute(text("SELECT 1"))
                result.fetchone()

            connect_time = (time.time() - start_time) * 1000  # Convert to ms
            health_status["checks"]["connectivity"] = {
                "status": "healthy",
                "response_time_ms": round(connect_time, 2),
                "message": "Database is reachable",
            }
        except Exception as e:
            # Unreachable database: short-circuit — the other checks would
            # fail for the same root cause and only add noise.
            health_status["checks"]["connectivity"] = {
                "status": "unhealthy",
                "error": str(e),
                "message": "Cannot connect to database",
            }
            health_status["status"] = "unhealthy"
            return health_status

        # Check 2: Connection pool
        pool_status = self.get_pool_status()
        health_status["checks"]["connection_pool"] = {
            "status": "healthy",
            "details": pool_status,
        }

        # Check 3: Query performance
        try:
            start_time = time.time()
            with SessionLocal() as session:
                # Test a simple query on a core table
                result = session.execute(text("SELECT COUNT(*) FROM stocks_stock"))
                count = result.scalar()

            query_time = (time.time() - start_time) * 1000
            self.query_times.append(query_time)
            if len(self.query_times) > 100:
                self.query_times.pop(0)

            avg_query_time = (
                sum(self.query_times) / len(self.query_times) if self.query_times else 0
            )

            # Queries slower than 1 second mark the check as degraded.
            health_status["checks"]["query_performance"] = {
                "status": "healthy" if query_time < 1000 else "degraded",
                "last_query_ms": round(query_time, 2),
                "avg_query_ms": round(avg_query_time, 2),
                "stock_count": count,
            }
        except Exception as e:
            health_status["checks"]["query_performance"] = {
                "status": "unhealthy",
                "error": str(e),
            }

        # Check 4: Connection statistics
        health_status["checks"]["connection_stats"] = {
            "total_connections": self.total_connections,
            "active_connections": self.active_connections,
            "failed_connections": self.failed_connections,
            # Percentage; max(..., 1) avoids division by zero before any
            # connection has been made.
            "failure_rate": round(
                self.failed_connections / max(self.total_connections, 1) * 100, 2
            ),
        }

        # Determine overall status: healthy only if every status-bearing
        # check is healthy; unhealthy if any check is unhealthy; degraded
        # otherwise. Checks without a "status" key (connection_stats) are
        # informational and excluded from the aggregation.
        if all(
            check.get("status") == "healthy"
            for check in health_status["checks"].values()
            if isinstance(check, dict) and "status" in check
        ):
            health_status["status"] = "healthy"
        elif any(
            check.get("status") == "unhealthy"
            for check in health_status["checks"].values()
            if isinstance(check, dict) and "status" in check
        ):
            health_status["status"] = "unhealthy"
        else:
            health_status["status"] = "degraded"

        return health_status

    def reset_statistics(self):
        """Reset all collected statistics.

        Clears timing samples and the total/failed counters. Deliberately
        leaves ``active_connections`` untouched — it reflects current state,
        not accumulated history.
        """
        self.connection_times.clear()
        self.query_times.clear()
        self.total_connections = 0
        self.failed_connections = 0
        logger.info("Database health statistics reset")
193 |
194 |
# Global health monitor instance
# NOTE: constructed at import time, which registers the SQLAlchemy event
# listeners on the shared engine as a side effect.
db_health_monitor: DatabaseHealthMonitor = DatabaseHealthMonitor(engine)
197 |
198 |
@contextmanager
def timed_query(name: str):
    """Log how long a database operation takes.

    Wrap a query in ``with timed_query("label"):``; the elapsed wall time
    is emitted at DEBUG level when the block exits, even if it raised.

    Args:
        name: Label used in the log message.
    """
    began = time.time()
    try:
        yield
    finally:
        duration = (time.time() - began) * 1000
        logger.debug(f"Query '{name}' completed in {duration:.2f}ms")
208 |
209 |
def get_database_health() -> dict[str, Any]:
    """Run the full health check on the shared monitor and return the report."""
    return db_health_monitor.check_database_health()
213 |
214 |
def get_pool_statistics() -> dict[str, Any]:
    """Return the shared monitor's current connection pool statistics."""
    return db_health_monitor.get_pool_status()
218 |
219 |
def warmup_connection_pool(num_connections: int = 5):
    """
    Warm up the connection pool by pre-establishing connections.

    This is useful after server startup to avoid cold start latency.
    Best-effort: failures are logged but never raised, so a cold database
    cannot prevent startup.

    Args:
        num_connections: Number of connections to pre-establish.
    """
    logger.info(f"Warming up connection pool with {num_connections} connections")

    connections = []
    try:
        for _ in range(num_connections):
            conn = engine.connect()
            # A trivial query forces the connection to be fully established.
            conn.execute(text("SELECT 1"))
            connections.append(conn)
        logger.info("Connection pool warmup completed")
    except Exception as e:
        logger.error(f"Error during connection pool warmup: {e}")
    finally:
        # Single cleanup path for success and failure alike: return every
        # established connection to the pool, guarding each close() so one
        # failure cannot leak the remaining connections (the original
        # success-path close loop was unguarded).
        for conn in connections:
            try:
                conn.close()
            except Exception:
                pass
248 |
```
--------------------------------------------------------------------------------
/maverick_mcp/core/visualization.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Visualization utilities for Maverick-MCP.
3 |
4 | This module contains functions for generating charts and visualizations
5 | for financial data, including technical analysis charts.
6 | """
7 |
8 | import base64
9 | import logging
10 | import os
11 | import tempfile
12 |
13 | import numpy as np
14 | import pandas as pd
15 | import plotly.graph_objects as go
16 | import plotly.subplots as sp
17 |
18 | from maverick_mcp.config.plotly_config import setup_plotly
19 |
# Set up logging
# NOTE(review): logging.basicConfig() at import time configures the *root*
# logger, which can override the host application's logging setup — confirm
# this is intentional for a library module.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("maverick_mcp.visualization")

# Configure Plotly to use modern defaults and suppress warnings
setup_plotly()
28 |
29 |
def plotly_fig_to_base64(fig: go.Figure, format: str = "png") -> str:
    """
    Convert a Plotly figure to a base64 encoded data URI string.

    Renders the figure to a temporary image file (via kaleido), reads it
    back, and encodes it as a ``data:image/...;base64,...`` URI. The temp
    file is always removed, including on the error path.

    Args:
        fig: The Plotly figure to convert
        format: Image format (default: 'png')

    Returns:
        Base64 encoded data URI string of the figure

    Raises:
        RuntimeError: If no image bytes were produced (kaleido missing).
        Exception: Re-raises any error from ``fig.write_image``.
    """
    img_bytes = None
    with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmpfile:
        try:
            fig.write_image(tmpfile.name)
            tmpfile.seek(0)
            img_bytes = tmpfile.read()
        except Exception as e:
            logger.error(f"Error writing image: {e}")
            raise
        finally:
            # Bug fix: the original removed the file only on success — the
            # re-raise in the except block skipped os.remove(), leaking a
            # delete=False temp file on every render failure.
            try:
                os.remove(tmpfile.name)
            except OSError:
                pass
    if not img_bytes:
        logger.error("No image bytes were written. Is kaleido installed?")
        raise RuntimeError(
            "Plotly failed to write image. Ensure 'kaleido' is installed."
        )
    base64_str = base64.b64encode(img_bytes).decode("utf-8")
    return f"data:image/{format};base64,{base64_str}"
58 |
59 |
def create_plotly_technical_chart(
    df: pd.DataFrame, ticker: str, height: int = 400, width: int = 600
) -> go.Figure:
    """
    Generate a Plotly technical analysis chart for financial data visualization.

    Builds a four-row chart: candlesticks with moving averages and Bollinger
    Bands (row 1), volume bars (row 2), RSI with 30/70 guide lines (row 3),
    and MACD with signal line and histogram (row 4). Only the most recent
    126 rows (~6 months of trading days) are plotted.

    Args:
        df: DataFrame with price and technical indicator data. Column names
            are lower-cased before use; the following columns are required:
            ``open``, ``high``, ``low``, ``close``, ``volume``, ``ema_21``,
            ``sma_50``, ``sma_200``, ``bbu_20_2.0``, ``bbl_20_2.0``, ``rsi``,
            ``macd_12_26_9``, ``macds_12_26_9``, ``macdh_12_26_9``.
            The index is used as the x-axis (presumably dates — the bottom
            axis is formatted as "%Y-%m-%d").
        ticker: The ticker symbol to display in the chart title
        height: Chart height
        width: Chart width

    Returns:
        A Plotly figure with the technical analysis chart

    Raises:
        KeyError: If any of the required columns listed above is missing.
    """
    # Work on a copy so the caller's DataFrame is never mutated.
    df = df.copy()
    df.columns = [col.lower() for col in df.columns]
    df = df.iloc[-126:].copy()  # Ensure we keep DataFrame structure

    # Four stacked panels sharing one x-axis; price panel gets 60% height.
    fig = sp.make_subplots(
        rows=4,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        subplot_titles=("", "", "", ""),
        row_heights=[0.6, 0.15, 0.15, 0.1],
    )

    # Light theme palette used consistently across all traces.
    bg_color = "#FFFFFF"
    text_color = "#000000"
    grid_color = "rgba(0, 0, 0, 0.35)"
    colors = {
        "green": "#00796B",
        "red": "#D32F2F",
        "blue": "#1565C0",
        "orange": "#E65100",
        "purple": "#6A1B9A",
        "gray": "#424242",
        "black": "#000000",
    }
    line_width = 1

    # Candlestick chart
    fig.add_trace(
        go.Candlestick(
            x=df.index,
            name="Price",
            open=df["open"],
            high=df["high"],
            low=df["low"],
            close=df["close"],
            increasing_line_color=colors["green"],
            decreasing_line_color=colors["red"],
            line={"width": line_width},
        ),
        row=1,
        col=1,
    )

    # Moving averages, overlaid on the price panel in fixed colors.
    for i, (col, name) in enumerate(
        [("ema_21", "EMA 21"), ("sma_50", "SMA 50"), ("sma_200", "SMA 200")]
    ):
        color = [colors["blue"], colors["green"], colors["red"]][i]
        fig.add_trace(
            go.Scatter(
                x=df.index,
                y=df[col],
                mode="lines",
                name=name,
                line={"color": color, "width": line_width},
            ),
            row=1,
            col=1,
        )

    # Bollinger Bands: upper band drawn first so the lower band's
    # fill="tonexty" shades the region between the two.
    light_blue = "rgba(21, 101, 192, 0.6)"
    fill_color = "rgba(21, 101, 192, 0.1)"
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df["bbu_20_2.0"],
            mode="lines",
            line={"color": light_blue, "width": line_width},
            name="Upper BB",
            legendgroup="bollinger",
            showlegend=True,
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df["bbl_20_2.0"],
            mode="lines",
            line={"color": light_blue, "width": line_width},
            name="Lower BB",
            legendgroup="bollinger",
            showlegend=False,
            fill="tonexty",
            fillcolor=fill_color,
        ),
        row=1,
        col=1,
    )

    # Volume: green bars for up days (close >= open), red otherwise.
    volume_colors = np.where(df["close"] >= df["open"], colors["green"], colors["red"])
    fig.add_trace(
        go.Bar(
            x=df.index,
            y=df["volume"],
            name="Volume",
            marker={"color": volume_colors},
            opacity=0.75,
            showlegend=False,
        ),
        row=2,
        col=1,
    )

    # RSI with dashed overbought (70) / oversold (30) guide lines.
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df["rsi"],
            mode="lines",
            name="RSI",
            line={"color": colors["blue"], "width": line_width},
        ),
        row=3,
        col=1,
    )
    fig.add_hline(
        y=70,
        line_dash="dash",
        line_color=colors["red"],
        line_width=line_width,
        row=3,
        col=1,
    )
    fig.add_hline(
        y=30,
        line_dash="dash",
        line_color=colors["green"],
        line_width=line_width,
        row=3,
        col=1,
    )

    # MACD line, signal line, and histogram colored by histogram value.
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df["macd_12_26_9"],
            mode="lines",
            name="MACD",
            line={"color": colors["blue"], "width": line_width},
        ),
        row=4,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df["macds_12_26_9"],
            mode="lines",
            name="Signal",
            line={"color": colors["orange"], "width": line_width},
        ),
        row=4,
        col=1,
    )
    fig.add_trace(
        go.Bar(
            x=df.index,
            y=df["macdh_12_26_9"],
            name="Histogram",
            showlegend=False,
            marker={"color": df["macdh_12_26_9"], "colorscale": "RdYlGn"},
        ),
        row=4,
        col=1,
    )

    # Layout
    # Local import: only needed here for the title timestamp.
    # NOTE: datetime.UTC requires Python 3.11+.
    import datetime

    now = datetime.datetime.now(datetime.UTC).strftime("%m/%d/%Y")
    fig.update_layout(
        height=height,
        width=width,
        title={
            "text": f"<b>{ticker.upper()} | {now} | Technical Analysis | Maverick-MCP</b>",
            "font": {"size": 12, "color": text_color, "family": "Arial, sans-serif"},
            "y": 0.98,
        },
        plot_bgcolor=bg_color,
        paper_bgcolor=bg_color,
        xaxis_rangeslider_visible=False,
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1,
            "xanchor": "left",
            "x": 0,
            "font": {"size": 10, "color": text_color, "family": "Arial, sans-serif"},
            "itemwidth": 30,
            "itemsizing": "constant",
            "borderwidth": 0,
            "tracegroupgap": 1,
        },
        font={"size": 10, "color": text_color, "family": "Arial, sans-serif"},
        margin={"r": 20, "l": 40, "t": 80, "b": 0},
    )

    # Dotted grid on both axes of every panel.
    fig.update_xaxes(
        gridcolor=grid_color,
        zerolinecolor=grid_color,
        zerolinewidth=line_width,
        gridwidth=1,
        griddash="dot",
    )
    fig.update_yaxes(
        gridcolor=grid_color,
        zerolinecolor=grid_color,
        zerolinewidth=line_width,
        gridwidth=1,
        griddash="dot",
    )

    # Per-panel y-axis titles (rows enumerate from 1 to match subplot rows).
    y_axis_titles = ["Price", "Volume", "RSI", "MACD"]
    for i, title in enumerate(y_axis_titles, start=1):
        if title:
            fig.update_yaxes(
                title={
                    "text": f"<b>{title}</b>",
                    "font": {"size": 8, "color": text_color},
                    "standoff": 0,
                },
                side="left",
                position=0,
                automargin=True,
                row=i,
                col=1,
                tickfont={"size": 8},
            )

    # Only the bottom panel shows date tick labels (axes are shared).
    fig.update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_xaxes(showticklabels=False, row=2, col=1)
    fig.update_xaxes(showticklabels=False, row=3, col=1)
    fig.update_xaxes(
        title={"text": "Date", "font": {"size": 8, "color": text_color}, "standoff": 5},
        row=4,
        col=1,
        tickfont={"size": 8},
        showticklabels=True,
        tickangle=45,
        tickformat="%Y-%m-%d",
    )

    return fig
324 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/yfinance_pool.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Optimized yfinance connection pooling and caching.
3 | Provides thread-safe connection pooling and request optimization for yfinance.
4 | """
5 |
6 | import logging
7 | import threading
8 | from concurrent.futures import ThreadPoolExecutor
9 | from datetime import datetime, timedelta
10 | from typing import Any
11 |
12 | import pandas as pd
13 | import yfinance as yf
14 | from requests import Session
15 | from requests.adapters import HTTPAdapter
16 | from urllib3.util.retry import Retry
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class YFinancePool:
22 | """Thread-safe yfinance connection pool with optimized session management."""
23 |
24 | _instance = None
25 | _lock = threading.Lock()
26 |
    def __new__(cls):
        """Singleton pattern to ensure single connection pool.

        Uses double-checked locking: the unlocked check is a fast path, and
        the second check under the lock prevents two racing threads from
        both constructing an instance.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    # Flag consumed by __init__ so the pool is set up once,
                    # even though __init__ runs on every YFinancePool() call.
                    cls._instance._initialized = False
        return cls._instance
35 |
    def __init__(self):
        """Initialize the connection pool once."""
        # __init__ runs on every YFinancePool() call (singleton __new__
        # returns the same object), so bail out after the first setup.
        if self._initialized:
            return

        # Create optimized session with connection pooling
        # NOTE(review): get_ticker() no longer passes this session to
        # yfinance (see its comment about curl_cffi) — confirm the session
        # is still used elsewhere or retire it.
        self.session = self._create_optimized_session()

        # Thread pool for parallel requests
        self.executor = ThreadPoolExecutor(
            max_workers=10, thread_name_prefix="yfinance_pool"
        )

        # Request cache (simple TTL cache)
        # Maps cache key -> (value, expiry-related float); guarded by a lock
        # since the pool is shared across threads.
        self._request_cache: dict[str, tuple[Any, float]] = {}
        self._cache_lock = threading.Lock()
        self._cache_ttl = 60  # 1 minute cache for quotes

        self._initialized = True
        logger.info("YFinance connection pool initialized")
56 |
57 | def _create_optimized_session(self) -> Session:
58 | """Create an optimized requests session with retry logic and connection pooling."""
59 | session = Session()
60 |
61 | # Configure retry strategy
62 | retry_strategy = Retry(
63 | total=3,
64 | backoff_factor=0.3,
65 | status_forcelist=[429, 500, 502, 503, 504],
66 | allowed_methods=["GET", "POST"],
67 | )
68 |
69 | # Configure adapter with connection pooling
70 | adapter = HTTPAdapter(
71 | pool_connections=10, # Number of connection pools
72 | pool_maxsize=50, # Max connections per pool
73 | max_retries=retry_strategy,
74 | pool_block=False, # Don't block when pool is full
75 | )
76 |
77 | # Mount adapter for HTTP and HTTPS
78 | session.mount("http://", adapter)
79 | session.mount("https://", adapter)
80 |
81 | # Set headers to avoid rate limiting
82 | session.headers.update(
83 | {
84 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
85 | "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
86 | "Accept-Language": "en-US,en;q=0.5",
87 | "Accept-Encoding": "gzip, deflate",
88 | "DNT": "1",
89 | "Connection": "keep-alive",
90 | "Upgrade-Insecure-Requests": "1",
91 | }
92 | )
93 |
94 | return session
95 |
96 | def get_ticker(self, symbol: str) -> yf.Ticker:
97 | """Get a ticker object - let yfinance handle session for compatibility."""
98 | # Check cache first
99 | cache_key = f"ticker_{symbol}"
100 | cached = self._get_from_cache(cache_key)
101 | if cached:
102 | return cached
103 |
104 | # Create ticker without custom session (yfinance now requires curl_cffi)
105 | ticker = yf.Ticker(symbol)
106 |
107 | # Cache for short duration
108 | self._add_to_cache(cache_key, ticker, ttl=300) # 5 minutes
109 |
110 | return ticker
111 |
112 | def get_history(
113 | self,
114 | symbol: str,
115 | start: str | None = None,
116 | end: str | None = None,
117 | period: str | None = None,
118 | interval: str = "1d",
119 | ) -> pd.DataFrame:
120 | """Get historical data with connection pooling."""
121 | # Create cache key
122 | cache_key = f"history_{symbol}_{start}_{end}_{period}_{interval}"
123 |
124 | # Check cache
125 | cached = self._get_from_cache(cache_key)
126 | if cached is not None and not cached.empty:
127 | return cached
128 |
129 | # Get ticker with optimized session
130 | ticker = self.get_ticker(symbol)
131 |
132 | # Fetch data
133 | if period:
134 | df = ticker.history(period=period, interval=interval)
135 | else:
136 | if start is None:
137 | start = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
138 | if end is None:
139 | end = datetime.now().strftime("%Y-%m-%d")
140 | df = ticker.history(start=start, end=end, interval=interval)
141 |
142 | # Cache the result (longer TTL for historical data)
143 | if not df.empty:
144 | ttl = (
145 | 3600 if interval == "1d" else 300
146 | ) # 1 hour for daily, 5 min for intraday
147 | self._add_to_cache(cache_key, df, ttl=ttl)
148 |
149 | return df
150 |
151 | def get_info(self, symbol: str) -> dict:
152 | """Get stock info with caching."""
153 | cache_key = f"info_{symbol}"
154 |
155 | # Check cache
156 | cached = self._get_from_cache(cache_key)
157 | if cached:
158 | return cached
159 |
160 | # Get ticker and info
161 | ticker = self.get_ticker(symbol)
162 | info = ticker.info
163 |
164 | # Cache for longer duration (info doesn't change often)
165 | self._add_to_cache(cache_key, info, ttl=3600) # 1 hour
166 |
167 | return info
168 |
169 | def batch_download(
170 | self,
171 | symbols: list[str],
172 | start: str | None = None,
173 | end: str | None = None,
174 | period: str | None = None,
175 | interval: str = "1d",
176 | group_by: str = "ticker",
177 | threads: bool = True,
178 | ) -> pd.DataFrame:
179 | """Download data for multiple symbols efficiently."""
180 | # Use yfinance's batch download without custom session
181 | if period:
182 | data = yf.download(
183 | tickers=symbols,
184 | period=period,
185 | interval=interval,
186 | group_by=group_by,
187 | threads=threads,
188 | progress=False,
189 | )
190 | else:
191 | if start is None:
192 | start = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
193 | if end is None:
194 | end = datetime.now().strftime("%Y-%m-%d")
195 |
196 | data = yf.download(
197 | tickers=symbols,
198 | start=start,
199 | end=end,
200 | interval=interval,
201 | group_by=group_by,
202 | threads=threads,
203 | progress=False,
204 | )
205 |
206 | return data
207 |
208 | def _get_from_cache(self, key: str) -> Any | None:
209 | """Get item from cache if not expired."""
210 | with self._cache_lock:
211 | if key in self._request_cache:
212 | value, expiry = self._request_cache[key]
213 | if datetime.now().timestamp() < expiry:
214 | logger.debug(f"Cache hit for {key}")
215 | return value
216 | else:
217 | del self._request_cache[key]
218 | return None
219 |
220 | def _add_to_cache(self, key: str, value: Any, ttl: int = 60):
221 | """Add item to cache with TTL."""
222 | with self._cache_lock:
223 | expiry = datetime.now().timestamp() + ttl
224 | self._request_cache[key] = (value, expiry)
225 |
226 | # Clean up old entries if cache is too large
227 | if len(self._request_cache) > 1000:
228 | self._cleanup_cache()
229 |
230 | def _cleanup_cache(self):
231 | """Remove expired entries from cache."""
232 | current_time = datetime.now().timestamp()
233 | expired_keys = [
234 | k for k, (_, expiry) in self._request_cache.items() if expiry < current_time
235 | ]
236 | for key in expired_keys:
237 | del self._request_cache[key]
238 |
239 | # If still too large, remove oldest entries
240 | if len(self._request_cache) > 800:
241 | sorted_items = sorted(
242 | self._request_cache.items(),
243 | key=lambda x: x[1][1], # Sort by expiry time
244 | )
245 | # Keep only the newest 600 entries
246 | self._request_cache = dict(sorted_items[-600:])
247 |
248 | def close(self):
249 | """Clean up resources."""
250 | try:
251 | self.session.close()
252 | self.executor.shutdown(wait=False)
253 | logger.info("YFinance connection pool closed")
254 | except Exception as e:
255 | logger.warning(f"Error closing connection pool: {e}")
256 |
257 |
# Lazily created module-level shared pool.
_yfinance_pool: YFinancePool | None = None


def get_yfinance_pool() -> YFinancePool:
    """Return the shared YFinancePool, creating it on first use."""
    global _yfinance_pool
    if _yfinance_pool is None:
        _yfinance_pool = YFinancePool()
    return _yfinance_pool


def cleanup_yfinance_pool():
    """Close the shared pool (if any) and drop the module-level reference."""
    global _yfinance_pool
    pool = _yfinance_pool
    if pool:
        pool.close()
        _yfinance_pool = None
276 |
```
--------------------------------------------------------------------------------
/maverick_mcp/validation/middleware.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Validation middleware for FastAPI to standardize error handling.
3 |
4 | This module provides middleware to catch validation errors and
5 | return standardized error responses.
6 | """
7 |
8 | import logging
9 | import time
10 | import traceback
11 | import uuid
12 |
13 | from fastapi import Request, Response, status
14 | from fastapi.exceptions import RequestValidationError
15 | from fastapi.responses import JSONResponse
16 | from pydantic import ValidationError
17 | from starlette.middleware.base import BaseHTTPMiddleware
18 |
19 | from maverick_mcp.exceptions import MaverickException
20 |
21 | from .responses import error_response, validation_error_response
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
class ValidationMiddleware(BaseHTTPMiddleware):
    """Middleware to handle validation errors and API exceptions.

    Attaches a unique trace ID to every request and converts known exception
    types (domain errors, validation errors, unexpected failures) into
    standardized JSON error responses.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process the request, translating exceptions into error responses."""
        # Generate trace ID for request tracking
        trace_id = str(uuid.uuid4())
        request.state.trace_id = trace_id

        try:
            return await call_next(request)

        except MaverickException as e:
            # Domain errors carry their own status code and error metadata.
            logger.warning(
                f"API error: {e.error_code} - {e.message}",
                extra={
                    "trace_id": trace_id,
                    "path": request.url.path,
                    "method": request.method,
                    "error_code": e.error_code,
                },
            )
            return JSONResponse(
                status_code=e.status_code,
                content=error_response(
                    code=e.error_code,
                    message=e.message,
                    status_code=e.status_code,
                    field=e.field,
                    context=e.context,
                    trace_id=trace_id,
                ),
            )

        except RequestValidationError as e:
            logger.warning(
                f"Request validation error: {str(e)}",
                extra={
                    "trace_id": trace_id,
                    "path": request.url.path,
                    "method": request.method,
                },
            )
            return JSONResponse(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                content=validation_error_response(
                    errors=self._convert_errors(e), trace_id=trace_id
                ),
            )

        except ValidationError as e:
            logger.warning(
                f"Pydantic validation error: {str(e)}",
                extra={
                    "trace_id": trace_id,
                    "path": request.url.path,
                    "method": request.method,
                },
            )
            return JSONResponse(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                content=validation_error_response(
                    errors=self._convert_errors(e), trace_id=trace_id
                ),
            )

        except Exception as e:
            # Last-resort handler: log with full traceback, return opaque 500.
            logger.error(
                f"Unexpected error: {str(e)}",
                extra={
                    "trace_id": trace_id,
                    "path": request.url.path,
                    "method": request.method,
                    "traceback": traceback.format_exc(),
                },
            )
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content=error_response(
                    code="INTERNAL_ERROR",
                    message="An unexpected error occurred",
                    status_code=500,
                    trace_id=trace_id,
                ),
            )

    @staticmethod
    def _convert_errors(exc) -> list[dict]:
        """Convert Pydantic/FastAPI validation errors to the standard payload.

        Both RequestValidationError and pydantic.ValidationError expose
        ``errors()`` with the same entry shape, so one converter serves both
        handlers (the original duplicated this loop verbatim).
        """
        return [
            {
                "code": "VALIDATION_ERROR",
                "field": ".".join(str(x) for x in error["loc"]),
                "message": error["msg"],
                "context": {"input": error.get("input"), "type": error["type"]},
            }
            for error in exc.errors()
        ]
135 |
136 |
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware enforcing a per-API-key request rate limit.

    Keeps an in-memory map of api_key -> recent request timestamps over a
    sliding 1-minute window (60 requests/minute). This is a simplified
    single-process limiter; a shared store such as Redis is required for
    distributed deployments.
    """

    def __init__(self, app, rate_limit_store=None):
        super().__init__(app)
        self.rate_limit_store = rate_limit_store or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        """Check rate limits before processing request."""
        # Health/diagnostic endpoints are never throttled.
        if request.url.path in ("/health", "/metrics", "/docs", "/openapi.json"):
            return await call_next(request)

        api_key = self._extract_api_key(request)
        # Requests without a key pass through un-throttled.
        if not api_key:
            return await call_next(request)

        now = int(time.time())
        window_start = now - 60  # 1-minute sliding window

        # Keep only timestamps inside the current window.
        recent = [ts for ts in self.rate_limit_store.get(api_key, []) if ts > window_start]

        if len(recent) >= 60:
            trace_id = getattr(request.state, "trace_id", str(uuid.uuid4()))
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content=error_response(
                    code="RATE_LIMIT_EXCEEDED",
                    message="Rate limit exceeded",
                    status_code=429,
                    context={
                        "limit": 60,
                        "window": "1 minute",
                        "retry_after": 60 - (now % 60),
                    },
                    trace_id=trace_id,
                ),
                headers={"Retry-After": "60"},
            )

        # Record this request and continue.
        recent.append(now)
        self.rate_limit_store[api_key] = recent
        return await call_next(request)

    @staticmethod
    def _extract_api_key(request: Request) -> str | None:
        """Pull the API key from Authorization: Bearer or x-api-key headers."""
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            return auth_header[7:]
        if "x-api-key" in request.headers:
            return request.headers["x-api-key"]
        return None
192 |
193 |
class SecurityMiddleware(BaseHTTPMiddleware):
    """Security middleware for response headers and basic request validation."""

    # Headers appended to every successful response.
    _SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    }

    async def dispatch(self, request: Request, call_next) -> Response:
        """Validate incoming requests, then attach security headers."""
        # Mutating verbs must declare a JSON body.
        if request.method in ("POST", "PUT", "PATCH"):
            declared_type = request.headers.get("content-type", "")
            if not declared_type.startswith("application/json"):
                trace_id = getattr(request.state, "trace_id", str(uuid.uuid4()))
                return JSONResponse(
                    status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
                    content=error_response(
                        code="UNSUPPORTED_MEDIA_TYPE",
                        message="Content-Type must be application/json",
                        status_code=415,
                        trace_id=trace_id,
                    ),
                )

        # Reject oversized payloads (10MB cap, based on Content-Length only).
        declared_size = request.headers.get("content-length")
        if declared_size and int(declared_size) > 10 * 1024 * 1024:
            trace_id = getattr(request.state, "trace_id", str(uuid.uuid4()))
            return JSONResponse(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                content=error_response(
                    code="REQUEST_TOO_LARGE",
                    message="Request entity too large (max 10MB)",
                    status_code=413,
                    trace_id=trace_id,
                ),
            )

        response = await call_next(request)

        for header, value in self._SECURITY_HEADERS.items():
            response.headers[header] = value

        return response
240 |
```
--------------------------------------------------------------------------------
/maverick_mcp/agents/circuit_breaker.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Circuit Breaker pattern for resilient external API calls.
3 | """
4 |
5 | import asyncio
6 | import logging
7 | import time
8 | from collections.abc import Callable
9 | from enum import Enum
10 | from typing import Any
11 |
12 | from maverick_mcp.config.settings import get_settings
13 |
14 | logger = logging.getLogger(__name__)
15 | settings = get_settings()
16 |
17 |
class CircuitState(Enum):
    """Lifecycle states of a circuit breaker (see CircuitBreaker)."""

    CLOSED = "closed"  # Normal operation: calls pass through
    OPEN = "open"  # Failing: calls are rejected immediately
    HALF_OPEN = "half_open"  # Probing: one call tests if the service recovered
24 |
25 |
class CircuitBreaker:
    """
    Circuit breaker for protecting against cascading failures.

    Implements the circuit breaker pattern to prevent repeated calls
    to failing services and allow them time to recover. State transitions:
    CLOSED -> OPEN after ``failure_threshold`` failures; OPEN -> HALF_OPEN
    after ``recovery_timeout`` seconds; HALF_OPEN -> CLOSED on success or
    back to OPEN on failure.
    """

    def __init__(
        self,
        failure_threshold: int | None = None,
        recovery_timeout: int | None = None,
        expected_exception: type[Exception] = Exception,
        name: str = "CircuitBreaker",
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of failures before opening circuit (uses config default if None)
            recovery_timeout: Seconds to wait before testing recovery (uses config default if None)
            expected_exception: Exception type that counts as a failure
            name: Name for logging
        """
        self.failure_threshold = (
            failure_threshold or settings.agent.circuit_breaker_failure_threshold
        )
        self.recovery_timeout = (
            recovery_timeout or settings.agent.circuit_breaker_recovery_timeout
        )
        self.expected_exception = expected_exception
        self.name = name

        self._failure_count = 0
        self._last_failure_time: float | None = None
        self._state = CircuitState.CLOSED
        self._lock = asyncio.Lock()

    @property
    def state(self) -> CircuitState:
        """Get current circuit state."""
        return self._state

    @property
    def failure_count(self) -> int:
        """Get current failure count."""
        return self._failure_count

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Call function through circuit breaker.

        Args:
            func: Function to call (sync or async)
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Function result

        Raises:
            Exception: If circuit is open or function fails
        """
        # The lock covers only the state check/transition; the call itself
        # executes outside the lock so slow calls don't serialize everything.
        async with self._lock:
            if self._state == CircuitState.OPEN:
                if self._should_attempt_reset():
                    self._state = CircuitState.HALF_OPEN
                    logger.info(f"{self.name}: Attempting reset (half-open)")
                else:
                    raise Exception(f"{self.name}: Circuit breaker is OPEN")

        try:
            # Support both coroutine functions and plain callables.
            if asyncio.iscoroutinefunction(func):
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)

            # Success - reset on half-open or reduce failure count
            await self._on_success()
            return result

        except self.expected_exception:
            # Failure - increment counter and possibly open circuit.
            await self._on_failure()
            # Bare raise preserves the original traceback (``raise e`` did not).
            raise

    async def _on_success(self):
        """Handle a successful call: close after half-open, decay failures."""
        async with self._lock:
            if self._state == CircuitState.HALF_OPEN:
                self._state = CircuitState.CLOSED
                self._failure_count = 0
                logger.info(f"{self.name}: Circuit breaker CLOSED after recovery")
            elif self._failure_count > 0:
                # Gradual decay: one success forgives one failure.
                self._failure_count = max(0, self._failure_count - 1)

    async def _on_failure(self):
        """Handle a failed call: count it and open the circuit if warranted."""
        async with self._lock:
            self._failure_count += 1
            self._last_failure_time = time.time()

            if self._failure_count >= self.failure_threshold:
                self._state = CircuitState.OPEN
                logger.warning(
                    f"{self.name}: Circuit breaker OPEN after {self._failure_count} failures"
                )
            elif self._state == CircuitState.HALF_OPEN:
                # The probe call failed: go straight back to OPEN.
                self._state = CircuitState.OPEN
                logger.warning(
                    f"{self.name}: Circuit breaker OPEN after half-open test failed"
                )

    def _should_attempt_reset(self) -> bool:
        """Check if enough time has passed to attempt reset."""
        if self._last_failure_time is None:
            return False

        return (time.time() - self._last_failure_time) >= self.recovery_timeout

    async def reset(self):
        """Manually reset the circuit breaker to CLOSED."""
        async with self._lock:
            self._state = CircuitState.CLOSED
            self._failure_count = 0
            self._last_failure_time = None
            logger.info(f"{self.name}: Circuit breaker manually RESET")

    def get_status(self) -> dict[str, Any]:
        """Return a status snapshot suitable for monitoring/serialization."""
        return {
            "name": self.name,
            "state": self._state.value,
            "failure_count": self._failure_count,
            "failure_threshold": self.failure_threshold,
            "recovery_timeout": self.recovery_timeout,
            "time_until_retry": self._get_time_until_retry(),
        }

    def _get_time_until_retry(self) -> float | None:
        """Get seconds until retry is allowed (None unless circuit is OPEN)."""
        if self._state != CircuitState.OPEN or self._last_failure_time is None:
            return None

        elapsed = time.time() - self._last_failure_time
        remaining = self.recovery_timeout - elapsed
        return max(0, remaining)
174 |
175 |
class CircuitBreakerManager:
    """Registry of named circuit breakers with lazy, lock-protected creation."""

    def __init__(self):
        """Initialize an empty registry."""
        self._breakers: dict[str, CircuitBreaker] = {}
        self._lock = asyncio.Lock()

    async def get_or_create(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        expected_exception: type[Exception] = Exception,
    ) -> CircuitBreaker:
        """Return the breaker registered under ``name``, creating it if absent."""
        async with self._lock:
            breaker = self._breakers.get(name)
            if breaker is None:
                breaker = CircuitBreaker(
                    failure_threshold=failure_threshold,
                    recovery_timeout=recovery_timeout,
                    expected_exception=expected_exception,
                    name=name,
                )
                self._breakers[name] = breaker
            return breaker

    def get_all_status(self) -> dict[str, dict[str, Any]]:
        """Snapshot the status of every registered breaker."""
        statuses: dict[str, dict[str, Any]] = {}
        for name, breaker in self._breakers.items():
            statuses[name] = breaker.get_status()
        return statuses

    async def reset_all(self):
        """Reset every registered breaker back to CLOSED."""
        for breaker in self._breakers.values():
            await breaker.reset()
210 |
211 |
212 | # Global circuit breaker manager
213 | circuit_manager = CircuitBreakerManager()
214 |
215 |
def circuit_breaker(
    name: str | None = None,
    failure_threshold: int = 5,
    recovery_timeout: int = 60,
    expected_exception: type[Exception] = Exception,
):
    """
    Decorator to wrap functions with circuit breaker protection.

    Args:
        name: Circuit breaker name (uses function name if None)
        failure_threshold: Number of failures before opening circuit
        recovery_timeout: Seconds to wait before testing recovery
        expected_exception: Exception type to catch

    Example:
        @circuit_breaker("api_call", failure_threshold=3, recovery_timeout=30)
        async def call_external_api():
            # API call logic
            pass
    """
    import functools  # local import keeps this fix self-contained

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        breaker_name = (
            name or f"{func.__module__}.{getattr(func, '__name__', 'unknown')}"
        )

        if asyncio.iscoroutinefunction(func):

            @functools.wraps(func)  # preserve __name__/__doc__ for introspection
            async def async_wrapper(*args, **kwargs):
                breaker = await circuit_manager.get_or_create(
                    breaker_name,
                    failure_threshold=failure_threshold,
                    recovery_timeout=recovery_timeout,
                    expected_exception=expected_exception,
                )
                return await breaker.call(func, *args, **kwargs)

            return async_wrapper
        else:

            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs):
                # NOTE(review): sync functions get logging only, not true
                # circuit breaking — the breaker registry is async-only here.
                try:
                    return func(*args, **kwargs)
                except expected_exception as e:
                    logger.warning(f"Circuit breaker {breaker_name}: {e}")
                    raise

            return sync_wrapper

    return decorator
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/implementations/stock_data_adapter.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Stock data provider adapter.
3 |
4 | This module provides adapters that make the existing StockDataProvider
5 | compatible with the new interface-based architecture while maintaining
6 | all existing functionality.
7 | """
8 |
9 | import asyncio
10 | import logging
11 | from typing import Any
12 |
13 | import pandas as pd
14 | from sqlalchemy.orm import Session
15 |
16 | from maverick_mcp.providers.interfaces.cache import ICacheManager
17 | from maverick_mcp.providers.interfaces.config import IConfigurationProvider
18 | from maverick_mcp.providers.interfaces.persistence import IDataPersistence
19 | from maverick_mcp.providers.interfaces.stock_data import (
20 | IStockDataFetcher,
21 | IStockScreener,
22 | )
23 | from maverick_mcp.providers.stock_data import StockDataProvider
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
class StockDataAdapter(IStockDataFetcher, IStockScreener):
    """
    Adapter that makes the existing StockDataProvider compatible with new interfaces.

    This adapter wraps the existing provider and exposes it through the new
    interface contracts, enabling gradual migration to the new architecture.
    Every async method delegates the blocking provider call to the default
    executor so the event loop is never blocked.
    """

    def __init__(
        self,
        cache_manager: ICacheManager | None = None,
        persistence: IDataPersistence | None = None,
        config: IConfigurationProvider | None = None,
        db_session: Session | None = None,
    ):
        """
        Initialize the stock data adapter.

        Args:
            cache_manager: Cache manager for data caching
            persistence: Persistence layer for database operations
            config: Configuration provider
            db_session: Optional database session for dependency injection
        """
        self._cache_manager = cache_manager
        self._persistence = persistence
        self._config = config
        self._db_session = db_session

        # Wrap the existing synchronous provider.
        self._provider = StockDataProvider(db_session=db_session)

        logger.debug("StockDataAdapter initialized")

    async def _run_sync(self, func, *args):
        """Run a blocking provider call on the default executor.

        Uses ``asyncio.get_running_loop()`` — ``get_event_loop()`` is
        deprecated inside coroutines since Python 3.10, and all callers of
        this helper are coroutines, so a running loop always exists.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    async def get_stock_data(
        self,
        symbol: str,
        start_date: str | None = None,
        end_date: str | None = None,
        period: str | None = None,
        interval: str = "1d",
        use_cache: bool = True,
    ) -> pd.DataFrame:
        """
        Fetch historical stock data (async wrapper).

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            period: Alternative to start/end dates (e.g., '1y', '6mo')
            interval: Data interval ('1d', '1wk', '1mo', etc.)
            use_cache: Whether to use cached data if available

        Returns:
            DataFrame with OHLCV data indexed by date
        """
        return await self._run_sync(
            self._provider.get_stock_data,
            symbol,
            start_date,
            end_date,
            period,
            interval,
            use_cache,
        )

    async def get_realtime_data(self, symbol: str) -> dict[str, Any] | None:
        """
        Get real-time stock data (async wrapper).

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with current price, change, volume, etc. or None if unavailable
        """
        return await self._run_sync(self._provider.get_realtime_data, symbol)

    async def get_stock_info(self, symbol: str) -> dict[str, Any]:
        """
        Get detailed stock information and fundamentals (async wrapper).

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with company info, financials, and market data
        """
        return await self._run_sync(self._provider.get_stock_info, symbol)

    async def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
        """
        Get news articles for a stock (async wrapper).

        Args:
            symbol: Stock ticker symbol
            limit: Maximum number of articles to return

        Returns:
            DataFrame with news articles
        """
        return await self._run_sync(self._provider.get_news, symbol, limit)

    async def get_earnings(self, symbol: str) -> dict[str, Any]:
        """
        Get earnings information for a stock (async wrapper).

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with earnings data and dates
        """
        return await self._run_sync(self._provider.get_earnings, symbol)

    async def get_recommendations(self, symbol: str) -> pd.DataFrame:
        """
        Get analyst recommendations for a stock (async wrapper).

        Args:
            symbol: Stock ticker symbol

        Returns:
            DataFrame with analyst recommendations
        """
        return await self._run_sync(self._provider.get_recommendations, symbol)

    async def is_market_open(self) -> bool:
        """
        Check if the stock market is currently open (async wrapper).

        Returns:
            True if market is open, False otherwise
        """
        return await self._run_sync(self._provider.is_market_open)

    async def is_etf(self, symbol: str) -> bool:
        """
        Check if a symbol represents an ETF (async wrapper).

        Args:
            symbol: Stock ticker symbol

        Returns:
            True if symbol is an ETF, False otherwise
        """
        return await self._run_sync(self._provider.is_etf, symbol)

    # IStockScreener implementation
    async def get_maverick_recommendations(
        self, limit: int = 20, min_score: int | None = None
    ) -> list[dict[str, Any]]:
        """
        Get bullish Maverick stock recommendations (async wrapper).

        Args:
            limit: Maximum number of recommendations
            min_score: Minimum combined score filter

        Returns:
            List of stock recommendations with technical analysis
        """
        return await self._run_sync(
            self._provider.get_maverick_recommendations, limit, min_score
        )

    async def get_maverick_bear_recommendations(
        self, limit: int = 20, min_score: int | None = None
    ) -> list[dict[str, Any]]:
        """
        Get bearish Maverick stock recommendations (async wrapper).

        Args:
            limit: Maximum number of recommendations
            min_score: Minimum score filter

        Returns:
            List of bear stock recommendations
        """
        return await self._run_sync(
            self._provider.get_maverick_bear_recommendations, limit, min_score
        )

    async def get_trending_recommendations(
        self, limit: int = 20, min_momentum_score: float | None = None
    ) -> list[dict[str, Any]]:
        """
        Get trending stock recommendations (async wrapper).

        Delegates to the provider's supply/demand breakout screening.

        Args:
            limit: Maximum number of recommendations
            min_momentum_score: Minimum momentum score filter

        Returns:
            List of trending stock recommendations
        """
        return await self._run_sync(
            self._provider.get_supply_demand_breakout_recommendations,
            limit,
            min_momentum_score,
        )

    async def get_all_screening_recommendations(
        self,
    ) -> dict[str, list[dict[str, Any]]]:
        """
        Get all screening recommendations in one call (async wrapper).

        Returns:
            Dictionary with all screening types and their recommendations
        """
        return await self._run_sync(self._provider.get_all_screening_recommendations)

    # Additional methods to expose provider functionality
    def get_sync_provider(self) -> StockDataProvider:
        """
        Get the underlying synchronous provider for backward compatibility.

        Returns:
            The wrapped StockDataProvider instance
        """
        return self._provider

    async def get_all_realtime_data(self, symbols: list[str]) -> dict[str, Any]:
        """
        Get real-time data for multiple symbols (async wrapper).

        Args:
            symbols: List of stock ticker symbols

        Returns:
            Dictionary mapping symbols to their real-time data
        """
        return await self._run_sync(self._provider.get_all_realtime_data, symbols)
286 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/parallel_screening.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Parallel stock screening utilities for Maverick-MCP.
3 |
4 | This module provides utilities for running stock screening operations
5 | in parallel using ProcessPoolExecutor for significant performance gains.
6 | """
7 |
8 | import asyncio
9 | import time
10 | from collections.abc import Callable
11 | from concurrent.futures import ProcessPoolExecutor, as_completed
12 | from typing import Any
13 |
14 | from maverick_mcp.utils.logging import get_logger
15 |
16 | logger = get_logger(__name__)
17 |
18 |
class ParallelScreener:
    """
    Parallel stock screening executor.

    This class provides methods to run screening functions in parallel
    across multiple processes for better performance.

    Must be used as a context manager; the worker pool is created in
    ``__enter__`` and torn down in ``__exit__``.
    """

    def __init__(self, max_workers: int | None = None):
        """
        Initialize the parallel screener.

        Args:
            max_workers: Maximum number of worker processes.
                Defaults to CPU count.
        """
        self.max_workers = max_workers
        # Created lazily in __enter__ so the pool only exists while the
        # context-manager scope is active.
        self._executor: ProcessPoolExecutor | None = None

    def __enter__(self):
        """Context manager entry."""
        self._executor = ProcessPoolExecutor(max_workers=self.max_workers)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        if self._executor:
            # wait=True blocks until in-flight batches finish before teardown.
            self._executor.shutdown(wait=True)
            self._executor = None

    def screen_batch(
        self,
        symbols: list[str],
        screening_func: Callable[[str], dict[str, Any]],
        batch_size: int = 10,
        timeout: float = 30.0,
    ) -> list[dict[str, Any]]:
        """
        Screen a batch of symbols in parallel.

        Args:
            symbols: List of stock symbols to screen
            screening_func: Function that takes a symbol and returns screening results
                (must be picklable, i.e. defined at module level)
            batch_size: Number of symbols to process per worker
            timeout: Timeout for each screening operation

        Returns:
            List of screening results for symbols that passed

        Raises:
            RuntimeError: If called outside the context-manager scope.
            TimeoutError: Propagated from ``as_completed`` if the overall
                deadline (timeout * number of batches) expires.
        """
        if not self._executor:
            raise RuntimeError("ParallelScreener must be used as context manager")

        start_time = time.time()
        results = []
        failed_symbols = []

        # Create batches: chunking amortizes per-task process-pool overhead
        # over batch_size symbols.
        batches = [
            symbols[i : i + batch_size] for i in range(0, len(symbols), batch_size)
        ]

        logger.info(
            f"Starting parallel screening of {len(symbols)} symbols "
            f"in {len(batches)} batches"
        )

        # Submit batch processing jobs
        future_to_batch = {
            self._executor.submit(self._process_batch, batch, screening_func): batch
            for batch in batches
        }

        # Collect results as they complete
        for future in as_completed(future_to_batch, timeout=timeout * len(batches)):
            batch = future_to_batch[future]
            try:
                batch_results = future.result()
                results.extend(batch_results)
            except Exception as e:
                # A failed batch is logged and all its symbols recorded so
                # the caller can see what was skipped.
                logger.error(f"Batch processing failed: {e}")
                failed_symbols.extend(batch)

        elapsed = time.time() - start_time
        # NOTE(review): `results` contains only symbols whose screen returned
        # passed=True (see _process_batch), so this "success rate" mixes
        # "screened without error" with "passed the screen" — confirm intent.
        success_rate = (len(results) / len(symbols)) * 100 if symbols else 0

        logger.info(
            f"Parallel screening completed in {elapsed:.2f}s "
            f"({len(results)}/{len(symbols)} succeeded, "
            f"{success_rate:.1f}% success rate)"
        )

        if failed_symbols:
            logger.warning(f"Failed to screen symbols: {failed_symbols[:10]}...")

        return results

    @staticmethod
    def _process_batch(
        symbols: list[str], screening_func: Callable[[str], dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Process a batch of symbols.

        This runs in a separate process.

        Only results whose dict has a truthy "passed" key are kept;
        per-symbol exceptions are logged at debug level and skipped.
        """
        results = []

        for symbol in symbols:
            try:
                result = screening_func(symbol)
                if result and result.get("passed", False):
                    results.append(result)
            except Exception as e:
                # Log errors but continue processing
                logger.debug(f"Screening failed for {symbol}: {e}")

        return results
136 |
137 |
async def parallel_screen_async(
    symbols: list[str],
    screening_func: Callable[[str], dict[str, Any]],
    max_workers: int | None = None,
    batch_size: int = 10,
) -> list[dict[str, Any]]:
    """
    Async wrapper for parallel screening.

    Runs the blocking ParallelScreener in the default thread-pool executor
    so the event loop is never blocked by process-pool management.

    Args:
        symbols: List of stock symbols to screen
        screening_func: Screening function (must be picklable)
        max_workers: Maximum number of worker processes
        batch_size: Number of symbols per batch

    Returns:
        List of screening results
    """
    # get_running_loop() is the correct API inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()

    # Run screening in thread pool to avoid blocking
    def run_screening():
        with ParallelScreener(max_workers=max_workers) as screener:
            return screener.screen_batch(symbols, screening_func, batch_size)

    results = await loop.run_in_executor(None, run_screening)
    return results
165 |
166 |
167 | # Example screening function (must be at module level for pickling)
def example_momentum_screen(symbol: str) -> dict[str, Any]:
    """
    Example momentum screening function.

    This must be defined at module level to be picklable for multiprocessing.
    """
    from maverick_mcp.core.technical_analysis import calculate_rsi, calculate_sma
    from maverick_mcp.providers.stock_data import StockDataProvider

    try:
        # Fetch one year of daily bars, bypassing the cache layer.
        fetcher = StockDataProvider(use_cache=False)
        frame = fetcher.get_stock_data(
            symbol, start_date="2023-01-01", end_date="2024-01-01"
        )

        # Need at least 50 rows to compute the 50-day moving average.
        if len(frame) < 50:
            return {"symbol": symbol, "passed": False, "reason": "Insufficient data"}

        # Latest close plus the two indicator values.
        last_close = frame["Close"].iloc[-1]
        last_sma = calculate_sma(frame, 50).iloc[-1]
        last_rsi = calculate_rsi(frame, 14).iloc[-1]

        # Momentum criteria: price above the 50-day SMA with RSI in a
        # healthy (neither oversold nor overbought) band.
        above_sma = last_close > last_sma
        rsi_healthy = 40 <= last_rsi <= 70

        return {
            "symbol": symbol,
            "passed": above_sma and rsi_healthy,
            "price": round(last_close, 2),
            "sma_50": round(last_sma, 2),
            "rsi": round(last_rsi, 2),
            "above_sma": above_sma,
        }

    except Exception as e:
        # Any failure (bad ticker, network, indicator math) fails the screen.
        return {"symbol": symbol, "passed": False, "error": str(e)}
209 |
210 |
211 | # Decorator for making functions parallel-friendly
def make_parallel_safe(func: Callable) -> Callable:
    """
    Decorator to make a function safe for parallel execution.

    This ensures the function:
    1. Doesn't rely on shared state
    2. Handles its own database connections
    3. Returns picklable results
    """
    import json
    import os
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Worker processes must never attempt authentication.
        os.environ["AUTH_ENABLED"] = "false"

        try:
            outcome = func(*args, **kwargs)
            # Round-trip through JSON to prove the result can cross a
            # process boundary; raises if anything unpicklable sneaked in.
            json.dumps(outcome)
            return outcome
        except Exception as e:
            logger.error(f"Parallel execution error in {func.__name__}: {e}")
            return {"error": str(e), "passed": False}

    return wrapper
242 |
243 |
244 | # Batch screening with progress tracking
class BatchScreener:
    """Enhanced batch screener with progress tracking."""

    def __init__(self, screening_func: Callable, max_workers: int = 4):
        """
        Initialize the batch screener.

        Args:
            screening_func: Per-symbol screening function (must be picklable)
            max_workers: Maximum number of worker processes
        """
        self.screening_func = screening_func
        self.max_workers = max_workers
        self.results: list[dict[str, Any]] = []
        self.progress = 0
        self.total = 0

    def screen_with_progress(
        self,
        symbols: list[str],
        progress_callback: Callable[[int, int], None] | None = None,
    ) -> list[dict[str, Any]]:
        """
        Screen symbols with progress tracking.

        Args:
            symbols: List of symbols to screen
            progress_callback: Optional callback for progress updates,
                invoked as ``progress_callback(done_count, total_count)``

        Returns:
            List of screening results (only symbols that passed)
        """
        self.total = len(symbols)
        self.progress = 0
        # Bug fix: reset accumulated results so repeated calls on the same
        # instance don't return stale entries from earlier runs (total and
        # progress were already reset, but results was not).
        self.results = []

        with ParallelScreener(max_workers=self.max_workers) as screener:
            # Sub-batches sized so each worker yields ~4 progress updates,
            # trading a little overhead for smoother progress reporting.
            batch_size = max(1, len(symbols) // (self.max_workers * 4))

            for i in range(0, len(symbols), batch_size):
                batch = symbols[i : i + batch_size]
                batch_results = screener.screen_batch(
                    batch,
                    self.screening_func,
                    batch_size=1,  # Process one at a time within batch
                )

                self.results.extend(batch_results)
                self.progress = min(i + batch_size, self.total)

                if progress_callback:
                    progress_callback(self.progress, self.total)

        return self.results
292 |
```
--------------------------------------------------------------------------------
/PLANS.md:
--------------------------------------------------------------------------------
```markdown
1 | # PLANS.md
2 |
3 | The detailed Execution Plan (`PLANS.md`) is a **living document** and the **memory** that helps Codex steer toward a completed project. Fel mentioned that his actual `plans.md` file was about **160 lines** long; this example has been expanded to approximate the detail required for a major project, such as the 15,000-line change to the JSON parser for streaming tool calls.
4 |
5 | ## 1. Big Picture / Goal
6 |
7 | - **Objective:** To execute a core refactor of the existing streaming JSON parser architecture to seamlessly integrate the specialized `ToolCall_V2` library, enabling advanced, concurrent tool call processing and maintaining robust performance characteristics suitable for the "AI age". This refactor must minimize latency introduced during intermediate stream buffering.
8 | - **Architectural Goal:** Transition the core tokenization and parsing logic from synchronous, block-based handling to a fully asynchronous, state-machine-driven model, specifically targeting non-blocking tool call detection within the stream.
9 | - **Success Criteria (Mandatory):**
10 | - All existing unit, property, and fuzzing tests must pass successfully post-refactor.
11 | - New comprehensive integration tests must be written and passed to fully validate `ToolCall_V2` library functionality and streaming integration.
12 | - Performance benchmarks must demonstrate no more than a 5% regression in parsing speed under high-concurrency streaming loads.
13 | - The `plans.md` document must be fully updated upon completion, serving as the executive summary of the work accomplished.
14 | - A high-quality summary and documentation updates (e.g., Readme, API guides) reflecting the new architecture must be generated and committed.
15 |
16 | ## 2. To-Do List (High Level)
17 |
18 | - [ ] **Spike 1:** Comprehensive research and PoC for `ToolCall_V2` integration points.
19 | - [ ] **Refactor Core:** Implement the new asynchronous state machine for streaming tokenization.
20 | - [ ] **Feature A:** Implement the parsing hook necessary to detect `ToolCall_V2` structures mid-stream.
21 | - [ ] **Feature B:** Develop the compatibility layer (shim) for backward support of legacy tool call formats.
22 | - [ ] **Testing:** Write extensive property tests specifically targeting concurrency and error handling around tool calls.
23 | - [ ] **Documentation:** Update all internal and external documentation, including `README.md` and inline comments.
24 |
25 | ## 3. Plan Details (Spikes & Features)
26 |
27 | ### Spike 1: Research `ToolCall_V2` Integration
28 |
29 | - **Action:** Investigate the API signature of the `ToolCall_V2` library, focusing on its memory allocation strategies and compatibility with the current Rust asynchronous ecosystem (Tokio/Async-std). Determine if vendoring or a simple dependency inclusion is required.
30 | - **Steps:**
31 | 1. Analyze `ToolCall_V2` source code to understand its core dependencies and threading requirements.
32 | 2. Create a minimal proof-of-concept (PoC) file to test basic instantiation and serialization/deserialization flow.
33 | 3. Benchmark PoC for initial overhead costs compared to the previous custom parser logic.
34 | - **Expected Outcome:** A clear architectural recommendation regarding dependency management and an understanding of necessary low-level code modifications.
35 |
36 | ### Refactor Core: Asynchronous State Machine Implementation
37 |
38 | - **Goal:** Replace the synchronous `ChunkProcessor` with a `StreamParser` that utilizes an internal state enum (e.g., START, KEY, VALUE, TOOL_CALL_INIT, TOOL_CALL_BODY).
39 | - **Steps:**
40 | 1. Define the new `StreamParser` trait and associated state structures.
41 | 2. Migrate existing buffer management to use asynchronous channels/queues where appropriate.
42 | 3. Refactor token emission logic to be non-blocking.
43 | 4. Ensure all existing `panic!` points are converted to recoverable `Result` types for robust streaming.
44 |
45 | ### Feature A: `ToolCall_V2` Stream Hook
46 |
47 | - **Goal:** Inject logic into the `StreamParser` to identify the start of a tool call structure (e.g., specific JSON key sequence) and hand control to the `ToolCall_V2` handler without blocking the main parser thread.
48 | - **Steps:**
49 | 1. Implement the `ParseState::TOOL_CALL_INIT` state.
50 | 2. Write the bridging code that streams raw bytes/tokens directly into the `ToolCall_V2` library's parser.
51 | 3. Handle the return of control to the main parser stream once the tool call object is fully constructed.
52 | 4. Verify that subsequent JSON data (after the tool call structure) is processed correctly.
53 |
54 | ### Feature B: Legacy Tool Call Compatibility Shim
55 |
56 | - **Goal:** Create a compatibility wrapper that translates incoming legacy tool call formats into the structures expected by the new `ToolCall_V2` processor, ensuring backward compatibility.
57 | - **Steps:**
58 | 1. Identify all legacy parsing endpoints that still utilize the old format.
59 | 2. Implement a `LegacyToolCallAdapter` struct to wrap the old format.
60 | 3. Test the adapter against a suite of known legacy inputs.
61 |
62 | ### Testing Phase
63 |
64 | - **Goal:** Achieve 100% test passing rate and add specific coverage for the new feature.
65 | - **Steps:**
66 | 1. Run the complete existing test suite to ensure the core refactor has not caused regressions.
67 | 2. Implement new property tests focused on interleaved data streams: standard JSON data mixed with large, complex `ToolCall_V2` objects.
68 | 3. Integrate and run the fuzzing tests against the new `StreamParser`.
69 |
70 | ## 4. Progress (Living Document Section)
71 |
72 | _(This section is regularly updated by Codex, acting as its memory, showing items completed and current status)._
73 |
74 | |Date|Time|Item Completed / Status Update|Resulting Changes (LOC/Commit)|
75 | |:--|:--|:--|:--|
76 | |2023-11-01|09:30|Plan initialized. Began research on Spike 1.|Initial `plans.md` committed.|
77 | |2023-11-01|11:45|Completed Spike 1 research. Decision made to vendor/fork `ToolCall_V2`.|Research notes added to Decision Log.|
78 | |2023-11-01|14:00|Defined `StreamParser` trait and core state enum structures.|Initial ~500 lines of refactor boilerplate.|
79 | |2023-11-01|17:15|Migrated synchronous buffer logic to non-blocking approach. Core tests failing (expected).|~2,500 LOC modified in `core/parser_engine.rs`.|
80 | |2023-11-02|10:30|Completed implementation of Feature A (Tool Call Stream Hook).|New `tool_call_handler.rs` module committed.|
81 | |2023-11-02|13:45|Wrote initial suite of integration tests for Feature A. Tests now intermittently passing.|~600 LOC of new test code.|
82 | |2023-11-02|15:50|Implemented Feature B (Legacy Shim). All existing unit tests pass again.|Code change finalized. Total PR delta now > 4,200 LOC.|
83 | |2023-11-02|16:20|Documentation updates for `README.md` completed and committed.|Documentation finalized.|
84 | |**Current Status:**|**[Timestamp]**|Tests are stable, clean-up phase initiated. Ready for final review and PR submission.|All checks complete.|
85 |
86 | ## 5. Surprises and Discoveries
87 |
88 | _(Unexpected technical issues or findings that influence the overall plan)._
89 |
90 | 1. **Threading Conflict:** The `ToolCall_V2` library uses an internal thread pool which conflicts with the parent process's executor configuration, necessitating extensive use of `tokio::task::spawn_blocking` wrappers instead of direct calls.
91 | 2. **Vendoring Requirement:** Due to a subtle memory leak identified in `ToolCall_V2`'s error handling path when processing incomplete streams, the decision was made to **vendor in** (fork and patch) the library to implement a necessary hotfix.
92 | 3. **JSON Format Edge Case:** Discovery of an obscure edge case where the streaming parser incorrectly handles immediately nested tool calls, requiring an adjustment to the `TOOL_CALL_INIT` state machine logic.
93 |
94 | ## 6. Decision Log
95 |
96 | _(Key implementation decisions made during the execution of the plan)._
97 |
98 | | Date | Decision | Rationale |
99 | | :--------- | :------------------------------------------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------- |
100 | | 2023-11-01 | Chosen Language/Framework: Rust and Tokio. | Maintain consistency with established project codebase. |
101 | | 2023-11-01 | Dependency Strategy: Vendoring/Forking `ToolCall_V2` library. | Provides greater control over critical memory management and allows for immediate patching of stream-related bugs. |
102 | | 2023-11-02 | Error Handling: Adopted custom `ParserError` enum for all failures. | Standardized error reporting across the new asynchronous streams, preventing unexpected panics in production. |
103 | | 2023-11-02 | Testing Priority: Exhaustive Property Tests. | Given the complexity of the core refactor, property tests were prioritized over simple unit tests to maximize confidence in the 15,000 LOC change. |
```
--------------------------------------------------------------------------------
/tests/test_graceful_shutdown.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Test graceful shutdown functionality.
3 | """
4 |
5 | import asyncio
6 | import os
7 | import signal
8 | import subprocess
9 | import sys
10 | import time
11 | from unittest.mock import patch
12 |
13 | import pytest
14 |
15 | from maverick_mcp.utils.shutdown import GracefulShutdownHandler, get_shutdown_handler
16 |
17 |
class TestGracefulShutdown:
    """Test graceful shutdown handler."""

    def test_shutdown_handler_creation(self):
        """Test creating shutdown handler."""
        handler = GracefulShutdownHandler("test", shutdown_timeout=10, drain_timeout=5)

        # Constructor arguments should be stored verbatim, and a fresh
        # handler must not report an in-progress shutdown.
        assert handler.name == "test"
        assert handler.shutdown_timeout == 10
        assert handler.drain_timeout == 5
        assert not handler.is_shutting_down()

    def test_cleanup_registration(self):
        """Test registering cleanup callbacks."""
        handler = GracefulShutdownHandler("test")

        # Register callbacks
        # (the *_called flags are intentionally unused here — only the
        # registration bookkeeping is under test, not invocation)
        callback1_called = False
        callback2_called = False

        def callback1():
            nonlocal callback1_called
            callback1_called = True

        def callback2():
            nonlocal callback2_called
            callback2_called = True

        handler.register_cleanup(callback1)
        handler.register_cleanup(callback2)

        assert len(handler._cleanup_callbacks) == 2
        assert callback1 in handler._cleanup_callbacks
        assert callback2 in handler._cleanup_callbacks

    @pytest.mark.asyncio
    async def test_request_tracking(self):
        """Test request tracking."""
        handler = GracefulShutdownHandler("test")

        # Create mock tasks
        async def dummy_task():
            await asyncio.sleep(0.1)

        task1 = asyncio.create_task(dummy_task())
        task2 = asyncio.create_task(dummy_task())

        # Track tasks
        handler.track_request(task1)
        handler.track_request(task2)

        assert len(handler._active_requests) == 2

        # Wait for tasks to complete; completed tasks should be removed
        # from the active-request set automatically.
        await task1
        await task2
        await asyncio.sleep(0.1)  # Allow cleanup

        assert len(handler._active_requests) == 0

    def test_signal_handler_installation(self):
        """Test signal handler installation."""
        handler = GracefulShutdownHandler("test")

        # Store original handlers (signal.signal returns the previous
        # handler, so this save/restore pattern reads and resets in one call)
        original_sigterm = signal.signal(signal.SIGTERM, signal.SIG_DFL)
        original_sigint = signal.signal(signal.SIGINT, signal.SIG_DFL)

        try:
            # Install handlers
            handler.install_signal_handlers()

            # Verify handlers were changed
            current_sigterm = signal.signal(signal.SIGTERM, signal.SIG_DFL)
            current_sigint = signal.signal(signal.SIGINT, signal.SIG_DFL)

            assert current_sigterm == handler._signal_handler
            assert current_sigint == handler._signal_handler

        finally:
            # Restore original handlers so other tests are unaffected
            signal.signal(signal.SIGTERM, original_sigterm)
            signal.signal(signal.SIGINT, original_sigint)

    @pytest.mark.asyncio
    async def test_async_shutdown_sequence(self):
        """Test async shutdown sequence."""
        handler = GracefulShutdownHandler("test", drain_timeout=0.5)

        # Track cleanup calls
        sync_called = False
        async_called = False

        def sync_cleanup():
            nonlocal sync_called
            sync_called = True

        async def async_cleanup():
            nonlocal async_called
            async_called = True

        handler.register_cleanup(sync_cleanup)
        handler.register_cleanup(async_cleanup)

        # Mock sys.exit to prevent actual exit
        with patch("sys.exit") as mock_exit:
            # Trigger shutdown (reset the guard flag so _async_shutdown
            # doesn't treat this as an already-running shutdown)
            handler._shutdown_in_progress = False
            await handler._async_shutdown("SIGTERM")

            # Verify shutdown sequence: both sync and async cleanups ran,
            # the shutdown event fired, and the process would exit cleanly.
            assert handler._shutdown_event.is_set()
            assert sync_called
            assert async_called
            mock_exit.assert_called_once_with(0)

    @pytest.mark.asyncio
    async def test_request_draining_timeout(self):
        """Test request draining with timeout."""
        handler = GracefulShutdownHandler("test", drain_timeout=0.2)

        # Create long-running task
        async def long_task():
            await asyncio.sleep(1.0)  # Longer than drain timeout

        task = asyncio.create_task(long_task())
        handler.track_request(task)

        # Start draining
        start_time = time.time()
        try:
            await asyncio.wait_for(handler._wait_for_requests(), timeout=0.3)
        except TimeoutError:
            pass

        drain_time = time.time() - start_time

        # Should timeout quickly since task won't complete
        assert drain_time < 0.5
        assert task in handler._active_requests

        # Cancel task to clean up
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    def test_global_shutdown_handler(self):
        """Test global shutdown handler singleton."""
        handler1 = get_shutdown_handler("test1")
        handler2 = get_shutdown_handler("test2")

        # Should return same instance
        assert handler1 is handler2
        assert handler1.name == "test1"  # First call sets the name

    @pytest.mark.asyncio
    async def test_cleanup_callback_error_handling(self):
        """Test error handling in cleanup callbacks."""
        handler = GracefulShutdownHandler("test")

        # Create callback that raises exception
        def failing_callback():
            raise RuntimeError("Cleanup failed")

        async def async_failing_callback():
            raise RuntimeError("Async cleanup failed")

        handler.register_cleanup(failing_callback)
        handler.register_cleanup(async_failing_callback)

        # Mock sys.exit
        with patch("sys.exit"):
            # Should not raise despite callback errors
            await handler._async_shutdown("SIGTERM")

        # Handler should still complete shutdown
        assert handler._shutdown_event.is_set()

    @pytest.mark.asyncio
    async def test_sync_request_tracking(self):
        """Test synchronous request tracking context manager."""
        handler = GracefulShutdownHandler("test")

        # Use context manager
        with handler.track_sync_request():
            # In real usage, this would track the request
            pass

        # Should complete without error
        assert True

    @pytest.mark.skipif(
        sys.platform == "win32", reason="SIGHUP not available on Windows"
    )
    def test_sighup_handling(self):
        """Test SIGHUP signal handling."""
        handler = GracefulShutdownHandler("test")

        # Store original handler
        original_sighup = signal.signal(signal.SIGHUP, signal.SIG_DFL)

        try:
            handler.install_signal_handlers()

            # Verify SIGHUP handler was installed
            current_sighup = signal.signal(signal.SIGHUP, signal.SIG_DFL)
            assert current_sighup == handler._signal_handler

        finally:
            # Restore original handler
            signal.signal(signal.SIGHUP, original_sighup)
231 |
232 |
@pytest.mark.integration
class TestGracefulShutdownIntegration:
    """Integration tests for graceful shutdown."""

    @pytest.mark.asyncio
    async def test_server_graceful_shutdown(self):
        """Test actual server graceful shutdown.

        Spawns a real child process running the shutdown handler, sends it
        SIGTERM, and asserts the process drains and exits with status 0.
        """
        # This would test with a real server process
        # For now, we'll simulate it

        # Start a subprocess that uses our shutdown handler
        # (the script text is executed verbatim by the child interpreter)
        script = """
import asyncio
import signal
import sys
import time
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from maverick_mcp.utils.shutdown import graceful_shutdown

async def main():
    with graceful_shutdown("test-server") as handler:
        # Simulate server running
        print("Server started", flush=True)

        # Wait for shutdown
        try:
            await handler.wait_for_shutdown()
        except KeyboardInterrupt:
            pass

        print("Server shutting down", flush=True)

if __name__ == "__main__":
    asyncio.run(main())
"""

        # Write script to temp file (delete=False so the child can open it;
        # removed explicitly in the finally block below)
        import tempfile

        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
            f.write(script)
            script_path = f.name

        try:
            # Start subprocess
            proc = subprocess.Popen(
                [sys.executable, script_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )

            # Wait for startup
            await asyncio.sleep(0.5)

            # Send SIGTERM
            proc.send_signal(signal.SIGTERM)

            # Wait for completion
            stdout, stderr = proc.communicate(timeout=5)

            # Verify graceful shutdown: both banner lines printed and the
            # child exited cleanly (status 0, not killed by the signal)
            assert "Server started" in stdout
            assert "Server shutting down" in stdout
            assert proc.returncode == 0

        finally:
            os.unlink(script_path)
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/dependencies.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Dependency injection utilities for Maverick-MCP.
3 |
4 | This module provides dependency injection support for routers and other components,
5 | enabling clean separation of concerns and improved testability.
6 | """
7 |
8 | import logging
9 | from functools import lru_cache
10 |
11 | from maverick_mcp.providers.factories.config_factory import ConfigurationFactory
12 | from maverick_mcp.providers.factories.provider_factory import ProviderFactory
13 | from maverick_mcp.providers.interfaces.cache import ICacheManager
14 | from maverick_mcp.providers.interfaces.config import IConfigurationProvider
15 | from maverick_mcp.providers.interfaces.macro_data import IMacroDataProvider
16 | from maverick_mcp.providers.interfaces.market_data import IMarketDataProvider
17 | from maverick_mcp.providers.interfaces.persistence import IDataPersistence
18 | from maverick_mcp.providers.interfaces.stock_data import (
19 | IStockDataFetcher,
20 | IStockScreener,
21 | )
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | # Global provider factory instance
26 | _provider_factory: ProviderFactory | None = None
27 |
28 |
def get_provider_factory() -> ProviderFactory:
    """
    Return the process-wide ProviderFactory, creating it on first use.

    Lazily initializes a singleton: the first caller triggers configuration
    auto-detection and factory construction; every later caller receives
    the same instance.

    Returns:
        ProviderFactory instance
    """
    global _provider_factory

    if _provider_factory is None:
        detected_config = ConfigurationFactory.auto_detect_config()
        _provider_factory = ProviderFactory(detected_config)
        logger.debug("Global provider factory initialized")

    return _provider_factory
47 |
48 |
def set_provider_factory(factory: ProviderFactory) -> None:
    """
    Replace the process-wide provider factory.

    Primarily intended for tests that need to inject a custom factory.

    Args:
        factory: ProviderFactory instance to use globally
    """
    global _provider_factory
    _provider_factory = factory
    logger.debug("Global provider factory overridden")
61 |
62 |
def reset_provider_factory() -> None:
    """
    Discard the process-wide provider factory.

    The next call to ``get_provider_factory`` will re-initialize it, which
    is useful for tests or after configuration changes.
    """
    global _provider_factory
    _provider_factory = None
    logger.debug("Global provider factory reset")
73 |
74 |
75 | # Dependency injection functions for use with FastAPI Depends() or similar
76 |
77 |
def get_configuration() -> IConfigurationProvider:
    """
    Get configuration provider dependency.

    Returns:
        IConfigurationProvider instance
    """
    # NOTE(review): reaches into the factory's private _config attribute;
    # a public accessor on ProviderFactory would be cleaner.
    factory = get_provider_factory()
    return factory._config
86 |
87 |
def get_cache_manager() -> ICacheManager:
    """
    Get cache manager dependency.

    Returns:
        ICacheManager instance
    """
    factory = get_provider_factory()
    return factory.get_cache_manager()
96 |
97 |
def get_persistence() -> IDataPersistence:
    """
    Get persistence layer dependency.

    Returns:
        IDataPersistence instance
    """
    factory = get_provider_factory()
    return factory.get_persistence()
106 |
107 |
def get_stock_data_fetcher() -> IStockDataFetcher:
    """
    Get stock data fetcher dependency.

    Returns:
        IStockDataFetcher instance
    """
    factory = get_provider_factory()
    return factory.get_stock_data_fetcher()
116 |
117 |
def get_stock_screener() -> IStockScreener:
    """
    Get stock screener dependency.

    Returns:
        IStockScreener instance
    """
    factory = get_provider_factory()
    return factory.get_stock_screener()
126 |
127 |
def get_market_data_provider() -> IMarketDataProvider:
    """
    Get market data provider dependency.

    Returns:
        IMarketDataProvider instance
    """
    factory = get_provider_factory()
    return factory.get_market_data_provider()
136 |
137 |
def get_macro_data_provider() -> IMacroDataProvider:
    """
    Get macro data provider dependency.

    Returns:
        IMacroDataProvider instance
    """
    factory = get_provider_factory()
    return factory.get_macro_data_provider()
146 |
147 |
148 | # Context manager for dependency overrides (useful for testing)
149 |
150 |
class DependencyOverride:
    """
    Context manager for temporarily overriding dependencies.

    This is primarily useful for testing where you want to inject
    mock implementations for specific test cases. On exit, every
    swapped provider and the factory reference itself are restored.
    """

    def __init__(self, **overrides):
        """
        Initialize dependency override context.

        Args:
            **overrides: Keyword arguments mapping dependency names to override instances
        """
        self.overrides = overrides
        self.original_factory = None
        self.original_providers = {}

    def __enter__(self):
        """Enter the context and apply overrides."""
        global _provider_factory

        # Remember the factory object itself so __exit__ can restore it.
        self.original_factory = _provider_factory

        if _provider_factory is not None:
            for dep_name, replacement in self.overrides.items():
                # Provider instances are stored on the factory under
                # underscore-prefixed attribute names.
                attr = f"_{dep_name}"
                if not hasattr(_provider_factory, attr):
                    logger.warning(f"Unknown dependency override: {dep_name}")
                    continue
                # Stash the current instance, then swap in the replacement.
                self.original_providers[dep_name] = getattr(_provider_factory, attr)
                setattr(_provider_factory, attr, replacement)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context and restore original dependencies."""
        global _provider_factory

        if _provider_factory is not None:
            # Put back every provider instance we swapped out.
            for dep_name, saved in self.original_providers.items():
                setattr(_provider_factory, f"_{dep_name}", saved)

        # Restore the factory reference itself (even if it was None).
        _provider_factory = self.original_factory
206 |
207 |
208 | # Utility functions for testing
209 |
210 |
def create_test_dependencies(**overrides) -> dict:
    """
    Build a fresh set of test dependencies without touching global state.

    A new ProviderFactory is created around a test configuration and each
    provider is pulled from it; caller-supplied overrides then replace the
    corresponding entries.

    Args:
        **overrides: Dependency-name -> instance replacements.

    Returns:
        Mapping of dependency names to instances.
    """
    config = ConfigurationFactory.create_test_config()
    factory = ProviderFactory(config)

    deps = {"configuration": config}
    deps["cache_manager"] = factory.get_cache_manager()
    deps["persistence"] = factory.get_persistence()
    deps["stock_data_fetcher"] = factory.get_stock_data_fetcher()
    deps["stock_screener"] = factory.get_stock_screener()
    deps["market_data_provider"] = factory.get_market_data_provider()
    deps["macro_data_provider"] = factory.get_macro_data_provider()

    # Overrides win over the factory-built instances.
    deps.update(overrides)
    return deps
241 |
242 |
def validate_dependencies() -> list[str]:
    """
    Validate the globally configured dependencies.

    Returns:
        List of validation error messages; empty when everything is valid.
    """
    try:
        return get_provider_factory().validate_configuration()
    except Exception as exc:
        return [f"Failed to validate dependencies: {exc}"]
255 |
256 |
257 | # Caching decorators for expensive dependency creation
258 |
259 |
@lru_cache(maxsize=1)
def get_cached_configuration() -> IConfigurationProvider:
    """Return the configuration provider, cached for the process lifetime."""
    return get_configuration()
264 |
265 |
@lru_cache(maxsize=1)
def get_cached_cache_manager() -> ICacheManager:
    """Return the cache manager, cached for the process lifetime."""
    return get_cache_manager()
270 |
271 |
@lru_cache(maxsize=1)
def get_cached_persistence() -> IDataPersistence:
    """Return the persistence layer, cached for the process lifetime."""
    return get_persistence()
276 |
277 |
278 | # Helper functions for router integration
279 |
280 |
def inject_dependencies(**dependency_overrides):
    """
    Decorator that injects provider dependencies into router functions.

    Each wrapped call receives ``stock_data_fetcher``, ``cache_manager`` and
    ``config`` keyword arguments unless the caller already supplied them.
    Values come from ``dependency_overrides`` when given, otherwise from the
    global provider factory.

    Bug fixed: the previous implementation used ``dict.get(key, get_x())``,
    which evaluated the factory lookup eagerly even when an override was
    supplied — defeating the override and touching global state needlessly.
    A sentinel is now used so the factory is only consulted when needed.
    Also adds ``functools.wraps`` so the wrapped function keeps its metadata.

    Args:
        **dependency_overrides: Optional dependency overrides keyed by
            parameter name.

    Returns:
        A decorator producing a wrapper with the dependencies injected.
    """
    from functools import wraps  # local import: file's import block not editable here

    _MISSING = object()  # sentinel: distinguishes "no override" from override=None

    def decorator(func):
        @wraps(func)  # preserve __name__/__doc__ of the wrapped function
        def wrapper(*args, **kwargs):
            if "stock_data_fetcher" not in kwargs:
                override = dependency_overrides.get("stock_data_fetcher", _MISSING)
                kwargs["stock_data_fetcher"] = (
                    get_stock_data_fetcher() if override is _MISSING else override
                )

            if "cache_manager" not in kwargs:
                override = dependency_overrides.get("cache_manager", _MISSING)
                kwargs["cache_manager"] = (
                    get_cache_manager() if override is _MISSING else override
                )

            if "config" not in kwargs:
                override = dependency_overrides.get("config", _MISSING)
                kwargs["config"] = (
                    get_configuration() if override is _MISSING else override
                )

            return func(*args, **kwargs)

        return wrapper

    return decorator
318 |
319 |
def get_dependencies_for_testing() -> dict:
    """Return test-configured dependencies (alias for create_test_dependencies)."""
    return create_test_dependencies()
328 |
```
--------------------------------------------------------------------------------
/maverick_mcp/data/session_management.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Enhanced database session management with context managers.
3 |
4 | This module provides robust context managers for database session management
5 | that guarantee proper cleanup, automatic rollback on errors, and connection
6 | pool monitoring to prevent connection leaks.
7 |
8 | Addresses Issue #55: Implement Proper Database Session Management with Context Managers
9 | """
10 |
11 | import logging
12 | from collections.abc import AsyncGenerator, Generator
13 | from contextlib import asynccontextmanager, contextmanager
14 | from typing import Any
15 |
16 | from sqlalchemy.ext.asyncio import AsyncSession
17 | from sqlalchemy.orm import Session
18 |
19 | from maverick_mcp.data.models import (
20 | SessionLocal,
21 | _get_async_session_factory,
22 | )
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
    """
    Sync database session context manager with commit/rollback handling.

    On normal exit the session is committed; any exception triggers a rollback
    and is re-raised. The session is closed in all cases.

    Usage:
        with get_db_session() as session:
            rows = session.query(Model).all()

    Yields:
        Session: a managed SQLAlchemy session.

    Raises:
        Exception: whatever the block raised, after the rollback completes.
    """
    db = SessionLocal()
    try:
        yield db
        # Commit inside the try so a failing commit is still rolled back below.
        db.commit()
        logger.debug("Database session committed successfully")
    except Exception as exc:
        db.rollback()
        logger.warning(f"Database session rolled back due to error: {exc}")
        raise
    finally:
        db.close()
        logger.debug("Database session closed")
62 |
63 |
@asynccontextmanager
async def get_async_db_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Async database session context manager with commit/rollback handling.

    On normal exit the session is committed; any exception triggers a rollback
    and is re-raised. The session is closed in all cases.

    Usage:
        async with get_async_db_session() as session:
            result = await session.execute(query)

    Yields:
        AsyncSession: a managed async SQLAlchemy session.

    Raises:
        Exception: whatever the block raised, after the rollback completes.
    """
    factory = _get_async_session_factory()

    async with factory() as db:
        try:
            yield db
            # Commit inside the try so a failing commit is still rolled back.
            await db.commit()
            logger.debug("Async database session committed successfully")
        except Exception as exc:
            await db.rollback()
            logger.warning(f"Async database session rolled back due to error: {exc}")
            raise
        finally:
            # Redundant with the `async with` teardown, but kept for parity
            # with the sync path and to close as early as possible.
            await db.close()
            logger.debug("Async database session closed")
101 |
102 |
@contextmanager
def get_db_session_read_only() -> Generator[Session, None, None]:
    """
    Sync session context manager for read-only work (never commits).

    Rolls back and re-raises on any exception; always closes the session.

    Usage:
        with get_db_session_read_only() as session:
            rows = session.query(Model).all()

    Yields:
        Session: a managed SQLAlchemy session (no commit on exit).

    Raises:
        Exception: whatever the block raised, after the rollback completes.
    """
    db = SessionLocal()
    try:
        yield db
        # Intentionally no commit: read-only usage.
        logger.debug("Read-only database session completed successfully")
    except Exception as exc:
        db.rollback()
        logger.warning(f"Read-only database session rolled back due to error: {exc}")
        raise
    finally:
        db.close()
        logger.debug("Read-only database session closed")
137 |
138 |
@asynccontextmanager
async def get_async_db_session_read_only() -> AsyncGenerator[AsyncSession, None]:
    """
    Async session context manager for read-only work (never commits).

    Rolls back and re-raises on any exception; always closes the session.

    Usage:
        async with get_async_db_session_read_only() as session:
            result = await session.execute(query)

    Yields:
        AsyncSession: a managed async SQLAlchemy session (no commit on exit).

    Raises:
        Exception: whatever the block raised, after the rollback completes.
    """
    factory = _get_async_session_factory()

    async with factory() as db:
        try:
            yield db
            # Intentionally no commit: read-only usage.
            logger.debug("Read-only async database session completed successfully")
        except Exception as exc:
            await db.rollback()
            logger.warning(
                f"Read-only async database session rolled back due to error: {exc}"
            )
            raise
        finally:
            # Redundant with the `async with` teardown; kept for parity with
            # the sync path.
            await db.close()
            logger.debug("Read-only async database session closed")
177 |
178 |
def get_connection_pool_status() -> dict[str, Any]:
    """
    Snapshot the sync engine's connection-pool metrics for monitoring.

    Returns:
        Dict with pool_size, checked_in, checked_out, overflow, invalid and a
        coarse "pool_status" of "healthy"/"warning" (warning at >=80% usage,
        matching checked_out < size * 0.8 as healthy).
    """
    from maverick_mcp.data.models import engine

    pool = engine.pool

    def _metric(name: str, default: int = 0):
        # Pool metrics are exposed as methods; the default covers pool
        # implementations that lack a given accessor.
        return getattr(pool, name, lambda: default)()

    return {
        "pool_size": _metric("size"),
        "checked_in": _metric("checkedin"),
        "checked_out": _metric("checkedout"),
        "overflow": _metric("overflow"),
        "invalid": _metric("invalid"),
        "pool_status": "healthy"
        if _metric("checkedout") < _metric("size", 10) * 0.8
        else "warning",
    }
206 |
207 |
async def get_async_connection_pool_status() -> dict[str, Any]:
    """
    Snapshot the async engine's connection-pool metrics for monitoring.

    Returns:
        Dict with pool_size, checked_in, checked_out, overflow, invalid and a
        coarse "pool_status" of "healthy"/"warning" (warning at >=80% usage).
    """
    from maverick_mcp.data.models import _get_async_engine

    async_engine = _get_async_engine()
    pool = async_engine.pool

    def _metric(name: str, default: int = 0):
        # Pool metrics are exposed as methods; the default covers pool
        # implementations that lack a given accessor.
        return getattr(pool, name, lambda: default)()

    return {
        "pool_size": _metric("size"),
        "checked_in": _metric("checkedin"),
        "checked_out": _metric("checkedout"),
        "overflow": _metric("overflow"),
        "invalid": _metric("invalid"),
        "pool_status": "healthy"
        if _metric("checkedout") < _metric("size", 10) * 0.8
        else "warning",
    }
231 |
232 |
def check_connection_pool_health() -> bool:
    """
    Report whether the sync connection pool looks healthy.

    Unhealthy when utilization exceeds 80% or any invalid connections exist.

    Returns:
        True when healthy, False otherwise (including on status-read failure).
    """
    try:
        status = get_connection_pool_status()
        size = status["pool_size"]
        pool_utilization = status["checked_out"] / size if size > 0 else 0

        # > 80% utilization counts as approaching the pool limit.
        if pool_utilization > 0.8:
            logger.warning(f"High connection pool utilization: {pool_utilization:.2%}")
            return False

        # Any invalid connection is a red flag.
        if status["invalid"] > 0:
            logger.warning(f"Invalid connections detected: {status['invalid']}")
            return False

        return True

    except Exception as e:
        logger.error(f"Failed to check connection pool health: {e}")
        return False
263 |
264 |
async def check_async_connection_pool_health() -> bool:
    """
    Report whether the async connection pool looks healthy.

    Unhealthy when utilization exceeds 80% or any invalid connections exist.

    Returns:
        True when healthy, False otherwise (including on status-read failure).
    """
    try:
        status = await get_async_connection_pool_status()
        size = status["pool_size"]
        pool_utilization = status["checked_out"] / size if size > 0 else 0

        # > 80% utilization counts as approaching the pool limit.
        if pool_utilization > 0.8:
            logger.warning(
                f"High async connection pool utilization: {pool_utilization:.2%}"
            )
            return False

        # Any invalid connection is a red flag.
        if status["invalid"] > 0:
            logger.warning(f"Invalid async connections detected: {status['invalid']}")
            return False

        return True

    except Exception as e:
        logger.error(f"Failed to check async connection pool health: {e}")
        return False
297 |
```
--------------------------------------------------------------------------------
/maverick_mcp/tests/test_mcp_tool_fixes.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Test suite to verify the three critical MCP tool fixes are working properly.
4 |
5 | This test validates that the fixes for:
6 | 1. Research returning empty results (API keys not passed to DeepResearchAgent)
7 | 2. Portfolio risk analysis cryptic "'high'" error (DataFrame validation and column case)
8 | 3. External API key hard dependency (graceful degradation)
9 |
10 | All continue to work correctly after code changes.
11 |
12 | ## Issues Fixed
13 |
14 | ### Issue #1: Research Returning Empty Results
15 | - **Root Cause**: API keys weren't passed from settings to DeepResearchAgent constructor
16 | - **Files Modified**:
17 | - `maverick_mcp/api/routers/research.py:line 35-40` - Added API key parameters
18 | - `maverick_mcp/providers/llm_factory.py:line 30` - Fixed temperature and streaming
19 | - **Fix**: Pass exa_api_key and tavily_api_key to DeepResearchAgent, fix LLM config
20 |
21 | ### Issue #2: Portfolio Risk Analysis "'high'" Error
22 | - **Root Cause**: DataFrame column name case mismatch and date range problems
23 | - **Files Modified**: `maverick_mcp/api/routers/portfolio.py:line 66-84`
24 | - **Fixes**:
25 | - Added DataFrame validation before column access
26 | - Fixed column name case sensitivity (High/Low/Close vs high/low/close)
27 | - Used explicit date range to avoid weekend/holiday data fetch issues
28 |
29 | ### Issue #3: External API Key Hard Dependency
30 | - **Root Cause**: Hard failure when EXTERNAL_DATA_API_KEY not configured
31 | - **Files Modified**: `maverick_mcp/api/routers/data.py:line 244-253`
32 | - **Fix**: Graceful degradation with informative fallback message
33 |
34 | ## Running This Test
35 |
36 | ```bash
37 | # Via Makefile (recommended)
38 | make test-fixes
39 |
40 | # Direct execution
41 | uv run python maverick_mcp/tests/test_mcp_tool_fixes.py
42 |
43 | # Via pytest (if environment allows)
44 | pytest maverick_mcp/tests/test_fixes_validation.py
45 | ```
46 |
47 | This test should be run after any changes to ensure the MCP tool fixes remain intact.
48 | """
49 |
50 | import asyncio
51 | import os
52 |
53 | from maverick_mcp.api.routers.data import get_stock_info
54 | from maverick_mcp.api.routers.portfolio import risk_adjusted_analysis
55 | from maverick_mcp.validation.data import GetStockInfoRequest
56 |
57 |
def test_portfolio_risk_analysis() -> bool:
    """
    Test Issue #2: Portfolio risk analysis (formerly returned cryptic 'high' error).

    This test validates:
    - DataFrame is properly retrieved with correct columns
    - Column name case sensitivity is handled correctly
    - Date range calculation avoids weekend/holiday issues
    - Risk calculations complete successfully

    Returns:
        True when the analysis succeeds, False otherwise.
    """
    print("🧪 Testing portfolio risk analysis (Issue #2)...")
    try:
        # First test what data we actually get from the provider
        from datetime import UTC, datetime, timedelta

        from maverick_mcp.api.routers.portfolio import stock_provider

        print("   Debugging: Testing data provider directly...")
        # End a week back so the latest bar is a settled trading day (avoids
        # weekend/holiday gaps); one year of history for the risk metrics.
        end_date = (datetime.now(UTC) - timedelta(days=7)).strftime("%Y-%m-%d")
        start_date = (datetime.now(UTC) - timedelta(days=365)).strftime("%Y-%m-%d")
        df = stock_provider.get_stock_data(
            "MSFT", start_date=start_date, end_date=end_date
        )

        print(f"   DataFrame shape: {df.shape}")
        print(f"   DataFrame columns: {list(df.columns)}")
        print(f"   DataFrame empty: {df.empty}")
        if not df.empty:
            print(f"   Sample data (last 3 rows):\n{df.tail(3)}")

        # Now test the actual function
        result = risk_adjusted_analysis("MSFT", 75.0)
        if "error" in result:
            # If still error, try string conversion
            result = risk_adjusted_analysis("MSFT", "75")
            if "error" in result:
                print(f"❌ Still has error: {result}")
                return False

        print(
            f"✅ Success! Current price: ${result.get('current_price')}, Risk level: {result.get('risk_level')}"
        )
        print(
            f"   Position sizing: ${result.get('position_sizing', {}).get('suggested_position_size')}"
        )
        print(f"   Strategy type: {result.get('analysis', {}).get('strategy_type')}")
        return True
    except Exception as e:
        print(f"❌ Exception: {e}")
        return False
108 |
109 |
def test_stock_info_external_api():
    """
    Test Issue #3: Stock info requiring EXTERNAL_DATA_API_KEY.

    Validates that the external API dependency is optional: when the key is
    not configured the endpoint falls back gracefully and core stock info is
    still returned.
    """
    print("\n🧪 Testing stock info external API handling (Issue #3)...")
    try:
        result = get_stock_info(GetStockInfoRequest(ticker="MSFT"))
        # A hard failure on the external API key means the fallback regressed.
        hard_failure = "error" in result and "Invalid API key" in str(
            result.get("error")
        )
        if hard_failure:
            print(f"❌ Still failing on external API: {result}")
            return False
        print(f"✅ Success! Company: {result.get('company', {}).get('name')}")
        print(
            f"   Current price: ${result.get('market_data', {}).get('current_price')}"
        )
        return True
    except Exception as e:
        print(f"❌ Exception: {e}")
        return False
135 |
136 |
async def test_research_empty_results() -> bool:
    """
    Test Issue #1: Research returning empty results.

    This test validates:
    - DeepResearchAgent is created with API keys from settings
    - Search providers are properly initialized
    - API keys are correctly passed through the configuration chain

    Returns:
        True when at least one search provider has an API key configured.
    """
    print("\n🧪 Testing research functionality (Issue #1)...")
    try:
        # Import the research function
        from maverick_mcp.api.routers.research import get_research_agent

        # Test that the research agent can be created with API keys
        agent = get_research_agent()

        # Check if API keys are available in environment
        exa_key = os.getenv("EXA_API_KEY")
        tavily_key = os.getenv("TAVILY_API_KEY")

        print(f"   API keys available: EXA={bool(exa_key)}, TAVILY={bool(tavily_key)}")

        # Check if the agent has search providers (indicates API keys were passed correctly)
        if hasattr(agent, "search_providers") and len(agent.search_providers) > 0:
            print(
                f"✅ Research agent created with {len(agent.search_providers)} search providers!"
            )

            # Try to access the provider API keys to verify they're configured
            providers_configured = 0
            for provider in agent.search_providers:
                # A provider only counts as configured when it carries a key.
                if hasattr(provider, "api_key") and provider.api_key:
                    providers_configured += 1

            if providers_configured > 0:
                print(
                    f"✅ {providers_configured} search providers have API keys configured"
                )
                return True
            else:
                print("❌ Search providers missing API keys")
                return False
        else:
            print("❌ Research agent has no search providers configured")
            return False
    except Exception as e:
        print(f"❌ Exception: {e}")
        return False
186 |
187 |
def test_llm_configuration():
    """
    Validate the LLM configuration fixes.

    Ensures an LLM instance can be constructed (temperature/streaming settings
    compatible with gpt-5-mini) and that it answers a trivial query.
    """
    print("\n🧪 Testing LLM configuration...")
    try:
        from maverick_mcp.providers.llm_factory import get_llm

        print("   Creating LLM instance...")
        llm = get_llm()
        print(f"   LLM created: {type(llm).__name__}")

        # A trivial round-trip proves the settings are accepted by the backend.
        print("   Testing LLM query...")
        response = llm.invoke("What is 2+2?")
        print(f"✅ LLM response: {response.content}")
        return True
    except Exception as e:
        print(f"❌ LLM test failed: {e}")
        return False
213 |
214 |
def main():
    """Run the full MCP tool fix validation suite and summarize the outcome."""
    print("🚀 Testing MCP Tool Fixes")
    print("=" * 50)

    # Checks run in a fixed order; the async research check is driven
    # through asyncio.run so the whole suite stays synchronous.
    checks = [
        test_portfolio_risk_analysis,  # Issue #2
        test_stock_info_external_api,  # Issue #3
        lambda: asyncio.run(test_research_empty_results()),  # Issue #1
        test_llm_configuration,
    ]
    results = [check() for check in checks]

    print("\n" + "=" * 50)
    print("📊 Test Results Summary:")
    print(f"✅ Passed: {sum(results)}/{len(results)}")
    print(f"❌ Failed: {len(results) - sum(results)}/{len(results)}")

    if all(results):
        print("\n🎉 All MCP tool fixes are working correctly!")
        print("\nFixed Issues:")
        print("1. ✅ Research tools return actual content (API keys properly passed)")
        print(
            "2. ✅ Portfolio risk analysis works (DataFrame validation & column case)"
        )
        print("3. ✅ Stock info graceful fallback (external API optional)")
        print("4. ✅ LLM configuration compatible (temperature & streaming)")
    else:
        print("\n⚠️ Some issues remain to be fixed.")
        print("Please check the individual test results above.")

    return all(results)
253 |
254 |
if __name__ == "__main__":
    import sys

    # Exit code mirrors the suite result: 0 on success, 1 on any failure.
    sys.exit(0 if main() else 1)
260 |
```
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/visualization.py:
--------------------------------------------------------------------------------
```python
1 | import base64
2 | import io
3 | import logging
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pandas as pd
8 | import seaborn as sns
9 | from matplotlib.colors import LinearSegmentedColormap
10 | from matplotlib.figure import Figure
11 |
12 | # Configure logging
13 | logging.basicConfig(
14 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
15 | )
16 | logger = logging.getLogger(__name__)
17 |
18 |
def set_chart_style(theme: str = "light") -> None:
    """
    Configure matplotlib rcParams for light or dark charts.

    Args:
        theme (str): Chart theme, either 'light' or 'dark'
    """
    # Matplotlib renamed the bundled seaborn styles to "seaborn-v0_8*" in 3.6
    # and removed the old alias in later releases, where
    # plt.style.use("seaborn") raises OSError. Try the legacy name first for
    # old installs, then the new name; fall back to defaults if neither exists.
    for style_name in ("seaborn", "seaborn-v0_8"):
        try:
            plt.style.use(style_name)
            break
        except OSError:
            continue

    if theme == "dark":
        plt.style.use("dark_background")
        plt.rcParams["axes.facecolor"] = "#1E1E1E"
        plt.rcParams["figure.facecolor"] = "#121212"
        text_color = "white"
    else:
        plt.rcParams["axes.facecolor"] = "white"
        plt.rcParams["figure.facecolor"] = "white"
        text_color = "black"

    plt.rcParams["font.size"] = 10
    plt.rcParams["axes.labelcolor"] = text_color
    plt.rcParams["xtick.color"] = text_color
    plt.rcParams["ytick.color"] = text_color
    plt.rcParams["text.color"] = text_color
43 |
44 |
def image_to_base64(fig: Figure, dpi: int = 100, max_width: int = 800) -> str:
    """
    Render a matplotlib figure to a base64-encoded PNG string.

    The figure is scaled down (keeping aspect ratio) when its pixel width at
    the requested DPI would exceed ``max_width``; it is closed after a
    successful render.

    Args:
        fig (Figure): Figure to render.
        dpi (int): Output resolution in dots per inch.
        max_width (int): Maximum width in pixels.

    Returns:
        str: Base64-encoded PNG, or "" on failure.
    """
    try:
        width_in, height_in = fig.get_size_inches()
        ratio = height_in / width_in

        # Shrink to max_width while preserving the aspect ratio.
        if width_in * dpi > max_width:
            width_in = max_width / dpi
            fig.set_size_inches(width_in, width_in * ratio)

        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight")
        buffer.seek(0)
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        plt.close(fig)
        return encoded
    except Exception as e:
        logger.error(f"Error converting image to base64: {e}")
        return ""
77 |
78 |
def generate_equity_curve(
    returns: pd.Series,
    drawdown: pd.Series | None = None,
    title: str = "Equity Curve",
    theme: str = "light",
) -> str:
    """
    Generate equity curve with optional drawdown subplot.

    Fix: previously a second axes was always created even when no drawdown
    series was supplied, leaving a blank panel in the output image. The
    drawdown subplot is now only created when drawdown data exists.

    Args:
        returns (pd.Series): Cumulative returns series
        drawdown (pd.Series, optional): Drawdown series
        title (str): Chart title
        theme (str): Chart theme

    Returns:
        str: Base64 encoded image
    """
    set_chart_style(theme)

    try:
        if drawdown is not None:
            fig, (ax1, ax2) = plt.subplots(
                2, 1, figsize=(10, 6), gridspec_kw={"height_ratios": [3, 1]}
            )
        else:
            # Single panel: no empty drawdown axes.
            fig, ax1 = plt.subplots(figsize=(10, 6))
            ax2 = None

        # Equity curve
        returns.plot(ax=ax1, linewidth=2, color="blue")
        ax1.set_title(title)
        ax1.set_xlabel("")
        ax1.set_ylabel("Cumulative Returns")
        ax1.grid(True, linestyle="--", alpha=0.7)

        # Drawdown subplot
        if drawdown is not None:
            drawdown.plot(ax=ax2, linewidth=2, color="red")
            ax2.set_title("Maximum Drawdown")
            ax2.set_ylabel("Drawdown (%)")
            ax2.grid(True, linestyle="--", alpha=0.7)

        plt.tight_layout()
        return image_to_base64(fig)
    except Exception as e:
        logger.error(f"Error generating equity curve: {e}")
        return ""
123 |
124 |
def generate_trade_scatter(
    prices: pd.Series,
    trades: pd.DataFrame,
    title: str = "Trade Scatter Plot",
    theme: str = "light",
) -> str:
    """
    Overlay trade entry/exit markers on a price chart.

    Args:
        prices (pd.Series): Price series indexed by date.
        trades (pd.DataFrame): Trades with a "type" column ("entry"/"exit")
            and a "price" column, indexed by trade date.
        title (str): Chart title.
        theme (str): Chart theme.

    Returns:
        str: Base64 encoded image ("" on failure).
    """
    set_chart_style(theme)

    try:
        fig, ax = plt.subplots(figsize=(10, 6))

        # Price line first so the trade markers draw on top of it.
        prices.plot(ax=ax, linewidth=1, label="Price", color="blue")

        entries = trades[trades["type"] == "entry"]
        exits = trades[trades["type"] == "exit"]

        # Green up-triangles for entries, red down-triangles for exits.
        ax.scatter(
            entries.index,
            entries["price"],
            color="green",
            marker="^",
            label="Entry",
            s=100,
        )
        ax.scatter(
            exits.index,
            exits["price"],
            color="red",
            marker="v",
            label="Exit",
            s=100,
        )

        ax.set_title(title)
        ax.set_xlabel("Date")
        ax.set_ylabel("Price")
        ax.legend()
        ax.grid(True, linestyle="--", alpha=0.7)

        plt.tight_layout()
        return image_to_base64(fig)
    except Exception as e:
        logger.error(f"Error generating trade scatter plot: {e}")
        return ""
183 |
184 |
def generate_optimization_heatmap(
    param_results: dict[str, dict[str, float]],
    title: str = "Parameter Optimization",
    theme: str = "light",
) -> str:
    """
    Generate heatmap for parameter optimization results.

    Fix: the x-axis was previously labeled with the parameter-combination
    keys, although the heatmap columns are the per-combination metric values.
    Columns are now labeled with the metric names taken from the inner dicts.

    Args:
        param_results (Dict): Maps parameter-combination labels to a dict of
            metric name -> metric value.
        title (str): Chart title
        theme (str): Chart theme

    Returns:
        str: Base64 encoded image
    """
    set_chart_style(theme)

    try:
        # Rows are parameter combinations; columns are the metrics.
        params = list(param_results.keys())
        results = [list(result.values()) for result in param_results.values()]
        metric_names = (
            list(next(iter(param_results.values())).keys()) if param_results else []
        )

        fig, ax = plt.subplots(figsize=(10, 8))

        # Custom red -> yellow -> green colormap (bad -> good performance).
        cmap = LinearSegmentedColormap.from_list(
            "performance", ["red", "yellow", "green"]
        )

        sns.heatmap(
            results,
            annot=True,
            cmap=cmap,
            xticklabels=metric_names,
            yticklabels=params,
            ax=ax,
            fmt=".2f",
        )

        ax.set_title(title)
        plt.tight_layout()
        return image_to_base64(fig)
    except Exception as e:
        logger.error(f"Error generating optimization heatmap: {e}")
        return ""
231 |
232 |
def generate_portfolio_allocation(
    allocations: dict[str, float],
    title: str = "Portfolio Allocation",
    theme: str = "light",
) -> str:
    """
    Render portfolio weights as a pie chart.

    Args:
        allocations (Dict): Symbol -> weight mapping.
        title (str): Chart title.
        theme (str): Chart theme.

    Returns:
        str: Base64 encoded image ("" on failure).
    """
    set_chart_style(theme)

    try:
        fig, ax = plt.subplots(figsize=(8, 8))

        labels = list(allocations.keys())
        sizes = [allocations[symbol] for symbol in labels]
        # One pastel color per slice.
        palette = plt.cm.Pastel1(np.linspace(0, 1, len(labels)))

        ax.pie(
            sizes,
            labels=labels,
            colors=palette,
            autopct="%1.1f%%",
            startangle=90,
            pctdistance=0.85,
        )
        ax.set_title(title)

        plt.tight_layout()
        return image_to_base64(fig)
    except Exception as e:
        logger.error(f"Error generating portfolio allocation chart: {e}")
        return ""
275 |
276 |
def generate_strategy_comparison(
    strategies: dict[str, pd.Series],
    title: str = "Strategy Comparison",
    theme: str = "light",
) -> str:
    """
    Plot several strategies' cumulative returns on a single chart.

    Args:
        strategies (Dict): Strategy name -> cumulative-returns series.
        title (str): Chart title.
        theme (str): Chart theme.

    Returns:
        str: Base64 encoded image ("" on failure).
    """
    set_chart_style(theme)

    try:
        fig, ax = plt.subplots(figsize=(10, 6))

        # One labeled line per strategy; the legend disambiguates them.
        for strategy_name, curve in strategies.items():
            curve.plot(ax=ax, label=strategy_name, linewidth=2)

        ax.set_title(title)
        ax.set_xlabel("Date")
        ax.set_ylabel("Cumulative Returns")
        ax.legend()
        ax.grid(True, linestyle="--", alpha=0.7)

        plt.tight_layout()
        return image_to_base64(fig)
    except Exception as e:
        logger.error(f"Error generating strategy comparison chart: {e}")
        return ""
312 |
313 |
def generate_performance_dashboard(
    metrics: dict[str, float | str],
    title: str = "Performance Dashboard",
    theme: str = "light",
) -> str:
    """
    Render performance metrics as a table image.

    Args:
        metrics (Dict): Mapping of metric name -> value
        title (str): Dashboard title
        theme (str): Chart theme

    Returns:
        str: Base64 encoded image, or "" if rendering fails
    """
    set_chart_style(theme)

    try:
        fig, ax = plt.subplots(figsize=(8, 6))
        # Table-only figure: hide the axes entirely.
        ax.axis("off")

        # Two rows: metric names on top, stringified values beneath.
        rows = [list(metrics), [str(v) for v in metrics.values()]]

        table = ax.table(cellText=rows, loc="center", cellLoc="center")
        table.auto_set_font_size(False)
        table.set_fontsize(10)

        ax.set_title(title)
        plt.tight_layout()
        return image_to_base64(fig)
    except Exception as exc:
        logger.error(f"Error generating performance dashboard: {exc}")
        return ""
352 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/mcp_logging.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Integration of structured logging with FastMCP.
3 |
4 | This module provides:
5 | - Automatic request context capture from MCP
6 | - Tool execution logging
7 | - Performance monitoring
8 | - Error tracking
9 | """
10 |
11 | import functools
12 | import time
13 | from collections.abc import Callable
14 | from typing import Any
15 |
16 | from fastmcp import Context
17 | from fastmcp.exceptions import ToolError
18 |
19 | from .logging import (
20 | PerformanceMonitor,
21 | get_logger,
22 | log_cache_operation,
23 | log_database_query,
24 | log_external_api_call,
25 | request_id_var,
26 | request_start_var,
27 | tool_name_var,
28 | user_id_var,
29 | )
30 |
31 |
def with_logging(tool_name: str | None = None):
    """
    Decorator for FastMCP tools that adds structured logging.

    Automatically logs:
    - Tool invocation with parameters
    - Execution time
    - Success/failure status
    - Context information (request ID, user)

    Example:
        @mcp.tool()
        @with_logging()
        async def fetch_stock_data(context: Context, ticker: str) -> dict:
            # Tool implementation
            pass
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Find the MCP Context among positional arguments, if present.
            context = next((a for a in args if isinstance(a, Context)), None)

            actual_tool_name = tool_name or func.__name__

            if context:
                # Prefer the context's request ID; fall back to a timestamp.
                request_id = getattr(context, "request_id", None) or str(time.time())
                request_id_var.set(request_id)

                # Attach user info to the logging context when available.
                user_id = getattr(context, "user_id", None)
                if user_id:
                    user_id_var.set(user_id)

            tool_name_var.set(actual_tool_name)
            request_start_var.set(time.time())

            logger = get_logger(f"maverick_mcp.tools.{actual_tool_name}")

            invocation_details = {
                "tool_name": actual_tool_name,
                "has_context": context is not None,
                "args_count": len(args),
                "kwargs_keys": list(kwargs.keys()),
            }
            logger.info(f"Tool invoked: {actual_tool_name}", extra=invocation_details)

            async def _report(progress: int, message: str) -> None:
                # Forward progress to the client only if the context supports it.
                if context and hasattr(context, "report_progress"):
                    await context.report_progress(
                        progress=progress, total=100, message=message
                    )

            try:
                await _report(0, f"Starting {actual_tool_name}")

                # Time the actual tool execution.
                with PerformanceMonitor(f"tool_{actual_tool_name}"):
                    result = await func(*args, **kwargs)

                logger.info(
                    f"Tool completed: {actual_tool_name}",
                    extra={"tool_name": actual_tool_name, "status": "success"},
                )

                await _report(100, f"Completed {actual_tool_name}")

                return result

            except ToolError as exc:
                # Expected, tool-specific failures are logged at warning level.
                logger.warning(
                    f"Tool error in {actual_tool_name}: {str(exc)}",
                    extra={
                        "tool_name": actual_tool_name,
                        "status": "tool_error",
                        "error_message": str(exc),
                    },
                )
                raise

            except Exception as exc:
                # Anything else is unexpected: log with traceback, then re-raise.
                logger.error(
                    f"Unexpected error in {actual_tool_name}: {str(exc)}",
                    exc_info=True,
                    extra={
                        "tool_name": actual_tool_name,
                        "status": "error",
                        "error_type": type(exc).__name__,
                    },
                )
                raise

            finally:
                # Always reset the logging context vars for the next request.
                request_id_var.set(None)
                tool_name_var.set(None)
                user_id_var.set(None)
                request_start_var.set(None)

        return wrapper

    return decorator
151 |
152 |
def log_mcp_context(context: Context, operation: str, **extra):
    """
    Log capability information gleaned from an MCP context.

    Args:
        context: FastMCP context object
        operation: Description of the operation
        **extra: Additional fields merged into the log record
    """
    logger = get_logger("maverick_mcp.context")

    # Record which optional capabilities this context exposes, plus any
    # caller-supplied fields (extras may override the defaults).
    log_data = {
        "operation": operation,
        "has_request_id": hasattr(context, "request_id"),
        "can_report_progress": hasattr(context, "report_progress"),
        "can_log": hasattr(context, "info"),
        **extra,
    }

    logger.info(f"MCP Context: {operation}", extra=log_data)
175 |
176 |
class LoggingStockDataProvider:
    """
    Wrapper for StockDataProvider that adds logging.

    This demonstrates how to add logging to existing classes.
    """

    def __init__(self, provider):
        # The wrapped provider does the real work; we only observe it.
        self.provider = provider
        self.logger = get_logger("maverick_mcp.providers.stock_data")

    async def get_stock_data(
        self, ticker: str, start_date: str, end_date: str, **kwargs
    ):
        """Get stock data with cache and external-API call logging."""
        with PerformanceMonitor(f"fetch_stock_data_{ticker}"):
            cache_key = f"stock:{ticker}:{start_date}:{end_date}"

            # Time the cache lookup itself.
            lookup_start = time.time()
            cached = await self._check_cache(cache_key)
            lookup_ms = int((time.time() - lookup_start) * 1000)

            # Note: hit/miss is decided by truthiness, matching the return below.
            if cached:
                log_cache_operation("get", cache_key, hit=True, duration_ms=lookup_ms)
                return cached
            log_cache_operation("get", cache_key, hit=False, duration_ms=lookup_ms)

            # Cache miss: go to the underlying provider and log the API call.
            try:
                fetch_start = time.time()
                data = await self.provider.get_stock_data(
                    ticker, start_date, end_date, **kwargs
                )
                fetch_ms = int((time.time() - fetch_start) * 1000)

                log_external_api_call(
                    service="yfinance",
                    endpoint=f"/quote/{ticker}",
                    method="GET",
                    status_code=200,
                    duration_ms=fetch_ms,
                )

                # Store for future lookups before returning.
                await self._set_cache(cache_key, data)

                return data

            except Exception as exc:
                log_external_api_call(
                    service="yfinance",
                    endpoint=f"/quote/{ticker}",
                    method="GET",
                    error=str(exc),
                )
                raise

    async def _check_cache(self, key: str):
        """Check cache (placeholder)."""
        # This would integrate with actual cache
        return None

    async def _set_cache(self, key: str, data: Any):
        """Set cache (placeholder)."""
        # This would integrate with actual cache
        pass
250 |
251 |
252 | # SQL query logging wrapper
class LoggingSession:
    """SQLAlchemy session proxy that logs every executed query."""

    def __init__(self, session):
        # Wrapped session; unknown attributes are forwarded via __getattr__.
        self.session = session
        self.logger = get_logger("maverick_mcp.database")

    def execute(self, query, params=None):
        """Execute a query, logging its duration and any failure."""
        began = time.time()
        try:
            outcome = self.session.execute(query, params)
        except Exception as exc:
            # Failed queries still get a duration record, plus an error log.
            elapsed = int((time.time() - began) * 1000)
            log_database_query(str(query), params, elapsed)
            self.logger.error(
                f"Database query failed: {str(exc)}",
                extra={"query": str(query)[:200], "error_type": type(exc).__name__},
            )
            raise
        else:
            elapsed = int((time.time() - began) * 1000)
            log_database_query(str(query), params, elapsed)
            return outcome

    def __getattr__(self, name):
        """Forward every other attribute to the wrapped session."""
        return getattr(self.session, name)
280 |
281 |
282 | # Example usage in routers
def setup_router_logging(router):
    """
    Add logging middleware to a FastMCP router.

    This should be called when setting up routers.
    """
    router_name = type(router).__name__
    logger = get_logger(f"maverick_mcp.routers.{router_name}")

    # Record that the router came up and how many tools it carries.
    logger.info(
        "Router initialized",
        extra={
            "router_class": router_name,
            "tool_count": len(getattr(router, "tools", [])),
        },
    )

    # Placeholder: request-level middleware hookup once FastMCP supports it.
    pass
303 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/performance.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Performance monitoring router for Maverick-MCP.
3 |
4 | This router provides endpoints for monitoring system performance,
5 | including Redis connection health, cache performance, query optimization,
6 | and database index analysis.
7 | """
8 |
9 | import logging
10 | from typing import Any
11 |
12 | from fastmcp import FastMCP
13 | from pydantic import Field
14 |
15 | from maverick_mcp.tools.performance_monitoring import (
16 | analyze_database_indexes,
17 | clear_performance_caches,
18 | get_cache_performance_metrics,
19 | get_comprehensive_performance_report,
20 | get_query_performance_metrics,
21 | get_redis_connection_health,
22 | optimize_cache_settings,
23 | )
24 | from maverick_mcp.validation.base import BaseRequest, BaseResponse
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | # Create router
29 | performance_router = FastMCP("Performance_Monitoring")
30 |
31 |
32 | # Request/Response Models
class PerformanceHealthRequest(BaseRequest):
    """Request model for performance health check."""

    # When True, the full detailed-metrics payload is included in the
    # response; otherwise only scores and recommendations are returned.
    include_detailed_metrics: bool = Field(
        default=False, description="Include detailed metrics in the response"
    )
39 |
40 |
class CacheClearRequest(BaseRequest):
    """Request model for cache clearing operations."""

    # None means "no explicit selection"; the handler falls back to ["all"].
    cache_types: list[str] | None = Field(
        default=None,
        description="Types of caches to clear: stock_data, screening, market_data, all",
    )
48 |
49 |
class PerformanceMetricsResponse(BaseResponse):
    """Response model for performance metrics.

    On failure, handlers return this with ``metrics={"error": "..."}``.
    """

    metrics: dict[str, Any] = Field(description="Performance metrics data")
54 |
55 |
class PerformanceReportResponse(BaseResponse):
    """Response model for comprehensive performance report."""

    overall_health_score: float = Field(
        description="Overall system health score (0-100)"
    )
    component_scores: dict[str, float] = Field(
        description="Individual component scores"
    )
    recommendations: list[str] = Field(
        description="Performance improvement recommendations"
    )
    # No default: callers must pass this explicitly (None when detailed
    # metrics were not requested).
    detailed_metrics: dict[str, Any] | None = Field(description="Detailed metrics data")
69 |
70 |
async def get_system_performance_health(
    request: PerformanceHealthRequest,
) -> PerformanceReportResponse:
    """
    Get comprehensive system performance health report.

    Produces an overall health assessment of the MaverickMCP system,
    covering Redis connectivity, cache performance, database query metrics,
    and index usage analysis. Use this for general system health monitoring.

    Args:
        request: Performance health check request

    Returns:
        Comprehensive performance health report with scores and recommendations
    """
    try:
        logger.info("Generating comprehensive performance health report")

        report = await get_comprehensive_performance_report()

        # The underlying report signals failure via an "error" key.
        if "error" in report:
            return PerformanceReportResponse(
                overall_health_score=0.0,
                component_scores={},
                recommendations=[f"System health check failed: {report['error']}"],
                detailed_metrics=None,
            )

        overall_score = report.get("overall_health_score", 0.0)

        logger.info(
            f"Performance health report generated: overall score {overall_score}"
        )

        return PerformanceReportResponse(
            overall_health_score=overall_score,
            component_scores=report.get("component_scores", {}),
            recommendations=report.get("recommendations", []),
            # Detailed metrics are only echoed back when explicitly requested.
            detailed_metrics=report.get("detailed_metrics")
            if request.include_detailed_metrics
            else None,
        )

    except Exception as exc:
        logger.error(f"Error getting system performance health: {exc}")
        return PerformanceReportResponse(
            overall_health_score=0.0,
            component_scores={},
            recommendations=[f"Failed to assess system health: {str(exc)}"],
            detailed_metrics=None,
        )
128 |
129 |
async def get_redis_health_status() -> PerformanceMetricsResponse:
    """
    Get Redis connection pool health and performance metrics.

    Reports Redis connectivity, connection pool status, operation latency,
    and basic health tests. Use this when diagnosing Redis-related
    performance issues.

    Returns:
        Redis health status and connection metrics
    """
    try:
        logger.info("Checking Redis connection health")
        # Wrap the raw health payload directly in the response model.
        return PerformanceMetricsResponse(metrics=await get_redis_connection_health())
    except Exception as exc:
        logger.error(f"Error getting Redis health status: {exc}")
        return PerformanceMetricsResponse(metrics={"error": str(exc)})
151 |
152 |
async def get_cache_performance_status() -> PerformanceMetricsResponse:
    """
    Get cache performance metrics and optimization suggestions.

    Reports cache hit/miss ratios, operation latencies, Redis memory usage,
    and performance test results. Use this to optimize caching strategies
    and identify cache bottlenecks.

    Returns:
        Cache performance metrics and test results
    """
    try:
        logger.info("Getting cache performance metrics")
        return PerformanceMetricsResponse(metrics=await get_cache_performance_metrics())
    except Exception as exc:
        logger.error(f"Error getting cache performance status: {exc}")
        return PerformanceMetricsResponse(metrics={"error": str(exc)})
174 |
175 |
async def get_database_performance_status() -> PerformanceMetricsResponse:
    """
    Get database query performance metrics and connection pool status.

    Reports database query statistics, slow query detection, connection
    pool metrics, and database health tests. Use this to identify database
    performance bottlenecks and optimization opportunities.

    Returns:
        Database performance metrics and query statistics
    """
    try:
        logger.info("Getting database performance metrics")
        return PerformanceMetricsResponse(metrics=await get_query_performance_metrics())
    except Exception as exc:
        logger.error(f"Error getting database performance status: {exc}")
        return PerformanceMetricsResponse(metrics={"error": str(exc)})
197 |
198 |
async def analyze_database_index_usage() -> PerformanceMetricsResponse:
    """
    Analyze database index usage and provide optimization recommendations.

    Examines index usage statistics, identifies missing indexes, analyzes
    table scan patterns, and provides recommendations for database
    performance optimization. Use this for database tuning.

    Returns:
        Database index analysis and optimization recommendations
    """
    try:
        logger.info("Analyzing database index usage")
        return PerformanceMetricsResponse(metrics=await analyze_database_indexes())
    except Exception as exc:
        logger.error(f"Error analyzing database index usage: {exc}")
        return PerformanceMetricsResponse(metrics={"error": str(exc)})
220 |
221 |
async def optimize_cache_configuration() -> PerformanceMetricsResponse:
    """
    Analyze cache usage patterns and recommend optimal configuration.

    Analyzes current cache hit rates, memory usage, and access patterns to
    recommend optimal TTL values, cache sizes, and configuration settings
    for maximum performance. Use this for cache tuning.

    Returns:
        Cache optimization analysis and recommended settings
    """
    try:
        logger.info("Optimizing cache configuration")
        return PerformanceMetricsResponse(metrics=await optimize_cache_settings())
    except Exception as exc:
        logger.error(f"Error optimizing cache configuration: {exc}")
        return PerformanceMetricsResponse(metrics={"error": str(exc)})
243 |
244 |
async def clear_system_caches(
    request: CacheClearRequest,
) -> PerformanceMetricsResponse:
    """
    Clear specific performance caches for maintenance or testing.

    Supported cache types:
    - stock_data: Stock price and company information caches
    - screening: Maverick and trending stock screening caches
    - market_data: High volume and market analysis caches
    - all: Clear all performance caches

    Use this for cache maintenance, testing, or when stale data is suspected.

    Args:
        request: Cache clearing request with specific cache types

    Returns:
        Cache clearing results and statistics
    """
    try:
        # Default to clearing everything when no types were specified.
        targets = request.cache_types or ["all"]
        logger.info(f"Clearing performance caches: {targets}")
        return PerformanceMetricsResponse(metrics=await clear_performance_caches(targets))
    except Exception as exc:
        logger.error(f"Error clearing system caches: {exc}")
        return PerformanceMetricsResponse(metrics={"error": str(exc)})
276 |
277 |
278 | # Router configuration
def get_performance_router():
    """Get the configured performance monitoring router.

    Returns:
        The module-level ``performance_router`` FastMCP instance.
    """
    return performance_router
282 |
```
--------------------------------------------------------------------------------
/scripts/test_tiingo_loader.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Test script for the Tiingo data loader.
4 |
5 | This script performs basic validation that the loader components work correctly
6 | without requiring an actual API call or database connection.
7 | """
8 |
9 | import sys
10 | import unittest
11 | from pathlib import Path
12 | from unittest.mock import MagicMock, patch
13 |
14 | # Add parent directory to path
15 | sys.path.insert(0, str(Path(__file__).parent.parent))
16 |
17 | from scripts.load_tiingo_data import (
18 | SP500_SYMBOLS,
19 | ProgressTracker,
20 | RateLimiter,
21 | TiingoDataLoader,
22 | )
23 | from scripts.tiingo_config import (
24 | SCREENING_CONFIGS,
25 | SYMBOL_LISTS,
26 | TiingoConfig,
27 | get_config_for_environment,
28 | )
29 |
30 |
class TestProgressTracker(unittest.TestCase):
    """Exercise the progress tracking functionality."""

    def setUp(self):
        # A fresh tracker for each test method.
        self.progress = ProgressTracker("test_progress.json")

    def test_initialization(self):
        """A new tracker starts with zeroed counters and empty symbol sets."""
        tracker = self.progress
        self.assertEqual(tracker.processed_symbols, 0)
        self.assertEqual(tracker.successful_symbols, 0)
        self.assertFalse(tracker.failed_symbols)
        self.assertFalse(tracker.completed_symbols)

    def test_update_progress_success(self):
        """A successful symbol bumps both counters and is marked completed."""
        tracker = self.progress
        tracker.total_symbols = 5
        tracker.update_progress("AAPL", True)

        self.assertEqual(tracker.processed_symbols, 1)
        self.assertEqual(tracker.successful_symbols, 1)
        self.assertIn("AAPL", tracker.completed_symbols)
        self.assertFalse(tracker.failed_symbols)

    def test_update_progress_failure(self):
        """A failed symbol is counted as processed but not successful."""
        tracker = self.progress
        tracker.total_symbols = 5
        tracker.update_progress("BADSTOCK", False, "Not found")

        self.assertEqual(tracker.processed_symbols, 1)
        self.assertEqual(tracker.successful_symbols, 0)
        self.assertIn("BADSTOCK", tracker.failed_symbols)
        self.assertEqual(len(tracker.errors), 1)
63 |
64 |
class TestRateLimiter(unittest.TestCase):
    """Exercise the rate limiting functionality."""

    def test_initialization(self):
        """3600 requests/hour works out to a 1-second minimum interval."""
        one_per_second = RateLimiter(3600)
        self.assertEqual(one_per_second.max_requests, 3600)
        self.assertEqual(one_per_second.min_interval, 1.0)

    def test_tiingo_rate_limit(self):
        """The Tiingo free-tier limit (2400/hour) yields a 1.5s interval."""
        free_tier = RateLimiter(2400)
        self.assertEqual(free_tier.min_interval, 3600.0 / 2400)
78 | self.assertEqual(limiter.min_interval, expected_interval)
79 |
80 |
class TestTiingoConfig(unittest.TestCase):
    """Test configuration management."""

    def test_default_config(self):
        """Test default configuration values."""
        config = TiingoConfig()

        # 2400/hour matches the Tiingo free tier used elsewhere in this suite.
        self.assertEqual(config.rate_limit_per_hour, 2400)
        self.assertEqual(config.max_retries, 3)
        self.assertEqual(config.default_batch_size, 50)
        self.assertEqual(config.rsi_period, 14)
        # SMA periods must include the standard 50- and 200-day windows.
        self.assertIsInstance(config.sma_periods, list)
        self.assertIn(50, config.sma_periods)
        self.assertIn(200, config.sma_periods)

    def test_environment_configs(self):
        """Test environment-specific configurations."""
        dev_config = get_config_for_environment("development")
        prod_config = get_config_for_environment("production")
        test_config = get_config_for_environment("testing")

        # Production should have higher limits
        self.assertGreaterEqual(
            prod_config.max_concurrent_requests, dev_config.max_concurrent_requests
        )
        self.assertGreaterEqual(
            prod_config.default_batch_size, dev_config.default_batch_size
        )

        # Test should have lower limits
        self.assertLessEqual(
            test_config.max_concurrent_requests, dev_config.max_concurrent_requests
        )
        self.assertLessEqual(
            test_config.default_batch_size, dev_config.default_batch_size
        )

    def test_symbol_lists(self):
        """Test that symbol lists are properly configured."""
        self.assertIn("sp500_top_100", SYMBOL_LISTS)
        self.assertIn("nasdaq_100", SYMBOL_LISTS)
        self.assertIn("dow_30", SYMBOL_LISTS)

        # Check that lists have reasonable sizes
        self.assertGreater(len(SYMBOL_LISTS["sp500_top_100"]), 50)
        self.assertLess(len(SYMBOL_LISTS["dow_30"]), 35)

    def test_screening_configs(self):
        """Test screening algorithm configurations."""
        maverick_config = SCREENING_CONFIGS["maverick_momentum"]

        # The momentum screen needs a threshold plus per-factor weights.
        self.assertIn("min_momentum_score", maverick_config)
        self.assertIn("scoring_weights", maverick_config)
        self.assertIsInstance(maverick_config["scoring_weights"], dict)
135 |
136 |
class TestTiingoDataLoader(unittest.TestCase):
    """Test the main TiingoDataLoader class."""

    @patch.dict("os.environ", {"TIINGO_API_TOKEN": "test_token"})
    def test_initialization(self):
        """Loader picks up constructor args and the API token from the env."""
        loader = TiingoDataLoader(batch_size=25, max_concurrent=3)

        self.assertEqual(loader.batch_size, 25)
        self.assertEqual(loader.max_concurrent, 3)
        self.assertEqual(loader.api_token, "test_token")
        self.assertIsNotNone(loader.rate_limiter)

    def test_initialization_without_token(self):
        """Constructing the loader without an API token raises ValueError."""
        with patch.dict("os.environ", {}, clear=True):
            with self.assertRaises(ValueError):
                TiingoDataLoader()

    @patch("aiohttp.ClientSession")
    def test_context_manager(self, mock_session_class):
        """Test async context manager functionality.

        Bug fix: this was previously declared ``async def`` on a plain
        ``unittest.TestCase``, so unittest never awaited the coroutine and
        the test silently passed without executing. It is now a synchronous
        test that drives the coroutine explicitly via ``asyncio.run``.
        """
        # Local imports keep this fix self-contained in the test method.
        import asyncio
        from unittest.mock import AsyncMock

        mock_session = MagicMock()
        # aiohttp's ClientSession.close() is a coroutine, so the mock must
        # be awaitable in case the loader awaits it on context exit.
        mock_session.close = AsyncMock()
        mock_session_class.return_value = mock_session

        async def exercise():
            with patch.dict("os.environ", {"TIINGO_API_TOKEN": "test_token"}):
                async with TiingoDataLoader() as loader:
                    self.assertIsNotNone(loader.session)

        asyncio.run(exercise())

        # Session should be closed after context exit
        mock_session.close.assert_called_once()
168 |
169 |
class TestSymbolValidation(unittest.TestCase):
    """Validate the bundled symbol lists."""

    def test_sp500_symbols(self):
        """The S&P 500 list is a reasonably sized list of upper-case tickers."""
        self.assertIsInstance(SP500_SYMBOLS, list)
        self.assertGreater(len(SP500_SYMBOLS), 90)  # Should have at least 90 symbols

        # Spot-check the first 10 entries for well-formed ticker strings.
        for ticker in SP500_SYMBOLS[:10]:
            self.assertIsInstance(ticker, str)
            self.assertEqual(ticker, ticker.upper())
            self.assertGreater(len(ticker), 0)
            self.assertLess(len(ticker), 10)  # Reasonable symbol length
184 |
185 |
class TestUtilityFunctions(unittest.TestCase):
    """Test utility functions."""

    def test_symbol_file_content(self):
        """Parse the comma/line separated symbol-file format, skipping comments."""
        test_content = "AAPL,MSFT,GOOGL\nTSLA,NVDA\n# Comment\nAMZN"

        symbols = []
        for raw in test_content.split("\n"):
            entry = raw.strip()
            # Blank lines and '#' comment lines contribute no symbols.
            if not entry or entry.startswith("#"):
                continue
            symbols.extend(part.strip().upper() for part in entry.split(","))

        expected = ["AAPL", "MSFT", "GOOGL", "TSLA", "NVDA", "AMZN"]
        self.assertEqual(symbols, expected)
204 |
205 |
def run_basic_validation():
    """Run basic validation without external dependencies.

    Performs quick sanity checks (imports, config defaults, symbol lists,
    progress tracking) and prints a status line per step.

    Returns:
        bool: True if every check passed, False on the first failure.
    """
    print("🧪 Running basic validation tests...")

    # Test imports
    try:
        from scripts.load_tiingo_data import ProgressTracker
        from scripts.tiingo_config import SYMBOL_LISTS, TiingoConfig

        print("✅ All imports successful")
    except ImportError as e:
        print(f"❌ Import error: {e}")
        return False

    # Test configuration defaults (2400/hour matches the Tiingo free tier).
    try:
        config = TiingoConfig()
        assert config.rate_limit_per_hour == 2400
        assert len(config.sma_periods) > 0
        print("✅ Configuration validation passed")
    except Exception as e:
        print(f"❌ Configuration error: {e}")
        return False

    # Test symbol lists: sizes and element types.
    try:
        assert len(SP500_SYMBOLS) > 90
        assert len(SYMBOL_LISTS["sp500_top_100"]) > 90
        assert all(isinstance(s, str) for s in SP500_SYMBOLS[:10])
        print("✅ Symbol list validation passed")
    except Exception as e:
        print(f"❌ Symbol list error: {e}")
        return False

    # Test progress tracker: one successful update is counted and recorded.
    try:
        tracker = ProgressTracker("test.json")
        tracker.update_progress("TEST", True)
        assert tracker.successful_symbols == 1
        assert "TEST" in tracker.completed_symbols
        print("✅ Progress tracker validation passed")
    except Exception as e:
        print(f"❌ Progress tracker error: {e}")
        return False

    print("🎉 All basic validations passed!")
    return True
253 |
254 |
if __name__ == "__main__":
    print("Tiingo Data Loader Test Suite")
    print("=" * 40)

    # Run the fast dependency-free checks first; bail out before the suite.
    if not run_basic_validation():
        sys.exit(1)

    # Run unit tests
    print("\n🧪 Running unit tests...")
    test_program = unittest.main(verbosity=2, exit=False)

    # Bug fix: previously the script always printed success and exited 0
    # even when unit tests failed. Propagate the real result for CI use.
    if not test_program.result.wasSuccessful():
        print("\n❌ Unit tests failed!")
        sys.exit(1)

    print("\n✅ Test suite completed!")
268 |
```