# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── config.yml
│   │   ├── feature_request.md
│   │   ├── question.md
│   │   └── security_report.md
│   ├── pull_request_template.md
│   └── workflows
│       ├── claude-code-review.yml
│       └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│   ├── launch.json
│   └── settings.json
├── alembic
│   ├── env.py
│   ├── script.py.mako
│   └── versions
│       ├── 001_initial_schema.py
│       ├── 003_add_performance_indexes.py
│       ├── 006_rename_metadata_columns.py
│       ├── 008_performance_optimization_indexes.py
│       ├── 009_rename_to_supply_demand.py
│       ├── 010_self_contained_schema.py
│       ├── 011_remove_proprietary_terms.py
│       ├── 013_add_backtest_persistence_models.py
│       ├── 014_add_portfolio_models.py
│       ├── 08e3945a0c93_merge_heads.py
│       ├── 9374a5c9b679_merge_heads_for_testing.py
│       ├── abf9b9afb134_merge_multiple_heads.py
│       ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│       ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│       ├── f0696e2cac15_add_essential_performance_indexes.py
│       └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│   ├── api
│   │   └── backtesting.md
│   ├── BACKTESTING.md
│   ├── COST_BASIS_SPECIFICATION.md
│   ├── deep_research_agent.md
│   ├── exa_research_testing_strategy.md
│   ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│   ├── PORTFOLIO.md
│   ├── SETUP_SELF_CONTAINED.md
│   └── speed_testing_framework.md
├── examples
│   ├── complete_speed_validation.py
│   ├── deep_research_integration.py
│   ├── llm_optimization_example.py
│   ├── llm_speed_demo.py
│   ├── monitoring_example.py
│   ├── parallel_research_example.py
│   ├── speed_optimization_demo.py
│   └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── circuit_breaker.py
│   │   ├── deep_research.py
│   │   ├── market_analysis.py
│   │   ├── optimized_research.py
│   │   ├── supervisor.py
│   │   └── technical_analysis.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── api_server.py
│   │   ├── connection_manager.py
│   │   ├── dependencies
│   │   │   ├── __init__.py
│   │   │   ├── stock_analysis.py
│   │   │   └── technical_analysis.py
│   │   ├── error_handling.py
│   │   ├── inspector_compatible_sse.py
│   │   ├── inspector_sse.py
│   │   ├── middleware
│   │   │   ├── error_handling.py
│   │   │   ├── mcp_logging.py
│   │   │   ├── rate_limiting_enhanced.py
│   │   │   └── security.py
│   │   ├── openapi_config.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── agents.py
│   │   │   ├── backtesting.py
│   │   │   ├── data_enhanced.py
│   │   │   ├── data.py
│   │   │   ├── health_enhanced.py
│   │   │   ├── health_tools.py
│   │   │   ├── health.py
│   │   │   ├── intelligent_backtesting.py
│   │   │   ├── introspection.py
│   │   │   ├── mcp_prompts.py
│   │   │   ├── monitoring.py
│   │   │   ├── news_sentiment_enhanced.py
│   │   │   ├── performance.py
│   │   │   ├── portfolio.py
│   │   │   ├── research.py
│   │   │   ├── screening_ddd.py
│   │   │   ├── screening_parallel.py
│   │   │   ├── screening.py
│   │   │   ├── technical_ddd.py
│   │   │   ├── technical_enhanced.py
│   │   │   ├── technical.py
│   │   │   └── tool_registry.py
│   │   ├── server.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── base_service.py
│   │   │   ├── market_service.py
│   │   │   ├── portfolio_service.py
│   │   │   ├── prompt_service.py
│   │   │   └── resource_service.py
│   │   ├── simple_sse.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── insomnia_export.py
│   │       └── postman_export.py
│   ├── application
│   │   ├── __init__.py
│   │   ├── commands
│   │   │   └── __init__.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_dto.py
│   │   ├── queries
│   │   │   ├── __init__.py
│   │   │   └── get_technical_analysis.py
│   │   └── screening
│   │       ├── __init__.py
│   │       ├── dtos.py
│   │       └── queries.py
│   ├── backtesting
│   │   ├── __init__.py
│   │   ├── ab_testing.py
│   │   ├── analysis.py
│   │   ├── batch_processing_stub.py
│   │   ├── batch_processing.py
│   │   ├── model_manager.py
│   │   ├── optimization.py
│   │   ├── persistence.py
│   │   ├── retraining_pipeline.py
│   │   ├── strategies
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── ml
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adaptive.py
│   │   │   │   ├── ensemble.py
│   │   │   │   ├── feature_engineering.py
│   │   │   │   └── regime_aware.py
│   │   │   ├── ml_strategies.py
│   │   │   ├── parser.py
│   │   │   └── templates.py
│   │   ├── strategy_executor.py
│   │   ├── vectorbt_engine.py
│   │   └── visualization.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── database_self_contained.py
│   │   ├── database.py
│   │   ├── llm_optimization_config.py
│   │   ├── logging_settings.py
│   │   ├── plotly_config.py
│   │   ├── security_utils.py
│   │   ├── security.py
│   │   ├── settings.py
│   │   ├── technical_constants.py
│   │   ├── tool_estimation.py
│   │   └── validation.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── technical_analysis.py
│   │   └── visualization.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cache_manager.py
│   │   ├── cache.py
│   │   ├── django_adapter.py
│   │   ├── health.py
│   │   ├── models.py
│   │   ├── performance.py
│   │   ├── session_management.py
│   │   └── validation.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── optimization.py
│   ├── dependencies.py
│   ├── domain
│   │   ├── __init__.py
│   │   ├── entities
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis.py
│   │   ├── events
│   │   │   └── __init__.py
│   │   ├── portfolio.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   ├── entities.py
│   │   │   ├── services.py
│   │   │   └── value_objects.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_service.py
│   │   ├── stock_analysis
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis_service.py
│   │   └── value_objects
│   │       ├── __init__.py
│   │       └── technical_indicators.py
│   ├── exceptions.py
│   ├── infrastructure
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   └── __init__.py
│   │   ├── caching
│   │   │   ├── __init__.py
│   │   │   └── cache_management_service.py
│   │   ├── connection_manager.py
│   │   ├── data_fetching
│   │   │   ├── __init__.py
│   │   │   └── stock_data_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_checker.py
│   │   ├── persistence
│   │   │   ├── __init__.py
│   │   │   └── stock_repository.py
│   │   ├── providers
│   │   │   └── __init__.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   └── repositories.py
│   │   └── sse_optimizer.py
│   ├── langchain_tools
│   │   ├── __init__.py
│   │   ├── adapters.py
│   │   └── registry.py
│   ├── logging_config.py
│   ├── memory
│   │   ├── __init__.py
│   │   └── stores.py
│   ├── monitoring
│   │   ├── __init__.py
│   │   ├── health_check.py
│   │   ├── health_monitor.py
│   │   ├── integration_example.py
│   │   ├── metrics.py
│   │   ├── middleware.py
│   │   └── status_dashboard.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── dependencies.py
│   │   ├── factories
│   │   │   ├── __init__.py
│   │   │   ├── config_factory.py
│   │   │   └── provider_factory.py
│   │   ├── implementations
│   │   │   ├── __init__.py
│   │   │   ├── cache_adapter.py
│   │   │   ├── macro_data_adapter.py
│   │   │   ├── market_data_adapter.py
│   │   │   ├── persistence_adapter.py
│   │   │   └── stock_data_adapter.py
│   │   ├── interfaces
│   │   │   ├── __init__.py
│   │   │   ├── cache.py
│   │   │   ├── config.py
│   │   │   ├── macro_data.py
│   │   │   ├── market_data.py
│   │   │   ├── persistence.py
│   │   │   └── stock_data.py
│   │   ├── llm_factory.py
│   │   ├── macro_data.py
│   │   ├── market_data.py
│   │   ├── mocks
│   │   │   ├── __init__.py
│   │   │   ├── mock_cache.py
│   │   │   ├── mock_config.py
│   │   │   ├── mock_macro_data.py
│   │   │   ├── mock_market_data.py
│   │   │   ├── mock_persistence.py
│   │   │   └── mock_stock_data.py
│   │   ├── openrouter_provider.py
│   │   ├── optimized_screening.py
│   │   ├── optimized_stock_data.py
│   │   └── stock_data.py
│   ├── README.md
│   ├── tests
│   │   ├── __init__.py
│   │   ├── README_INMEMORY_TESTS.md
│   │   ├── test_cache_debug.py
│   │   ├── test_fixes_validation.py
│   │   ├── test_in_memory_routers.py
│   │   ├── test_in_memory_server.py
│   │   ├── test_macro_data_provider.py
│   │   ├── test_mailgun_email.py
│   │   ├── test_market_calendar_caching.py
│   │   ├── test_mcp_tool_fixes_pytest.py
│   │   ├── test_mcp_tool_fixes.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_models_functional.py
│   │   ├── test_server.py
│   │   ├── test_stock_data_enhanced.py
│   │   ├── test_stock_data_provider.py
│   │   └── test_technical_analysis.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── performance_monitoring.py
│   │   ├── portfolio_manager.py
│   │   ├── risk_management.py
│   │   └── sentiment_analysis.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── agent_errors.py
│   │   ├── batch_processing.py
│   │   ├── cache_warmer.py
│   │   ├── circuit_breaker_decorators.py
│   │   ├── circuit_breaker_services.py
│   │   ├── circuit_breaker.py
│   │   ├── data_chunking.py
│   │   ├── database_monitoring.py
│   │   ├── debug_utils.py
│   │   ├── fallback_strategies.py
│   │   ├── llm_optimization.py
│   │   ├── logging_example.py
│   │   ├── logging_init.py
│   │   ├── logging.py
│   │   ├── mcp_logging.py
│   │   ├── memory_profiler.py
│   │   ├── monitoring_middleware.py
│   │   ├── monitoring.py
│   │   ├── orchestration_logging.py
│   │   ├── parallel_research.py
│   │   ├── parallel_screening.py
│   │   ├── quick_cache.py
│   │   ├── resource_manager.py
│   │   ├── shutdown.py
│   │   ├── stock_helpers.py
│   │   ├── structured_logger.py
│   │   ├── tool_monitoring.py
│   │   ├── tracing.py
│   │   └── yfinance_pool.py
│   ├── validation
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── data.py
│   │   ├── middleware.py
│   │   ├── portfolio.py
│   │   ├── responses.py
│   │   ├── screening.py
│   │   └── technical.py
│   └── workflows
│       ├── __init__.py
│       ├── agents
│       │   ├── __init__.py
│       │   ├── market_analyzer.py
│       │   ├── optimizer_agent.py
│       │   ├── strategy_selector.py
│       │   └── validator_agent.py
│       ├── backtesting_workflow.py
│       └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│   ├── dev.sh
│   ├── INSTALLATION_GUIDE.md
│   ├── load_example.py
│   ├── load_market_data.py
│   ├── load_tiingo_data.py
│   ├── migrate_db.py
│   ├── README_TIINGO_LOADER.md
│   ├── requirements_tiingo.txt
│   ├── run_stock_screening.py
│   ├── run-migrations.sh
│   ├── seed_db.py
│   ├── seed_sp500.py
│   ├── setup_database.sh
│   ├── setup_self_contained.py
│   ├── setup_sp500_database.sh
│   ├── test_seeded_data.py
│   ├── test_tiingo_loader.py
│   ├── tiingo_config.py
│   └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── core
│   │   └── test_technical_analysis.py
│   ├── data
│   │   └── test_portfolio_models.py
│   ├── domain
│   │   ├── conftest.py
│   │   ├── test_portfolio_entities.py
│   │   └── test_technical_analysis_service.py
│   ├── fixtures
│   │   └── orchestration_fixtures.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── README.md
│   │   ├── run_integration_tests.sh
│   │   ├── test_api_technical.py
│   │   ├── test_chaos_engineering.py
│   │   ├── test_config_management.py
│   │   ├── test_full_backtest_workflow_advanced.py
│   │   ├── test_full_backtest_workflow.py
│   │   ├── test_high_volume.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_orchestration_complete.py
│   │   ├── test_portfolio_persistence.py
│   │   ├── test_redis_cache.py
│   │   ├── test_security_integration.py.disabled
│   │   └── vcr_setup.py
│   ├── performance
│   │   ├── __init__.py
│   │   ├── test_benchmarks.py
│   │   ├── test_load.py
│   │   ├── test_profiling.py
│   │   └── test_stress.py
│   ├── providers
│   │   └── test_stock_data_simple.py
│   ├── README.md
│   ├── test_agents_router_mcp.py
│   ├── test_backtest_persistence.py
│   ├── test_cache_management_service.py
│   ├── test_cache_serialization.py
│   ├── test_circuit_breaker.py
│   ├── test_database_pool_config_simple.py
│   ├── test_database_pool_config.py
│   ├── test_deep_research_functional.py
│   ├── test_deep_research_integration.py
│   ├── test_deep_research_parallel_execution.py
│   ├── test_error_handling.py
│   ├── test_event_loop_integrity.py
│   ├── test_exa_research_integration.py
│   ├── test_exception_hierarchy.py
│   ├── test_financial_search.py
│   ├── test_graceful_shutdown.py
│   ├── test_integration_simple.py
│   ├── test_langgraph_workflow.py
│   ├── test_market_data_async.py
│   ├── test_market_data_simple.py
│   ├── test_mcp_orchestration_functional.py
│   ├── test_ml_strategies.py
│   ├── test_optimized_research_agent.py
│   ├── test_orchestration_integration.py
│   ├── test_orchestration_logging.py
│   ├── test_orchestration_tools_simple.py
│   ├── test_parallel_research_integration.py
│   ├── test_parallel_research_orchestrator.py
│   ├── test_parallel_research_performance.py
│   ├── test_performance_optimizations.py
│   ├── test_production_validation.py
│   ├── test_provider_architecture.py
│   ├── test_rate_limiting_enhanced.py
│   ├── test_runner_validation.py
│   ├── test_security_comprehensive.py.disabled
│   ├── test_security_cors.py
│   ├── test_security_enhancements.py.disabled
│   ├── test_security_headers.py
│   ├── test_security_penetration.py
│   ├── test_session_management.py
│   ├── test_speed_optimization_validation.py
│   ├── test_stock_analysis_dependencies.py
│   ├── test_stock_analysis_service.py
│   ├── test_stock_data_fetching_service.py
│   ├── test_supervisor_agent.py
│   ├── test_supervisor_functional.py
│   ├── test_tool_estimation_config.py
│   ├── test_visualization.py
│   └── utils
│       ├── test_agent_errors.py
│       ├── test_logging.py
│       ├── test_parallel_screening.py
│       └── test_quick_cache.py
├── tools
│   ├── check_orchestration_config.py
│   ├── experiments
│   │   ├── validation_examples.py
│   │   └── validation_fixed.py
│   ├── fast_dev.sh
│   ├── hot_reload.py
│   ├── quick_test.py
│   └── templates
│       ├── new_router_template.py
│       ├── new_tool_template.py
│       ├── screening_strategy_template.py
│       └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/providers/mocks/mock_stock_data.py:
--------------------------------------------------------------------------------
```python
"""
Mock stock data provider implementations for testing.
"""
from datetime import datetime, timedelta
from typing import Any
import numpy as np
import pandas as pd
class MockStockDataFetcher:
"""
Mock implementation of IStockDataFetcher for testing.
This implementation provides predictable test data without requiring
external API calls or database access.
"""
def __init__(self, test_data: dict[str, pd.DataFrame] | None = None):
"""
Initialize the mock stock data fetcher.
Args:
test_data: Optional dictionary mapping symbols to DataFrames
"""
self._test_data = test_data or {}
self._call_log: list[dict[str, Any]] = []
async def get_stock_data(
self,
symbol: str,
start_date: str | None = None,
end_date: str | None = None,
period: str | None = None,
interval: str = "1d",
use_cache: bool = True,
) -> pd.DataFrame:
"""Get mock stock data."""
self._log_call(
"get_stock_data",
{
"symbol": symbol,
"start_date": start_date,
"end_date": end_date,
"period": period,
"interval": interval,
"use_cache": use_cache,
},
)
symbol = symbol.upper()
# Return test data if available
if symbol in self._test_data:
df = self._test_data[symbol].copy()
# Filter by date range if specified
if start_date or end_date:
if start_date:
df = df[df.index >= start_date]
if end_date:
df = df[df.index <= end_date]
return df
# Generate synthetic data
return self._generate_synthetic_data(symbol, start_date, end_date, period)
async def get_realtime_data(self, symbol: str) -> dict[str, Any] | None:
"""Get mock real-time stock data."""
self._log_call("get_realtime_data", {"symbol": symbol})
# Return predictable mock data
return {
"symbol": symbol.upper(),
"price": 150.25,
"change": 2.15,
"change_percent": 1.45,
"volume": 1234567,
"timestamp": datetime.now(),
"timestamp_display": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"is_real_time": False,
}
async def get_stock_info(self, symbol: str) -> dict[str, Any]:
"""Get mock stock information."""
self._log_call("get_stock_info", {"symbol": symbol})
return {
"symbol": symbol.upper(),
"longName": f"{symbol.upper()} Inc.",
"sector": "Technology",
"industry": "Software",
"marketCap": 1000000000,
"previousClose": 148.10,
"beta": 1.2,
"dividendYield": 0.02,
"peRatio": 25.5,
}
async def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
"""Get mock news data."""
self._log_call("get_news", {"symbol": symbol, "limit": limit})
# Generate mock news data
news_data = []
for i in range(min(limit, 5)): # Generate up to 5 mock articles
news_data.append(
{
"title": f"Mock news article {i + 1} for {symbol}",
"publisher": f"Mock Publisher {i + 1}",
"link": f"https://example.com/news/{symbol.lower()}/{i + 1}",
"providerPublishTime": datetime.now() - timedelta(hours=i),
"type": "STORY",
}
)
return pd.DataFrame(news_data)
async def get_earnings(self, symbol: str) -> dict[str, Any]:
"""Get mock earnings data."""
self._log_call("get_earnings", {"symbol": symbol})
return {
"earnings": {
"2023": 5.25,
"2022": 4.80,
"2021": 4.35,
},
"earnings_dates": {
"next_date": "2024-01-25",
"eps_estimate": 1.35,
},
"earnings_trend": {
"current_quarter": 1.30,
"next_quarter": 1.35,
"current_year": 5.40,
"next_year": 5.85,
},
}
async def get_recommendations(self, symbol: str) -> pd.DataFrame:
"""Get mock analyst recommendations."""
self._log_call("get_recommendations", {"symbol": symbol})
recommendations_data = [
{
"firm": "Mock Investment Bank",
"toGrade": "Buy",
"fromGrade": "Hold",
"action": "up",
},
{
"firm": "Another Mock Firm",
"toGrade": "Hold",
"fromGrade": "Hold",
"action": "main",
},
]
return pd.DataFrame(recommendations_data)
async def is_market_open(self) -> bool:
"""Check if market is open (mock)."""
self._log_call("is_market_open", {})
# Return True for testing by default
return True
async def is_etf(self, symbol: str) -> bool:
"""Check if symbol is an ETF (mock)."""
self._log_call("is_etf", {"symbol": symbol})
# Mock ETF detection based on common ETF symbols
etf_symbols = {"SPY", "QQQ", "IWM", "VTI", "VEA", "VWO", "XLK", "XLF"}
return symbol.upper() in etf_symbols
def _generate_synthetic_data(
self,
symbol: str,
start_date: str | None = None,
end_date: str | None = None,
period: str | None = None,
) -> pd.DataFrame:
"""Generate synthetic stock data for testing."""
# Determine date range
if period:
days = {"1d": 1, "5d": 5, "1mo": 30, "3mo": 90, "1y": 365}.get(period, 30)
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=days)
else:
end_dt = pd.to_datetime(end_date) if end_date else datetime.now()
start_dt = (
pd.to_datetime(start_date)
if start_date
else end_dt - timedelta(days=30)
)
# Generate date range (business days only)
dates = pd.bdate_range(start=start_dt, end=end_dt)
if len(dates) == 0:
# Return empty DataFrame with proper columns
return pd.DataFrame(
columns=[
"Open",
"High",
"Low",
"Close",
"Volume",
"Dividends",
"Stock Splits",
]
)
# Generate synthetic price data
np.random.seed(hash(symbol) % 2**32) # Consistent data per symbol
base_price = 100.0
returns = np.random.normal(
0.001, 0.02, len(dates)
) # 0.1% mean return, 2% volatility
prices = [base_price]
for ret in returns[1:]:
prices.append(prices[-1] * (1 + ret))
# Generate OHLCV data
data = []
for _i, (_date, close_price) in enumerate(zip(dates, prices, strict=False)):
# Generate Open, High, Low based on Close
volatility = close_price * 0.02 # 2% intraday volatility
open_price = close_price + np.random.normal(0, volatility * 0.5)
high_price = max(open_price, close_price) + abs(
np.random.normal(0, volatility * 0.3)
)
low_price = min(open_price, close_price) - abs(
np.random.normal(0, volatility * 0.3)
)
# Ensure High >= Low and prices are positive
high_price = max(high_price, low_price + 0.01, 0.01)
low_price = max(low_price, 0.01)
volume = int(
np.random.lognormal(15, 0.5)
) # Log-normal distribution for volume
data.append(
{
"Open": round(open_price, 2),
"High": round(high_price, 2),
"Low": round(low_price, 2),
"Close": round(close_price, 2),
"Volume": volume,
"Dividends": 0.0,
"Stock Splits": 0.0,
}
)
df = pd.DataFrame(data, index=dates)
df.index.name = "Date"
return df
# Testing utilities
def _log_call(self, method: str, args: dict[str, Any]) -> None:
"""Log method calls for testing verification."""
self._call_log.append(
{
"method": method,
"args": args,
"timestamp": datetime.now(),
}
)
def get_call_log(self) -> list[dict[str, Any]]:
"""Get the log of method calls."""
return self._call_log.copy()
def clear_call_log(self) -> None:
"""Clear the method call log."""
self._call_log.clear()
def set_test_data(self, symbol: str, data: pd.DataFrame) -> None:
"""Set test data for a specific symbol."""
self._test_data[symbol.upper()] = data
def clear_test_data(self) -> None:
"""Clear all test data."""
self._test_data.clear()
class MockStockScreener:
"""
Mock implementation of IStockScreener for testing.
"""
def __init__(
self, test_recommendations: dict[str, list[dict[str, Any]]] | None = None
):
"""
Initialize the mock stock screener.
Args:
test_recommendations: Optional dictionary of test screening results
"""
self._test_recommendations = test_recommendations or {}
self._call_log: list[dict[str, Any]] = []
async def get_maverick_recommendations(
self, limit: int = 20, min_score: int | None = None
) -> list[dict[str, Any]]:
"""Get mock maverick recommendations."""
self._log_call(
"get_maverick_recommendations", {"limit": limit, "min_score": min_score}
)
if "maverick" in self._test_recommendations:
results = self._test_recommendations["maverick"]
else:
results = self._generate_mock_maverick_recommendations()
# Apply filters
if min_score:
results = [r for r in results if r.get("combined_score", 0) >= min_score]
return results[:limit]
async def get_maverick_bear_recommendations(
self, limit: int = 20, min_score: int | None = None
) -> list[dict[str, Any]]:
"""Get mock maverick bear recommendations."""
self._log_call(
"get_maverick_bear_recommendations",
{"limit": limit, "min_score": min_score},
)
if "bear" in self._test_recommendations:
results = self._test_recommendations["bear"]
else:
results = self._generate_mock_bear_recommendations()
# Apply filters
if min_score:
results = [r for r in results if r.get("score", 0) >= min_score]
return results[:limit]
async def get_trending_recommendations(
self, limit: int = 20, min_momentum_score: float | None = None
) -> list[dict[str, Any]]:
"""Get mock trending recommendations."""
self._log_call(
"get_trending_recommendations",
{"limit": limit, "min_momentum_score": min_momentum_score},
)
if "trending" in self._test_recommendations:
results = self._test_recommendations["trending"]
else:
results = self._generate_mock_trending_recommendations()
# Apply filters
if min_momentum_score:
results = [
r for r in results if r.get("momentum_score", 0) >= min_momentum_score
]
return results[:limit]
async def get_all_screening_recommendations(
self,
) -> dict[str, list[dict[str, Any]]]:
"""Get all mock screening recommendations."""
self._log_call("get_all_screening_recommendations", {})
return {
"maverick_stocks": await self.get_maverick_recommendations(),
"maverick_bear_stocks": await self.get_maverick_bear_recommendations(),
"supply_demand_breakouts": await self.get_trending_recommendations(),
}
def _generate_mock_maverick_recommendations(self) -> list[dict[str, Any]]:
"""Generate mock maverick recommendations."""
return [
{
"symbol": "AAPL",
"combined_score": 95,
"momentum_score": 92,
"pattern": "Cup with Handle",
"consolidation": "yes",
"squeeze": "firing",
"recommendation_type": "maverick_bullish",
"reason": "Exceptional combined score with outstanding relative strength",
},
{
"symbol": "MSFT",
"combined_score": 88,
"momentum_score": 85,
"pattern": "Flat Base",
"consolidation": "no",
"squeeze": "setup",
"recommendation_type": "maverick_bullish",
"reason": "Strong combined score with strong relative strength",
},
]
def _generate_mock_bear_recommendations(self) -> list[dict[str, Any]]:
"""Generate mock bear recommendations."""
return [
{
"symbol": "BEAR1",
"score": 92,
"momentum_score": 25,
"rsi_14": 28,
"atr_contraction": True,
"big_down_vol": True,
"recommendation_type": "maverick_bearish",
"reason": "Exceptional bear score with weak relative strength, oversold RSI",
},
{
"symbol": "BEAR2",
"score": 85,
"momentum_score": 30,
"rsi_14": 35,
"atr_contraction": False,
"big_down_vol": True,
"recommendation_type": "maverick_bearish",
"reason": "Strong bear score with weak relative strength",
},
]
def _generate_mock_trending_recommendations(self) -> list[dict[str, Any]]:
"""Generate mock trending recommendations."""
return [
{
"symbol": "TREND1",
"momentum_score": 95,
"close": 150.25,
"sma_50": 145.50,
"sma_150": 140.25,
"sma_200": 135.75,
"pattern": "Breakout",
"recommendation_type": "trending_stage2",
"reason": "Uptrend with exceptional momentum strength",
},
{
"symbol": "TREND2",
"momentum_score": 88,
"close": 85.30,
"sma_50": 82.15,
"sma_150": 79.80,
"sma_200": 76.45,
"pattern": "Higher Lows",
"recommendation_type": "trending_stage2",
"reason": "Uptrend with strong momentum strength",
},
]
# Testing utilities
def _log_call(self, method: str, args: dict[str, Any]) -> None:
"""Log method calls for testing verification."""
self._call_log.append(
{
"method": method,
"args": args,
"timestamp": datetime.now(),
}
)
def get_call_log(self) -> list[dict[str, Any]]:
"""Get the log of method calls."""
return self._call_log.copy()
def clear_call_log(self) -> None:
"""Clear the method call log."""
self._call_log.clear()
def set_test_recommendations(
self, screening_type: str, recommendations: list[dict[str, Any]]
) -> None:
"""Set test recommendations for a specific screening type."""
self._test_recommendations[screening_type] = recommendations
def clear_test_recommendations(self) -> None:
"""Clear all test recommendations."""
self._test_recommendations.clear()
```
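
Usage note (not part of the repository): a minimal sketch of how these mocks can drive an async test. The `main` driver, fixture values, and assertions below are illustrative assumptions; only the `MockStockDataFetcher` API (`set_test_data`, `get_stock_data`, `get_call_log`) comes from the file above.

```python
import asyncio

import pandas as pd

from maverick_mcp.providers.mocks.mock_stock_data import MockStockDataFetcher


async def main():
    fetcher = MockStockDataFetcher()

    # Inject fixed test data for one symbol; other symbols fall back to the
    # deterministic synthetic generator (seeded by hash(symbol)).
    fixture = pd.DataFrame(
        {
            "Open": [100.0],
            "High": [101.0],
            "Low": [99.0],
            "Close": [100.5],
            "Volume": [1_000_000],
            "Dividends": [0.0],
            "Stock Splits": [0.0],
        },
        index=pd.DatetimeIndex(["2024-01-02"], name="Date"),
    )
    fetcher.set_test_data("AAPL", fixture)

    df = await fetcher.get_stock_data("AAPL", start_date="2024-01-01")
    assert len(df) == 1  # fixture row survives the date filter

    synthetic = await fetcher.get_stock_data("MSFT", period="1mo")
    assert not synthetic.empty  # synthetic OHLCV for an unknown symbol

    # The call log lets tests verify how the fetcher was exercised.
    assert [c["method"] for c in fetcher.get_call_log()] == [
        "get_stock_data",
        "get_stock_data",
    ]


asyncio.run(main())
```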
--------------------------------------------------------------------------------
/maverick_mcp/utils/database_monitoring.py:
--------------------------------------------------------------------------------
```python
"""
Database and Redis monitoring utilities for MaverickMCP.
This module provides comprehensive monitoring for:
- SQLAlchemy database connection pools
- Database query performance
- Redis connection pools and operations
- Cache hit rates and performance metrics
"""
import asyncio
import time
from contextlib import asynccontextmanager, contextmanager
from typing import Any
from sqlalchemy.event import listen
from sqlalchemy.pool import Pool
from maverick_mcp.utils.logging import get_logger
from maverick_mcp.utils.monitoring import (
redis_connections,
redis_memory_usage,
track_cache_operation,
track_database_connection_event,
track_database_query,
track_redis_operation,
update_database_metrics,
update_redis_metrics,
)
from maverick_mcp.utils.tracing import trace_cache_operation, trace_database_query
logger = get_logger(__name__)
class DatabaseMonitor:
"""Monitor for SQLAlchemy database operations and connection pools."""
def __init__(self, engine=None):
self.engine = engine
self.query_stats = {}
self._setup_event_listeners()
def _setup_event_listeners(self):
"""Set up SQLAlchemy event listeners for monitoring."""
if not self.engine:
return
# Connection pool events
listen(Pool, "connect", self._on_connection_created)
listen(Pool, "checkout", self._on_connection_checkout)
listen(Pool, "checkin", self._on_connection_checkin)
listen(Pool, "close", self._on_connection_closed)
# Query execution events
listen(self.engine, "before_cursor_execute", self._on_before_query)
listen(self.engine, "after_cursor_execute", self._on_after_query)
def _on_connection_created(self, dbapi_connection, connection_record):
"""Handle new database connection creation."""
track_database_connection_event("created")
logger.debug("Database connection created")
def _on_connection_checkout(
self, dbapi_connection, connection_record, connection_proxy
):
"""Handle connection checkout from pool."""
# Update connection metrics
pool = self.engine.pool
self._update_pool_metrics(pool)
def _on_connection_checkin(self, dbapi_connection, connection_record):
"""Handle connection checkin to pool."""
# Update connection metrics
pool = self.engine.pool
self._update_pool_metrics(pool)
def _on_connection_closed(self, dbapi_connection, connection_record):
"""Handle connection closure."""
track_database_connection_event("closed", "normal")
logger.debug("Database connection closed")
def _on_before_query(
self, conn, cursor, statement, parameters, context, executemany
):
"""Handle query execution start."""
context._query_start_time = time.time()
context._query_statement = statement
def _on_after_query(
self, conn, cursor, statement, parameters, context, executemany
):
"""Handle query execution completion."""
if hasattr(context, "_query_start_time"):
duration = time.time() - context._query_start_time
query_type = self._extract_query_type(statement)
table = self._extract_table_name(statement)
# Track metrics
track_database_query(query_type, table, duration, "success")
# Log slow queries
if duration > 1.0: # Queries over 1 second
logger.warning(
"Slow database query detected",
extra={
"query_type": query_type,
"table": table,
"duration_seconds": duration,
"statement": statement[:200] + "..."
if len(statement) > 200
else statement,
},
)
def _update_pool_metrics(self, pool):
"""Update connection pool metrics."""
try:
pool_size = pool.size()
checked_out = pool.checkedout()
checked_in = pool.checkedin()
update_database_metrics(
pool_size=pool_size,
active_connections=checked_out,
idle_connections=checked_in,
)
except Exception as e:
logger.warning(f"Failed to update pool metrics: {e}")
def _extract_query_type(self, statement: str) -> str:
"""Extract query type from SQL statement."""
statement_upper = statement.strip().upper()
if statement_upper.startswith("SELECT"):
return "SELECT"
elif statement_upper.startswith("INSERT"):
return "INSERT"
elif statement_upper.startswith("UPDATE"):
return "UPDATE"
elif statement_upper.startswith("DELETE"):
return "DELETE"
elif statement_upper.startswith("CREATE"):
return "CREATE"
elif statement_upper.startswith("DROP"):
return "DROP"
elif statement_upper.startswith("ALTER"):
return "ALTER"
else:
return "OTHER"
def _extract_table_name(self, statement: str) -> str | None:
"""Extract table name from SQL statement."""
import re
# Simple regex to extract table names
patterns = [
r"FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)", # SELECT FROM table
r"INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)", # INSERT INTO table
r"UPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)", # UPDATE table
r"DELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)", # DELETE FROM table
]
for pattern in patterns:
match = re.search(pattern, statement.upper())
if match:
return match.group(1).lower()
return "unknown"
@contextmanager
def trace_query(self, query_type: str, table: str | None = None):
"""Context manager for tracing database queries."""
with trace_database_query(query_type, table) as span:
start_time = time.time()
try:
yield span
duration = time.time() - start_time
track_database_query(
query_type, table or "unknown", duration, "success"
)
except Exception:
duration = time.time() - start_time
track_database_query(query_type, table or "unknown", duration, "error")
raise
def get_pool_status(self) -> dict[str, Any]:
"""Get current database pool status."""
if not self.engine:
return {}
try:
pool = self.engine.pool
return {
"pool_size": pool.size(),
"checked_out": pool.checkedout(),
"checked_in": pool.checkedin(),
"overflow": pool.overflow(),
"invalid": pool.invalid(),
}
except Exception as e:
logger.error(f"Failed to get pool status: {e}")
return {}
class RedisMonitor:
"""Monitor for Redis operations and connection pools."""
def __init__(self, redis_client=None):
self.redis_client = redis_client
self.operation_stats = {}
@asynccontextmanager
async def trace_operation(self, operation: str, key: str | None = None):
"""Context manager for tracing Redis operations."""
with trace_cache_operation(operation, "redis") as span:
start_time = time.time()
if span and key:
span.set_attribute("redis.key", key)
try:
yield span
duration = time.time() - start_time
track_redis_operation(operation, duration, "success")
except Exception as e:
duration = time.time() - start_time
track_redis_operation(operation, duration, "error")
if span:
span.record_exception(e)
logger.error(
f"Redis operation failed: {operation}",
extra={
"operation": operation,
"key": key,
"duration_seconds": duration,
"error": str(e),
},
)
raise
async def monitor_get(self, key: str):
"""Monitor Redis GET operation."""
async with self.trace_operation("get", key):
try:
result = await self.redis_client.get(key)
hit = result is not None
track_cache_operation("redis", "get", hit, self._get_key_prefix(key))
return result
except Exception:
track_cache_operation("redis", "get", False, self._get_key_prefix(key))
raise
async def monitor_set(self, key: str, value: Any, **kwargs):
"""Monitor Redis SET operation."""
async with self.trace_operation("set", key):
return await self.redis_client.set(key, value, **kwargs)
async def monitor_delete(self, key: str):
"""Monitor Redis DELETE operation."""
async with self.trace_operation("delete", key):
return await self.redis_client.delete(key)
async def monitor_exists(self, key: str):
"""Monitor Redis EXISTS operation."""
async with self.trace_operation("exists", key):
return await self.redis_client.exists(key)
async def update_redis_metrics(self):
"""Update Redis metrics from server info."""
if not self.redis_client:
return
try:
info = await self.redis_client.info()
# Connection metrics
connected_clients = info.get("connected_clients", 0)
redis_connections.set(connected_clients)
# Memory metrics
used_memory = info.get("used_memory", 0)
redis_memory_usage.set(used_memory)
# Keyspace metrics
keyspace_hits = info.get("keyspace_hits", 0)
keyspace_misses = info.get("keyspace_misses", 0)
# Keyspace counters are cumulative on the server; per-operation hits/misses
# are tracked individually, so they are not re-counted here
update_redis_metrics(
connections=connected_clients,
memory_bytes=used_memory,
hits=0, # Will be updated by individual operations
misses=0, # Will be updated by individual operations
)
logger.debug(
"Redis metrics updated",
extra={
"connected_clients": connected_clients,
"used_memory_mb": used_memory / 1024 / 1024,
"keyspace_hits": keyspace_hits,
"keyspace_misses": keyspace_misses,
},
)
except Exception as e:
logger.error(f"Failed to update Redis metrics: {e}")
def _get_key_prefix(self, key: str) -> str:
"""Extract key prefix for metrics grouping."""
if ":" in key:
return key.split(":")[0]
return "other"
async def get_redis_info(self) -> dict[str, Any]:
"""Get Redis server information."""
if not self.redis_client:
return {}
try:
info = await self.redis_client.info()
return {
"connected_clients": info.get("connected_clients", 0),
"used_memory": info.get("used_memory", 0),
"used_memory_human": info.get("used_memory_human", "0B"),
"keyspace_hits": info.get("keyspace_hits", 0),
"keyspace_misses": info.get("keyspace_misses", 0),
"total_commands_processed": info.get("total_commands_processed", 0),
"uptime_in_seconds": info.get("uptime_in_seconds", 0),
}
except Exception as e:
logger.error(f"Failed to get Redis info: {e}")
return {}
class CacheMonitor:
"""High-level cache monitoring that supports multiple backends."""
def __init__(self, redis_monitor: RedisMonitor | None = None):
self.redis_monitor = redis_monitor
@contextmanager
def monitor_operation(self, cache_type: str, operation: str, key: str):
"""Monitor cache operation across different backends."""
start_time = time.time()
hit = False
try:
yield
hit = True # If no exception, assume it was a hit for GET operations
except Exception as e:
logger.error(
f"Cache operation failed: {cache_type} {operation}",
extra={
"cache_type": cache_type,
"operation": operation,
"key": key,
"error": str(e),
},
)
raise
finally:
duration = time.time() - start_time
# Track metrics based on operation
if operation in ["get", "exists"]:
track_cache_operation(
cache_type, operation, hit, self._get_key_prefix(key)
)
# Log slow cache operations
if duration > 0.1: # Operations over 100ms
logger.warning(
f"Slow cache operation: {cache_type} {operation}",
extra={
"cache_type": cache_type,
"operation": operation,
"key": key,
"duration_seconds": duration,
},
)
def _get_key_prefix(self, key: str) -> str:
"""Extract key prefix for metrics grouping."""
if ":" in key:
return key.split(":")[0]
return "other"
async def update_all_metrics(self):
"""Update metrics for all monitored cache backends."""
tasks = []
if self.redis_monitor:
tasks.append(self.redis_monitor.update_redis_metrics())
if tasks:
try:
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
logger.error(f"Failed to update cache metrics: {e}")
# Global monitor instances
_database_monitor: DatabaseMonitor | None = None
_redis_monitor: RedisMonitor | None = None
_cache_monitor: CacheMonitor | None = None
def get_database_monitor(engine=None) -> DatabaseMonitor:
"""Get or create the global database monitor."""
global _database_monitor
if _database_monitor is None:
_database_monitor = DatabaseMonitor(engine)
return _database_monitor
def get_redis_monitor(redis_client=None) -> RedisMonitor:
"""Get or create the global Redis monitor."""
global _redis_monitor
if _redis_monitor is None:
_redis_monitor = RedisMonitor(redis_client)
return _redis_monitor
def get_cache_monitor() -> CacheMonitor:
"""Get or create the global cache monitor."""
global _cache_monitor
if _cache_monitor is None:
redis_monitor = get_redis_monitor()
_cache_monitor = CacheMonitor(redis_monitor)
return _cache_monitor
def initialize_database_monitoring(engine):
"""Initialize database monitoring with the given engine."""
logger.info("Initializing database monitoring...")
monitor = get_database_monitor(engine)
logger.info("Database monitoring initialized")
return monitor
def initialize_redis_monitoring(redis_client):
"""Initialize Redis monitoring with the given client."""
logger.info("Initializing Redis monitoring...")
monitor = get_redis_monitor(redis_client)
logger.info("Redis monitoring initialized")
return monitor
async def start_periodic_metrics_collection(interval: int = 30):
"""Start periodic collection of database and cache metrics."""
logger.info(f"Starting periodic metrics collection (interval: {interval}s)")
cache_monitor = get_cache_monitor()
while True:
try:
await cache_monitor.update_all_metrics()
except Exception as e:
logger.error(f"Error in periodic metrics collection: {e}")
await asyncio.sleep(interval)
```
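
Usage note (not part of the repository): a hypothetical wiring sketch for the monitors above. The `redis.asyncio` client, SQLite engine URL, and timing values are assumptions; the `initialize_*` helpers, `get_pool_status`, and `start_periodic_metrics_collection` come from the module itself.

```python
import asyncio

import redis.asyncio as redis
from sqlalchemy import create_engine

from maverick_mcp.utils.database_monitoring import (
    initialize_database_monitoring,
    initialize_redis_monitoring,
    start_periodic_metrics_collection,
)


async def main():
    # Assumed infrastructure: adapt the URL/host to your deployment.
    engine = create_engine("sqlite:///example.db")
    redis_client = redis.Redis(host="localhost", port=6379)

    # Registers SQLAlchemy event listeners and stores the global monitors.
    db_monitor = initialize_database_monitoring(engine)
    initialize_redis_monitoring(redis_client)

    # Pool status can be polled ad hoc...
    print(db_monitor.get_pool_status())

    # ...while periodic collection runs forever as a background task.
    task = asyncio.create_task(start_periodic_metrics_collection(interval=30))
    await asyncio.sleep(60)  # let two collection cycles run
    task.cancel()


asyncio.run(main())
```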
--------------------------------------------------------------------------------
/maverick_mcp/monitoring/middleware.py:
--------------------------------------------------------------------------------
```python
"""
Monitoring middleware for automatic metrics collection.
This module provides middleware components that automatically track:
- API calls and response times
- Strategy execution performance
- Resource usage during operations
- Anomaly detection triggers
"""
import asyncio
import time
from collections.abc import Callable
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any
from maverick_mcp.monitoring.metrics import get_backtesting_metrics
from maverick_mcp.utils.logging import get_logger
logger = get_logger(__name__)
class MetricsMiddleware:
"""
Middleware for automatic metrics collection during backtesting operations.
Provides decorators and context managers for seamless metrics integration.
"""
def __init__(self):
self.collector = get_backtesting_metrics()
self.logger = get_logger(f"{__name__}.MetricsMiddleware")
def track_api_call(self, provider: str, endpoint: str, method: str = "GET"):
"""
Decorator to automatically track API call metrics.
Usage:
@middleware.track_api_call("tiingo", "/daily/{symbol}")
async def get_stock_data(symbol: str):
# API call logic here
pass
"""
def decorator(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
start_time = time.time()
status_code = 200
error_type = None
try:
result = await func(*args, **kwargs)
return result
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_type = type(e).__name__
raise
finally:
duration = time.time() - start_time
self.collector.track_api_call(
provider=provider,
endpoint=endpoint,
method=method,
status_code=status_code,
duration=duration,
error_type=error_type,
)
@wraps(func)
def sync_wrapper(*args, **kwargs):
start_time = time.time()
status_code = 200
error_type = None
try:
result = func(*args, **kwargs)
return result
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_type = type(e).__name__
raise
finally:
duration = time.time() - start_time
self.collector.track_api_call(
provider=provider,
endpoint=endpoint,
method=method,
status_code=status_code,
duration=duration,
error_type=error_type,
)
# Return appropriate wrapper based on function type
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
def track_strategy_execution(
self, strategy_name: str, symbol: str, timeframe: str = "1D"
):
"""
Decorator to automatically track strategy execution metrics.
Usage:
@middleware.track_strategy_execution("RSI_Strategy", "AAPL")
def run_backtest(data):
# Strategy execution logic here
return results
"""
def decorator(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
with self.collector.track_backtest_execution(
strategy_name=strategy_name,
symbol=symbol,
timeframe=timeframe,
data_points=kwargs.get("data_points", 0),
):
result = await func(*args, **kwargs)
# Extract performance metrics from result if available
if isinstance(result, dict):
self._extract_and_track_performance(
result, strategy_name, symbol, timeframe
)
return result
@wraps(func)
def sync_wrapper(*args, **kwargs):
with self.collector.track_backtest_execution(
strategy_name=strategy_name,
symbol=symbol,
timeframe=timeframe,
data_points=kwargs.get("data_points", 0),
):
result = func(*args, **kwargs)
# Extract performance metrics from result if available
if isinstance(result, dict):
self._extract_and_track_performance(
result, strategy_name, symbol, timeframe
)
return result
# Return appropriate wrapper based on function type
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
def track_resource_usage(self, operation_type: str):
"""
Decorator to automatically track resource usage for operations.
Usage:
@middleware.track_resource_usage("vectorbt_backtest")
def run_vectorbt_analysis(data):
# VectorBT analysis logic here
pass
"""
def decorator(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
import psutil
process = psutil.Process()
start_memory = process.memory_info().rss / 1024 / 1024
start_time = time.time()
try:
result = await func(*args, **kwargs)
return result
finally:
end_memory = process.memory_info().rss / 1024 / 1024
duration = time.time() - start_time
memory_used = max(0, end_memory - start_memory)
# Determine data size category
data_size = "unknown"
if "data" in kwargs:
data_length = (
len(kwargs["data"])
if hasattr(kwargs["data"], "__len__")
else 0
)
data_size = self.collector._categorize_data_size(data_length)
self.collector.track_resource_usage(
operation_type=operation_type,
memory_mb=memory_used,
computation_time=duration,
data_size=data_size,
)
@wraps(func)
def sync_wrapper(*args, **kwargs):
import psutil
process = psutil.Process()
start_memory = process.memory_info().rss / 1024 / 1024
start_time = time.time()
try:
result = func(*args, **kwargs)
return result
finally:
end_memory = process.memory_info().rss / 1024 / 1024
duration = time.time() - start_time
memory_used = max(0, end_memory - start_memory)
# Determine data size category
data_size = "unknown"
if "data" in kwargs:
data_length = (
len(kwargs["data"])
if hasattr(kwargs["data"], "__len__")
else 0
)
data_size = self.collector._categorize_data_size(data_length)
self.collector.track_resource_usage(
operation_type=operation_type,
memory_mb=memory_used,
computation_time=duration,
data_size=data_size,
)
# Return appropriate wrapper based on function type
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
@asynccontextmanager
async def track_database_operation(
self, query_type: str, table_name: str, operation: str
):
"""
Context manager to track database operation performance.
Usage:
async with middleware.track_database_operation("SELECT", "stocks", "fetch"):
result = await db.execute(query)
"""
start_time = time.time()
try:
yield
finally:
duration = time.time() - start_time
self.collector.track_database_operation(
query_type=query_type,
table_name=table_name,
operation=operation,
duration=duration,
)
def _extract_and_track_performance(
self, result: dict[str, Any], strategy_name: str, symbol: str, timeframe: str
):
"""Extract and track strategy performance metrics from results."""
try:
# Extract common performance metrics from result dictionary
returns = result.get("total_return", result.get("returns", 0.0))
sharpe_ratio = result.get("sharpe_ratio", 0.0)
max_drawdown = result.get("max_drawdown", result.get("max_dd", 0.0))
win_rate = result.get("win_rate", result.get("win_ratio", 0.0))
total_trades = result.get("total_trades", result.get("num_trades", 0))
winning_trades = result.get("winning_trades", 0)
# Convert win rate to percentage if it's in decimal form
if win_rate <= 1.0:
win_rate *= 100
# Convert max drawdown to positive percentage if negative
if max_drawdown < 0:
max_drawdown = abs(max_drawdown) * 100
# Extract winning trades from win rate if not provided directly
if winning_trades == 0 and total_trades > 0:
winning_trades = int(total_trades * (win_rate / 100))
# Determine period from timeframe or use default
period_mapping = {"1D": "1Y", "1H": "3M", "5m": "1M", "1m": "1W"}
period = period_mapping.get(timeframe, "1Y")
# Track the performance metrics
self.collector.track_strategy_performance(
strategy_name=strategy_name,
symbol=symbol,
period=period,
returns=returns,
sharpe_ratio=sharpe_ratio,
max_drawdown=max_drawdown,
win_rate=win_rate,
total_trades=total_trades,
winning_trades=winning_trades,
)
self.logger.debug(
f"Tracked strategy performance for {strategy_name}",
extra={
"strategy": strategy_name,
"symbol": symbol,
"returns": returns,
"sharpe_ratio": sharpe_ratio,
"max_drawdown": max_drawdown,
"win_rate": win_rate,
"total_trades": total_trades,
},
)
except Exception as e:
self.logger.warning(
f"Failed to extract performance metrics from result: {e}",
extra={
"result_keys": list(result.keys())
if isinstance(result, dict)
else "not_dict"
},
)
# Global middleware instance
_middleware_instance: MetricsMiddleware | None = None
def get_metrics_middleware() -> MetricsMiddleware:
"""Get or create the global metrics middleware instance."""
global _middleware_instance
if _middleware_instance is None:
_middleware_instance = MetricsMiddleware()
return _middleware_instance
# Convenience decorators using global middleware instance
def track_api_call(provider: str, endpoint: str, method: str = "GET"):
"""Convenience decorator for API call tracking."""
return get_metrics_middleware().track_api_call(provider, endpoint, method)
def track_strategy_execution(strategy_name: str, symbol: str, timeframe: str = "1D"):
"""Convenience decorator for strategy execution tracking."""
return get_metrics_middleware().track_strategy_execution(
strategy_name, symbol, timeframe
)
def track_resource_usage(operation_type: str):
"""Convenience decorator for resource usage tracking."""
return get_metrics_middleware().track_resource_usage(operation_type)
def track_database_operation(query_type: str, table_name: str, operation: str):
"""Convenience context manager for database operation tracking."""
return get_metrics_middleware().track_database_operation(
query_type, table_name, operation
)
# Example circuit breaker with metrics
class MetricsCircuitBreaker:
"""
Circuit breaker with integrated metrics tracking.
Automatically tracks circuit breaker state changes and failures.
"""
def __init__(
self,
provider: str,
endpoint: str,
failure_threshold: int = 5,
recovery_timeout: int = 60,
expected_exception: type = Exception,
):
self.provider = provider
self.endpoint = endpoint
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.expected_exception = expected_exception
self.failure_count = 0
self.last_failure_time = 0
self.state = "closed" # closed, open, half-open
self.collector = get_backtesting_metrics()
self.logger = get_logger(f"{__name__}.MetricsCircuitBreaker")
async def call(self, func: Callable, *args, **kwargs):
"""Execute function with circuit breaker protection and metrics tracking."""
if self.state == "open":
if time.time() - self.last_failure_time > self.recovery_timeout:
self.state = "half-open"
self.collector.track_circuit_breaker(
self.provider, self.endpoint, self.state, 0
)
else:
raise Exception(
f"Circuit breaker is open for {self.provider}/{self.endpoint}"
)
try:
if asyncio.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
# Success - reset failure count and close circuit if half-open
if self.state == "half-open":
self.state = "closed"
self.failure_count = 0
self.collector.track_circuit_breaker(
self.provider, self.endpoint, self.state, 0
)
self.logger.info(
f"Circuit breaker closed for {self.provider}/{self.endpoint}"
)
return result
except self.expected_exception as e:
self.failure_count += 1
self.last_failure_time = time.time()
# Track failure
self.collector.track_circuit_breaker(
self.provider, self.endpoint, self.state, 1
)
# Open circuit if threshold reached
if self.failure_count >= self.failure_threshold:
self.state = "open"
self.collector.track_circuit_breaker(
self.provider, self.endpoint, self.state, 0
)
self.logger.warning(
f"Circuit breaker opened for {self.provider}/{self.endpoint} "
f"after {self.failure_count} failures"
)
raise e
```
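
Usage note (not part of the repository): a hypothetical sketch combining the convenience decorators with `MetricsCircuitBreaker`. The fetch/backtest functions and their return values are invented for illustration; the decorator names, breaker defaults (5 failures, 60s recovery), and result keys (`total_return`, `sharpe_ratio`, ...) are taken from the code above.

```python
import asyncio

from maverick_mcp.monitoring.middleware import (
    MetricsCircuitBreaker,
    track_api_call,
    track_strategy_execution,
)


@track_api_call("tiingo", "/daily/{symbol}")
async def fetch_daily(symbol: str) -> dict:
    # A real implementation would call the provider here; status, duration,
    # and error type are recorded automatically in the wrapper's finally block.
    return {"symbol": symbol, "close": 150.25}


@track_strategy_execution("RSI_Strategy", "AAPL", timeframe="1D")
def run_backtest(data: list[float]) -> dict:
    # Returning a dict lets the middleware extract performance metrics
    # (negative max_drawdown is converted to a positive percentage).
    return {
        "total_return": 0.12,
        "sharpe_ratio": 1.4,
        "max_drawdown": -0.08,
        "win_rate": 0.55,
        "total_trades": 40,
    }


async def main():
    breaker = MetricsCircuitBreaker("tiingo", "/daily/{symbol}")
    # After five consecutive failures the breaker opens and short-circuits
    # calls until the 60-second recovery timeout elapses.
    result = await breaker.call(fetch_daily, "AAPL")
    print(result, run_backtest([1.0, 2.0, 3.0]))


asyncio.run(main())
```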
--------------------------------------------------------------------------------
/tests/test_stock_analysis_service.py:
--------------------------------------------------------------------------------
```python
"""
Tests for StockAnalysisService.
"""
from unittest.mock import Mock, patch
import pandas as pd
from maverick_mcp.domain.stock_analysis import StockAnalysisService
from maverick_mcp.infrastructure.caching import CacheManagementService
from maverick_mcp.infrastructure.data_fetching import StockDataFetchingService
class TestStockAnalysisService:
"""Test cases for StockAnalysisService."""
def setup_method(self):
"""Set up test fixtures."""
self.mock_data_fetching_service = Mock(spec=StockDataFetchingService)
self.mock_cache_service = Mock(spec=CacheManagementService)
self.mock_db_session = Mock()
self.service = StockAnalysisService(
data_fetching_service=self.mock_data_fetching_service,
cache_service=self.mock_cache_service,
db_session=self.mock_db_session,
)
def test_init(self):
"""Test service initialization."""
assert self.service.data_fetching_service == self.mock_data_fetching_service
assert self.service.cache_service == self.mock_cache_service
assert self.service.db_session == self.mock_db_session
def test_get_stock_data_non_daily_interval(self):
"""Test get_stock_data with non-daily interval bypasses cache."""
mock_data = pd.DataFrame(
{"Open": [150.0], "Close": [151.0]},
index=pd.date_range("2024-01-01", periods=1),
)
self.mock_data_fetching_service.fetch_stock_data.return_value = mock_data
# Test with 1-hour interval
result = self.service.get_stock_data("AAPL", interval="1h")
# Assertions
assert not result.empty
self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
self.mock_cache_service.get_cached_data.assert_not_called()
def test_get_stock_data_with_period(self):
"""Test get_stock_data with period parameter bypasses cache."""
mock_data = pd.DataFrame(
{"Open": [150.0], "Close": [151.0]},
index=pd.date_range("2024-01-01", periods=1),
)
self.mock_data_fetching_service.fetch_stock_data.return_value = mock_data
# Test with period
result = self.service.get_stock_data("AAPL", period="1mo")
# Assertions
assert not result.empty
self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
self.mock_cache_service.get_cached_data.assert_not_called()
def test_get_stock_data_cache_disabled(self):
"""Test get_stock_data with cache disabled."""
mock_data = pd.DataFrame(
{"Open": [150.0], "Close": [151.0]},
index=pd.date_range("2024-01-01", periods=1),
)
self.mock_data_fetching_service.fetch_stock_data.return_value = mock_data
# Test with cache disabled
result = self.service.get_stock_data("AAPL", use_cache=False)
# Assertions
assert not result.empty
self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
self.mock_cache_service.get_cached_data.assert_not_called()
def test_get_stock_data_cache_hit(self):
"""Test get_stock_data with complete cache hit."""
# Mock cached data
mock_cached_data = pd.DataFrame(
{
"Open": [150.0, 151.0, 152.0],
"High": [151.0, 152.0, 153.0],
"Low": [149.0, 150.0, 151.0],
"Close": [150.5, 151.5, 152.5],
"Volume": [1000000, 1100000, 1200000],
},
index=pd.date_range("2024-01-01", periods=3),
)
self.mock_cache_service.get_cached_data.return_value = mock_cached_data
# Test
result = self.service.get_stock_data(
"AAPL", start_date="2024-01-01", end_date="2024-01-03"
)
# Assertions
assert not result.empty
assert len(result) == 3
self.mock_cache_service.get_cached_data.assert_called_once()
self.mock_data_fetching_service.fetch_stock_data.assert_not_called()
def test_get_stock_data_cache_miss(self):
"""Test get_stock_data with complete cache miss."""
# Mock no cached data
self.mock_cache_service.get_cached_data.return_value = None
# Mock market calendar
with patch.object(self.service, "_get_trading_days") as mock_trading_days:
mock_trading_days.return_value = pd.DatetimeIndex(
["2024-01-01", "2024-01-02"]
)
# Mock fetched data
mock_fetched_data = pd.DataFrame(
{
"Open": [150.0, 151.0],
"Close": [150.5, 151.5],
"Volume": [1000000, 1100000],
},
index=pd.date_range("2024-01-01", periods=2),
)
self.mock_data_fetching_service.fetch_stock_data.return_value = (
mock_fetched_data
)
# Test
result = self.service.get_stock_data(
"AAPL", start_date="2024-01-01", end_date="2024-01-02"
)
# Assertions
assert not result.empty
self.mock_cache_service.get_cached_data.assert_called_once()
self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
self.mock_cache_service.cache_data.assert_called_once()
def test_get_stock_data_partial_cache_hit(self):
"""Test get_stock_data with partial cache hit requiring additional data."""
# Mock partial cached data (missing recent data)
mock_cached_data = pd.DataFrame(
{"Open": [150.0], "Close": [150.5], "Volume": [1000000]},
index=pd.date_range("2024-01-01", periods=1),
)
self.mock_cache_service.get_cached_data.return_value = mock_cached_data
# Mock missing data fetch
mock_missing_data = pd.DataFrame(
{"Open": [151.0], "Close": [151.5], "Volume": [1100000]},
index=pd.date_range("2024-01-02", periods=1),
)
self.mock_data_fetching_service.fetch_stock_data.return_value = (
mock_missing_data
)
# Mock helper methods
with (
patch.object(self.service, "_get_trading_days") as mock_trading_days,
patch.object(
self.service, "_is_trading_day_between"
) as mock_trading_between,
):
mock_trading_days.return_value = pd.DatetimeIndex(["2024-01-02"])
mock_trading_between.return_value = True
# Test
result = self.service.get_stock_data(
"AAPL", start_date="2024-01-01", end_date="2024-01-02"
)
# Assertions
assert not result.empty
assert len(result) == 2 # Combined cached + fetched data
self.mock_cache_service.get_cached_data.assert_called_once()
self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
self.mock_cache_service.cache_data.assert_called_once()
def test_get_stock_data_smart_cache_fallback(self):
"""Test get_stock_data fallback when smart cache fails."""
# Mock cache service to raise exception
self.mock_cache_service.get_cached_data.side_effect = Exception("Cache error")
# Mock fallback data
mock_fallback_data = pd.DataFrame(
{"Open": [150.0], "Close": [150.5]},
index=pd.date_range("2024-01-01", periods=1),
)
self.mock_data_fetching_service.fetch_stock_data.return_value = (
mock_fallback_data
)
# Test
result = self.service.get_stock_data("AAPL")
# Assertions
assert not result.empty
self.mock_data_fetching_service.fetch_stock_data.assert_called()
def test_get_stock_info(self):
"""Test get_stock_info delegation."""
mock_info = {"longName": "Apple Inc."}
self.mock_data_fetching_service.fetch_stock_info.return_value = mock_info
# Test
result = self.service.get_stock_info("AAPL")
# Assertions
assert result == mock_info
self.mock_data_fetching_service.fetch_stock_info.assert_called_once_with("AAPL")
def test_get_realtime_data(self):
"""Test get_realtime_data delegation."""
mock_data = {"symbol": "AAPL", "price": 150.0}
self.mock_data_fetching_service.fetch_realtime_data.return_value = mock_data
# Test
result = self.service.get_realtime_data("AAPL")
# Assertions
assert result == mock_data
self.mock_data_fetching_service.fetch_realtime_data.assert_called_once_with(
"AAPL"
)
def test_get_multiple_realtime_data(self):
"""Test get_multiple_realtime_data delegation."""
mock_data = {"AAPL": {"price": 150.0}, "MSFT": {"price": 300.0}}
self.mock_data_fetching_service.fetch_multiple_realtime_data.return_value = (
mock_data
)
# Test
result = self.service.get_multiple_realtime_data(["AAPL", "MSFT"])
# Assertions
assert result == mock_data
self.mock_data_fetching_service.fetch_multiple_realtime_data.assert_called_once_with(
["AAPL", "MSFT"]
)
@patch("maverick_mcp.domain.stock_analysis.stock_analysis_service.datetime")
@patch("maverick_mcp.domain.stock_analysis.stock_analysis_service.pytz")
def test_is_market_open_weekday_during_hours(self, mock_pytz, mock_datetime):
"""Test market open check during trading hours on weekday."""
# Mock current time: Wednesday 10:00 AM ET
mock_now = Mock()
mock_now.weekday.return_value = 2 # Wednesday
mock_now.replace.return_value = mock_now
mock_now.__le__ = lambda self, other: True
mock_now.__ge__ = lambda self, other: True
mock_datetime.now.return_value = mock_now
mock_pytz.timezone.return_value.localize = lambda x: x
# Test
result = self.service.is_market_open()
# Assertions
assert result is True
@patch("maverick_mcp.domain.stock_analysis.stock_analysis_service.datetime")
def test_is_market_open_weekend(self, mock_datetime):
"""Test market open check on weekend."""
# Mock current time: Saturday
mock_now = Mock()
mock_now.weekday.return_value = 5 # Saturday
mock_datetime.now.return_value = mock_now
# Test
result = self.service.is_market_open()
# Assertions
assert result is False
def test_get_news(self):
"""Test get_news delegation."""
mock_news = pd.DataFrame({"title": ["Apple News"]})
self.mock_data_fetching_service.fetch_news.return_value = mock_news
# Test
result = self.service.get_news("AAPL", limit=5)
# Assertions
assert not result.empty
self.mock_data_fetching_service.fetch_news.assert_called_once_with("AAPL", 5)
def test_get_earnings(self):
"""Test get_earnings delegation."""
mock_earnings = {"earnings": {}}
self.mock_data_fetching_service.fetch_earnings.return_value = mock_earnings
# Test
result = self.service.get_earnings("AAPL")
# Assertions
assert result == mock_earnings
self.mock_data_fetching_service.fetch_earnings.assert_called_once_with("AAPL")
def test_get_recommendations(self):
"""Test get_recommendations delegation."""
mock_recs = pd.DataFrame({"firm": ["Goldman Sachs"]})
self.mock_data_fetching_service.fetch_recommendations.return_value = mock_recs
# Test
result = self.service.get_recommendations("AAPL")
# Assertions
assert not result.empty
self.mock_data_fetching_service.fetch_recommendations.assert_called_once_with(
"AAPL"
)
def test_is_etf(self):
"""Test is_etf delegation."""
self.mock_data_fetching_service.check_if_etf.return_value = True
# Test
result = self.service.is_etf("SPY")
# Assertions
assert result is True
self.mock_data_fetching_service.check_if_etf.assert_called_once_with("SPY")
def test_invalidate_cache(self):
"""Test invalidate_cache delegation."""
self.mock_cache_service.invalidate_cache.return_value = True
# Test
result = self.service.invalidate_cache("AAPL", "2024-01-01", "2024-01-02")
# Assertions
assert result is True
self.mock_cache_service.invalidate_cache.assert_called_once_with(
"AAPL", "2024-01-01", "2024-01-02"
)
def test_get_cache_stats(self):
"""Test get_cache_stats delegation."""
mock_stats = {"symbol": "AAPL", "total_records": 100}
self.mock_cache_service.get_cache_stats.return_value = mock_stats
# Test
result = self.service.get_cache_stats("AAPL")
# Assertions
assert result == mock_stats
self.mock_cache_service.get_cache_stats.assert_called_once_with("AAPL")
def test_get_trading_days(self):
"""Test get_trading_days helper method."""
with patch.object(self.service.market_calendar, "schedule") as mock_schedule:
# Mock schedule response
mock_df = Mock()
mock_df.index = pd.DatetimeIndex(["2024-01-01", "2024-01-02"])
mock_schedule.return_value = mock_df
# Test
result = self.service._get_trading_days("2024-01-01", "2024-01-02")
# Assertions
assert len(result) == 2
assert result[0] == pd.Timestamp("2024-01-01")
def test_is_trading_day(self):
"""Test is_trading_day helper method."""
with patch.object(self.service.market_calendar, "schedule") as mock_schedule:
# Mock schedule response with trading session
mock_df = Mock()
mock_df.__len__ = Mock(return_value=1) # Has trading session
mock_schedule.return_value = mock_df
# Test
result = self.service._is_trading_day("2024-01-01")
# Assertions
assert result is True
def test_get_last_trading_day_is_trading_day(self):
"""Test get_last_trading_day when date is already a trading day."""
with patch.object(self.service, "_is_trading_day") as mock_is_trading:
mock_is_trading.return_value = True
# Test
result = self.service._get_last_trading_day("2024-01-01")
# Assertions
assert result == pd.Timestamp("2024-01-01")
def test_get_last_trading_day_find_previous(self):
"""Test get_last_trading_day finding previous trading day."""
with patch.object(self.service, "_is_trading_day") as mock_is_trading:
# First call (date itself) returns False, second call (previous day) returns True
mock_is_trading.side_effect = [False, True]
# Test
result = self.service._get_last_trading_day("2024-01-01")
# Assertions
assert result == pd.Timestamp("2023-12-31")
def test_is_trading_day_between_true(self):
"""Test is_trading_day_between when there are trading days between dates."""
with patch.object(self.service, "_get_trading_days") as mock_trading_days:
mock_trading_days.return_value = pd.DatetimeIndex(["2024-01-02"])
# Test
start_date = pd.Timestamp("2024-01-01")
end_date = pd.Timestamp("2024-01-03")
result = self.service._is_trading_day_between(start_date, end_date)
# Assertions
assert result is True
def test_is_trading_day_between_false(self):
"""Test is_trading_day_between when there are no trading days between dates."""
with patch.object(self.service, "_get_trading_days") as mock_trading_days:
mock_trading_days.return_value = pd.DatetimeIndex([])
# Test
start_date = pd.Timestamp("2024-01-01")
end_date = pd.Timestamp("2024-01-02")
result = self.service._is_trading_day_between(start_date, end_date)
# Assertions
assert result is False
```
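The calendar walk-back exercised by the last few tests can be sketched standalone. A minimal illustration, assuming the service's `market_calendar` is a `pandas_market_calendars` calendar (the helper name below is hypothetical):
```python
import pandas as pd
import pandas_market_calendars as mcal
calendar = mcal.get_calendar("NYSE")
def last_trading_day(date: str) -> pd.Timestamp:
    """Step back one calendar day at a time until a trading session exists."""
    current = pd.Timestamp(date)
    while len(calendar.schedule(start_date=current, end_date=current)) == 0:
        current -= pd.Timedelta(days=1)
    return current
# New Year's Day 2024 is a holiday, so this walks back to the prior session
# (the unit test above stubs _is_trading_day instead of hitting a real calendar).
print(last_trading_day("2024-01-01"))
```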
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/health_tools.py:
--------------------------------------------------------------------------------
```python
"""
MCP tools for health monitoring and system status.
These tools expose health monitoring functionality through the MCP interface,
allowing Claude to check system health, monitor component status, and get
real-time metrics about the backtesting system.
"""
import logging
from datetime import UTC, datetime
from typing import Any
from fastmcp import FastMCP
from maverick_mcp.config.settings import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
def register_health_tools(mcp: FastMCP):
"""Register all health monitoring tools with the MCP server."""
@mcp.tool()
async def get_system_health() -> dict[str, Any]:
"""
Get comprehensive system health status.
Returns detailed information about all system components including:
- Overall health status
- Component-by-component status
- Resource utilization
- Circuit breaker states
- Performance metrics
Returns:
Dictionary containing complete system health information
"""
try:
from maverick_mcp.api.routers.health_enhanced import (
_get_detailed_health_status,
)
health_status = await _get_detailed_health_status()
return {
"status": "success",
"data": health_status,
"timestamp": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Failed to get system health: {e}")
return {
"status": "error",
"error": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
@mcp.tool()
    async def get_component_status(component_name: str | None = None) -> dict[str, Any]:
"""
Get status of a specific component or all components.
Args:
component_name: Name of the component to check (optional).
If not provided, returns status of all components.
Returns:
Dictionary containing component status information
"""
try:
from maverick_mcp.api.routers.health_enhanced import (
_get_detailed_health_status,
)
health_status = await _get_detailed_health_status()
components = health_status.get("components", {})
if component_name:
if component_name in components:
return {
"status": "success",
"component": component_name,
"data": components[component_name].__dict__,
"timestamp": datetime.now(UTC).isoformat(),
}
else:
return {
"status": "error",
"error": f"Component '{component_name}' not found",
"available_components": list(components.keys()),
"timestamp": datetime.now(UTC).isoformat(),
}
else:
return {
"status": "success",
"data": {name: comp.__dict__ for name, comp in components.items()},
"total_components": len(components),
"timestamp": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Failed to get component status: {e}")
return {
"status": "error",
"error": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
@mcp.tool()
async def get_circuit_breaker_status() -> dict[str, Any]:
"""
Get status of all circuit breakers.
Returns information about circuit breaker states, failure counts,
and performance metrics for all external API connections.
Returns:
Dictionary containing circuit breaker status information
"""
try:
from maverick_mcp.utils.circuit_breaker import (
get_all_circuit_breaker_status,
)
cb_status = get_all_circuit_breaker_status()
return {
"status": "success",
"data": cb_status,
"summary": {
"total_breakers": len(cb_status),
"states": {
"closed": sum(
1
for cb in cb_status.values()
if cb.get("state") == "closed"
),
"open": sum(
1 for cb in cb_status.values() if cb.get("state") == "open"
),
"half_open": sum(
1
for cb in cb_status.values()
if cb.get("state") == "half_open"
),
},
},
"timestamp": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Failed to get circuit breaker status: {e}")
return {
"status": "error",
"error": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
@mcp.tool()
async def get_resource_usage() -> dict[str, Any]:
"""
Get current system resource usage.
Returns information about CPU, memory, disk usage, and other
system resources being consumed by the backtesting system.
Returns:
Dictionary containing resource usage information
"""
try:
from maverick_mcp.api.routers.health_enhanced import _get_resource_usage
resource_usage = _get_resource_usage()
return {
"status": "success",
"data": resource_usage.__dict__,
"alerts": {
"high_cpu": resource_usage.cpu_percent > 80,
"high_memory": resource_usage.memory_percent > 85,
"high_disk": resource_usage.disk_percent > 90,
},
"timestamp": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Failed to get resource usage: {e}")
return {
"status": "error",
"error": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
@mcp.tool()
async def get_status_dashboard() -> dict[str, Any]:
"""
Get comprehensive status dashboard data.
Returns aggregated health status, performance metrics, alerts,
and historical trends for the entire backtesting system.
Returns:
Dictionary containing complete dashboard information
"""
try:
from maverick_mcp.monitoring.status_dashboard import get_dashboard_data
dashboard_data = await get_dashboard_data()
return {
"status": "success",
"data": dashboard_data,
"timestamp": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Failed to get status dashboard: {e}")
return {
"status": "error",
"error": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
@mcp.tool()
async def reset_circuit_breaker(breaker_name: str) -> dict[str, Any]:
"""
Reset a specific circuit breaker.
Args:
breaker_name: Name of the circuit breaker to reset
Returns:
Dictionary containing operation result
"""
try:
from maverick_mcp.utils.circuit_breaker import get_circuit_breaker_manager
manager = get_circuit_breaker_manager()
success = manager.reset_breaker(breaker_name)
if success:
return {
"status": "success",
"message": f"Circuit breaker '{breaker_name}' reset successfully",
"breaker_name": breaker_name,
"timestamp": datetime.now(UTC).isoformat(),
}
else:
return {
"status": "error",
"error": f"Circuit breaker '{breaker_name}' not found or could not be reset",
"breaker_name": breaker_name,
"timestamp": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Failed to reset circuit breaker {breaker_name}: {e}")
return {
"status": "error",
"error": str(e),
"breaker_name": breaker_name,
"timestamp": datetime.now(UTC).isoformat(),
}
@mcp.tool()
async def get_health_history() -> dict[str, Any]:
"""
Get historical health data for trend analysis.
Returns recent health check history including component status
changes, resource usage trends, and system performance over time.
Returns:
Dictionary containing historical health information
"""
try:
from maverick_mcp.monitoring.health_monitor import get_health_monitor
monitor = get_health_monitor()
monitoring_status = monitor.get_monitoring_status()
# Get historical data from dashboard
from maverick_mcp.monitoring.status_dashboard import get_status_dashboard
dashboard = get_status_dashboard()
dashboard_data = await dashboard.get_dashboard_data()
historical_data = dashboard_data.get("historical", {})
return {
"status": "success",
"data": {
"monitoring_status": monitoring_status,
"historical_data": historical_data,
"trends": {
"data_points": len(historical_data.get("data", [])),
"timespan_hours": historical_data.get("summary", {}).get(
"timespan_hours", 0
),
},
},
"timestamp": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Failed to get health history: {e}")
return {
"status": "error",
"error": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
@mcp.tool()
async def run_health_diagnostics() -> dict[str, Any]:
"""
Run comprehensive health diagnostics.
Performs a complete system health check including all components,
circuit breakers, resource usage, and generates a diagnostic report
with recommendations.
Returns:
Dictionary containing diagnostic results and recommendations
"""
try:
# Get all health information
from maverick_mcp.api.routers.health_enhanced import (
_get_detailed_health_status,
)
from maverick_mcp.monitoring.health_monitor import get_monitoring_status
from maverick_mcp.utils.circuit_breaker import (
get_all_circuit_breaker_status,
)
health_status = await _get_detailed_health_status()
cb_status = get_all_circuit_breaker_status()
monitoring_status = get_monitoring_status()
# Generate recommendations
recommendations = []
# Check component health
components = health_status.get("components", {})
unhealthy_components = [
name for name, comp in components.items() if comp.status == "unhealthy"
]
if unhealthy_components:
recommendations.append(
{
"type": "component_health",
"severity": "critical",
"message": f"Unhealthy components detected: {', '.join(unhealthy_components)}",
"action": "Check component logs and dependencies",
}
)
# Check circuit breakers
open_breakers = [
name for name, cb in cb_status.items() if cb.get("state") == "open"
]
if open_breakers:
recommendations.append(
{
"type": "circuit_breaker",
"severity": "warning",
"message": f"Open circuit breakers: {', '.join(open_breakers)}",
"action": "Check external service availability and consider resetting breakers",
}
)
# Check resource usage
resource_usage = health_status.get("resource_usage", {})
if resource_usage.get("memory_percent", 0) > 85:
recommendations.append(
{
"type": "resource_usage",
"severity": "warning",
"message": f"High memory usage: {resource_usage.get('memory_percent')}%",
"action": "Monitor memory usage trends and consider scaling",
}
)
if resource_usage.get("cpu_percent", 0) > 80:
recommendations.append(
{
"type": "resource_usage",
"severity": "warning",
"message": f"High CPU usage: {resource_usage.get('cpu_percent')}%",
"action": "Check for high-load operations and optimize if needed",
}
)
# Generate overall assessment
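            # Scoring heuristic: start from 100 and subtract 20 per unhealthy
            # component, 10 per open circuit breaker, 15 for memory pressure
            # (>85%), and 10 for CPU pressure (>80%), flooring the score at 0.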
overall_health_score = 100
if unhealthy_components:
overall_health_score -= len(unhealthy_components) * 20
if open_breakers:
overall_health_score -= len(open_breakers) * 10
if resource_usage.get("memory_percent", 0) > 85:
overall_health_score -= 15
if resource_usage.get("cpu_percent", 0) > 80:
overall_health_score -= 10
overall_health_score = max(0, overall_health_score)
return {
"status": "success",
"data": {
"overall_health_score": overall_health_score,
"system_status": health_status.get("status", "unknown"),
"component_summary": {
"total": len(components),
"healthy": sum(
1 for c in components.values() if c.status == "healthy"
),
"degraded": sum(
1 for c in components.values() if c.status == "degraded"
),
"unhealthy": sum(
1 for c in components.values() if c.status == "unhealthy"
),
},
"circuit_breaker_summary": {
"total": len(cb_status),
"closed": sum(
1
for cb in cb_status.values()
if cb.get("state") == "closed"
),
"open": len(open_breakers),
"half_open": sum(
1
for cb in cb_status.values()
if cb.get("state") == "half_open"
),
},
"resource_summary": resource_usage,
"monitoring_summary": monitoring_status,
"recommendations": recommendations,
},
"timestamp": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Failed to run health diagnostics: {e}")
return {
"status": "error",
"error": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
logger.info("Health monitoring tools registered successfully")
```
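For local experimentation, the registration entry point above can be attached to a standalone server. A minimal sketch, assuming FastMCP's default stdio transport; the server name is arbitrary:
```python
from fastmcp import FastMCP
from maverick_mcp.api.routers.health_tools import register_health_tools
mcp = FastMCP("maverick-health")
register_health_tools(mcp)  # attaches get_system_health, get_component_status, etc.
if __name__ == "__main__":
    mcp.run()
```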
--------------------------------------------------------------------------------
/maverick_mcp/tools/performance_monitoring.py:
--------------------------------------------------------------------------------
```python
"""
Performance monitoring tools for Maverick-MCP.
This module provides MCP tools for monitoring and analyzing system performance,
including Redis connection health, cache hit rates, query performance, and
database index usage.
"""
import logging
import time
from datetime import datetime
from typing import Any
from sqlalchemy import text
from maverick_mcp.data.performance import (
query_optimizer,
redis_manager,
request_cache,
)
from maverick_mcp.data.session_management import (
get_async_connection_pool_status,
get_async_db_session,
)
from maverick_mcp.providers.optimized_stock_data import optimized_stock_provider
logger = logging.getLogger(__name__)
async def get_redis_connection_health() -> dict[str, Any]:
"""
Get comprehensive Redis connection health metrics.
Returns:
Dictionary with Redis health information
"""
try:
metrics = redis_manager.get_metrics()
# Test basic Redis operations
test_key = f"health_check_{int(time.time())}"
test_value = "test_value"
client = await redis_manager.get_client()
if client:
# Test basic operations
start_time = time.time()
await client.set(test_key, test_value, ex=60) # 1 minute expiry
get_result = await client.get(test_key)
await client.delete(test_key)
operation_time = time.time() - start_time
redis_health = {
"status": "healthy",
"basic_operations_working": get_result == test_value,
"operation_latency_ms": round(operation_time * 1000, 2),
}
else:
redis_health = {
"status": "unhealthy",
"basic_operations_working": False,
"operation_latency_ms": None,
}
return {
"redis_health": redis_health,
"connection_metrics": metrics,
"timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Error checking Redis health: {e}")
return {
"redis_health": {
"status": "error",
"error": str(e),
"basic_operations_working": False,
"operation_latency_ms": None,
},
"connection_metrics": {},
"timestamp": datetime.now().isoformat(),
}
async def get_cache_performance_metrics() -> dict[str, Any]:
"""
Get comprehensive cache performance metrics.
Returns:
Dictionary with cache performance information
"""
try:
# Get basic cache metrics
cache_metrics = request_cache.get_metrics()
# Test cache performance
test_key = f"cache_perf_test_{int(time.time())}"
test_data = {"test": "data", "timestamp": time.time()}
# Test cache operations
start_time = time.time()
set_success = await request_cache.set(test_key, test_data, ttl=60)
set_time = time.time() - start_time
start_time = time.time()
retrieved_data = await request_cache.get(test_key)
get_time = time.time() - start_time
# Cleanup
await request_cache.delete(test_key)
performance_test = {
"set_operation_ms": round(set_time * 1000, 2),
"get_operation_ms": round(get_time * 1000, 2),
"set_success": set_success,
"get_success": retrieved_data is not None,
"data_integrity": retrieved_data == test_data if retrieved_data else False,
}
# Get Redis-specific metrics if available
redis_metrics = redis_manager.get_metrics()
return {
"cache_performance": cache_metrics,
"performance_test": performance_test,
"redis_metrics": redis_metrics,
"timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Error getting cache performance metrics: {e}")
return {
"error": str(e),
"timestamp": datetime.now().isoformat(),
}
async def get_query_performance_metrics() -> dict[str, Any]:
"""
Get database query performance metrics.
Returns:
Dictionary with query performance information
"""
try:
# Get query optimizer stats
query_stats = query_optimizer.get_query_stats()
# Get database connection pool stats
try:
async with get_async_db_session() as session:
pool_status = await get_async_connection_pool_status()
start_time = time.time()
result = await session.execute(text("SELECT 1 as test"))
result.fetchone()
db_latency = time.time() - start_time
db_health = {
"status": "healthy",
"latency_ms": round(db_latency * 1000, 2),
"pool_status": pool_status,
}
except Exception as e:
db_health = {
"status": "unhealthy",
"error": str(e),
"latency_ms": None,
"pool_status": {},
}
return {
"query_performance": query_stats,
"database_health": db_health,
"timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Error getting query performance metrics: {e}")
return {
"error": str(e),
"timestamp": datetime.now().isoformat(),
}
async def analyze_database_indexes() -> dict[str, Any]:
"""
Analyze database index usage and provide recommendations.
Returns:
Dictionary with index analysis and recommendations
"""
try:
async with get_async_db_session() as session:
recommendations = await query_optimizer.analyze_missing_indexes(session)
# Get index usage statistics
index_usage_query = text(
"""
SELECT
schemaname,
tablename,
indexname,
idx_scan,
idx_tup_read,
idx_tup_fetch
FROM pg_stat_user_indexes
WHERE schemaname = 'public'
AND tablename LIKE 'stocks_%'
ORDER BY idx_scan DESC
"""
)
result = await session.execute(index_usage_query)
index_usage = [dict(row._mapping) for row in result.fetchall()] # type: ignore[attr-defined]
# Get table scan statistics
table_scan_query = text(
"""
SELECT
schemaname,
tablename,
seq_scan,
seq_tup_read,
idx_scan,
idx_tup_fetch,
CASE
WHEN seq_scan + idx_scan = 0 THEN 0
ELSE ROUND(100.0 * idx_scan / (seq_scan + idx_scan), 2)
END as index_usage_percent
FROM pg_stat_user_tables
WHERE schemaname = 'public'
AND tablename LIKE 'stocks_%'
ORDER BY seq_tup_read DESC
"""
)
result = await session.execute(table_scan_query)
table_stats = [dict(row._mapping) for row in result.fetchall()] # type: ignore[attr-defined]
# Identify tables with poor index usage
poor_index_usage = [
table
for table in table_stats
if table["index_usage_percent"] < 80 and table["seq_scan"] > 100
]
return {
"index_recommendations": recommendations,
"index_usage_stats": index_usage,
"table_scan_stats": table_stats,
"poor_index_usage": poor_index_usage,
"analysis_timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Error analyzing database indexes: {e}")
return {
"error": str(e),
"timestamp": datetime.now().isoformat(),
}
async def get_comprehensive_performance_report() -> dict[str, Any]:
"""
Get a comprehensive performance report combining all metrics.
Returns:
Dictionary with complete performance analysis
"""
try:
# Gather all performance metrics
redis_health = await get_redis_connection_health()
cache_metrics = await get_cache_performance_metrics()
query_metrics = await get_query_performance_metrics()
index_analysis = await analyze_database_indexes()
provider_metrics = await optimized_stock_provider.get_performance_metrics()
# Calculate overall health scores
redis_score = 100 if redis_health["redis_health"]["status"] == "healthy" else 0
cache_hit_rate = cache_metrics.get("cache_performance", {}).get("hit_rate", 0)
cache_score = cache_hit_rate * 100
# Database performance score based on average query time
query_stats = query_metrics.get("query_performance", {}).get("query_stats", {})
avg_query_times = [stats.get("avg_time", 0) for stats in query_stats.values()]
avg_query_time = (
sum(avg_query_times) / len(avg_query_times) if avg_query_times else 0
)
db_score = max(0, 100 - (avg_query_time * 100)) # Penalty for slow queries
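        # With avg_query_time in seconds, every 10 ms of average latency costs
        # one point, so queries averaging >= 1 s floor the database score at 0.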
overall_score = (redis_score + cache_score + db_score) / 3
# Performance recommendations
recommendations = []
if redis_score < 100:
recommendations.append(
"Redis connection issues detected. Check Redis server status."
)
if cache_hit_rate < 0.8:
recommendations.append(
f"Cache hit rate is {cache_hit_rate:.1%}. Consider increasing TTL values or cache size."
)
if avg_query_time > 0.5:
recommendations.append(
f"Average query time is {avg_query_time:.2f}s. Consider adding database indexes."
)
poor_index_tables = index_analysis.get("poor_index_usage", [])
if poor_index_tables:
table_names = [table["tablename"] for table in poor_index_tables]
recommendations.append(
f"Poor index usage on tables: {', '.join(table_names)}"
)
if not recommendations:
recommendations.append("System performance is optimal.")
return {
"overall_health_score": round(overall_score, 1),
"component_scores": {
"redis": redis_score,
"cache": round(cache_score, 1),
"database": round(db_score, 1),
},
"recommendations": recommendations,
"detailed_metrics": {
"redis_health": redis_health,
"cache_performance": cache_metrics,
"query_performance": query_metrics,
"index_analysis": index_analysis,
"provider_metrics": provider_metrics,
},
"report_timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Error generating comprehensive performance report: {e}")
return {
"error": str(e),
"timestamp": datetime.now().isoformat(),
}
async def optimize_cache_settings() -> dict[str, Any]:
"""
Analyze current cache usage and suggest optimal settings.
Returns:
Dictionary with cache optimization recommendations
"""
try:
# Get current cache metrics
cache_metrics = request_cache.get_metrics()
# Analyze cache performance patterns
hit_rate = cache_metrics.get("hit_rate", 0)
total_requests = cache_metrics.get("total_requests", 0)
# Get Redis memory usage if available
client = await redis_manager.get_client()
redis_info = {}
if client:
try:
redis_info = await client.info("memory")
except Exception as e:
logger.warning(f"Could not get Redis memory info: {e}")
# Generate recommendations
recommendations = []
optimal_settings = {}
if hit_rate < 0.7:
recommendations.append("Increase cache TTL values to improve hit rate")
optimal_settings["stock_data_ttl"] = 7200 # 2 hours instead of 1
optimal_settings["screening_ttl"] = 14400 # 4 hours instead of 2
elif hit_rate > 0.95:
recommendations.append(
"Consider reducing TTL values to ensure data freshness"
)
optimal_settings["stock_data_ttl"] = 1800 # 30 minutes
optimal_settings["screening_ttl"] = 3600 # 1 hour
if total_requests > 10000:
recommendations.append(
"High cache usage detected. Consider increasing Redis memory allocation"
)
# Memory usage recommendations
if redis_info.get("used_memory"):
used_memory_mb = int(redis_info["used_memory"]) / (1024 * 1024)
if used_memory_mb > 100:
recommendations.append(
f"Redis using {used_memory_mb:.1f}MB. Monitor memory usage."
)
return {
"current_performance": cache_metrics,
"redis_memory_info": redis_info,
"recommendations": recommendations,
"optimal_settings": optimal_settings,
"analysis_timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Error optimizing cache settings: {e}")
return {
"error": str(e),
"timestamp": datetime.now().isoformat(),
}
async def clear_performance_caches(
cache_types: list[str] | None = None,
) -> dict[str, Any]:
"""
Clear specific performance caches for testing or maintenance.
Args:
cache_types: List of cache types to clear (stock_data, screening, market_data, all)
Returns:
Dictionary with cache clearing results
"""
if cache_types is None:
cache_types = ["all"]
results = {}
total_cleared = 0
try:
for cache_type in cache_types:
if cache_type == "all":
# Clear all caches
cleared = await request_cache.delete_pattern("cache:*")
results["all_caches"] = cleared
total_cleared += cleared
elif cache_type == "stock_data":
# Clear stock data caches
patterns = [
"cache:*get_stock_basic_info*",
"cache:*get_stock_price_data*",
"cache:*bulk_get_stock_data*",
]
cleared = 0
for pattern in patterns:
cleared += await request_cache.delete_pattern(pattern)
results["stock_data"] = cleared
total_cleared += cleared
elif cache_type == "screening":
# Clear screening caches
patterns = [
"cache:*get_maverick_recommendations*",
"cache:*get_trending_recommendations*",
]
cleared = 0
for pattern in patterns:
cleared += await request_cache.delete_pattern(pattern)
results["screening"] = cleared
total_cleared += cleared
elif cache_type == "market_data":
# Clear market data caches
patterns = [
"cache:*get_high_volume_stocks*",
"cache:*market_data*",
]
cleared = 0
for pattern in patterns:
cleared += await request_cache.delete_pattern(pattern)
results["market_data"] = cleared
total_cleared += cleared
return {
"success": True,
"total_entries_cleared": total_cleared,
"cleared_by_type": results,
"timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Error clearing performance caches: {e}")
return {
"success": False,
"error": str(e),
"timestamp": datetime.now().isoformat(),
}
```
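These coroutines are plain async functions, so they can be driven directly for ad hoc checks. A minimal sketch, assuming the configured database and Redis are reachable:
```python
import asyncio
from maverick_mcp.tools.performance_monitoring import (
    get_comprehensive_performance_report,
)
async def main() -> None:
    report = await get_comprehensive_performance_report()
    print(f"Overall health score: {report.get('overall_health_score')}")
    for recommendation in report.get("recommendations", []):
        print(f"- {recommendation}")
asyncio.run(main())
```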
--------------------------------------------------------------------------------
/maverick_mcp/api/middleware/rate_limiting_enhanced.py:
--------------------------------------------------------------------------------
```python
"""Enhanced rate limiting middleware and utilities."""
from __future__ import annotations
import asyncio
import logging
import time
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import Any
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response
from maverick_mcp.data.performance import redis_manager
from maverick_mcp.exceptions import RateLimitError
logger = logging.getLogger(__name__)
class RateLimitStrategy(str, Enum):
"""Supported rate limiting strategies."""
SLIDING_WINDOW = "sliding_window"
TOKEN_BUCKET = "token_bucket"
FIXED_WINDOW = "fixed_window"
class RateLimitTier(str, Enum):
"""Logical tiers used to classify API endpoints."""
PUBLIC = "public"
AUTHENTICATION = "authentication"
DATA_RETRIEVAL = "data_retrieval"
ANALYSIS = "analysis"
BULK_OPERATION = "bulk_operation"
ADMINISTRATIVE = "administrative"
class EndpointClassification:
"""Utility helpers for mapping endpoints to rate limit tiers."""
@staticmethod
def classify_endpoint(path: str) -> RateLimitTier:
normalized = path.lower()
if normalized in {
"/health",
"/metrics",
"/docs",
"/openapi.json",
"/api/docs",
"/api/openapi.json",
}:
return RateLimitTier.PUBLIC
if normalized.startswith("/api/auth"):
return RateLimitTier.AUTHENTICATION
if "admin" in normalized:
return RateLimitTier.ADMINISTRATIVE
if "bulk" in normalized or normalized.endswith("/batch"):
return RateLimitTier.BULK_OPERATION
if any(
segment in normalized
for segment in ("analysis", "technical", "screening", "portfolio")
):
return RateLimitTier.ANALYSIS
return RateLimitTier.DATA_RETRIEVAL
@dataclass(slots=True)
class RateLimitConfig:
"""Configuration options for rate limiting."""
public_limit: int = 100
auth_limit: int = 30
data_limit: int = 60
data_limit_anonymous: int = 15
analysis_limit: int = 30
analysis_limit_anonymous: int = 10
bulk_limit_per_hour: int = 10
admin_limit: int = 20
premium_multiplier: float = 3.0
enterprise_multiplier: float = 5.0
default_strategy: RateLimitStrategy = RateLimitStrategy.SLIDING_WINDOW
burst_multiplier: float = 1.5
window_size_seconds: int = 60
token_refill_rate: float = 1.0
max_tokens: int | None = None
log_violations: bool = True
alert_threshold: int = 5
def limit_for(
self, tier: RateLimitTier, *, authenticated: bool, role: str | None = None
) -> int:
limit = self.data_limit
if tier == RateLimitTier.PUBLIC:
limit = self.public_limit
elif tier == RateLimitTier.AUTHENTICATION:
limit = self.auth_limit
elif tier == RateLimitTier.DATA_RETRIEVAL:
limit = self.data_limit if authenticated else self.data_limit_anonymous
elif tier == RateLimitTier.ANALYSIS:
limit = (
self.analysis_limit if authenticated else self.analysis_limit_anonymous
)
elif tier == RateLimitTier.BULK_OPERATION:
limit = self.bulk_limit_per_hour
elif tier == RateLimitTier.ADMINISTRATIVE:
limit = self.admin_limit
normalized_role = (role or "").lower()
if normalized_role == "premium":
limit = int(limit * self.premium_multiplier)
elif normalized_role == "enterprise":
limit = int(limit * self.enterprise_multiplier)
return max(limit, 1)
class RateLimiter:
"""Core rate limiter that operates with Redis and an in-process fallback."""
def __init__(self, config: RateLimitConfig) -> None:
self.config = config
self._local_counters: dict[str, deque[float]] = defaultdict(deque)
self._violations: dict[str, int] = defaultdict(int)
@staticmethod
def _tiered_key(tier: RateLimitTier, identifier: str) -> str:
"""Compose a namespaced key for tracking tier-specific counters."""
return f"{tier.value}:{identifier}"
def _redis_key(self, prefix: str, *, tier: RateLimitTier, identifier: str) -> str:
"""Build a Redis key for the given strategy prefix and identifier."""
tiered_identifier = self._tiered_key(tier, identifier)
return f"rate_limit:{prefix}:{tiered_identifier}"
async def check_rate_limit(
self,
*,
key: str,
tier: RateLimitTier,
limit: int,
window_seconds: int,
strategy: RateLimitStrategy | None = None,
) -> tuple[bool, dict[str, Any]]:
strategy = strategy or self.config.default_strategy
client = await redis_manager.get_client()
tiered_key = self._tiered_key(tier, key)
if client is None:
allowed, info = self._check_local_rate_limit(
key=tiered_key,
limit=limit,
window_seconds=window_seconds,
)
info["strategy"] = strategy.value
info["tier"] = tier.value
info["fallback"] = True
return allowed, info
if strategy == RateLimitStrategy.SLIDING_WINDOW:
return await self._check_sliding_window(
client, key, tier, limit, window_seconds
)
if strategy == RateLimitStrategy.TOKEN_BUCKET:
return await self._check_token_bucket(
client, key, tier, limit, window_seconds
)
return await self._check_fixed_window(client, key, tier, limit, window_seconds)
def record_violation(self, key: str, *, tier: RateLimitTier | None = None) -> None:
namespaced_key = self._tiered_key(tier, key) if tier else key
self._violations[namespaced_key] += 1
def get_violation_count(
self, key: str, *, tier: RateLimitTier | None = None
) -> int:
namespaced_key = self._tiered_key(tier, key) if tier else key
return self._violations.get(namespaced_key, 0)
def _check_local_rate_limit(
self,
*,
key: str,
limit: int,
window_seconds: int,
) -> tuple[bool, dict[str, Any]]:
now = time.time()
window_start = now - window_seconds
bucket = self._local_counters[key]
while bucket and bucket[0] <= window_start:
bucket.popleft()
if len(bucket) >= limit:
retry_after = int(bucket[0] + window_seconds - now) + 1
return False, {
"limit": limit,
"remaining": 0,
"retry_after": max(retry_after, 1),
}
bucket.append(now)
remaining = max(limit - len(bucket), 0)
return True, {"limit": limit, "remaining": remaining}
async def _check_sliding_window(
self,
client: Any,
key: str,
tier: RateLimitTier,
limit: int,
window_seconds: int,
) -> tuple[bool, dict[str, Any]]:
redis_key = self._redis_key("sw", tier=tier, identifier=key)
now = time.time()
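        # Sliding-window accounting via a Redis sorted set: each request is a
        # member scored by its timestamp. The pipeline atomically trims entries
        # older than the window, counts what remains, and records this request.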
pipeline = client.pipeline()
pipeline.zremrangebyscore(redis_key, 0, now - window_seconds)
pipeline.zcard(redis_key)
pipeline.zadd(redis_key, {str(now): now})
pipeline.expire(redis_key, window_seconds)
results = await pipeline.execute()
current_count = int(results[1]) + 1
remaining = max(limit - current_count, 0)
info: dict[str, Any] = {
"limit": limit,
"remaining": remaining,
"burst_limit": int(limit * self.config.burst_multiplier),
"strategy": RateLimitStrategy.SLIDING_WINDOW.value,
"tier": tier.value,
}
if current_count > limit:
oldest = await client.zrange(redis_key, 0, 0, withscores=True)
retry_after = 1
if oldest:
oldest_ts = float(oldest[0][1])
retry_after = max(int(oldest_ts + window_seconds - now), 1)
info["remaining"] = 0
info["retry_after"] = retry_after
return False, info
return True, info
async def _check_token_bucket(
self,
client: Any,
key: str,
tier: RateLimitTier,
limit: int,
window_seconds: int,
) -> tuple[bool, dict[str, Any]]:
redis_key = self._redis_key("tb", tier=tier, identifier=key)
now = time.time()
state = await client.hgetall(redis_key)
        def _decode_value(mapping: dict[Any, Any], field: str) -> str | None:
            value = mapping.get(field)
            if value is None:
                value = mapping.get(field.encode("utf-8"))
            if isinstance(value, bytes):
                return value.decode("utf-8")
            return value
if state:
tokens_value = _decode_value(state, "tokens")
last_refill_value = _decode_value(state, "last_refill")
else:
tokens_value = None
last_refill_value = None
tokens = float(tokens_value) if tokens_value is not None else float(limit)
last_refill = float(last_refill_value) if last_refill_value is not None else now
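        # Refill: tokens accrue at `token_refill_rate` per second since the
        # last request, capped at the bucket capacity (the limit, optionally
        # clamped by `max_tokens`); serving a request then spends one token.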
elapsed = max(now - last_refill, 0)
capacity = float(limit)
if self.config.max_tokens is not None:
capacity = min(capacity, float(self.config.max_tokens))
tokens = min(capacity, tokens + elapsed * self.config.token_refill_rate)
info: dict[str, Any] = {
"limit": limit,
"tokens": tokens,
"refill_rate": self.config.token_refill_rate,
"strategy": RateLimitStrategy.TOKEN_BUCKET.value,
"tier": tier.value,
}
if tokens < 1:
retry_after = int(max((1 - tokens) / self.config.token_refill_rate, 1))
info["remaining"] = 0
info["retry_after"] = retry_after
await client.hset(redis_key, mapping={"tokens": tokens, "last_refill": now})
await client.expire(redis_key, window_seconds)
return False, info
tokens -= 1
info["remaining"] = int(tokens)
await client.hset(redis_key, mapping={"tokens": tokens, "last_refill": now})
await client.expire(redis_key, window_seconds)
return True, info
async def _check_fixed_window(
self,
client: Any,
key: str,
tier: RateLimitTier,
limit: int,
window_seconds: int,
) -> tuple[bool, dict[str, Any]]:
redis_key = self._redis_key("fw", tier=tier, identifier=key)
pipeline = client.pipeline()
pipeline.incr(redis_key)
pipeline.expire(redis_key, window_seconds)
results = await pipeline.execute()
current = int(results[0])
remaining = max(limit - current, 0)
info = {
"limit": limit,
"current_count": current,
"remaining": remaining,
"strategy": RateLimitStrategy.FIXED_WINDOW.value,
"tier": tier.value,
}
if current > limit:
info["retry_after"] = window_seconds
info["remaining"] = 0
return False, info
return True, info
async def cleanup_old_data(self, *, older_than_hours: int = 24) -> None:
client = await redis_manager.get_client()
if client is None:
return
cutoff = time.time() - older_than_hours * 3600
cursor = 0
while True:
cursor, keys = await client.scan(
cursor=cursor, match="rate_limit:*", count=200
)
for raw_key in keys:
key = (
raw_key.decode()
if isinstance(raw_key, bytes | bytearray)
else raw_key
)
redis_type = await client.type(key)
if redis_type == "zset":
await client.zremrangebyscore(key, 0, cutoff)
if await client.zcard(key) == 0:
await client.delete(key)
elif redis_type == "string":
ttl = await client.ttl(key)
if ttl == -1:
await client.expire(key, int(older_than_hours * 3600))
if cursor == 0:
break
class EnhancedRateLimitMiddleware(BaseHTTPMiddleware):
"""FastAPI middleware that enforces rate limits for each request."""
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[override]
super().__init__(app)
self.config = config or RateLimitConfig()
self.rate_limiter = RateLimiter(self.config)
async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response: # type: ignore[override]
path = request.url.path
tier = EndpointClassification.classify_endpoint(path)
user_id = getattr(request.state, "user_id", None)
user_context = getattr(request.state, "user_context", {}) or {}
role = user_context.get("role") if isinstance(user_context, dict) else None
authenticated = bool(user_id)
client = getattr(request, "client", None)
client_host = getattr(client, "host", None) if client else None
key = str(user_id or client_host or "anonymous")
limit = self.config.limit_for(tier, authenticated=authenticated, role=role)
allowed, info = await self.rate_limiter.check_rate_limit(
key=key,
tier=tier,
limit=limit,
window_seconds=self.config.window_size_seconds,
)
if not allowed:
retry_after = int(info.get("retry_after", 1))
self.rate_limiter.record_violation(key, tier=tier)
if self.config.log_violations:
logger.warning("Rate limit exceeded for %s (%s)", key, tier.value)
error = RateLimitError(retry_after=retry_after, context={"info": info})
headers = {"Retry-After": str(retry_after)}
body = {"error": error.message}
return JSONResponse(body, status_code=error.status_code, headers=headers)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit)
response.headers["X-RateLimit-Remaining"] = str(
max(info.get("remaining", limit), 0)
)
response.headers["X-RateLimit-Tier"] = tier.value
return response
_default_config = RateLimitConfig()
_default_limiter = RateLimiter(_default_config)
def _extract_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
for value in list(args) + list(kwargs.values()):
if isinstance(value, Request):
return value
if hasattr(value, "state") and hasattr(value, "url"):
return value # type: ignore[return-value]
return None
def rate_limit(
*,
requests_per_minute: int,
strategy: RateLimitStrategy | None = None,
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
"""Function decorator enforcing rate limits for arbitrary callables."""
window_seconds = 60
def decorator(func: Callable[..., Awaitable[Any]]):
if not asyncio.iscoroutinefunction(func):
raise TypeError(
"rate_limit decorator can only be applied to async callables"
)
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = _extract_request(args, kwargs)
if request is None:
return await func(*args, **kwargs)
tier = EndpointClassification.classify_endpoint(request.url.path)
key = str(getattr(request.state, "user_id", None) or request.url.path)
allowed, info = await _default_limiter.check_rate_limit(
key=key,
tier=tier,
limit=requests_per_minute,
window_seconds=window_seconds,
strategy=strategy,
)
if not allowed:
retry_after = int(info.get("retry_after", 1))
raise RateLimitError(retry_after=retry_after, context={"info": info})
return await func(*args, **kwargs)
return wrapper
return decorator
```
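Both enforcement paths can be wired up in a few lines. A sketch, assuming a standard FastAPI app; the endpoint and its limits are illustrative, not part of the repo:
```python
from fastapi import FastAPI, Request
from maverick_mcp.api.middleware.rate_limiting_enhanced import (
    EnhancedRateLimitMiddleware,
    RateLimitConfig,
    rate_limit,
)
app = FastAPI()
app.add_middleware(
    EnhancedRateLimitMiddleware,
    config=RateLimitConfig(data_limit=120, premium_multiplier=4.0),
)
@app.get("/api/data/quote")
@rate_limit(requests_per_minute=30)  # decorator path: raises RateLimitError when exhausted
async def get_quote(request: Request) -> dict[str, str]:
    return {"status": "ok"}
```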
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/screening_ddd.py:
--------------------------------------------------------------------------------
```python
"""
DDD-based screening router for Maverick-MCP.
This module demonstrates Domain-Driven Design principles with clear
separation between layers and dependency injection.
"""
import logging
import time
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from fastapi import Depends
from fastmcp import FastMCP
from pydantic import BaseModel, Field
# Application layer imports
from maverick_mcp.application.screening.queries import (
GetAllScreeningResultsQuery,
GetScreeningResultsQuery,
GetScreeningStatisticsQuery,
)
from maverick_mcp.domain.screening.services import IStockRepository
# Domain layer imports
from maverick_mcp.domain.screening.value_objects import (
ScreeningCriteria,
ScreeningStrategy,
)
# Infrastructure layer imports
from maverick_mcp.infrastructure.screening.repositories import (
CachedStockRepository,
PostgresStockRepository,
)
logger = logging.getLogger(__name__)
# Create the DDD screening router
screening_ddd_router: FastMCP = FastMCP("Stock_Screening_DDD")
# Dependency Injection Setup
def get_stock_repository() -> IStockRepository:
"""
Dependency injection for stock repository.
This function provides the concrete repository implementation
with caching capabilities.
"""
base_repository = PostgresStockRepository()
cached_repository = CachedStockRepository(base_repository, cache_ttl_seconds=300)
return cached_repository
# Request/Response Models for MCP Tools
class ScreeningRequest(BaseModel):
"""Request model for screening operations."""
strategy: str = Field(
description="Screening strategy (maverick_bullish, maverick_bearish, trending_stage2)"
)
limit: int = Field(
default=20, ge=1, le=100, description="Maximum number of results"
)
min_momentum_score: float | None = Field(
default=None, ge=0, le=100, description="Minimum momentum score"
)
min_volume: int | None = Field(
default=None, ge=0, description="Minimum daily volume"
)
max_price: float | None = Field(
default=None, gt=0, description="Maximum stock price"
)
min_price: float | None = Field(
default=None, gt=0, description="Minimum stock price"
)
require_above_sma50: bool = Field(
default=False, description="Require price above SMA 50"
)
require_pattern_detected: bool = Field(
default=False, description="Require pattern detection"
)
class AllScreeningRequest(BaseModel):
"""Request model for all screening strategies."""
limit_per_strategy: int = Field(
default=10, ge=1, le=50, description="Results per strategy"
)
min_momentum_score: float | None = Field(
default=None, ge=0, le=100, description="Minimum momentum score filter"
)
class StatisticsRequest(BaseModel):
"""Request model for screening statistics."""
strategy: str | None = Field(
default=None, description="Specific strategy to analyze (None for all)"
)
limit: int = Field(
default=100, ge=1, le=500, description="Maximum results to analyze"
)
# Helper Functions
def _create_screening_criteria_from_request(
request: ScreeningRequest,
) -> ScreeningCriteria:
"""Convert API request to domain value object."""
    return ScreeningCriteria(
        min_momentum_score=Decimal(str(request.min_momentum_score))
        if request.min_momentum_score is not None
        else None,
        min_volume=request.min_volume,
        min_price=Decimal(str(request.min_price))
        if request.min_price is not None
        else None,
        max_price=Decimal(str(request.max_price))
        if request.max_price is not None
        else None,
        require_above_sma50=request.require_above_sma50,
        require_pattern_detected=request.require_pattern_detected,
    )
def _convert_collection_to_dto(
collection, execution_time_ms: float, applied_filters: dict[str, Any]
) -> dict[str, Any]:
"""Convert domain collection to API response DTO."""
results_dto = []
for result in collection.results:
result_dict = result.to_dict()
# Convert domain result to DTO format
result_dto = {
"stock_symbol": result_dict["stock_symbol"],
"screening_date": result_dict["screening_date"],
"close_price": result_dict["close_price"],
"volume": result.volume,
"momentum_score": result_dict["momentum_score"],
"adr_percentage": result_dict["adr_percentage"],
"ema_21": float(result.ema_21),
"sma_50": float(result.sma_50),
"sma_150": float(result.sma_150),
"sma_200": float(result.sma_200),
"avg_volume_30d": float(result.avg_volume_30d),
"atr": float(result.atr),
"pattern": result.pattern,
"squeeze": result.squeeze,
"consolidation": result.vcp,
"entry_signal": result.entry_signal,
"combined_score": result.combined_score,
"bear_score": result.bear_score,
"quality_score": result_dict["quality_score"],
"is_bullish": result_dict["is_bullish"],
"is_bearish": result_dict["is_bearish"],
"is_trending_stage2": result_dict["is_trending_stage2"],
"risk_reward_ratio": result_dict["risk_reward_ratio"],
# Bearish-specific fields
"rsi_14": float(result.rsi_14) if result.rsi_14 else None,
"macd": float(result.macd) if result.macd else None,
"macd_signal": float(result.macd_signal) if result.macd_signal else None,
"macd_histogram": float(result.macd_histogram)
if result.macd_histogram
else None,
"distribution_days_20": result.distribution_days_20,
"atr_contraction": result.atr_contraction,
"big_down_volume": result.big_down_volume,
}
results_dto.append(result_dto)
return {
"strategy_used": collection.strategy_used,
"screening_timestamp": collection.screening_timestamp.isoformat(),
"total_candidates_analyzed": collection.total_candidates_analyzed,
"results_returned": len(collection.results),
"results": results_dto,
"statistics": collection.get_statistics(),
"applied_filters": applied_filters,
"sorting_applied": {"field": "strategy_default", "descending": True},
"status": "success",
"execution_time_ms": execution_time_ms,
"warnings": [],
}
# MCP Tools
@screening_ddd_router.tool()
async def get_screening_results_ddd(
request: ScreeningRequest,
repository: IStockRepository = Depends(get_stock_repository),
) -> dict[str, Any]:
"""
Get stock screening results using Domain-Driven Design architecture.
This tool demonstrates DDD principles with clean separation of concerns:
- Domain layer: Pure business logic and rules
- Application layer: Orchestration and use cases
- Infrastructure layer: Data access and external services
- API layer: Request/response handling with dependency injection
Args:
request: Screening parameters including strategy and filters
repository: Injected repository dependency
Returns:
Comprehensive screening results with business intelligence
"""
start_time = time.time()
try:
# Validate strategy
try:
strategy = ScreeningStrategy(request.strategy)
except ValueError:
return {
"status": "error",
"error_code": "INVALID_STRATEGY",
"error_message": f"Invalid strategy: {request.strategy}",
"valid_strategies": [s.value for s in ScreeningStrategy],
"timestamp": datetime.utcnow().isoformat(),
}
# Convert request to domain value objects
criteria = _create_screening_criteria_from_request(request)
# Execute application query
query = GetScreeningResultsQuery(repository)
collection = await query.execute(
strategy=strategy, limit=request.limit, criteria=criteria
)
# Calculate execution time
execution_time_ms = (time.time() - start_time) * 1000
# Convert to API response
applied_filters = {
"strategy": request.strategy,
"limit": request.limit,
"min_momentum_score": request.min_momentum_score,
"min_volume": request.min_volume,
"min_price": request.min_price,
"max_price": request.max_price,
"require_above_sma50": request.require_above_sma50,
"require_pattern_detected": request.require_pattern_detected,
}
response = _convert_collection_to_dto(
collection, execution_time_ms, applied_filters
)
logger.info(
f"DDD screening completed: {strategy.value}, "
f"{len(collection.results)} results, {execution_time_ms:.1f}ms"
)
return response
except Exception as e:
execution_time_ms = (time.time() - start_time) * 1000
logger.error(f"Error in DDD screening: {e}")
return {
"status": "error",
"error_code": "SCREENING_FAILED",
"error_message": str(e),
"execution_time_ms": execution_time_ms,
"timestamp": datetime.utcnow().isoformat(),
}
@screening_ddd_router.tool()
async def get_all_screening_results_ddd(
request: AllScreeningRequest,
repository: IStockRepository = Depends(get_stock_repository),
) -> dict[str, Any]:
"""
Get screening results from all strategies using DDD architecture.
This tool executes all available screening strategies and provides
comprehensive cross-strategy analysis and insights.
Args:
request: Parameters for multi-strategy screening
repository: Injected repository dependency
Returns:
Results from all strategies with cross-strategy analysis
"""
start_time = time.time()
try:
# Create criteria if filters provided
criteria = None
        if request.min_momentum_score is not None:
criteria = ScreeningCriteria(
min_momentum_score=Decimal(str(request.min_momentum_score))
)
# Execute application query
query = GetAllScreeningResultsQuery(repository)
all_collections = await query.execute(
limit_per_strategy=request.limit_per_strategy, criteria=criteria
)
# Calculate execution time
execution_time_ms = (time.time() - start_time) * 1000
# Convert collections to DTOs
response = {
"screening_timestamp": datetime.utcnow().isoformat(),
"strategies_executed": list(all_collections.keys()),
"execution_time_ms": execution_time_ms,
"status": "success",
"errors": [],
}
# Add individual strategy results
for strategy_name, collection in all_collections.items():
applied_filters = {"limit": request.limit_per_strategy}
            if request.min_momentum_score is not None:
applied_filters["min_momentum_score"] = request.min_momentum_score
strategy_dto = _convert_collection_to_dto(
collection,
execution_time_ms
/ len(all_collections), # Approximate per-strategy time
applied_filters,
)
# Map strategy names to response fields
if strategy_name == ScreeningStrategy.MAVERICK_BULLISH.value:
response["maverick_bullish"] = strategy_dto
elif strategy_name == ScreeningStrategy.MAVERICK_BEARISH.value:
response["maverick_bearish"] = strategy_dto
elif strategy_name == ScreeningStrategy.TRENDING_STAGE2.value:
response["trending_stage2"] = strategy_dto
# Add cross-strategy analysis
statistics_query = GetScreeningStatisticsQuery(repository)
stats = await statistics_query.execute(limit=request.limit_per_strategy * 3)
response["cross_strategy_analysis"] = stats.get("cross_strategy_analysis", {})
response["overall_summary"] = stats.get("overall_summary", {})
logger.info(
f"DDD all screening completed: {len(all_collections)} strategies, "
f"{execution_time_ms:.1f}ms"
)
return response
except Exception as e:
execution_time_ms = (time.time() - start_time) * 1000
logger.error(f"Error in DDD all screening: {e}")
return {
"screening_timestamp": datetime.utcnow().isoformat(),
"strategies_executed": [],
"status": "error",
"error_code": "ALL_SCREENING_FAILED",
"error_message": str(e),
"execution_time_ms": execution_time_ms,
"errors": [str(e)],
}
@screening_ddd_router.tool()
async def get_screening_statistics_ddd(
request: StatisticsRequest,
repository: IStockRepository = Depends(get_stock_repository),
) -> dict[str, Any]:
"""
Get comprehensive screening statistics and analytics using DDD architecture.
This tool provides business intelligence and analytical insights
for screening operations, demonstrating how domain services can
calculate complex business metrics.
Args:
request: Statistics parameters
repository: Injected repository dependency
Returns:
Comprehensive statistics and business intelligence
"""
start_time = time.time()
try:
# Validate strategy if provided
strategy = None
if request.strategy:
try:
strategy = ScreeningStrategy(request.strategy)
except ValueError:
return {
"status": "error",
"error_code": "INVALID_STRATEGY",
"error_message": f"Invalid strategy: {request.strategy}",
"valid_strategies": [s.value for s in ScreeningStrategy],
"timestamp": datetime.utcnow().isoformat(),
}
# Execute statistics query
query = GetScreeningStatisticsQuery(repository)
stats = await query.execute(strategy=strategy, limit=request.limit)
# Calculate execution time
execution_time_ms = (time.time() - start_time) * 1000
# Enhance response with metadata
stats.update(
{
"execution_time_ms": execution_time_ms,
"status": "success",
"analysis_scope": "single" if strategy else "all",
"results_analyzed": request.limit,
}
)
logger.info(
f"DDD statistics completed: {strategy.value if strategy else 'all'}, "
f"{execution_time_ms:.1f}ms"
)
return stats
except Exception as e:
execution_time_ms = (time.time() - start_time) * 1000
logger.error(f"Error in DDD statistics: {e}")
return {
"status": "error",
"error_code": "STATISTICS_FAILED",
"error_message": str(e),
"execution_time_ms": execution_time_ms,
"timestamp": datetime.utcnow().isoformat(),
"analysis_scope": "failed",
"results_analyzed": 0,
}
@screening_ddd_router.tool()
async def get_repository_cache_stats(
repository: IStockRepository = Depends(get_stock_repository),
) -> dict[str, Any]:
"""
Get repository cache statistics for monitoring and optimization.
This tool demonstrates infrastructure layer monitoring capabilities
and provides insights into caching performance.
Args:
repository: Injected repository dependency
Returns:
Cache statistics and performance metrics
"""
try:
# Check if repository supports cache statistics
if hasattr(repository, "get_cache_stats"):
cache_stats = repository.get_cache_stats()
return {
"status": "success",
"cache_enabled": True,
"cache_statistics": cache_stats,
"timestamp": datetime.utcnow().isoformat(),
}
else:
return {
"status": "success",
"cache_enabled": False,
"message": "Repository does not support caching",
"timestamp": datetime.utcnow().isoformat(),
}
except Exception as e:
logger.error(f"Error getting cache stats: {e}")
return {
"status": "error",
"error_message": str(e),
"timestamp": datetime.utcnow().isoformat(),
}
```
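Because the layers are decoupled, the application query can be exercised without the MCP transport. A minimal sketch, assuming a populated database behind the injected repository:
```python
import asyncio
from decimal import Decimal
from maverick_mcp.api.routers.screening_ddd import get_stock_repository
from maverick_mcp.application.screening.queries import GetScreeningResultsQuery
from maverick_mcp.domain.screening.value_objects import (
    ScreeningCriteria,
    ScreeningStrategy,
)
async def main() -> None:
    query = GetScreeningResultsQuery(get_stock_repository())
    collection = await query.execute(
        strategy=ScreeningStrategy.MAVERICK_BULLISH,
        limit=5,
        criteria=ScreeningCriteria(min_momentum_score=Decimal("80")),
    )
    print(f"{len(collection.results)} candidates via {collection.strategy_used}")
asyncio.run(main())
```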
--------------------------------------------------------------------------------
/maverick_mcp/workflows/agents/strategy_selector.py:
--------------------------------------------------------------------------------
```python
"""
Strategy Selector Agent for intelligent strategy recommendation.
This agent analyzes market regime and selects the most appropriate trading strategies
based on current market conditions, volatility, and trend characteristics.
"""
import logging
from datetime import datetime
from typing import Any
from maverick_mcp.backtesting.strategies.templates import (
get_strategy_info,
list_available_strategies,
)
from maverick_mcp.workflows.state import BacktestingWorkflowState
logger = logging.getLogger(__name__)
class StrategySelectorAgent:
"""Intelligent strategy selector based on market regime analysis."""
def __init__(self):
"""Initialize strategy selector agent."""
# Strategy fitness mapping for different market regimes
self.REGIME_STRATEGY_FITNESS = {
"trending": {
"sma_cross": 0.9,
"ema_cross": 0.85,
"macd": 0.8,
"breakout": 0.9,
"momentum": 0.85,
"rsi": 0.3, # Poor for trending markets
"bollinger": 0.4,
"mean_reversion": 0.2,
},
"ranging": {
"rsi": 0.9,
"bollinger": 0.85,
"mean_reversion": 0.9,
"sma_cross": 0.3, # Poor for ranging markets
"ema_cross": 0.3,
"breakout": 0.2,
"momentum": 0.25,
"macd": 0.5,
},
"volatile": {
"breakout": 0.8,
"momentum": 0.7,
"volatility_breakout": 0.9,
"bollinger": 0.7,
"sma_cross": 0.4,
"rsi": 0.6,
"mean_reversion": 0.5,
"macd": 0.5,
},
"volatile_trending": {
"breakout": 0.85,
"momentum": 0.8,
"volatility_breakout": 0.9,
"macd": 0.7,
"ema_cross": 0.6,
"sma_cross": 0.6,
"rsi": 0.4,
"bollinger": 0.6,
},
"low_volume": {
"sma_cross": 0.7,
"ema_cross": 0.7,
"rsi": 0.6,
"mean_reversion": 0.6,
"breakout": 0.3, # Poor for low volume
"momentum": 0.4,
"bollinger": 0.6,
"macd": 0.6,
},
"low_volume_ranging": {
"rsi": 0.8,
"mean_reversion": 0.8,
"bollinger": 0.7,
"sma_cross": 0.5,
"ema_cross": 0.5,
"breakout": 0.2,
"momentum": 0.3,
"macd": 0.4,
},
"unknown": {
# Balanced approach for unknown regimes
"sma_cross": 0.6,
"ema_cross": 0.6,
"rsi": 0.6,
"macd": 0.6,
"bollinger": 0.6,
"momentum": 0.5,
"breakout": 0.5,
"mean_reversion": 0.5,
},
}
# Additional fitness adjustments based on market conditions
self.CONDITION_ADJUSTMENTS = {
"high_volatility": {
"rsi": -0.1,
"breakout": 0.1,
"volatility_breakout": 0.15,
},
"low_volatility": {
"mean_reversion": 0.1,
"rsi": 0.1,
"breakout": -0.1,
},
"high_volume": {
"breakout": 0.1,
"momentum": 0.1,
"sma_cross": 0.05,
},
"low_volume": {
"breakout": -0.15,
"momentum": -0.1,
"mean_reversion": 0.05,
},
}
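        # Example: "breakout" in a "trending" regime starts at 0.9 base fitness and
        # gains +0.1 under "high_volatility" conditions before confidence weighting.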
logger.info("StrategySelectorAgent initialized")
async def select_strategies(
self, state: BacktestingWorkflowState
) -> BacktestingWorkflowState:
"""Select optimal strategies based on market regime analysis.
Args:
state: Current workflow state with market regime analysis
Returns:
Updated state with strategy selection results
"""
try:
logger.info(
f"Selecting strategies for {state['symbol']} in {state['market_regime']} regime"
)
# Get available strategies
available_strategies = list_available_strategies()
# Calculate strategy fitness scores
strategy_rankings = self._calculate_strategy_fitness(
state["market_regime"],
state["market_conditions"],
available_strategies,
state["regime_confidence"],
)
# Select top strategies
selected_strategies = self._select_top_strategies(
strategy_rankings,
user_preference=state["requested_strategy"],
max_strategies=5, # Limit to top 5 for optimization efficiency
)
# Generate strategy candidates with metadata
candidates = self._generate_strategy_candidates(
selected_strategies, available_strategies
)
# Create selection reasoning
reasoning = self._generate_selection_reasoning(
state["market_regime"],
state["regime_confidence"],
selected_strategies,
state["market_conditions"],
)
# Calculate selection confidence
selection_confidence = self._calculate_selection_confidence(
strategy_rankings,
selected_strategies,
state["regime_confidence"],
)
# Update state
state["candidate_strategies"] = candidates
state["strategy_rankings"] = strategy_rankings
state["selected_strategies"] = selected_strategies
state["strategy_selection_reasoning"] = reasoning
state["strategy_selection_confidence"] = selection_confidence
# Update workflow status
state["workflow_status"] = "optimizing"
state["current_step"] = "strategy_selection_completed"
state["steps_completed"].append("strategy_selection")
logger.info(
f"Strategy selection completed for {state['symbol']}: "
f"Selected {len(selected_strategies)} strategies with confidence {selection_confidence:.2f}"
)
return state
except Exception as e:
error_info = {
"step": "strategy_selection",
"error": str(e),
"timestamp": datetime.now().isoformat(),
"symbol": state["symbol"],
}
state["errors_encountered"].append(error_info)
# Fallback to basic strategy set
fallback_strategies = ["sma_cross", "rsi", "macd"]
state["selected_strategies"] = fallback_strategies
state["strategy_selection_confidence"] = 0.3
state["fallback_strategies_used"].append("strategy_selection_fallback")
logger.error(f"Strategy selection failed for {state['symbol']}: {e}")
return state
def _calculate_strategy_fitness(
self,
regime: str,
market_conditions: dict[str, Any],
available_strategies: list[str],
regime_confidence: float,
) -> dict[str, float]:
"""Calculate fitness scores for all available strategies."""
fitness_scores = {}
# Base fitness from regime mapping
base_fitness = self.REGIME_STRATEGY_FITNESS.get(
regime, self.REGIME_STRATEGY_FITNESS["unknown"]
)
for strategy in available_strategies:
# Start with base fitness score
score = base_fitness.get(strategy, 0.5) # Default to neutral if not defined
# Apply condition-based adjustments
score = self._apply_condition_adjustments(
score, strategy, market_conditions
)
# Weight by regime confidence
# If low confidence, move scores toward neutral (0.5)
confidence_weight = regime_confidence
score = score * confidence_weight + 0.5 * (1 - confidence_weight)
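            # Worked example: regime_confidence=0.6 and base score 0.9 blend to
            # 0.9 * 0.6 + 0.5 * 0.4 = 0.74; low confidence pulls scores toward 0.5.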
# Ensure score is in valid range
fitness_scores[strategy] = max(0.0, min(1.0, score))
return fitness_scores
def _apply_condition_adjustments(
self, base_score: float, strategy: str, market_conditions: dict[str, Any]
) -> float:
"""Apply market condition adjustments to base fitness score."""
score = base_score
# Get relevant conditions
volatility_regime = market_conditions.get(
"volatility_regime", "medium_volatility"
)
volume_regime = market_conditions.get("volume_regime", "normal_volume")
# Apply volatility adjustments
if volatility_regime in self.CONDITION_ADJUSTMENTS:
adjustment = self.CONDITION_ADJUSTMENTS[volatility_regime].get(strategy, 0)
score += adjustment
# Apply volume adjustments
if volume_regime in self.CONDITION_ADJUSTMENTS:
adjustment = self.CONDITION_ADJUSTMENTS[volume_regime].get(strategy, 0)
score += adjustment
return score
def _select_top_strategies(
self,
strategy_rankings: dict[str, float],
user_preference: str | None = None,
max_strategies: int = 5,
) -> list[str]:
"""Select top strategies based on fitness scores and user preferences."""
# Sort strategies by fitness score
sorted_strategies = sorted(
strategy_rankings.items(), key=lambda x: x[1], reverse=True
)
selected = []
# Always include user preference if specified and available
if user_preference and user_preference in strategy_rankings:
selected.append(user_preference)
logger.info(f"Including user-requested strategy: {user_preference}")
# Add top strategies up to limit
for strategy, score in sorted_strategies:
if len(selected) >= max_strategies:
break
if strategy not in selected and score > 0.4: # Minimum threshold
selected.append(strategy)
# Ensure we have at least 2 strategies
if len(selected) < 2:
for strategy, _ in sorted_strategies:
if strategy not in selected:
selected.append(strategy)
if len(selected) >= 2:
break
return selected
def _generate_strategy_candidates(
self, selected_strategies: list[str], available_strategies: list[str]
) -> list[dict[str, Any]]:
"""Generate detailed candidate information for selected strategies."""
candidates = []
for strategy in selected_strategies:
if strategy in available_strategies:
strategy_info = get_strategy_info(strategy)
candidates.append(
{
"strategy": strategy,
"name": strategy_info.get("name", strategy.title()),
"description": strategy_info.get("description", ""),
"category": strategy_info.get("category", "unknown"),
"parameters": strategy_info.get("parameters", {}),
"risk_level": strategy_info.get("risk_level", "medium"),
"best_market_conditions": strategy_info.get(
"best_conditions", []
),
}
)
return candidates
def _generate_selection_reasoning(
self,
regime: str,
regime_confidence: float,
selected_strategies: list[str],
market_conditions: dict[str, Any],
) -> str:
"""Generate human-readable reasoning for strategy selection."""
reasoning_parts = []
# Market regime reasoning
reasoning_parts.append(
f"Market regime identified as '{regime}' with {regime_confidence:.1%} confidence."
)
# Strategy selection reasoning
if regime == "trending":
reasoning_parts.append(
"In trending markets, trend-following strategies like moving average crossovers "
"and momentum strategies typically perform well."
)
elif regime == "ranging":
reasoning_parts.append(
"In ranging markets, mean-reversion strategies like RSI and Bollinger Bands "
"are favored as they capitalize on price oscillations within a range."
)
elif regime == "volatile":
reasoning_parts.append(
"In volatile markets, breakout strategies and volatility-based approaches "
"can capture large price movements effectively."
)
# Condition-specific reasoning
volatility_regime = market_conditions.get("volatility_regime", "")
if volatility_regime == "high_volatility":
reasoning_parts.append(
"High volatility conditions favor strategies that can handle larger price swings."
)
elif volatility_regime == "low_volatility":
reasoning_parts.append(
"Low volatility conditions favor mean-reversion and range-bound strategies."
)
volume_regime = market_conditions.get("volume_regime", "")
if volume_regime == "low_volume":
reasoning_parts.append(
"Low volume conditions reduce reliability of breakout strategies and favor "
"trend-following approaches with longer timeframes."
)
# Selected strategies summary
reasoning_parts.append(
f"Selected strategies: {', '.join(selected_strategies)} "
f"based on their historical performance in similar market conditions."
)
return " ".join(reasoning_parts)
def _calculate_selection_confidence(
self,
strategy_rankings: dict[str, float],
selected_strategies: list[str],
regime_confidence: float,
) -> float:
"""Calculate confidence in strategy selection."""
if not selected_strategies or not strategy_rankings:
return 0.0
# Average fitness of selected strategies
selected_scores = [strategy_rankings.get(s, 0.5) for s in selected_strategies]
avg_selected_fitness = sum(selected_scores) / len(selected_scores)
        # Score spread (higher spread = more confident in selection)
        all_scores = list(strategy_rankings.values())
        mean_score = sum(all_scores) / len(all_scores)
        score_std = (
            sum((s - mean_score) ** 2 for s in all_scores) / len(all_scores)
        ) ** 0.5
        score_spread = score_std / mean_score if mean_score > 0 else 0
# Combine factors
fitness_confidence = avg_selected_fitness # 0-1
spread_confidence = min(score_spread, 1.0) # Normalize spread
# Weight by regime confidence
total_confidence = (
fitness_confidence * 0.5 + spread_confidence * 0.2 + regime_confidence * 0.3
)
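        # Worked example: avg fitness 0.8, spread 0.5, regime confidence 0.7 gives
        # 0.8 * 0.5 + 0.5 * 0.2 + 0.7 * 0.3 = 0.71 before clamping to [0.1, 0.95].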
return max(0.1, min(0.95, total_confidence))
def get_strategy_compatibility_matrix(self) -> dict[str, dict[str, float]]:
"""Get compatibility matrix showing strategy fitness for each regime."""
return self.REGIME_STRATEGY_FITNESS.copy()
def explain_strategy_selection(
self, regime: str, strategy: str, market_conditions: dict[str, Any]
) -> str:
"""Explain why a specific strategy is suitable for given conditions."""
base_fitness = self.REGIME_STRATEGY_FITNESS.get(regime, {}).get(strategy, 0.5)
explanations = {
"sma_cross": {
"trending": "SMA crossovers excel in trending markets by catching trend changes early.",
"ranging": "SMA crossovers produce many false signals in ranging markets.",
},
"rsi": {
"ranging": "RSI is ideal for ranging markets, buying oversold and selling overbought levels.",
"trending": "RSI can remain overbought/oversold for extended periods in strong trends.",
},
"breakout": {
"volatile": "Breakout strategies capitalize on high volatility and strong price moves.",
"ranging": "Breakout strategies struggle in ranging markets with frequent false breakouts.",
},
}
specific_explanation = explanations.get(strategy, {}).get(regime, "")
return f"Strategy '{strategy}' has {base_fitness:.1%} fitness for '{regime}' markets. {specific_explanation}"
```
--------------------------------------------------------------------------------
/scripts/seed_sp500.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
S&P 500 database seeding script for MaverickMCP.
This script populates the database with all S&P 500 stocks, including
company information, sector data, and comprehensive stock details.
"""
import logging
import os
import sys
import time
from pathlib import Path
# Add the project root to the Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# noqa: E402 - imports must come after sys.path modification
import pandas as pd # noqa: E402
import yfinance as yf # noqa: E402
from sqlalchemy import create_engine, text # noqa: E402
from sqlalchemy.orm import sessionmaker # noqa: E402
from maverick_mcp.data.models import ( # noqa: E402
MaverickBearStocks,
MaverickStocks,
PriceCache,
Stock,
SupplyDemandBreakoutStocks,
TechnicalCache,
bulk_insert_screening_data,
)
# Set up logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("maverick_mcp.seed_sp500")
def get_database_url() -> str:
    """Get the database URL from the environment, defaulting to local SQLite."""
    return os.getenv("DATABASE_URL") or "sqlite:///maverick_mcp.db"
def fetch_sp500_list() -> pd.DataFrame:
"""Fetch the current S&P 500 stock list from Wikipedia."""
logger.info("Fetching S&P 500 stock list from Wikipedia...")
try:
# Read S&P 500 list from Wikipedia
tables = pd.read_html(
"https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
)
sp500_df = tables[0] # First table contains the stock list
# Clean up column names
sp500_df.columns = [
"symbol",
"company",
"gics_sector",
"gics_sub_industry",
"headquarters",
"date_added",
"cik",
"founded",
]
        # Normalize symbols for yfinance (e.g. "BRK.B" -> "BRK-B")
        sp500_df["symbol"] = sp500_df["symbol"].str.replace(".", "-", regex=False)
logger.info(f"Fetched {len(sp500_df)} S&P 500 companies")
return sp500_df[
["symbol", "company", "gics_sector", "gics_sub_industry"]
].copy()
except Exception as e:
logger.error(f"Failed to fetch S&P 500 list from Wikipedia: {e}")
logger.info("Falling back to manually curated S&P 500 list...")
# Fallback to a curated list of major S&P 500 stocks
fallback_stocks = [
(
"AAPL",
"Apple Inc.",
"Information Technology",
"Technology Hardware, Storage & Peripherals",
),
("MSFT", "Microsoft Corporation", "Information Technology", "Software"),
(
"GOOGL",
"Alphabet Inc.",
"Communication Services",
"Interactive Media & Services",
),
(
"AMZN",
"Amazon.com Inc.",
"Consumer Discretionary",
"Internet & Direct Marketing Retail",
),
("TSLA", "Tesla Inc.", "Consumer Discretionary", "Automobiles"),
(
"NVDA",
"NVIDIA Corporation",
"Information Technology",
"Semiconductors & Semiconductor Equipment",
),
(
"META",
"Meta Platforms Inc.",
"Communication Services",
"Interactive Media & Services",
),
("BRK-B", "Berkshire Hathaway Inc.", "Financials", "Multi-Sector Holdings"),
("JNJ", "Johnson & Johnson", "Health Care", "Pharmaceuticals"),
(
"V",
"Visa Inc.",
"Information Technology",
"Data Processing & Outsourced Services",
),
# Add more major S&P 500 companies
("JPM", "JPMorgan Chase & Co.", "Financials", "Banks"),
("WMT", "Walmart Inc.", "Consumer Staples", "Food & Staples Retailing"),
("PG", "Procter & Gamble Co.", "Consumer Staples", "Household Products"),
(
"UNH",
"UnitedHealth Group Inc.",
"Health Care",
"Health Care Providers & Services",
),
(
"MA",
"Mastercard Inc.",
"Information Technology",
"Data Processing & Outsourced Services",
),
("HD", "Home Depot Inc.", "Consumer Discretionary", "Specialty Retail"),
("BAC", "Bank of America Corp.", "Financials", "Banks"),
("PFE", "Pfizer Inc.", "Health Care", "Pharmaceuticals"),
("KO", "Coca-Cola Co.", "Consumer Staples", "Beverages"),
("ABBV", "AbbVie Inc.", "Health Care", "Pharmaceuticals"),
]
fallback_df = pd.DataFrame(
fallback_stocks,
columns=["symbol", "company", "gics_sector", "gics_sub_industry"],
)
logger.info(
f"Using fallback list with {len(fallback_df)} major S&P 500 companies"
)
return fallback_df
def enrich_stock_data(symbol: str) -> dict:
"""Enrich stock data with additional information from yfinance."""
try:
ticker = yf.Ticker(symbol)
info = ticker.info
# Extract relevant information
enriched_data = {
"market_cap": info.get("marketCap"),
"shares_outstanding": info.get("sharesOutstanding"),
"description": info.get("longBusinessSummary", ""),
"country": info.get("country", "US"),
"currency": info.get("currency", "USD"),
"exchange": info.get("exchange", "NASDAQ"),
"industry": info.get("industry", ""),
"sector": info.get("sector", ""),
}
# Clean up description (limit length)
if enriched_data["description"] and len(enriched_data["description"]) > 500:
enriched_data["description"] = enriched_data["description"][:500] + "..."
return enriched_data
except Exception as e:
logger.warning(f"Failed to enrich data for {symbol}: {e}")
return {}
def create_sp500_stocks(session, sp500_df: pd.DataFrame) -> dict[str, Stock]:
"""Create S&P 500 stock records with comprehensive data."""
logger.info(f"Creating {len(sp500_df)} S&P 500 stocks...")
created_stocks = {}
batch_size = 10
for i, row in sp500_df.iterrows():
symbol = row["symbol"]
company = row["company"]
gics_sector = row["gics_sector"]
gics_sub_industry = row["gics_sub_industry"]
try:
logger.info(f"Processing {symbol} ({i + 1}/{len(sp500_df)})...")
# Rate limiting - pause every batch to be nice to APIs
if i > 0 and i % batch_size == 0:
logger.info(f"Processed {i} stocks, pausing for 2 seconds...")
time.sleep(2)
# Enrich with additional data from yfinance
enriched_data = enrich_stock_data(symbol)
# Create stock record
stock = Stock.get_or_create(
session,
ticker_symbol=symbol,
company_name=company,
sector=enriched_data.get("sector") or gics_sector or "Unknown",
industry=enriched_data.get("industry")
or gics_sub_industry
or "Unknown",
description=enriched_data.get("description")
or f"{company} - S&P 500 component",
exchange=enriched_data.get("exchange", "NASDAQ"),
country=enriched_data.get("country", "US"),
currency=enriched_data.get("currency", "USD"),
market_cap=enriched_data.get("market_cap"),
shares_outstanding=enriched_data.get("shares_outstanding"),
is_active=True,
)
created_stocks[symbol] = stock
logger.info(f"✓ Created {symbol}: {company}")
except Exception as e:
logger.error(f"✗ Error creating stock {symbol}: {e}")
continue
session.commit()
logger.info(f"Successfully created {len(created_stocks)} S&P 500 stocks")
return created_stocks
def create_sample_screening_for_sp500(session, stocks: dict[str, Stock]) -> None:
"""Create sample screening results for S&P 500 stocks."""
logger.info("Creating sample screening results for S&P 500 stocks...")
# Generate screening data based on stock symbols
screening_data = []
stock_items = list(stocks.items())
for _i, (ticker, _stock) in enumerate(stock_items):
        # Hash the ticker for stable-within-a-run pseudo-random values (note:
        # str hashes are salted per process unless PYTHONHASHSEED is fixed)
        ticker_hash = hash(ticker)
# Generate realistic screening metrics
base_price = 50 + (ticker_hash % 200) # Price between 50-250
momentum_score = 30 + (ticker_hash % 70) # Score 30-100
data = {
"ticker": ticker,
"close": round(base_price + (ticker_hash % 50), 2),
"volume": 500000 + (ticker_hash % 10000000), # 0.5M - 10.5M volume
"momentum_score": round(momentum_score, 2),
"combined_score": min(100, momentum_score + (ticker_hash % 20)),
"ema_21": round(base_price * 0.98, 2),
"sma_50": round(base_price * 0.96, 2),
"sma_150": round(base_price * 0.94, 2),
"sma_200": round(base_price * 0.92, 2),
"adr_pct": round(1.5 + (ticker_hash % 6), 2), # ADR 1.5-7.5%
"atr": round(2 + (ticker_hash % 8), 2),
"pattern_type": ["Breakout", "Continuation", "Reversal", "Base"][
ticker_hash % 4
],
"squeeze_status": ["No Squeeze", "Low", "Mid", "High"][ticker_hash % 4],
"consolidation_status": ["Base", "Flag", "Pennant", "Triangle"][
ticker_hash % 4
],
"entry_signal": ["Buy", "Hold", "Watch", "Caution"][ticker_hash % 4],
"compression_score": ticker_hash % 10,
"pattern_detected": 1 if ticker_hash % 3 == 0 else 0,
}
screening_data.append(data)
# Sort by combined score for different screening types
total_stocks = len(screening_data)
# Top 30% for Maverick (bullish momentum)
maverick_count = max(10, int(total_stocks * 0.3)) # At least 10 stocks
maverick_data = sorted(
screening_data, key=lambda x: x["combined_score"], reverse=True
)[:maverick_count]
maverick_count = bulk_insert_screening_data(session, MaverickStocks, maverick_data)
logger.info(f"Created {maverick_count} Maverick screening results")
# Bottom 20% for Bear stocks (weak momentum)
bear_count = max(5, int(total_stocks * 0.2)) # At least 5 stocks
bear_data = sorted(screening_data, key=lambda x: x["combined_score"])[:bear_count]
# Add bear-specific fields
for data in bear_data:
data["score"] = 100 - data["combined_score"] # Invert score
data["rsi_14"] = 70 + (hash(data["ticker"]) % 25) # Overbought RSI
data["macd"] = -0.1 - (hash(data["ticker"]) % 5) / 20 # Negative MACD
data["macd_signal"] = -0.05 - (hash(data["ticker"]) % 3) / 30
data["macd_histogram"] = data["macd"] - data["macd_signal"]
data["dist_days_20"] = hash(data["ticker"]) % 20
data["atr_contraction"] = hash(data["ticker"]) % 2 == 0
data["big_down_vol"] = hash(data["ticker"]) % 4 == 0
bear_inserted = bulk_insert_screening_data(session, MaverickBearStocks, bear_data)
logger.info(f"Created {bear_inserted} Bear screening results")
# Top 25% for Supply/Demand breakouts
breakout_count = max(8, int(total_stocks * 0.25)) # At least 8 stocks
breakout_data = sorted(
screening_data, key=lambda x: x["momentum_score"], reverse=True
)[:breakout_count]
# Add supply/demand specific fields
for data in breakout_data:
data["accumulation_rating"] = 2 + (hash(data["ticker"]) % 8) # 2-9
data["distribution_rating"] = 10 - data["accumulation_rating"]
data["breakout_strength"] = 3 + (hash(data["ticker"]) % 7) # 3-9
data["avg_volume_30d"] = data["volume"] * 1.3 # 30% above current
breakout_inserted = bulk_insert_screening_data(
session, SupplyDemandBreakoutStocks, breakout_data
)
logger.info(f"Created {breakout_inserted} Supply/Demand breakout results")
def verify_sp500_data(session) -> None:
"""Verify that S&P 500 data was seeded correctly."""
logger.info("Verifying S&P 500 seeded data...")
# Count records in each table
stock_count = session.query(Stock).count()
price_count = session.query(PriceCache).count()
maverick_count = session.query(MaverickStocks).count()
bear_count = session.query(MaverickBearStocks).count()
supply_demand_count = session.query(SupplyDemandBreakoutStocks).count()
technical_count = session.query(TechnicalCache).count()
logger.info("=== S&P 500 Data Seeding Summary ===")
logger.info(f"S&P 500 Stocks: {stock_count}")
logger.info(f"Price records: {price_count}")
logger.info(f"Maverick screening: {maverick_count}")
logger.info(f"Bear screening: {bear_count}")
logger.info(f"Supply/Demand screening: {supply_demand_count}")
logger.info(f"Technical indicators: {technical_count}")
logger.info("===================================")
# Show top stocks by sector
logger.info("\n📊 S&P 500 Stocks by Sector:")
sector_counts = session.execute(
text("""
SELECT sector, COUNT(*) as count
FROM mcp_stocks
WHERE sector IS NOT NULL
GROUP BY sector
ORDER BY count DESC
LIMIT 10
""")
).fetchall()
for sector, count in sector_counts:
logger.info(f" {sector}: {count} stocks")
# Test screening queries
if maverick_count > 0:
top_maverick = (
session.query(MaverickStocks)
.order_by(MaverickStocks.combined_score.desc())
.first()
)
if top_maverick and top_maverick.stock:
logger.info(
f"\n🚀 Top Maverick (Bullish): {top_maverick.stock.ticker_symbol} (Score: {top_maverick.combined_score})"
)
if bear_count > 0:
top_bear = (
session.query(MaverickBearStocks)
.order_by(MaverickBearStocks.score.desc())
.first()
)
if top_bear and top_bear.stock:
logger.info(
f"🐻 Top Bear: {top_bear.stock.ticker_symbol} (Score: {top_bear.score})"
)
if supply_demand_count > 0:
top_breakout = (
session.query(SupplyDemandBreakoutStocks)
.order_by(SupplyDemandBreakoutStocks.breakout_strength.desc())
.first()
)
if top_breakout and top_breakout.stock:
logger.info(
f"📈 Top Breakout: {top_breakout.stock.ticker_symbol} (Strength: {top_breakout.breakout_strength})"
)
def main():
"""Main S&P 500 seeding function."""
logger.info("🚀 Starting S&P 500 database seeding for MaverickMCP...")
# Set up database connection
database_url = get_database_url()
logger.info(f"Using database: {database_url}")
engine = create_engine(database_url, echo=False)
SessionLocal = sessionmaker(bind=engine)
with SessionLocal() as session:
try:
# Fetch S&P 500 stock list
sp500_df = fetch_sp500_list()
if sp500_df.empty:
logger.error("No S&P 500 stocks found. Exiting.")
return False
# Create S&P 500 stocks with comprehensive data
stocks = create_sp500_stocks(session, sp500_df)
if not stocks:
logger.error("No S&P 500 stocks created. Exiting.")
return False
# Create screening results for S&P 500 stocks
create_sample_screening_for_sp500(session, stocks)
# Verify data
verify_sp500_data(session)
logger.info("🎉 S&P 500 database seeding completed successfully!")
logger.info(f"📈 Added {len(stocks)} S&P 500 companies to the database")
logger.info("\n🔧 Next steps:")
logger.info("1. Run 'make dev' to start the MCP server")
logger.info("2. Connect with Claude Desktop using mcp-remote")
logger.info("3. Test with: 'Show me top S&P 500 momentum stocks'")
return True
except Exception as e:
logger.error(f"S&P 500 database seeding failed: {e}")
session.rollback()
raise
if __name__ == "__main__":
try:
success = main()
if not success:
sys.exit(1)
except KeyboardInterrupt:
logger.info("\n⏹️ Seeding interrupted by user")
sys.exit(1)
except Exception as e:
logger.error(f"❌ Fatal error: {e}")
sys.exit(1)
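# Typical invocation (a sketch; adjust DATABASE_URL for your environment):
#   DATABASE_URL=postgresql://user:pass@localhost/maverick python scripts/seed_sp500.py
# With DATABASE_URL unset, the script seeds the local sqlite:///maverick_mcp.db file.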
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/monitoring_middleware.py:
--------------------------------------------------------------------------------
```python
"""
Monitoring middleware for FastMCP and FastAPI applications.
This module provides comprehensive monitoring middleware that automatically:
- Tracks request metrics (count, duration, size)
- Creates distributed traces for all requests
- Monitors database and cache operations
- Tracks business metrics and user behavior
- Integrates with Prometheus and OpenTelemetry
"""
import time
from collections.abc import Callable
from typing import Any
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from maverick_mcp.utils.logging import get_logger, user_id_var
from maverick_mcp.utils.monitoring import (
active_connections,
concurrent_requests,
request_counter,
request_duration,
request_size_bytes,
response_size_bytes,
track_authentication,
track_rate_limit_hit,
track_security_violation,
update_performance_metrics,
)
from maverick_mcp.utils.tracing import get_tracing_service, trace_operation
logger = get_logger(__name__)
class MonitoringMiddleware(BaseHTTPMiddleware):
"""
Comprehensive monitoring middleware for FastAPI applications.
Automatically tracks:
- Request/response metrics
- Distributed tracing
- Performance monitoring
- Security events
- Business metrics
"""
def __init__(self, app, enable_detailed_logging: bool = True):
super().__init__(app)
self.enable_detailed_logging = enable_detailed_logging
self.tracing = get_tracing_service()
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""Process request with comprehensive monitoring."""
# Start timing
start_time = time.time()
# Track active connections
active_connections.inc()
concurrent_requests.inc()
# Extract request information
method = request.method
path = request.url.path
endpoint = self._normalize_endpoint(path)
user_agent = request.headers.get("user-agent", "unknown")
# Calculate request size
content_length = request.headers.get("content-length")
req_size = int(content_length) if content_length else 0
# Extract user information for monitoring
user_id = self._extract_user_id(request)
user_type = self._determine_user_type(request, user_id)
# Set context variables for logging
if user_id:
user_id_var.set(str(user_id))
response = None
status_code = 500
error_type = None
# Create tracing span for the entire request
with trace_operation(
f"{method} {endpoint}",
attributes={
"http.method": method,
"http.route": endpoint,
"http.user_agent": user_agent[:100], # Truncate long user agents
"user.id": str(user_id) if user_id else "anonymous",
"user.type": user_type,
"http.request_size": req_size,
},
) as span:
try:
# Process the request
response = await call_next(request)
status_code = response.status_code
# Track successful request
if span:
span.set_attribute("http.status_code", status_code)
span.set_attribute(
"http.response_size", self._get_response_size(response)
)
# Track authentication events
if self._is_auth_endpoint(endpoint):
auth_status = "success" if 200 <= status_code < 300 else "failure"
track_authentication(
method="bearer_token",
status=auth_status,
user_agent=user_agent[:50],
)
# Track rate limiting
if status_code == 429:
track_rate_limit_hit(
user_id=str(user_id) if user_id else "anonymous",
endpoint=endpoint,
limit_type="request_rate",
)
except Exception as e:
error_type = type(e).__name__
status_code = 500
# Record exception in trace
if span:
span.record_exception(e)
span.set_attribute("error", True)
span.set_attribute("error.type", error_type)
# Track security violations for certain errors
if self._is_security_error(e):
track_security_violation(
violation_type=error_type,
severity="high" if status_code >= 400 else "medium",
)
# Re-raise the exception
raise
finally:
# Calculate duration
duration = time.time() - start_time
# Determine final status for metrics
final_status = "success" if 200 <= status_code < 400 else "error"
# Track request metrics
request_counter.labels(
method=method,
endpoint=endpoint,
status=final_status,
user_type=user_type,
).inc()
request_duration.labels(
method=method, endpoint=endpoint, user_type=user_type
).observe(duration)
# Track request/response sizes
if req_size > 0:
request_size_bytes.labels(method=method, endpoint=endpoint).observe(
req_size
)
if response:
resp_size = self._get_response_size(response)
if resp_size > 0:
response_size_bytes.labels(
method=method, endpoint=endpoint, status=str(status_code)
).observe(resp_size)
                # Update performance metrics opportunistically: this fires only when
                # a request completes during a second divisible by 30, so it samples
                # roughly every 30 seconds under steady traffic
                if int(time.time()) % 30 == 0:
try:
update_performance_metrics()
except Exception as e:
logger.warning(f"Failed to update performance metrics: {e}")
# Log detailed request information
if self.enable_detailed_logging:
self._log_request_details(
method, endpoint, status_code, duration, user_id, error_type
)
# Update connection counters
active_connections.dec()
concurrent_requests.dec()
return response
def _normalize_endpoint(self, path: str) -> str:
"""Normalize endpoint path for metrics (replace IDs with placeholders)."""
# Replace UUIDs and IDs in paths
import re
# Replace UUID patterns
path = re.sub(r"/[a-f0-9-]{36}", "/{uuid}", path)
# Replace numeric IDs
path = re.sub(r"/\d+", "/{id}", path)
# Replace API keys or tokens
path = re.sub(r"/[a-zA-Z0-9]{20,}", "/{token}", path)
return path
def _extract_user_id(self, request: Request) -> str | None:
"""Extract user ID from request (from JWT, session, etc.)."""
# Check Authorization header
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
try:
# In a real implementation, you'd decode the JWT
# For now, we'll check if there's a user context
if hasattr(request.state, "user_id"):
return request.state.user_id
except Exception:
pass
# Check for user ID in path parameters
if hasattr(request, "path_params") and "user_id" in request.path_params:
return request.path_params["user_id"]
return None
def _determine_user_type(self, request: Request, user_id: str | None) -> str:
"""Determine user type for metrics."""
if not user_id:
return "anonymous"
# Check if it's an admin user (you'd implement your own logic)
if hasattr(request.state, "user_role"):
return request.state.user_role
# Check for API key usage
if request.headers.get("x-api-key"):
return "api_user"
return "authenticated"
def _is_auth_endpoint(self, endpoint: str) -> bool:
"""Check if endpoint is authentication-related."""
auth_endpoints = ["/login", "/auth", "/token", "/signup", "/register"]
return any(auth_ep in endpoint for auth_ep in auth_endpoints)
def _is_security_error(self, exception: Exception) -> bool:
"""Check if exception indicates a security issue."""
security_errors = [
"PermissionError",
"Unauthorized",
"Forbidden",
"ValidationError",
"SecurityError",
]
return any(error in str(type(exception)) for error in security_errors)
def _get_response_size(self, response: Response) -> int:
"""Calculate response size in bytes."""
content_length = response.headers.get("content-length")
if content_length:
return int(content_length)
# Estimate size if content-length is not set
if hasattr(response, "body") and response.body:
return len(response.body)
return 0
def _log_request_details(
self,
method: str,
endpoint: str,
status_code: int,
duration: float,
user_id: str | None,
error_type: str | None,
):
"""Log detailed request information."""
log_data = {
"http_method": method,
"endpoint": endpoint,
"status_code": status_code,
"duration_ms": int(duration * 1000),
"user_id": str(user_id) if user_id else None,
}
if error_type:
log_data["error_type"] = error_type
if status_code >= 400:
logger.warning(f"HTTP {status_code}: {method} {endpoint}", extra=log_data)
else:
logger.info(f"HTTP {status_code}: {method} {endpoint}", extra=log_data)
class MCPToolMonitoringWrapper:
"""
Wrapper for MCP tools to add monitoring capabilities.
This class wraps MCP tool execution to automatically:
- Track tool usage metrics
- Create distributed traces
- Monitor performance
"""
def __init__(self, enable_tracing: bool = True):
self.enable_tracing = enable_tracing
self.tracing = get_tracing_service()
def monitor_tool(self, tool_func: Callable) -> Callable:
"""
Decorator to add monitoring to MCP tools.
Args:
tool_func: The MCP tool function to monitor
Returns:
Wrapped function with monitoring
"""
from functools import wraps
@wraps(tool_func)
async def wrapper(*args, **kwargs):
tool_name = tool_func.__name__
start_time = time.time()
# Extract user context from args
user_id = None
for arg in args:
if hasattr(arg, "user_id"):
user_id = arg.user_id
break
# Check if it's an MCP context
if hasattr(arg, "__class__") and "Context" in arg.__class__.__name__:
# Extract user from context if available
if hasattr(arg, "user_id"):
user_id = arg.user_id
# Set context for logging
if user_id:
user_id_var.set(str(user_id))
# Create tracing span
with trace_operation(
f"tool.{tool_name}",
attributes={
"tool.name": tool_name,
"user.id": str(user_id) if user_id else "anonymous",
"tool.args_count": len(args),
"tool.kwargs_count": len(kwargs),
},
) as span:
try:
# Execute the tool
result = await tool_func(*args, **kwargs)
# Calculate execution time
duration = time.time() - start_time
# Track successful execution
from maverick_mcp.utils.monitoring import track_tool_usage
track_tool_usage(
tool_name=tool_name,
user_id=str(user_id) if user_id else "anonymous",
duration=duration,
status="success",
complexity=self._determine_complexity(tool_name, kwargs),
)
# Add attributes to span
if span:
span.set_attribute("tool.duration_seconds", duration)
span.set_attribute("tool.success", True)
span.set_attribute("tool.result_size", len(str(result)))
# Add usage information to result if it's a dict
if isinstance(result, dict):
result["_monitoring"] = {
"execution_time_ms": int(duration * 1000),
"tool_name": tool_name,
"timestamp": time.time(),
}
return result
except Exception as e:
# Calculate execution time
duration = time.time() - start_time
error_type = type(e).__name__
# Track failed execution
from maverick_mcp.utils.monitoring import track_tool_error
track_tool_error(
tool_name=tool_name,
error_type=error_type,
complexity=self._determine_complexity(tool_name, kwargs),
)
# Add error attributes to span
if span:
span.set_attribute("tool.duration_seconds", duration)
span.set_attribute("tool.success", False)
span.set_attribute("error.type", error_type)
span.record_exception(e)
logger.error(
f"Tool execution failed: {tool_name}",
extra={
"tool_name": tool_name,
"user_id": str(user_id) if user_id else None,
"duration_ms": int(duration * 1000),
"error_type": error_type,
},
exc_info=True,
)
# Re-raise the exception
raise
return wrapper
def _determine_complexity(self, tool_name: str, kwargs: dict[str, Any]) -> str:
"""Determine tool complexity based on parameters."""
# Simple heuristics for complexity
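        # e.g. limit=150 -> "high"; a 7-symbol list -> "medium"; otherwise "standard"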
if "limit" in kwargs:
limit = kwargs.get("limit", 0)
if limit > 100:
return "high"
elif limit > 50:
return "medium"
if "symbols" in kwargs:
symbols = kwargs.get("symbols", [])
if isinstance(symbols, list) and len(symbols) > 10:
return "high"
elif isinstance(symbols, list) and len(symbols) > 5:
return "medium"
# Check for complex analysis tools
complex_tools = [
"get_portfolio_optimization",
"get_market_analysis",
"screen_stocks",
]
if any(complex_tool in tool_name for complex_tool in complex_tools):
return "high"
return "standard"
def create_monitoring_middleware(
enable_detailed_logging: bool = True,
) -> MonitoringMiddleware:
"""Create a monitoring middleware instance."""
return MonitoringMiddleware(enable_detailed_logging=enable_detailed_logging)
def create_tool_monitor(enable_tracing: bool = True) -> MCPToolMonitoringWrapper:
"""Create a tool monitoring wrapper instance."""
return MCPToolMonitoringWrapper(enable_tracing=enable_tracing)
# Global instances
_monitoring_middleware: MonitoringMiddleware | None = None
_tool_monitor: MCPToolMonitoringWrapper | None = None
def get_monitoring_middleware() -> MonitoringMiddleware:
"""Get or create the global monitoring middleware."""
global _monitoring_middleware
if _monitoring_middleware is None:
_monitoring_middleware = create_monitoring_middleware()
return _monitoring_middleware
def get_tool_monitor() -> MCPToolMonitoringWrapper:
"""Get or create the global tool monitor."""
global _tool_monitor
if _tool_monitor is None:
_tool_monitor = create_tool_monitor()
return _tool_monitor
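# Wiring sketch (illustrative; `app` is an existing FastAPI instance and `my_tool`
# an async MCP tool, both assumed rather than defined in this module):
#
#   app.add_middleware(MonitoringMiddleware, enable_detailed_logging=True)
#   monitored_tool = get_tool_monitor().monitor_tool(my_tool)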
```
--------------------------------------------------------------------------------
/examples/llm_speed_demo.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Focused LLM Speed Optimization Demonstration
This script demonstrates the core LLM optimization capabilities that provide
2-3x speed improvements, focusing on areas we can control directly.
Demonstrates:
- Adaptive model selection based on time constraints
- Fast model execution (Gemini 2.5 Flash)
- Token generation speed optimization
- Progressive timeout management
- Model performance comparison
"""
import asyncio
import os
import sys
import time
from datetime import datetime
from typing import Any
# Add the project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from maverick_mcp.providers.openrouter_provider import OpenRouterProvider, TaskType
from maverick_mcp.utils.llm_optimization import AdaptiveModelSelector
class LLMSpeedDemonstrator:
"""Focused demonstration of LLM speed optimizations."""
def __init__(self):
"""Initialize the demonstration."""
api_key = os.getenv("OPENROUTER_API_KEY")
if not api_key:
raise ValueError(
"OPENROUTER_API_KEY environment variable is required. "
"Please set it with your OpenRouter API key."
)
self.openrouter_provider = OpenRouterProvider(api_key=api_key)
self.model_selector = AdaptiveModelSelector(self.openrouter_provider)
# Test scenarios focused on different urgency levels
self.test_scenarios = [
{
"name": "Emergency Analysis (Critical Speed)",
"prompt": "Analyze NVIDIA's latest earnings impact on AI market sentiment. 2-3 key points only.",
"time_budget": 15.0,
"task_type": TaskType.QUICK_ANSWER,
"expected_speed": ">100 tok/s",
},
{
"name": "Technical Analysis (Fast Response)",
"prompt": "Provide technical analysis of Apple stock including RSI, MACD, and support levels. Be concise.",
"time_budget": 30.0,
"task_type": TaskType.TECHNICAL_ANALYSIS,
"expected_speed": ">80 tok/s",
},
{
"name": "Market Research (Moderate Speed)",
"prompt": "Analyze Federal Reserve interest rate policy impact on technology sector. Include risk assessment.",
"time_budget": 45.0,
"task_type": TaskType.MARKET_ANALYSIS,
"expected_speed": ">60 tok/s",
},
{
"name": "Complex Synthesis (Quality Balance)",
"prompt": "Synthesize renewable energy investment opportunities for 2025, considering policy changes, technology advances, and market trends.",
"time_budget": 60.0,
"task_type": TaskType.RESULT_SYNTHESIS,
"expected_speed": ">40 tok/s",
},
]
def print_header(self, title: str):
"""Print formatted header."""
print("\n" + "=" * 80)
print(f" {title}")
print("=" * 80)
def print_subheader(self, title: str):
"""Print formatted subheader."""
print(f"\n--- {title} ---")
async def validate_openrouter_connection(self) -> bool:
"""Validate OpenRouter API is accessible."""
self.print_header("🔧 API VALIDATION")
try:
test_llm = self.openrouter_provider.get_llm(TaskType.GENERAL)
from langchain_core.messages import HumanMessage
test_response = await asyncio.wait_for(
test_llm.ainvoke([HumanMessage(content="test connection")]),
timeout=10.0,
)
print("✅ OpenRouter API: Connected successfully")
print(f" Response length: {len(test_response.content)} chars")
return True
except Exception as e:
print(f"❌ OpenRouter API: Failed - {e}")
return False
async def demonstrate_model_selection(self):
"""Show intelligent model selection for different scenarios."""
self.print_header("🧠 ADAPTIVE MODEL SELECTION")
for scenario in self.test_scenarios:
print(f"\n📋 Scenario: {scenario['name']}")
print(f" Time Budget: {scenario['time_budget']}s")
print(f" Task Type: {scenario['task_type'].value}")
print(f" Expected Speed: {scenario['expected_speed']}")
# Calculate task complexity
complexity = self.model_selector.calculate_task_complexity(
content=scenario["prompt"],
task_type=scenario["task_type"],
focus_areas=["analysis"],
)
# Get optimal model for time budget
model_config = self.model_selector.select_model_for_time_budget(
task_type=scenario["task_type"],
time_remaining_seconds=scenario["time_budget"],
complexity_score=complexity,
content_size_tokens=len(scenario["prompt"]) // 4,
)
print(f" 📊 Complexity Score: {complexity:.2f}")
print(f" 🎯 Selected Model: {model_config.model_id}")
print(f" ⏱️ Max Timeout: {model_config.timeout_seconds}s")
print(f" 🌡️ Temperature: {model_config.temperature}")
print(f" 📝 Max Tokens: {model_config.max_tokens}")
# Check if speed-optimized
is_speed_model = any(
x in model_config.model_id.lower()
for x in ["flash", "haiku", "4o-mini", "deepseek"]
)
print(f" 🚀 Speed Optimized: {'✅' if is_speed_model else '❌'}")
async def run_speed_benchmarks(self):
"""Run actual speed benchmarks for each scenario."""
self.print_header("⚡ LIVE SPEED BENCHMARKS")
results = []
baseline_time = 60.0 # Historical baseline from timeout issues
for i, scenario in enumerate(self.test_scenarios, 1):
print(f"\n🔍 Benchmark {i}/{len(self.test_scenarios)}: {scenario['name']}")
print(f" Query: {scenario['prompt'][:60]}...")
try:
# Get optimal model configuration
complexity = self.model_selector.calculate_task_complexity(
content=scenario["prompt"],
task_type=scenario["task_type"],
)
model_config = self.model_selector.select_model_for_time_budget(
task_type=scenario["task_type"],
time_remaining_seconds=scenario["time_budget"],
complexity_score=complexity,
content_size_tokens=len(scenario["prompt"]) // 4,
)
# Execute with timing
llm = self.openrouter_provider.get_llm(
model_override=model_config.model_id,
temperature=model_config.temperature,
max_tokens=model_config.max_tokens,
)
start_time = time.time()
from langchain_core.messages import HumanMessage
response = await asyncio.wait_for(
llm.ainvoke([HumanMessage(content=scenario["prompt"])]),
timeout=model_config.timeout_seconds,
)
execution_time = time.time() - start_time
# Calculate metrics
response_length = len(response.content)
estimated_tokens = response_length // 4
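                # ~4 characters per token is a rough heuristic for English text,
                # so a 600-char reply is counted as ~150 tokens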
tokens_per_second = (
estimated_tokens / execution_time if execution_time > 0 else 0
)
speed_improvement = (
baseline_time / execution_time if execution_time > 0 else 0
)
# Results
result = {
"scenario": scenario["name"],
"model_used": model_config.model_id,
"execution_time": execution_time,
"time_budget": scenario["time_budget"],
"budget_used_pct": (execution_time / scenario["time_budget"]) * 100,
"tokens_per_second": tokens_per_second,
"response_length": response_length,
"speed_improvement": speed_improvement,
"target_achieved": execution_time <= scenario["time_budget"],
"response_preview": response.content[:150] + "..."
if len(response.content) > 150
else response.content,
}
results.append(result)
# Print immediate results
status_icon = "✅" if result["target_achieved"] else "⚠️"
print(
f" {status_icon} Completed: {execution_time:.2f}s ({result['budget_used_pct']:.1f}% of budget)"
)
print(f" 🎯 Model: {model_config.model_id}")
print(f" 🚀 Speed: {tokens_per_second:.0f} tok/s")
print(
f" 📊 Improvement: {speed_improvement:.1f}x faster than baseline"
)
print(f" 💬 Preview: {result['response_preview']}")
# Brief pause between tests
await asyncio.sleep(1)
except Exception as e:
print(f" ❌ Failed: {str(e)}")
results.append(
{
"scenario": scenario["name"],
"error": str(e),
"target_achieved": False,
}
)
return results
def analyze_benchmark_results(self, results: list[dict[str, Any]]):
"""Analyze and report benchmark results."""
self.print_header("📊 SPEED OPTIMIZATION ANALYSIS")
successful_tests = [r for r in results if not r.get("error")]
failed_tests = [r for r in results if r.get("error")]
targets_achieved = [r for r in successful_tests if r.get("target_achieved")]
print("📈 Overall Performance:")
print(f" Total Tests: {len(results)}")
print(f" Successful: {len(successful_tests)}")
print(f" Failed: {len(failed_tests)}")
print(f" Targets Hit: {len(targets_achieved)}/{len(results)}")
print(f" Success Rate: {(len(targets_achieved) / len(results) * 100):.1f}%")
if successful_tests:
# Speed metrics
avg_execution_time = sum(
r["execution_time"] for r in successful_tests
) / len(successful_tests)
max_execution_time = max(r["execution_time"] for r in successful_tests)
avg_tokens_per_second = sum(
r["tokens_per_second"] for r in successful_tests
) / len(successful_tests)
avg_speed_improvement = sum(
r["speed_improvement"] for r in successful_tests
) / len(successful_tests)
print("\n⚡ Speed Metrics:")
print(f" Average Execution Time: {avg_execution_time:.2f}s")
print(f" Maximum Execution Time: {max_execution_time:.2f}s")
print(f" Average Token Generation: {avg_tokens_per_second:.0f} tok/s")
print(f" Average Speed Improvement: {avg_speed_improvement:.1f}x")
# Historical comparison
historical_baseline = 60.0 # Average timeout failure time
if max_execution_time > 0:
overall_improvement = historical_baseline / max_execution_time
print("\n🎯 Speed Validation:")
print(
f" Historical Baseline: {historical_baseline}s (timeout failures)"
)
print(f" Current Max Time: {max_execution_time:.2f}s")
print(f" Overall Improvement: {overall_improvement:.1f}x")
if overall_improvement >= 3.0:
print(
f" 🎉 EXCELLENT: {overall_improvement:.1f}x speed improvement!"
)
elif overall_improvement >= 2.0:
print(
f" ✅ SUCCESS: {overall_improvement:.1f}x speed improvement achieved!"
)
elif overall_improvement >= 1.5:
print(
f" 👍 GOOD: {overall_improvement:.1f}x improvement (target: 2x)"
)
else:
print(
f" ⚠️ NEEDS WORK: Only {overall_improvement:.1f}x improvement"
)
# Model performance breakdown
self.print_subheader("🧠 MODEL PERFORMANCE BREAKDOWN")
model_stats = {}
for result in successful_tests:
model = result["model_used"]
if model not in model_stats:
model_stats[model] = []
model_stats[model].append(result)
for model, model_results in model_stats.items():
avg_speed = sum(r["tokens_per_second"] for r in model_results) / len(
model_results
)
avg_time = sum(r["execution_time"] for r in model_results) / len(
model_results
)
success_rate = (
len([r for r in model_results if r["target_achieved"]])
/ len(model_results)
) * 100
print(f" {model}:")
print(f" Tests: {len(model_results)}")
print(f" Avg Speed: {avg_speed:.0f} tok/s")
print(f" Avg Time: {avg_time:.2f}s")
print(f" Success Rate: {success_rate:.0f}%")
async def run_comprehensive_demo(self):
"""Run the complete LLM speed demonstration."""
print("🚀 MaverickMCP LLM Speed Optimization Demonstration")
print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("🎯 Goal: Demonstrate 2-3x LLM speed improvements")
# Step 1: Validate connection
if not await self.validate_openrouter_connection():
print("\n❌ Cannot proceed - API connection failed")
return False
# Step 2: Show model selection intelligence
await self.demonstrate_model_selection()
# Step 3: Run live speed benchmarks
results = await self.run_speed_benchmarks()
# Step 4: Analyze results
self.analyze_benchmark_results(results)
# Final summary
self.print_header("🎉 DEMONSTRATION SUMMARY")
successful_tests = [r for r in results if not r.get("error")]
targets_achieved = [r for r in successful_tests if r.get("target_achieved")]
print("✅ LLM Speed Optimization Results:")
print(f" Tests Executed: {len(results)}")
print(f" Successful: {len(successful_tests)}")
print(f" Targets Achieved: {len(targets_achieved)}")
print(f" Success Rate: {(len(targets_achieved) / len(results) * 100):.1f}%")
if successful_tests:
max_time = max(r["execution_time"] for r in successful_tests)
avg_speed = sum(r["tokens_per_second"] for r in successful_tests) / len(
successful_tests
)
speed_improvement = 60.0 / max_time if max_time > 0 else 0
print(
f" Fastest Response: {min(r['execution_time'] for r in successful_tests):.2f}s"
)
print(f" Average Token Speed: {avg_speed:.0f} tok/s")
print(f" Speed Improvement: {speed_improvement:.1f}x faster")
print("\n📊 Key Optimizations Demonstrated:")
print(" ✅ Adaptive Model Selection (context-aware)")
print(" ✅ Time-Budget Optimization")
print(" ✅ Fast Model Utilization (Gemini Flash, Claude Haiku)")
print(" ✅ Progressive Timeout Management")
print(" ✅ Token Generation Speed Optimization")
# Success criteria: at least 75% success rate and 2x improvement
success_criteria = len(targets_achieved) >= len(results) * 0.75 and (
successful_tests
and 60.0 / max(r["execution_time"] for r in successful_tests) >= 2.0
)
return success_criteria
async def main():
"""Main demonstration entry point."""
demo = LLMSpeedDemonstrator()
try:
success = await demo.run_comprehensive_demo()
if success:
print("\n🎉 LLM Speed Demonstration PASSED - Optimizations validated!")
return 0
else:
print("\n⚠️ Demonstration had mixed results - review analysis above")
return 1
except KeyboardInterrupt:
print("\n\n⏹️ Demonstration interrupted by user")
return 130
except Exception as e:
print(f"\n💥 Demonstration failed with error: {e}")
import traceback
traceback.print_exc()
return 1
if __name__ == "__main__":
# Check required environment variables
if not os.getenv("OPENROUTER_API_KEY"):
print("❌ Missing OPENROUTER_API_KEY environment variable")
print("Please check your .env file")
sys.exit(1)
# Run the demonstration
exit_code = asyncio.run(main())
sys.exit(exit_code)
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/resource_manager.py:
--------------------------------------------------------------------------------
```python
"""
Resource management utilities for the backtesting system.
Handles memory limits, resource cleanup, and system resource monitoring.
"""
import asyncio
import logging
import os
import resource
import signal
import threading
import time
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional
import psutil
from maverick_mcp.utils.memory_profiler import (
check_memory_leak,
force_garbage_collection,
get_memory_stats,
)
logger = logging.getLogger(__name__)
# Resource limits (in bytes)
DEFAULT_MEMORY_LIMIT = 2 * 1024 * 1024 * 1024 # 2GB
DEFAULT_SWAP_LIMIT = 4 * 1024 * 1024 * 1024 # 4GB
CRITICAL_MEMORY_THRESHOLD = 0.9 # 90% of limit
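# Example: with the default 2GB limit, emergency cleanup triggers near 1.8GB RSS.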
# Global resource manager instance
_resource_manager: Optional["ResourceManager"] = None
@dataclass
class ResourceLimits:
"""Resource limit configuration."""
memory_limit_bytes: int = DEFAULT_MEMORY_LIMIT
swap_limit_bytes: int = DEFAULT_SWAP_LIMIT
cpu_time_limit_seconds: int = 3600 # 1 hour
file_descriptor_limit: int = 1024
enable_memory_monitoring: bool = True
enable_cpu_monitoring: bool = True
cleanup_interval_seconds: int = 60
@dataclass
class ResourceUsage:
"""Current resource usage snapshot."""
memory_rss_bytes: int
memory_vms_bytes: int
memory_percent: float
cpu_percent: float
open_files: int
threads: int
timestamp: float
class ResourceExhaustionError(Exception):
"""Raised when resource limits are exceeded."""
pass
class ResourceManager:
"""System resource manager with limits and cleanup."""
    def __init__(self, limits: ResourceLimits | None = None):
"""Initialize resource manager.
Args:
limits: Resource limits configuration
"""
self.limits = limits or ResourceLimits()
self.process = psutil.Process()
self.monitoring_active = False
self.cleanup_callbacks: list[Callable[[], None]] = []
self.resource_history: list[ResourceUsage] = []
self.max_history_size = 100
# Setup signal handlers for graceful shutdown
self._setup_signal_handlers()
# Start monitoring if enabled
if self.limits.enable_memory_monitoring or self.limits.enable_cpu_monitoring:
self.start_monitoring()
def _setup_signal_handlers(self):
"""Setup signal handlers for resource cleanup."""
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, performing cleanup")
self.cleanup_all()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
def start_monitoring(self):
"""Start resource monitoring in background thread."""
if self.monitoring_active:
return
self.monitoring_active = True
monitor_thread = threading.Thread(target=self._monitor_resources, daemon=True)
monitor_thread.start()
logger.info("Resource monitoring started")
def stop_monitoring(self):
"""Stop resource monitoring."""
self.monitoring_active = False
logger.info("Resource monitoring stopped")
def _monitor_resources(self):
"""Background resource monitoring loop."""
while self.monitoring_active:
try:
usage = self.get_current_usage()
self.resource_history.append(usage)
# Keep history size manageable
if len(self.resource_history) > self.max_history_size:
self.resource_history.pop(0)
# Check limits and trigger cleanup if needed
self._check_resource_limits(usage)
time.sleep(self.limits.cleanup_interval_seconds)
except Exception as e:
logger.error(f"Error in resource monitoring: {e}")
time.sleep(30) # Back off on errors
def get_current_usage(self) -> ResourceUsage:
"""Get current resource usage."""
try:
memory_info = self.process.memory_info()
cpu_percent = self.process.cpu_percent()
# Get open files count safely
try:
open_files = len(self.process.open_files())
except (psutil.AccessDenied, psutil.NoSuchProcess):
open_files = 0
# Get thread count safely
try:
threads = self.process.num_threads()
except (psutil.AccessDenied, psutil.NoSuchProcess):
threads = 0
return ResourceUsage(
memory_rss_bytes=memory_info.rss,
memory_vms_bytes=memory_info.vms,
memory_percent=self.process.memory_percent(),
cpu_percent=cpu_percent,
open_files=open_files,
threads=threads,
timestamp=time.time(),
)
except Exception as e:
logger.error(f"Error getting resource usage: {e}")
return ResourceUsage(0, 0, 0, 0, 0, 0, time.time())
def _check_resource_limits(self, usage: ResourceUsage):
"""Check if resource limits are exceeded and take action."""
# Memory limit check
if (
usage.memory_rss_bytes
> self.limits.memory_limit_bytes * CRITICAL_MEMORY_THRESHOLD
):
logger.warning(
f"Memory usage {usage.memory_rss_bytes / (1024**3):.2f}GB "
f"approaching limit {self.limits.memory_limit_bytes / (1024**3):.2f}GB"
)
self._trigger_emergency_cleanup()
if usage.memory_rss_bytes > self.limits.memory_limit_bytes:
logger.critical(
f"Memory limit exceeded: {usage.memory_rss_bytes / (1024**3):.2f}GB "
f"> {self.limits.memory_limit_bytes / (1024**3):.2f}GB"
)
raise ResourceExhaustionError("Memory limit exceeded")
# File descriptor check
if usage.open_files > self.limits.file_descriptor_limit * 0.9:
logger.warning(f"High file descriptor usage: {usage.open_files}")
self._close_unused_files()
def _trigger_emergency_cleanup(self):
"""Trigger emergency resource cleanup."""
logger.info("Triggering emergency resource cleanup")
# Force garbage collection
force_garbage_collection()
# Run cleanup callbacks
for callback in self.cleanup_callbacks:
try:
callback()
except Exception as e:
logger.error(f"Error in cleanup callback: {e}")
# Clear memory profiler snapshots
try:
from maverick_mcp.utils.memory_profiler import reset_memory_stats
reset_memory_stats()
except ImportError:
pass
# Clear cache if available
try:
from maverick_mcp.data.cache import clear_cache
clear_cache()
except ImportError:
pass
def _close_unused_files(self):
"""Close unused file descriptors."""
try:
# Get current open files
open_files = self.process.open_files()
logger.debug(f"Found {len(open_files)} open files")
# Note: We can't automatically close files as that might break the application
# This is mainly for monitoring and alerting
for file_info in open_files:
logger.debug(f"Open file: {file_info.path}")
except Exception as e:
logger.debug(f"Could not enumerate open files: {e}")
def add_cleanup_callback(self, callback: Callable[[], None]):
"""Add a cleanup callback function."""
self.cleanup_callbacks.append(callback)
def cleanup_all(self):
"""Run all cleanup callbacks and garbage collection."""
logger.info("Running comprehensive resource cleanup")
# Run cleanup callbacks
for callback in self.cleanup_callbacks:
try:
callback()
except Exception as e:
logger.error(f"Error in cleanup callback: {e}")
# Force garbage collection
force_garbage_collection()
# Log final resource usage
usage = self.get_current_usage()
logger.info(
f"Post-cleanup usage: {usage.memory_rss_bytes / (1024**2):.2f}MB memory, "
f"{usage.open_files} files, {usage.threads} threads"
)
def get_resource_report(self) -> dict[str, Any]:
"""Get comprehensive resource usage report."""
current = self.get_current_usage()
report = {
"current_usage": {
"memory_mb": current.memory_rss_bytes / (1024**2),
"memory_percent": current.memory_percent,
"cpu_percent": current.cpu_percent,
"open_files": current.open_files,
"threads": current.threads,
},
"limits": {
"memory_limit_mb": self.limits.memory_limit_bytes / (1024**2),
"memory_usage_ratio": current.memory_rss_bytes
/ self.limits.memory_limit_bytes,
"file_descriptor_limit": self.limits.file_descriptor_limit,
},
"monitoring": {
"active": self.monitoring_active,
"history_size": len(self.resource_history),
"cleanup_callbacks": len(self.cleanup_callbacks),
},
}
# Add memory profiler stats if available
try:
memory_stats = get_memory_stats()
report["memory_profiler"] = memory_stats
except Exception:
pass
return report
def set_memory_limit(self, limit_bytes: int):
"""Set memory limit for the process."""
try:
# Set soft and hard memory limits
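            # RLIMIT_AS caps virtual address space; enforcement varies by
            # platform (e.g. macOS does not reliably honor it)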
resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))
self.limits.memory_limit_bytes = limit_bytes
logger.info(f"Memory limit set to {limit_bytes / (1024**3):.2f}GB")
except Exception as e:
logger.warning(f"Could not set memory limit: {e}")
def check_memory_health(self) -> dict[str, Any]:
"""Check memory health and detect potential issues."""
health_report = {
"status": "healthy",
"issues": [],
"recommendations": [],
}
current = self.get_current_usage()
# Check memory usage
usage_ratio = current.memory_rss_bytes / self.limits.memory_limit_bytes
if usage_ratio > 0.9:
health_report["status"] = "critical"
health_report["issues"].append(f"Memory usage at {usage_ratio:.1%}")
health_report["recommendations"].append("Trigger immediate cleanup")
elif usage_ratio > 0.7:
health_report["status"] = "warning"
health_report["issues"].append(f"High memory usage at {usage_ratio:.1%}")
health_report["recommendations"].append("Consider cleanup")
# Check for memory leaks
if check_memory_leak(threshold_mb=100.0):
health_report["status"] = "warning"
health_report["issues"].append("Potential memory leak detected")
health_report["recommendations"].append("Review memory profiler logs")
# Check file descriptor usage
fd_ratio = current.open_files / self.limits.file_descriptor_limit
if fd_ratio > 0.8:
health_report["status"] = "warning"
health_report["issues"].append(
f"High file descriptor usage: {current.open_files}"
)
health_report["recommendations"].append("Review open files")
return health_report
@contextmanager
def resource_limit_context(
    memory_limit_mb: int | None = None,
    cpu_limit_percent: float | None = None,
    cleanup_on_exit: bool = True,
):
    """Context manager for resource-limited operations.
    Args:
        memory_limit_mb: Memory limit in MB
        cpu_limit_percent: CPU limit as a percentage (accepted but not yet enforced)
        cleanup_on_exit: Whether to run cleanup on exit
Yields:
ResourceManager instance
"""
limits = ResourceLimits()
if memory_limit_mb:
limits.memory_limit_bytes = memory_limit_mb * 1024 * 1024
manager = ResourceManager(limits)
try:
yield manager
finally:
if cleanup_on_exit:
manager.cleanup_all()
manager.stop_monitoring()
def get_resource_manager() -> ResourceManager:
"""Get or create global resource manager instance."""
global _resource_manager
if _resource_manager is None:
_resource_manager = ResourceManager()
return _resource_manager
def set_process_memory_limit(limit_gb: float):
"""Set memory limit for current process.
Args:
limit_gb: Memory limit in gigabytes
"""
limit_bytes = int(limit_gb * 1024 * 1024 * 1024)
manager = get_resource_manager()
manager.set_memory_limit(limit_bytes)
def monitor_async_task(task: asyncio.Task, name: str = "unknown"):
    """Attach a done-callback that logs the task outcome and triggers cleanup.
    Args:
        task: Asyncio task to monitor
        name: Name of the task for logging
    """
    def task_done_callback(finished_task):
        # .exception() raises CancelledError on a cancelled task, so check
        # cancellation before inspecting the exception
        if finished_task.cancelled():
            logger.debug(f"Task {name} was cancelled")
        elif finished_task.exception():
            logger.error(f"Task {name} failed: {finished_task.exception()}")
        else:
            logger.debug(f"Task {name} completed successfully")
        # Trigger cleanup
        manager = get_resource_manager()
        manager._trigger_emergency_cleanup()
    task.add_done_callback(task_done_callback)
class ResourceAwareExecutor:
"""Executor that respects resource limits."""
    def __init__(
        self, max_workers: int | None = None, memory_limit_mb: int | None = None
    ):
"""Initialize resource-aware executor.
Args:
max_workers: Maximum worker threads
memory_limit_mb: Memory limit in MB
"""
self.max_workers = max_workers or min(32, (os.cpu_count() or 1) + 4)
self.memory_limit_mb = memory_limit_mb or 500
self.active_tasks = 0
self.lock = threading.Lock()
def submit(self, fn: Callable, *args, **kwargs):
"""Submit a task for execution with resource monitoring."""
with self.lock:
if self.active_tasks >= self.max_workers:
raise ResourceExhaustionError("Too many active tasks")
# Check memory before starting
current_usage = get_resource_manager().get_current_usage()
if current_usage.memory_rss_bytes > self.memory_limit_mb * 1024 * 1024:
raise ResourceExhaustionError("Memory limit would be exceeded")
self.active_tasks += 1
try:
result = fn(*args, **kwargs)
return result
finally:
with self.lock:
self.active_tasks -= 1
# Utility functions
def cleanup_on_low_memory(threshold_mb: float = 500.0):
    """Decorator that triggers cleanup when available system memory is low.
    Args:
        threshold_mb: Minimum available system memory in MB before cleanup runs
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            # Check system-wide available memory before running the function
            available_mb = psutil.virtual_memory().available / (1024**2)
if available_mb < threshold_mb:
logger.warning(
f"Low memory detected ({available_mb:.1f}MB), triggering cleanup"
)
get_resource_manager()._trigger_emergency_cleanup()
return func(*args, **kwargs)
return wrapper
return decorator
def log_resource_usage(func: Callable | None = None, *, interval: int = 60):
    """Decorator that logs duration and memory delta of a function call.
    Args:
        func: Function to decorate
        interval: Logging interval in seconds (reserved for future use; usage
            is currently logged once, when the wrapped function returns)
    """
def decorator(f):
def wrapper(*args, **kwargs):
start_time = time.time()
start_usage = get_resource_manager().get_current_usage()
try:
return f(*args, **kwargs)
finally:
end_usage = get_resource_manager().get_current_usage()
duration = time.time() - start_time
memory_delta = end_usage.memory_rss_bytes - start_usage.memory_rss_bytes
logger.info(
f"{f.__name__} completed in {duration:.2f}s, "
f"memory delta: {memory_delta / (1024**2):+.2f}MB"
)
return wrapper
if func is None:
return decorator
else:
return decorator(func)
# Initialize global resource manager
def initialize_resource_management(memory_limit_gb: float = 2.0):
"""Initialize global resource management.
Args:
memory_limit_gb: Memory limit in GB
"""
global _resource_manager
limits = ResourceLimits(
memory_limit_bytes=int(memory_limit_gb * 1024 * 1024 * 1024),
enable_memory_monitoring=True,
enable_cpu_monitoring=True,
)
_resource_manager = ResourceManager(limits)
logger.info(
f"Resource management initialized with {memory_limit_gb}GB memory limit"
)
# Cleanup function for graceful shutdown
def shutdown_resource_management():
"""Shutdown resource management gracefully."""
global _resource_manager
if _resource_manager:
_resource_manager.stop_monitoring()
_resource_manager.cleanup_all()
_resource_manager = None
logger.info("Resource management shut down")
```
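
A minimal usage sketch for this module (the import path `maverick_mcp.utils.resource_manager` and the call sites are assumptions; the file path is not shown on this page):

```python
# Hypothetical usage sketch; module path and workload are assumptions.
from maverick_mcp.utils.resource_manager import (
    ResourceAwareExecutor,
    ResourceExhaustionError,
    cleanup_on_low_memory,
    initialize_resource_management,
    log_resource_usage,
    shutdown_resource_management,
)

# Install the global manager with a 2GB ceiling and monitoring enabled
initialize_resource_management(memory_limit_gb=2.0)

@cleanup_on_low_memory(threshold_mb=500.0)
@log_resource_usage
def run_heavy_job() -> None:
    data = [bytes(1024) for _ in range(10_000)]  # stand-in for real work
    del data

# submit() raises ResourceExhaustionError synchronously when limits are hit
executor = ResourceAwareExecutor(max_workers=4, memory_limit_mb=512)
try:
    executor.submit(run_heavy_job)
except ResourceExhaustionError:
    pass  # back off and retry later instead of exhausting memory

shutdown_resource_management()
```

Note that a `ResourceExhaustionError` raised inside the background monitoring thread does not propagate to callers; only synchronous checks such as `ResourceAwareExecutor.submit` surface it in your own code.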