This is page 28 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_backtest_persistence.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive tests for backtest persistence layer.
3 |
4 | Tests cover:
5 | - PostgreSQL persistence layer with comprehensive database operations
6 | - BacktestResult, BacktestTrade, OptimizationResult, and WalkForwardTest models
7 | - Database CRUD operations with proper error handling
8 | - Performance comparison and ranking functionality
9 | - Backtest result caching and retrieval optimization
10 | - Database constraint validation and data integrity
11 | - Concurrent access and transaction handling
12 | """
13 |
14 | from datetime import datetime, timedelta
15 | from decimal import Decimal
16 | from typing import Any
17 | from unittest.mock import Mock, patch
18 | from uuid import UUID, uuid4
19 |
20 | import numpy as np
21 | import pandas as pd
22 | import pytest
23 | from sqlalchemy.exc import SQLAlchemyError
24 | from sqlalchemy.orm import Session
25 |
26 | from maverick_mcp.backtesting.persistence import (
27 | BacktestPersistenceError,
28 | BacktestPersistenceManager,
29 | find_best_strategy_for_symbol,
30 | get_recent_backtests,
31 | save_vectorbt_results,
32 | )
33 | from maverick_mcp.data.models import (
34 | BacktestResult,
35 | BacktestTrade,
36 | OptimizationResult,
37 | WalkForwardTest,
38 | )
39 |
40 |
41 | class TestBacktestPersistenceManager:
42 | """Test suite for BacktestPersistenceManager class."""
43 |
44 | @pytest.fixture
45 | def sample_vectorbt_results(self) -> dict[str, Any]:
46 | """Create sample VectorBT results for testing."""
47 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
48 | equity_curve = np.cumsum(np.random.normal(0.001, 0.02, len(dates)))
49 | drawdown_series = np.minimum(
50 | 0, equity_curve - np.maximum.accumulate(equity_curve)
51 | )
52 |
53 | return {
54 | "symbol": "AAPL",
55 | "strategy": "momentum_crossover",
56 | "parameters": {
57 | "fast_window": 10,
58 | "slow_window": 20,
59 | "signal_threshold": 0.02,
60 | },
61 | "start_date": "2023-01-01",
62 | "end_date": "2023-12-31",
63 | "initial_capital": 10000.0,
64 | "metrics": {
65 | "total_return": 0.15,
66 | "annualized_return": 0.18,
67 | "sharpe_ratio": 1.25,
68 | "sortino_ratio": 1.45,
69 | "calmar_ratio": 1.10,
70 | "max_drawdown": -0.08,
71 | "max_drawdown_duration": 45,
72 | "volatility": 0.16,
73 | "downside_volatility": 0.12,
74 | "total_trades": 24,
75 | "winning_trades": 14,
76 | "losing_trades": 10,
77 | "win_rate": 0.583,
78 | "profit_factor": 1.35,
79 | "average_win": 0.045,
80 | "average_loss": -0.025,
81 | "largest_win": 0.12,
82 | "largest_loss": -0.08,
83 | "final_value": 11500.0,
84 | "peak_value": 12100.0,
85 | "beta": 1.05,
86 | "alpha": 0.03,
87 | },
88 | "equity_curve": equity_curve.tolist(),
89 | "drawdown_series": drawdown_series.tolist(),
90 | "trades": [
91 | {
92 | "entry_date": "2023-01-15",
93 | "entry_price": 150.0,
94 | "entry_time": "2023-01-15T09:30:00",
95 | "exit_date": "2023-01-25",
96 | "exit_price": 155.0,
97 | "exit_time": "2023-01-25T16:00:00",
98 | "position_size": 100,
99 | "direction": "long",
100 | "pnl": 500.0,
101 | "pnl_percent": 0.033,
102 | "mae": -150.0,
103 | "mfe": 600.0,
104 | "duration_days": 10,
105 | "duration_hours": 6.5,
106 | "exit_reason": "take_profit",
107 | "fees_paid": 2.0,
108 | "slippage_cost": 1.0,
109 | },
110 | {
111 | "entry_date": "2023-02-01",
112 | "entry_price": 160.0,
113 | "entry_time": "2023-02-01T10:00:00",
114 | "exit_date": "2023-02-10",
115 | "exit_price": 156.0,
116 | "exit_time": "2023-02-10T15:30:00",
117 | "position_size": 100,
118 | "direction": "long",
119 | "pnl": -400.0,
120 | "pnl_percent": -0.025,
121 | "mae": -500.0,
122 | "mfe": 200.0,
123 | "duration_days": 9,
124 | "duration_hours": 5.5,
125 | "exit_reason": "stop_loss",
126 | "fees_paid": 2.0,
127 | "slippage_cost": 1.0,
128 | },
129 | ],
130 | }
131 |
132 | @pytest.fixture
133 | def persistence_manager(self, db_session: Session):
134 | """Create a persistence manager with test database session."""
135 | return BacktestPersistenceManager(session=db_session)
136 |
137 | def test_persistence_manager_context_manager(self, db_session: Session):
138 | """Test persistence manager as context manager."""
139 | with BacktestPersistenceManager(session=db_session) as manager:
140 | assert manager.session == db_session
141 | assert not manager._owns_session
142 |
143 | # Test with auto-session creation (mocked)
144 | with patch(
145 | "maverick_mcp.backtesting.persistence.SessionLocal"
146 | ) as mock_session_local:
147 | mock_session = Mock(spec=Session)
148 | mock_session_local.return_value = mock_session
149 |
150 | with BacktestPersistenceManager() as manager:
151 | assert manager.session == mock_session
152 | assert manager._owns_session
153 | mock_session.commit.assert_called_once()
154 | mock_session.close.assert_called_once()
155 |
156 | def test_save_backtest_result_success(
157 | self, persistence_manager, sample_vectorbt_results
158 | ):
159 | """Test successful backtest result saving."""
160 | backtest_id = persistence_manager.save_backtest_result(
161 | vectorbt_results=sample_vectorbt_results,
162 | execution_time=2.5,
163 | notes="Test backtest run",
164 | )
165 |
166 | # Test return value
167 | assert isinstance(backtest_id, str)
168 | assert UUID(backtest_id) # Valid UUID
169 |
170 | # Test database record
171 | result = (
172 | persistence_manager.session.query(BacktestResult)
173 | .filter(BacktestResult.backtest_id == UUID(backtest_id))
174 | .first()
175 | )
176 |
177 | assert result is not None
178 | assert result.symbol == "AAPL"
179 | assert result.strategy_type == "momentum_crossover"
180 | assert result.total_return == Decimal("0.15")
181 | assert result.sharpe_ratio == Decimal("1.25")
182 | assert result.total_trades == 24
183 | assert result.execution_time_seconds == Decimal("2.5")
184 | assert result.notes == "Test backtest run"
185 |
186 | # Test trades were saved
187 | trades = (
188 | persistence_manager.session.query(BacktestTrade)
189 | .filter(BacktestTrade.backtest_id == UUID(backtest_id))
190 | .all()
191 | )
192 |
193 | assert len(trades) == 2
194 | assert trades[0].symbol == "AAPL"
195 | assert trades[0].pnl == Decimal("500.0")
196 | assert trades[1].pnl == Decimal("-400.0")
197 |
198 | def test_save_backtest_result_validation_error(self, persistence_manager):
199 | """Test backtest saving with validation errors."""
200 | # Missing required fields
201 | invalid_results = {"symbol": "", "strategy": ""}
202 |
203 | with pytest.raises(BacktestPersistenceError) as exc_info:
204 | persistence_manager.save_backtest_result(invalid_results)
205 |
206 | assert "Symbol and strategy type are required" in str(exc_info.value)
207 |
208 | def test_save_backtest_result_database_error(
209 | self, persistence_manager, sample_vectorbt_results
210 | ):
211 | """Test backtest saving with database errors."""
212 | with patch.object(
213 | persistence_manager.session, "add", side_effect=SQLAlchemyError("DB Error")
214 | ):
215 | with pytest.raises(BacktestPersistenceError) as exc_info:
216 | persistence_manager.save_backtest_result(sample_vectorbt_results)
217 |
218 | assert "Failed to save backtest" in str(exc_info.value)
219 |
220 | def test_get_backtest_by_id(self, persistence_manager, sample_vectorbt_results):
221 | """Test retrieval of backtest by ID."""
222 | # Save a backtest first
223 | backtest_id = persistence_manager.save_backtest_result(sample_vectorbt_results)
224 |
225 | # Retrieve it
226 | result = persistence_manager.get_backtest_by_id(backtest_id)
227 |
228 | assert result is not None
229 | assert str(result.backtest_id) == backtest_id
230 | assert result.symbol == "AAPL"
231 | assert result.strategy_type == "momentum_crossover"
232 |
233 | # Test non-existent ID
234 | fake_id = str(uuid4())
235 | result = persistence_manager.get_backtest_by_id(fake_id)
236 | assert result is None
237 |
238 | # Test invalid UUID format
239 | result = persistence_manager.get_backtest_by_id("invalid-uuid")
240 | assert result is None
241 |
242 | def test_get_backtests_by_symbol(
243 | self, persistence_manager, sample_vectorbt_results
244 | ):
245 | """Test retrieval of backtests by symbol."""
246 | # Save multiple backtests for same symbol
247 | sample_vectorbt_results["strategy"] = "momentum_v1"
248 | backtest_id1 = persistence_manager.save_backtest_result(sample_vectorbt_results)
249 |
250 | sample_vectorbt_results["strategy"] = "momentum_v2"
251 | backtest_id2 = persistence_manager.save_backtest_result(sample_vectorbt_results)
252 |
253 | # Save backtest for different symbol
254 | sample_vectorbt_results["symbol"] = "GOOGL"
255 | sample_vectorbt_results["strategy"] = "momentum_v1"
256 | backtest_id3 = persistence_manager.save_backtest_result(sample_vectorbt_results)
257 |
258 | # Test retrieval by symbol
259 | aapl_results = persistence_manager.get_backtests_by_symbol("AAPL")
260 | assert len(aapl_results) == 2
261 | assert all(result.symbol == "AAPL" for result in aapl_results)
262 | assert backtest_id1 != backtest_id2
263 | assert backtest_id3 not in {backtest_id1, backtest_id2}
264 | retrieved_ids = {str(result.backtest_id) for result in aapl_results}
265 | assert {backtest_id1, backtest_id2}.issubset(retrieved_ids)
266 |
267 | # Test with strategy filter
268 | aapl_v1_results = persistence_manager.get_backtests_by_symbol(
269 | "AAPL", "momentum_v1"
270 | )
271 | assert len(aapl_v1_results) == 1
272 | assert aapl_v1_results[0].strategy_type == "momentum_v1"
273 |
274 | # Test with limit
275 | limited_results = persistence_manager.get_backtests_by_symbol("AAPL", limit=1)
276 | assert len(limited_results) == 1
277 |
278 | # Test non-existent symbol
279 | empty_results = persistence_manager.get_backtests_by_symbol("NONEXISTENT")
280 | assert len(empty_results) == 0
281 |
282 | def test_get_best_performing_strategies(
283 | self, persistence_manager, sample_vectorbt_results
284 | ):
285 | """Test retrieval of best performing strategies."""
286 | # Create multiple backtests with different performance
287 | strategies_performance = [
288 | (
289 | "momentum",
290 | {"sharpe_ratio": 1.5, "total_return": 0.2, "total_trades": 15},
291 | ),
292 | (
293 | "mean_reversion",
294 | {"sharpe_ratio": 1.8, "total_return": 0.15, "total_trades": 20},
295 | ),
296 | (
297 | "breakout",
298 | {"sharpe_ratio": 0.8, "total_return": 0.25, "total_trades": 10},
299 | ),
300 | (
301 | "momentum_v2",
302 | {"sharpe_ratio": 2.0, "total_return": 0.3, "total_trades": 25},
303 | ),
304 | ]
305 |
306 | backtest_ids = []
307 | for strategy, metrics in strategies_performance:
308 | sample_vectorbt_results["strategy"] = strategy
309 | sample_vectorbt_results["metrics"].update(metrics)
310 | backtest_id = persistence_manager.save_backtest_result(
311 | sample_vectorbt_results
312 | )
313 | backtest_ids.append(backtest_id)
314 |
315 | # Test best by Sharpe ratio (default)
316 | best_sharpe = persistence_manager.get_best_performing_strategies(
317 | "sharpe_ratio", limit=3
318 | )
319 | assert len(best_sharpe) == 3
320 | assert best_sharpe[0].strategy_type == "momentum_v2" # Highest Sharpe
321 | assert best_sharpe[1].strategy_type == "mean_reversion" # Second highest
322 | assert best_sharpe[0].sharpe_ratio > best_sharpe[1].sharpe_ratio
323 |
324 | # Test best by total return
325 | best_return = persistence_manager.get_best_performing_strategies(
326 | "total_return", limit=2
327 | )
328 | assert len(best_return) == 2
329 | assert best_return[0].strategy_type == "momentum_v2" # Highest return
330 |
331 | # Test minimum trades filter
332 | high_volume = persistence_manager.get_best_performing_strategies(
333 | "sharpe_ratio", min_trades=20
334 | )
335 | assert len(high_volume) == 2 # Only momentum_v2 and mean_reversion
336 | assert all(result.total_trades >= 20 for result in high_volume)
337 |
338 | def test_compare_strategies(self, persistence_manager, sample_vectorbt_results):
339 | """Test strategy comparison functionality."""
340 | # Create backtests to compare
341 | strategies = ["momentum", "mean_reversion", "breakout"]
342 | backtest_ids = []
343 |
344 | for i, strategy in enumerate(strategies):
345 | sample_vectorbt_results["strategy"] = strategy
346 | sample_vectorbt_results["metrics"]["sharpe_ratio"] = 1.0 + i * 0.5
347 | sample_vectorbt_results["metrics"]["total_return"] = 0.1 + i * 0.05
348 | sample_vectorbt_results["metrics"]["max_drawdown"] = -0.05 - i * 0.02
349 | backtest_id = persistence_manager.save_backtest_result(
350 | sample_vectorbt_results
351 | )
352 | backtest_ids.append(backtest_id)
353 |
354 | # Test comparison
355 | comparison = persistence_manager.compare_strategies(backtest_ids)
356 |
357 | assert "backtests" in comparison
358 | assert "rankings" in comparison
359 | assert "summary" in comparison
360 | assert len(comparison["backtests"]) == 3
361 |
362 | # Test rankings
363 | assert "sharpe_ratio" in comparison["rankings"]
364 | sharpe_rankings = comparison["rankings"]["sharpe_ratio"]
365 | assert len(sharpe_rankings) == 3
366 | assert sharpe_rankings[0]["rank"] == 1 # Best rank
367 | assert sharpe_rankings[0]["value"] > sharpe_rankings[1]["value"]
368 |
369 | # Test max_drawdown ranking (lower is better)
370 | assert "max_drawdown" in comparison["rankings"]
371 | dd_rankings = comparison["rankings"]["max_drawdown"]
372 | assert (
373 | dd_rankings[0]["value"] > dd_rankings[-1]["value"]
374 | ) # Less negative is better
375 |
376 | # Test summary
377 | summary = comparison["summary"]
378 | assert summary["total_backtests"] == 3
379 | assert "date_range" in summary
380 |
381 | def test_save_optimization_results(
382 | self, persistence_manager, sample_vectorbt_results
383 | ):
384 | """Test saving parameter optimization results."""
385 | # Save parent backtest first
386 | backtest_id = persistence_manager.save_backtest_result(sample_vectorbt_results)
387 |
388 | # Create optimization results
389 | optimization_results = [
390 | {
391 | "parameters": {"window": 10, "threshold": 0.01},
392 | "objective_value": 1.2,
393 | "total_return": 0.15,
394 | "sharpe_ratio": 1.2,
395 | "max_drawdown": -0.08,
396 | "win_rate": 0.6,
397 | "profit_factor": 1.3,
398 | "total_trades": 20,
399 | "rank": 1,
400 | },
401 | {
402 | "parameters": {"window": 20, "threshold": 0.02},
403 | "objective_value": 1.5,
404 | "total_return": 0.18,
405 | "sharpe_ratio": 1.5,
406 | "max_drawdown": -0.06,
407 | "win_rate": 0.65,
408 | "profit_factor": 1.4,
409 | "total_trades": 18,
410 | "rank": 2,
411 | },
412 | ]
413 |
414 | # Save optimization results
415 | count = persistence_manager.save_optimization_results(
416 | backtest_id=backtest_id,
417 | optimization_results=optimization_results,
418 | objective_function="sharpe_ratio",
419 | )
420 |
421 | assert count == 2
422 |
423 | # Verify saved results
424 | opt_results = (
425 | persistence_manager.session.query(OptimizationResult)
426 | .filter(OptimizationResult.backtest_id == UUID(backtest_id))
427 | .all()
428 | )
429 |
430 | assert len(opt_results) == 2
431 | assert opt_results[0].objective_function == "sharpe_ratio"
432 | assert opt_results[0].parameters == {"window": 10, "threshold": 0.01}
433 | assert opt_results[0].objective_value == Decimal("1.2")
434 |
435 | def test_save_walk_forward_test(self, persistence_manager, sample_vectorbt_results):
436 | """Test saving walk-forward validation results."""
437 | # Save parent backtest first
438 | backtest_id = persistence_manager.save_backtest_result(sample_vectorbt_results)
439 |
440 | # Create walk-forward test data
441 | walk_forward_data = {
442 | "window_size_months": 6,
443 | "step_size_months": 1,
444 | "training_start": "2023-01-01",
445 | "training_end": "2023-06-30",
446 | "test_period_start": "2023-07-01",
447 | "test_period_end": "2023-07-31",
448 | "optimal_parameters": {"window": 15, "threshold": 0.015},
449 | "training_performance": 1.3,
450 | "out_of_sample_return": 0.12,
451 | "out_of_sample_sharpe": 1.1,
452 | "out_of_sample_drawdown": -0.05,
453 | "out_of_sample_trades": 8,
454 | "performance_ratio": 0.85,
455 | "degradation_factor": 0.15,
456 | "is_profitable": True,
457 | "is_statistically_significant": True,
458 | }
459 |
460 | # Save walk-forward test
461 | wf_id = persistence_manager.save_walk_forward_test(
462 | backtest_id, walk_forward_data
463 | )
464 |
465 | assert isinstance(wf_id, str)
466 | assert UUID(wf_id)
467 |
468 | # Verify saved result
469 | wf_test = (
470 | persistence_manager.session.query(WalkForwardTest)
471 | .filter(WalkForwardTest.walk_forward_id == UUID(wf_id))
472 | .first()
473 | )
474 |
475 | assert wf_test is not None
476 | assert wf_test.parent_backtest_id == UUID(backtest_id)
477 | assert wf_test.window_size_months == 6
478 | assert wf_test.out_of_sample_sharpe == Decimal("1.1")
479 | assert wf_test.is_profitable is True
480 |
def test_get_backtest_performance_summary(
    self, persistence_manager, sample_vectorbt_results
):
    """Test performance summary generation over a rolling date window.

    Saves three "recent" backtests (0, 10 and 20 days old) and one "old"
    backtest (45 days old), then checks that a 30-day summary counts only
    the recent ones and breaks them down by strategy and symbol.

    Backtest dates are assigned on the persisted records after saving.
    Replacing the mapped ``BacktestResult.backtest_date`` attribute on the
    class (e.g. with ``unittest.mock.patch``) is not a reliable way to
    control the value SQLAlchemy writes for that column.
    """
    base_date = datetime.utcnow()

    def _save_with_age(days_ago):
        """Save the current results and backdate the stored record by ``days_ago`` days."""
        backtest_id = persistence_manager.save_backtest_result(sample_vectorbt_results)
        record = persistence_manager.get_backtest_by_id(backtest_id)
        assert record is not None
        record.backtest_date = base_date - timedelta(days=days_ago)
        # Push the updated date to the database so the summary query sees it.
        persistence_manager.session.flush()
        return backtest_id

    # Recent backtests (within 30 days)
    for i in range(3):
        sample_vectorbt_results["strategy"] = f"momentum_v{i + 1}"
        sample_vectorbt_results["metrics"]["total_return"] = 0.1 + i * 0.05
        sample_vectorbt_results["metrics"]["sharpe_ratio"] = 1.0 + i * 0.3
        sample_vectorbt_results["metrics"]["win_rate"] = 0.5 + i * 0.1
        _save_with_age(i * 10)

    # Old backtest (outside 30 days)
    sample_vectorbt_results["strategy"] = "old_strategy"
    _save_with_age(45)

    # Get summary
    summary = persistence_manager.get_backtest_performance_summary(days_back=30)

    assert "period" in summary
    assert summary["total_backtests"] == 3  # Only recent ones
    assert "performance_metrics" in summary

    metrics = summary["performance_metrics"]
    assert "average_return" in metrics
    assert "best_return" in metrics
    assert "worst_return" in metrics
    assert "average_sharpe" in metrics

    # Test strategy and symbol breakdowns
    assert "strategy_breakdown" in summary
    assert len(summary["strategy_breakdown"]) == 3
    assert "symbol_breakdown" in summary
    assert "AAPL" in summary["symbol_breakdown"]
527 |
def test_delete_backtest(self, persistence_manager, sample_vectorbt_results):
    """Test backtest deletion, including cascading removal of its trades.

    Also checks that deleting an unknown ID reports ``False``.
    """
    # Save backtest with trades
    backtest_id = persistence_manager.save_backtest_result(sample_vectorbt_results)

    # Verify it exists
    result = persistence_manager.get_backtest_by_id(backtest_id)
    assert result is not None

    def _trade_count():
        """Count stored trades that reference this backtest."""
        return (
            persistence_manager.session.query(BacktestTrade)
            .filter(BacktestTrade.backtest_id == UUID(backtest_id))
            .count()
        )

    assert _trade_count() > 0

    # Delete backtest
    deleted = persistence_manager.delete_backtest(backtest_id)
    assert deleted is True

    # Verify deletion
    result = persistence_manager.get_backtest_by_id(backtest_id)
    assert result is None

    # Cascade: trades belonging to the deleted backtest must be gone as well.
    assert _trade_count() == 0

    # Test non-existent deletion
    fake_id = str(uuid4())
    deleted = persistence_manager.delete_backtest(fake_id)
    assert deleted is False
556 |
def test_safe_decimal_conversion(self):
    """Check ``_safe_decimal`` on convertible inputs and on inputs it must reject."""
    from maverick_mcp.backtesting.persistence import (
        BacktestPersistenceManager as Manager,
    )

    to_decimal = Manager._safe_decimal

    # Convertible inputs (int, float, numeric string, Decimal) map to the
    # matching Decimal value.
    convertible_cases = [
        (123, Decimal("123")),
        (123.45, Decimal("123.45")),
        ("456.78", Decimal("456.78")),
        (Decimal("789.01"), Decimal("789.01")),
    ]
    for raw_value, expected in convertible_cases:
        assert to_decimal(raw_value) == expected

    # None and values that cannot be converted come back as None.
    for rejected in (None, "invalid", [1, 2, 3]):
        assert to_decimal(rejected) is None
573 |
574 |
class TestConvenienceFunctions:
    """Tests for the module-level convenience wrappers around the persistence manager.

    Each test patches the manager factory so that no real database work happens.
    """

    _FACTORY_PATH = "maverick_mcp.backtesting.persistence.get_persistence_manager"

    @staticmethod
    def _context_manager_mock():
        """Build a ``BacktestPersistenceManager`` mock that works in a ``with`` block."""
        manager = Mock(spec=BacktestPersistenceManager)
        manager.__enter__ = Mock(return_value=manager)
        manager.__exit__ = Mock(return_value=None)
        return manager

    def test_save_vectorbt_results_function(
        self, db_session: Session, sample_vectorbt_results
    ):
        """``save_vectorbt_results`` forwards its arguments and returns the new ID."""
        with patch(self._FACTORY_PATH) as factory:
            manager = self._context_manager_mock()
            manager.save_backtest_result.return_value = "test-uuid-123"
            factory.return_value = manager

            saved_id = save_vectorbt_results(
                vectorbt_results=sample_vectorbt_results,
                execution_time=2.5,
                notes="Test run",
            )

            assert saved_id == "test-uuid-123"
            manager.save_backtest_result.assert_called_once_with(
                sample_vectorbt_results, 2.5, "Test run"
            )

    def test_get_recent_backtests_function(self, db_session: Session):
        """``get_recent_backtests`` queries ``BacktestResult`` and returns ``.all()``."""
        with patch(self._FACTORY_PATH) as factory:
            # filter()/order_by() hand back the same query so chained calls work.
            query = Mock()
            query.filter.return_value = query
            query.order_by.return_value = query
            query.all.return_value = ["result1", "result2"]

            session = Mock(spec=Session)
            session.query.return_value = query

            manager = self._context_manager_mock()
            manager.session = session
            factory.return_value = manager

            recent = get_recent_backtests("AAPL", days=7)

            assert recent == ["result1", "result2"]
            session.query.assert_called_once_with(BacktestResult)

    def test_find_best_strategy_for_symbol_function(self, db_session: Session):
        """``find_best_strategy_for_symbol`` picks from the symbol's stored backtests."""
        with patch(self._FACTORY_PATH) as factory:
            best = Mock(spec=BacktestResult)

            manager = self._context_manager_mock()
            manager.get_best_performing_strategies.return_value = [best]
            manager.get_backtests_by_symbol.return_value = [best]
            factory.return_value = manager

            outcome = find_best_strategy_for_symbol("AAPL", "sharpe_ratio")

            assert outcome == best
            manager.get_backtests_by_symbol.assert_called_once_with(
                "AAPL", limit=1000
            )
648 |
649 |
class TestPersistenceStressTests:
    """Stress tests for persistence layer performance and reliability."""

    def test_bulk_insert_performance(
        self, persistence_manager, sample_vectorbt_results, benchmark_timer
    ):
        """Saving many backtests in a row stays fast and persists every row."""
        backtest_count = 50

        with benchmark_timer() as timer:
            for i in range(backtest_count):
                sample_vectorbt_results["symbol"] = f"STOCK{i:03d}"
                sample_vectorbt_results["strategy"] = (
                    f"strategy_{i % 5}"  # 5 different strategies
                )
                persistence_manager.save_backtest_result(sample_vectorbt_results)

        # Should complete within reasonable time
        assert timer.elapsed < 30.0  # < 30 seconds for 50 backtests

        # Verify all were saved (assumes the test database starts empty)
        all_results = persistence_manager.session.query(BacktestResult).count()
        assert all_results == backtest_count

    def test_concurrent_access_handling(
        self, db_session: Session, sample_vectorbt_results
    ):
        """Several threads, each with its own manager, save backtests concurrently."""
        import queue
        import threading

        results_queue = queue.Queue()
        error_queue = queue.Queue()

        def save_backtest(thread_id):
            """Worker: save one backtest and report its ID or the error message."""
            try:
                # Each thread gets its own session
                with BacktestPersistenceManager() as manager:
                    modified_results = sample_vectorbt_results.copy()
                    modified_results["symbol"] = f"THREAD{thread_id}"
                    backtest_id = manager.save_backtest_result(modified_results)
                    results_queue.put(backtest_id)
            except Exception as e:
                error_queue.put(f"Thread {thread_id}: {e}")

        # Create multiple threads
        threads = []
        thread_count = 5

        for i in range(thread_count):
            thread = threading.Thread(target=save_backtest, args=(i,))
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join(timeout=10)  # 10 second timeout per thread

        # join() returns silently when the timeout expires, so detect hung
        # workers explicitly instead of failing later on a queue-size mismatch.
        hung = [thread.name for thread in threads if thread.is_alive()]
        assert not hung, f"Threads did not finish within the timeout: {hung}"

        # Check results
        assert error_queue.empty(), f"Errors occurred: {list(error_queue.queue)}"
        assert results_queue.qsize() == thread_count

        # Verify all backtests were saved with unique IDs
        saved_ids = []
        while not results_queue.empty():
            saved_ids.append(results_queue.get())

        assert len(saved_ids) == thread_count
        assert len(set(saved_ids)) == thread_count  # All unique

    def test_large_result_handling(self, persistence_manager, sample_vectorbt_results):
        """A backtest with very long series and many trades saves and reloads intact."""
        # Create large equity curve and drawdown series (1 year of minute data)
        large_data_size = 365 * 24 * 60  # ~525k data points

        sample_vectorbt_results["equity_curve"] = list(range(large_data_size))
        sample_vectorbt_results["drawdown_series"] = [
            -i / 1000 for i in range(large_data_size)
        ]

        # Also add many trades
        sample_vectorbt_results["trades"] = []
        for i in range(1000):  # 1000 trades
            trade = {
                "entry_date": f"2023-{(i % 12) + 1:02d}-{(i % 28) + 1:02d}",
                "entry_price": 100 + (i % 100),
                "exit_date": f"2023-{(i % 12) + 1:02d}-{(i % 28) + 1:02d}",
                "exit_price": 101 + (i % 100),
                "position_size": 100,
                "direction": "long",
                "pnl": i % 100 - 50,
                "pnl_percent": (i % 100 - 50) / 1000,
                "duration_days": i % 30 + 1,
                "exit_reason": "time_exit",
            }
            sample_vectorbt_results["trades"].append(trade)

        # Should handle large data without issues
        backtest_id = persistence_manager.save_backtest_result(sample_vectorbt_results)

        assert backtest_id is not None

        # Verify retrieval works
        result = persistence_manager.get_backtest_by_id(backtest_id)
        assert result is not None
        assert result.data_points == large_data_size

        # Verify trades were saved
        trades = (
            persistence_manager.session.query(BacktestTrade)
            .filter(BacktestTrade.backtest_id == UUID(backtest_id))
            .count()
        )
        assert trades == 1000

    def test_database_constraint_validation(
        self, persistence_manager, sample_vectorbt_results
    ):
        """Saving a backtest whose generated ID collides is handled gracefully.

        Either a fresh ID is produced or ``BacktestPersistenceError`` is raised.
        """
        # Save first backtest
        backtest_id1 = persistence_manager.save_backtest_result(sample_vectorbt_results)

        # Try to save with same UUID (should be prevented by unique constraint)
        # NOTE(review): patching ``uuid.uuid4`` only affects code that looks up
        # ``uuid.uuid4`` at call time. If IDs come from a reference captured at
        # import (``from uuid import uuid4``, or a column ``default=uuid.uuid4``),
        # this patch is a no-op and no collision is exercised — TODO confirm.
        with patch("uuid.uuid4", return_value=UUID(backtest_id1)):
            # This should handle the constraint violation gracefully
            try:
                backtest_id2 = persistence_manager.save_backtest_result(
                    sample_vectorbt_results
                )
                # If it succeeds, it should have generated a different UUID
                assert backtest_id2 != backtest_id1
            except BacktestPersistenceError:
                # Or it should raise a proper persistence error
                pass

    def test_memory_usage_with_large_datasets(
        self, persistence_manager, sample_vectorbt_results
    ):
        """Process RSS growth stays bounded when saving several large backtests.

        Skipped when the optional ``psutil`` package is not installed.
        """
        import os

        # psutil is not in the standard library; skip instead of erroring.
        psutil = pytest.importorskip("psutil")

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss

        # Create and save multiple large backtests
        for i in range(10):
            large_results = sample_vectorbt_results.copy()
            large_results["symbol"] = f"LARGE{i}"
            large_results["equity_curve"] = list(range(10000))  # 10k data points each
            large_results["drawdown_series"] = [-j / 1000 for j in range(10000)]

            persistence_manager.save_backtest_result(large_results)

        final_memory = process.memory_info().rss
        memory_growth = (final_memory - initial_memory) / 1024 / 1024  # MB

        # Memory growth should be reasonable (< 100MB for 10 large backtests)
        assert memory_growth < 100
810 |
811 |
if __name__ == "__main__":
    # Run tests with detailed output, stopping at the first failure (-x).
    # Propagate pytest's exit status so shells and CI see failures; the
    # return value of pytest.main() would otherwise be discarded (exit 0).
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short", "-x"]))
815 |
```
--------------------------------------------------------------------------------
/docs/api/backtesting.md:
--------------------------------------------------------------------------------
```markdown
1 | # Backtesting API Documentation
2 |
3 | ## Overview
4 |
5 | The MaverickMCP backtesting system provides comprehensive strategy backtesting capabilities powered by VectorBT. It offers both traditional technical analysis strategies and advanced ML-enhanced approaches, with extensive optimization, validation, and analysis tools.
6 |
7 | ### Key Features
8 |
9 | - **35+ Pre-built Strategies**: From simple moving averages to advanced ML ensembles
10 | - **Strategy Optimization**: Grid search with coarse/medium/fine granularity
11 | - **Walk-Forward Analysis**: Out-of-sample validation for strategy robustness
12 | - **Monte Carlo Simulation**: Risk assessment with confidence intervals
13 | - **Portfolio Backtesting**: Multi-symbol strategy application
14 | - **Market Regime Analysis**: Intelligent strategy selection based on market conditions
15 | - **ML-Enhanced Strategies**: Adaptive, ensemble, and regime-aware approaches
16 | - **Comprehensive Visualization**: Charts, heatmaps, and performance dashboards
17 |
18 | ## Core Backtesting Tools
19 |
20 | ### run_backtest
21 |
22 | Run a comprehensive backtest with the specified strategy and parameters.
23 |
24 | **Function**: `run_backtest`
25 |
26 | **Parameters**:
27 | - `symbol` (str, required): Stock symbol to backtest (e.g., "AAPL", "TSLA")
28 | - `strategy` (str, default: "sma_cross"): Strategy type to use
29 | - `start_date` (str, optional): Start date (YYYY-MM-DD), defaults to 1 year ago
30 | - `end_date` (str, optional): End date (YYYY-MM-DD), defaults to today
31 | - `initial_capital` (float, default: 10000.0): Starting capital for backtest
32 |
33 | **Strategy-Specific Parameters**:
34 | - `fast_period` (int, optional): Fast moving average period
35 | - `slow_period` (int, optional): Slow moving average period
36 | - `period` (int, optional): General period parameter (RSI, etc.)
37 | - `oversold` (float, optional): RSI oversold threshold (default: 30)
38 | - `overbought` (float, optional): RSI overbought threshold (default: 70)
39 | - `signal_period` (int, optional): MACD signal line period
40 | - `std_dev` (float, optional): Bollinger Bands standard deviation
41 | - `lookback` (int, optional): Lookback period for momentum/breakout
42 | - `threshold` (float, optional): Threshold for momentum strategies
43 | - `z_score_threshold` (float, optional): Z-score threshold for mean reversion
44 | - `breakout_factor` (float, optional): Breakout factor for channel strategies
45 |
46 | **Returns**:
47 | ```json
48 | {
49 | "symbol": "AAPL",
50 | "strategy": "sma_cross",
51 | "period": "2023-01-01 to 2024-01-01",
52 | "metrics": {
53 | "total_return": 0.15,
54 | "sharpe_ratio": 1.2,
55 | "max_drawdown": -0.08,
56 | "total_trades": 24,
57 | "win_rate": 0.58,
58 | "profit_factor": 1.45,
59 | "calmar_ratio": 1.85,
60 | "volatility": 0.18
61 | },
62 | "trades": [
63 | {
64 | "entry_date": "2023-01-15",
65 | "exit_date": "2023-02-10",
66 | "entry_price": 150.0,
67 | "exit_price": 158.5,
68 | "return": 0.057,
69 | "holding_period": 26
70 | }
71 | ],
72 | "equity_curve": [10000, 10150, 10200, ...],
73 | "drawdown_series": [0, -0.01, -0.02, ...],
74 | "analysis": {
75 | "risk_metrics": {...},
76 | "performance_analysis": {...}
77 | }
78 | }
79 | ```
80 |
81 | **Examples**:
82 | ```python
83 | # Simple SMA crossover
84 | run_backtest("AAPL", "sma_cross", fast_period=10, slow_period=20)
85 |
86 | # RSI mean reversion
87 | run_backtest("TSLA", "rsi", period=14, oversold=30, overbought=70)
88 |
89 | # MACD strategy with custom parameters
90 | run_backtest("MSFT", "macd", fast_period=12, slow_period=26, signal_period=9)
91 |
92 | # Bollinger Bands strategy
93 | run_backtest("GOOGL", "bollinger", period=20, std_dev=2.0)
94 | ```
95 |
96 | ### optimize_strategy
97 |
98 | Optimize strategy parameters using grid search to find the best-performing configuration.
99 |
100 | **Function**: `optimize_strategy`
101 |
102 | **Parameters**:
103 | - `symbol` (str, required): Stock symbol to optimize
104 | - `strategy` (str, default: "sma_cross"): Strategy type to optimize
105 | - `start_date` (str, optional): Start date (YYYY-MM-DD)
106 | - `end_date` (str, optional): End date (YYYY-MM-DD)
107 | - `optimization_metric` (str, default: "sharpe_ratio"): Metric to optimize ("sharpe_ratio", "total_return", "win_rate", "calmar_ratio")
108 | - `optimization_level` (str, default: "medium"): Level of optimization ("coarse", "medium", "fine")
109 | - `top_n` (int, default: 10): Number of top results to return
110 |
111 | **Returns**:
112 | ```json
113 | {
114 | "symbol": "AAPL",
115 | "strategy": "sma_cross",
116 | "optimization_metric": "sharpe_ratio",
117 | "optimization_level": "medium",
118 | "total_combinations": 64,
119 | "execution_time": 45.2,
120 | "best_parameters": {
121 | "fast_period": 8,
122 | "slow_period": 21,
123 | "sharpe_ratio": 1.85,
124 | "total_return": 0.28,
125 | "max_drawdown": -0.06
126 | },
127 | "top_results": [
128 | {
129 | "parameters": {"fast_period": 8, "slow_period": 21},
130 | "sharpe_ratio": 1.85,
131 | "total_return": 0.28,
132 | "max_drawdown": -0.06,
133 | "total_trades": 18
134 | }
135 | ],
136 | "parameter_sensitivity": {
137 | "fast_period": {"min": 5, "max": 20, "best": 8},
138 | "slow_period": {"min": 20, "max": 50, "best": 21}
139 | }
140 | }
141 | ```
142 |
143 | **Examples**:
144 | ```python
145 | # Optimize SMA crossover for Sharpe ratio
146 | optimize_strategy("AAPL", "sma_cross", optimization_metric="sharpe_ratio")
147 |
148 | # Fine-tune RSI parameters for total return
149 | optimize_strategy("TSLA", "rsi", optimization_metric="total_return", optimization_level="fine")
150 |
151 | # Quick coarse optimization of MACD, returning only the top 5 results
152 | optimize_strategy("MSFT", "macd", optimization_level="coarse", top_n=5)
153 | ```
154 |
155 | ### walk_forward_analysis
156 |
157 | Perform walk-forward analysis to test strategy robustness and out-of-sample performance.
158 |
159 | **Function**: `walk_forward_analysis`
160 |
161 | **Parameters**:
162 | - `symbol` (str, required): Stock symbol to analyze
163 | - `strategy` (str, default: "sma_cross"): Strategy type
164 | - `start_date` (str, optional): Start date (YYYY-MM-DD)
165 | - `end_date` (str, optional): End date (YYYY-MM-DD)
166 | - `window_size` (int, default: 252): Size of each rolling analysis window in trading days (252 ≈ 1 year)
167 | - `step_size` (int, default: 63): Number of trading days the window advances between iterations (63 ≈ 1 quarter)
168 |
169 | **Returns**:
170 | ```json
171 | {
172 | "symbol": "AAPL",
173 | "strategy": "sma_cross",
174 | "total_windows": 8,
175 | "window_size": 252,
176 | "step_size": 63,
177 | "out_of_sample_performance": {
178 | "average_return": 0.12,
179 | "average_sharpe": 0.95,
180 | "consistency_score": 0.75,
181 | "best_window": {"period": "2023-Q2", "return": 0.28},
182 | "worst_window": {"period": "2023-Q4", "return": -0.05}
183 | },
184 | "window_results": [
185 | {
186 | "window_id": 1,
187 | "optimization_period": "2022-01-01 to 2022-12-31",
188 | "test_period": "2023-01-01 to 2023-03-31",
189 | "best_parameters": {"fast_period": 10, "slow_period": 25},
190 | "out_of_sample_return": 0.08,
191 | "out_of_sample_sharpe": 1.1
192 | }
193 | ],
194 | "stability_metrics": {
195 | "parameter_stability": 0.85,
196 | "performance_stability": 0.72,
197 | "overfitting_risk": "low"
198 | }
199 | }
200 | ```
201 |
202 | ### monte_carlo_simulation
203 |
204 | Run a Monte Carlo simulation on backtest results to assess risk and estimate confidence intervals.
205 |
206 | **Function**: `monte_carlo_simulation`
207 |
208 | **Parameters**:
209 | - `symbol` (str, required): Stock symbol
210 | - `strategy` (str, default: "sma_cross"): Strategy type
211 | - `start_date` (str, optional): Start date (YYYY-MM-DD)
212 | - `end_date` (str, optional): End date (YYYY-MM-DD)
213 | - `num_simulations` (int, default: 1000): Number of Monte Carlo simulations
214 | - Strategy-specific parameters (same as `run_backtest`)
215 |
216 | **Returns**:
217 | ```json
218 | {
219 | "symbol": "AAPL",
220 | "strategy": "sma_cross",
221 | "num_simulations": 1000,
222 | "confidence_intervals": {
223 | "95%": {"lower": 0.05, "upper": 0.32},
224 | "90%": {"lower": 0.08, "upper": 0.28},
225 | "68%": {"lower": 0.12, "upper": 0.22}
226 | },
227 | "risk_metrics": {
228 | "probability_of_loss": 0.15,
229 | "expected_return": 0.17,
230 | "value_at_risk_5%": -0.12,
231 | "expected_shortfall": -0.18,
232 | "maximum_drawdown_95%": -0.15
233 | },
234 | "simulation_statistics": {
235 | "mean_return": 0.168,
236 | "std_return": 0.089,
237 | "skewness": -0.23,
238 | "kurtosis": 2.85,
239 | "best_simulation": 0.45,
240 | "worst_simulation": -0.28
241 | }
242 | }
243 | ```
244 |
245 | ### compare_strategies
246 |
247 | Compare multiple strategies on the same symbol to identify the best performer.
248 |
249 | **Function**: `compare_strategies`
250 |
251 | **Parameters**:
252 | - `symbol` (str, required): Stock symbol
253 | - `strategies` (list[str], optional): List of strategy types to compare (defaults to top 5)
254 | - `start_date` (str, optional): Start date (YYYY-MM-DD)
255 | - `end_date` (str, optional): End date (YYYY-MM-DD)
256 |
257 | **Returns**:
258 | ```json
259 | {
260 | "symbol": "AAPL",
261 | "comparison_period": "2023-01-01 to 2024-01-01",
262 | "strategies_compared": ["sma_cross", "rsi", "macd", "bollinger", "momentum"],
263 | "rankings": {
264 | "by_sharpe_ratio": [
265 | {"strategy": "macd", "sharpe_ratio": 1.45},
266 | {"strategy": "sma_cross", "sharpe_ratio": 1.22},
267 | {"strategy": "momentum", "sharpe_ratio": 0.98}
268 | ],
269 | "by_total_return": [
270 | {"strategy": "momentum", "total_return": 0.32},
271 | {"strategy": "macd", "total_return": 0.28},
272 | {"strategy": "sma_cross", "total_return": 0.18}
273 | ]
274 | },
275 | "detailed_comparison": {
276 | "sma_cross": {
277 | "total_return": 0.18,
278 | "sharpe_ratio": 1.22,
279 | "max_drawdown": -0.08,
280 | "total_trades": 24,
281 | "win_rate": 0.58
282 | }
283 | },
284 | "best_overall": "macd",
285 | "recommendation": "MACD strategy provides best risk-adjusted returns"
286 | }
287 | ```
288 |
289 | ### backtest_portfolio
290 |
291 | Backtest a strategy across multiple symbols to create a diversified portfolio.
292 |
293 | **Function**: `backtest_portfolio`
294 |
295 | **Parameters**:
296 | - `symbols` (list[str], required): List of stock symbols
297 | - `strategy` (str, default: "sma_cross"): Strategy type to apply
298 | - `start_date` (str, optional): Start date (YYYY-MM-DD)
299 | - `end_date` (str, optional): End date (YYYY-MM-DD)
300 | - `initial_capital` (float, default: 10000.0): Starting capital
301 | - `position_size` (float, default: 0.1): Fraction of capital allocated to each symbol (0.1 = 10%)
302 | - Strategy-specific parameters (same as `run_backtest`)
303 |
304 | **Returns**:
305 | ```json
306 | {
307 | "portfolio_metrics": {
308 | "symbols_tested": 5,
309 | "total_return": 0.22,
310 | "average_sharpe": 1.15,
311 | "max_drawdown": -0.12,
312 | "total_trades": 120,
313 | "diversification_benefit": 0.85
314 | },
315 | "individual_results": [
316 | {
317 | "symbol": "AAPL",
318 | "total_return": 0.18,
319 | "sharpe_ratio": 1.22,
320 | "max_drawdown": -0.08,
321 | "contribution_to_portfolio": 0.24
322 | }
323 | ],
324 | "correlation_matrix": {
325 | "AAPL": {"MSFT": 0.72, "GOOGL": 0.68},
326 | "MSFT": {"GOOGL": 0.75}
327 | },
328 | "summary": "Portfolio backtest of 5 symbols with sma_cross strategy"
329 | }
330 | ```
331 |
332 | ## Strategy Management
333 |
334 | ### list_strategies
335 |
336 | List all available backtesting strategies with descriptions and parameters.
337 |
338 | **Function**: `list_strategies`
339 |
340 | **Parameters**: None
341 |
342 | **Returns**:
343 | ```json
344 | {
345 | "available_strategies": {
346 | "sma_cross": {
347 | "type": "sma_cross",
348 | "name": "SMA Crossover",
349 | "description": "Buy when fast SMA crosses above slow SMA, sell when it crosses below",
350 | "default_parameters": {"fast_period": 10, "slow_period": 20},
351 | "optimization_ranges": {
352 | "fast_period": [5, 10, 15, 20],
353 | "slow_period": [20, 30, 50, 100]
354 | }
355 | }
356 | },
357 | "total_count": 9,
358 | "categories": {
359 | "trend_following": ["sma_cross", "ema_cross", "macd", "breakout"],
360 | "mean_reversion": ["rsi", "bollinger", "mean_reversion"],
361 | "momentum": ["momentum", "volume_momentum"]
362 | }
363 | }
364 | ```
365 |
366 | ### parse_strategy
367 |
368 | Parse a natural-language strategy description into VectorBT parameters.
369 |
370 | **Function**: `parse_strategy`
371 |
372 | **Parameters**:
373 | - `description` (str, required): Natural language description of trading strategy
374 |
375 | **Returns**:
376 | ```json
377 | {
378 | "success": true,
379 | "strategy": {
380 | "strategy_type": "rsi",
381 | "parameters": {
382 | "period": 14,
383 | "oversold": 30,
384 | "overbought": 70
385 | }
386 | },
387 | "message": "Successfully parsed as rsi strategy"
388 | }
389 | ```
390 |
391 | **Examples**:
392 | ```python
393 | # Parse natural language descriptions
394 | parse_strategy("Buy when RSI is below 30 and sell when above 70")
395 | parse_strategy("Use 10-day and 20-day moving average crossover")
396 | parse_strategy("MACD strategy with standard parameters")
397 | ```
398 |
399 | ## Visualization Tools
400 |
401 | ### generate_backtest_charts
402 |
403 | Generate comprehensive charts for a backtest, including the equity curve, a trade scatter plot, and a performance dashboard.
404 |
405 | **Function**: `generate_backtest_charts`
406 |
407 | **Parameters**:
408 | - `symbol` (str, required): Stock symbol
409 | - `strategy` (str, default: "sma_cross"): Strategy type
410 | - `start_date` (str, optional): Start date (YYYY-MM-DD)
411 | - `end_date` (str, optional): End date (YYYY-MM-DD)
412 | - `theme` (str, default: "light"): Chart theme ("light" or "dark")
413 |
414 | **Returns**:
415 | ```json
416 | {
417 | "equity_curve": "...",
418 | "trade_scatter": "...",
419 | "performance_dashboard": "..."
420 | }
421 | ```
422 |
423 | ### generate_optimization_charts
424 |
425 | Generate heatmap charts for strategy parameter optimization results.
426 |
427 | **Function**: `generate_optimization_charts`
428 |
429 | **Parameters**:
430 | - `symbol` (str, required): Stock symbol
431 | - `strategy` (str, default: "sma_cross"): Strategy type
432 | - `start_date` (str, optional): Start date (YYYY-MM-DD)
433 | - `end_date` (str, optional): End date (YYYY-MM-DD)
434 | - `theme` (str, default: "light"): Chart theme ("light" or "dark")
435 |
436 | **Returns**:
437 | ```json
438 | {
439 | "optimization_heatmap": "..."
440 | }
441 | ```
442 |
443 | ## ML-Enhanced Strategies
444 |
445 | ### run_ml_strategy_backtest
446 |
447 | Run a backtest using machine-learning-enhanced strategies with adaptive capabilities.
448 |
449 | **Function**: `run_ml_strategy_backtest`
450 |
451 | **Parameters**:
452 | - `symbol` (str, required): Stock symbol to backtest
453 | - `strategy_type` (str, default: "ml_predictor"): ML strategy type ("ml_predictor", "adaptive", "ensemble", "regime_aware")
454 | - `start_date` (str, optional): Start date (YYYY-MM-DD)
455 | - `end_date` (str, optional): End date (YYYY-MM-DD)
456 | - `initial_capital` (float, default: 10000.0): Initial capital amount
457 | - `train_ratio` (float, default: 0.8): Fraction of the data used for training (0.0-1.0)
458 | - `model_type` (str, default: "random_forest"): ML model type
459 | - `n_estimators` (int, default: 100): Number of estimators for ensemble models
460 | - `max_depth` (int, optional): Maximum tree depth
461 | - `learning_rate` (float, default: 0.01): Learning rate for adaptive strategies
462 | - `adaptation_method` (str, default: "gradient"): Adaptation method ("gradient", "momentum")
463 |
464 | **Returns**:
465 | ```json
466 | {
467 | "symbol": "AAPL",
468 | "strategy_type": "ml_predictor",
469 | "metrics": {
470 | "total_return": 0.24,
471 | "sharpe_ratio": 1.35,
472 | "max_drawdown": -0.09
473 | },
474 | "ml_metrics": {
475 | "training_period": 400,
476 | "testing_period": 100,
477 | "train_test_split": 0.8,
478 | "feature_importance": {
479 | "rsi": 0.25,
480 | "macd": 0.22,
481 | "volume_ratio": 0.18,
482 | "price_momentum": 0.16
483 | },
484 | "model_accuracy": 0.68,
485 | "prediction_confidence": 0.72
486 | }
487 | }
488 | ```
489 |
490 | ### train_ml_predictor
491 |
492 | Train a machine learning predictor model for generating trading signals.
493 |
494 | **Function**: `train_ml_predictor`
495 |
496 | **Parameters**:
497 | - `symbol` (str, required): Stock symbol to train on
498 | - `start_date` (str, optional): Start date for training data
499 | - `end_date` (str, optional): End date for training data
500 | - `model_type` (str, default: "random_forest"): ML model type
501 | - `target_periods` (int, default: 5): Number of forward periods used to compute the target variable
502 | - `return_threshold` (float, default: 0.02): Return threshold for signal classification
503 | - `n_estimators` (int, default: 100): Number of estimators
504 | - `max_depth` (int, optional): Maximum tree depth
505 | - `min_samples_split` (int, default: 2): Minimum samples to split
506 |
507 | **Returns**:
508 | ```json
509 | {
510 | "symbol": "AAPL",
511 | "model_type": "random_forest",
512 | "training_period": "2022-01-01 to 2024-01-01",
513 | "data_points": 500,
514 | "target_periods": 5,
515 | "return_threshold": 0.02,
516 | "model_parameters": {
517 | "n_estimators": 100,
518 | "max_depth": 10,
519 | "min_samples_split": 2
520 | },
521 | "training_metrics": {
522 | "accuracy": 0.68,
523 | "precision": 0.72,
524 | "recall": 0.65,
525 | "f1_score": 0.68,
526 | "feature_importance": {
527 | "rsi_14": 0.25,
528 | "macd_signal": 0.22,
529 | "volume_sma_ratio": 0.18
530 | }
531 | }
532 | }
533 | ```
534 |
535 | ### analyze_market_regimes
536 |
537 | Analyze market regimes using machine learning to identify different market conditions.
538 |
539 | **Function**: `analyze_market_regimes`
540 |
541 | **Parameters**:
542 | - `symbol` (str, required): Stock symbol to analyze
543 | - `start_date` (str, optional): Start date for analysis
544 | - `end_date` (str, optional): End date for analysis
545 | - `method` (str, default: "hmm"): Detection method ("hmm", "kmeans", "threshold")
546 | - `n_regimes` (int, default: 3): Number of regimes to detect
547 | - `lookback_period` (int, default: 50): Lookback period for regime detection
548 |
549 | **Returns**:
550 | ```json
551 | {
552 | "symbol": "AAPL",
553 | "analysis_period": "2023-01-01 to 2024-01-01",
554 | "method": "hmm",
555 | "n_regimes": 3,
556 | "regime_names": {
557 | "0": "Bear/Declining",
558 | "1": "Sideways/Uncertain",
559 | "2": "Bull/Trending"
560 | },
561 | "current_regime": 2,
562 | "regime_counts": {"0": 45, "1": 89, "2": 118},
563 | "regime_percentages": {"0": 17.9, "1": 35.3, "2": 46.8},
564 | "average_regime_durations": {"0": 15.2, "1": 22.3, "2": 28.7},
565 | "recent_regime_history": [
566 | {
567 | "date": "2024-01-15",
568 | "regime": 2,
569 | "probabilities": [0.05, 0.15, 0.80]
570 | }
571 | ],
572 | "total_regime_switches": 18
573 | }
574 | ```
575 |
576 | ### create_strategy_ensemble
577 |
578 | Create and backtest a strategy ensemble that combines multiple base strategies.
579 |
580 | **Function**: `create_strategy_ensemble`
581 |
582 | **Parameters**:
583 | - `symbols` (list[str], required): List of stock symbols
584 | - `base_strategies` (list[str], optional): List of base strategy names (defaults to ["sma_cross", "rsi", "macd"])
585 | - `weighting_method` (str, default: "performance"): Weighting method ("performance", "equal", "volatility")
586 | - `start_date` (str, optional): Start date for backtesting
587 | - `end_date` (str, optional): End date for backtesting
588 | - `initial_capital` (float, default: 10000.0): Initial capital per symbol
589 |
590 | **Returns**:
591 | ```json
592 | {
593 | "ensemble_summary": {
594 | "symbols_tested": 5,
595 | "base_strategies": ["sma_cross", "rsi", "macd"],
596 | "weighting_method": "performance",
597 | "average_return": 0.19,
598 | "total_trades": 87,
599 | "average_trades_per_symbol": 17.4
600 | },
601 | "individual_results": [
602 | {
603 | "symbol": "AAPL",
604 | "results": {
605 | "total_return": 0.21,
606 | "sharpe_ratio": 1.18
607 | },
608 | "ensemble_metrics": {
609 | "strategy_weights": {"sma_cross": 0.4, "rsi": 0.3, "macd": 0.3},
610 | "strategy_performance": {"sma_cross": 0.15, "rsi": 0.12, "macd": 0.18}
611 | }
612 | }
613 | ],
614 | "final_strategy_weights": {"sma_cross": 0.42, "rsi": 0.28, "macd": 0.30}
615 | }
616 | ```
617 |
618 | ## Intelligent Backtesting Workflow
619 |
620 | ### run_intelligent_backtest
621 |
622 | Run a comprehensive intelligent backtesting workflow with automatic market regime analysis and strategy optimization.
623 |
624 | **Function**: `run_intelligent_backtest`
625 |
626 | **Parameters**:
627 | - `symbol` (str, required): Stock symbol to analyze (e.g., 'AAPL', 'TSLA')
628 | - `start_date` (str, optional): Start date (YYYY-MM-DD), defaults to 1 year ago
629 | - `end_date` (str, optional): End date (YYYY-MM-DD), defaults to today
630 | - `initial_capital` (float, default: 10000.0): Starting capital for backtest
631 | - `requested_strategy` (str, optional): User-preferred strategy (e.g., 'sma_cross', 'rsi', 'macd')
632 |
633 | **Returns**:
634 | ```json
635 | {
636 | "symbol": "AAPL",
637 | "analysis_period": "2023-01-01 to 2024-01-01",
638 | "execution_metadata": {
639 | "total_execution_time": 45.2,
640 | "steps_completed": 6,
641 | "confidence_score": 0.87
642 | },
643 | "market_regime_analysis": {
644 | "current_regime": "trending",
645 | "regime_confidence": 0.85,
646 | "market_characteristics": {
647 | "volatility_percentile": 35,
648 | "trend_strength": 0.72,
649 | "volume_profile": "above_average"
650 | }
651 | },
652 | "strategy_recommendations": [
653 | {
654 | "strategy": "macd",
655 | "fitness_score": 0.92,
656 | "recommended_parameters": {"fast_period": 12, "slow_period": 26, "signal_period": 9},
657 | "expected_performance": {"sharpe_ratio": 1.45, "total_return": 0.28}
658 | },
659 | {
660 | "strategy": "sma_cross",
661 | "fitness_score": 0.88,
662 | "recommended_parameters": {"fast_period": 8, "slow_period": 21},
663 | "expected_performance": {"sharpe_ratio": 1.32, "total_return": 0.24}
664 | }
665 | ],
666 | "optimization_results": {
667 | "best_strategy": "macd",
668 | "optimized_parameters": {"fast_period": 12, "slow_period": 26, "signal_period": 9},
669 | "optimization_method": "grid_search",
670 | "combinations_tested": 48
671 | },
672 | "validation_results": {
673 | "walk_forward_analysis": {
674 | "out_of_sample_sharpe": 1.28,
675 | "consistency_score": 0.82,
676 | "overfitting_risk": "low"
677 | },
678 | "monte_carlo_simulation": {
679 | "probability_of_loss": 0.12,
680 | "95_percent_confidence_interval": {"lower": 0.08, "upper": 0.35}
681 | }
682 | },
683 | "final_recommendation": {
684 | "recommended_strategy": "macd",
685 | "confidence_level": "high",
686 | "expected_annual_return": 0.28,
687 | "expected_sharpe_ratio": 1.45,
688 | "maximum_expected_drawdown": -0.09,
689 | "risk_assessment": "moderate",
690 | "implementation_notes": [
691 | "Strategy performs well in trending markets",
692 | "Consider position sizing based on volatility",
693 | "Monitor for regime changes"
694 | ]
695 | }
696 | }
697 | ```
698 |
699 | ### quick_market_regime_analysis
700 |
701 | Perform fast market regime analysis and basic strategy recommendations without full optimization.
702 |
703 | **Function**: `quick_market_regime_analysis`
704 |
705 | **Parameters**:
706 | - `symbol` (str, required): Stock symbol to analyze
707 | - `start_date` (str, optional): Start date (YYYY-MM-DD), defaults to 1 year ago
708 | - `end_date` (str, optional): End date (YYYY-MM-DD), defaults to today
709 |
710 | **Returns**:
711 | ```json
712 | {
713 | "symbol": "AAPL",
714 | "analysis_type": "quick_analysis",
715 | "execution_time": 8.5,
716 | "market_regime": {
717 | "classification": "trending",
718 | "confidence": 0.78,
719 | "characteristics": {
720 | "trend_direction": "bullish",
721 | "volatility_level": "moderate",
722 | "volume_profile": "above_average"
723 | }
724 | },
725 | "strategy_recommendations": [
726 | {
727 | "strategy": "sma_cross",
728 | "fitness_score": 0.85,
729 | "reasoning": "Strong trend favors moving average strategies"
730 | },
731 | {
732 | "strategy": "macd",
733 | "fitness_score": 0.82,
734 | "reasoning": "MACD works well in trending environments"
735 | },
736 | {
737 | "strategy": "momentum",
738 | "fitness_score": 0.79,
739 | "reasoning": "Momentum strategies benefit from clear trends"
740 | }
741 | ],
742 | "market_conditions_summary": {
743 | "overall_assessment": "favorable_for_trend_following",
744 | "risk_level": "moderate",
745 | "recommended_position_sizing": "standard"
746 | }
747 | }
748 | ```
749 |
750 | ### explain_market_regime
751 |
752 | Get detailed explanation of market regime characteristics and suitable strategies.
753 |
754 | **Function**: `explain_market_regime`
755 |
756 | **Parameters**:
757 | - `regime` (str, required): Market regime to explain ("trending", "ranging", "volatile", "volatile_trending", "low_volume")
758 |
759 | **Returns**:
760 | ```json
761 | {
762 | "regime": "trending",
763 | "explanation": {
764 | "description": "A market in a clear directional movement (up or down trend)",
765 | "characteristics": [
766 | "Strong directional price movement",
767 | "Higher highs and higher lows (uptrend) or lower highs and lower lows (downtrend)",
768 | "Good momentum indicators",
769 | "Volume supporting the trend direction"
770 | ],
771 | "best_strategies": ["sma_cross", "ema_cross", "macd", "breakout", "momentum"],
772 | "avoid_strategies": ["rsi", "mean_reversion", "bollinger"],
773 | "risk_factors": [
774 | "Trend reversals can be sudden",
775 | "False breakouts in weak trends",
776 | "Momentum strategies can give late signals"
777 | ]
778 | },
779 | "trading_tips": [
780 | "Focus on sma_cross, ema_cross, macd, breakout, momentum strategies",
781 | "Avoid rsi, mean_reversion, bollinger strategies",
782 | "Always use proper risk management",
783 | "Consider the broader market context"
784 | ]
785 | }
786 | ```
787 |
788 | ## Available Strategies
789 |
790 | ### Traditional Technical Analysis Strategies
791 |
792 | #### 1. SMA Crossover (`sma_cross`)
793 | - **Description**: Buy when the fast SMA crosses above the slow SMA; sell when it crosses below
794 | - **Default Parameters**: `fast_period=10, slow_period=20`
795 | - **Best For**: Trending markets
796 | - **Optimization Ranges**: fast_period [5-20], slow_period [20-100]
797 |
798 | #### 2. EMA Crossover (`ema_cross`)
799 | - **Description**: Exponential moving average crossover with faster response than SMA
800 | - **Default Parameters**: `fast_period=12, slow_period=26`
801 | - **Best For**: Trending markets with more responsiveness
802 | - **Optimization Ranges**: fast_period [8-20], slow_period [20-50]
803 |
804 | #### 3. RSI Mean Reversion (`rsi`)
805 | - **Description**: Buy oversold (RSI < 30), sell overbought (RSI > 70)
806 | - **Default Parameters**: `period=14, oversold=30, overbought=70`
807 | - **Best For**: Ranging/sideways markets
808 | - **Optimization Ranges**: period [7-21], oversold [20-35], overbought [65-80]
809 |
810 | #### 4. MACD Signal (`macd`)
811 | - **Description**: Buy when the MACD line crosses above the signal line; sell when it crosses below
812 | - **Default Parameters**: `fast_period=12, slow_period=26, signal_period=9`
813 | - **Best For**: Trending markets with momentum confirmation
814 | - **Optimization Ranges**: fast_period [8-14], slow_period [21-30], signal_period [7-11]
815 |
816 | #### 5. Bollinger Bands (`bollinger`)
817 | - **Description**: Buy at lower band (oversold), sell at upper band (overbought)
818 | - **Default Parameters**: `period=20, std_dev=2.0`
819 | - **Best For**: Mean-reverting/ranging markets
820 | - **Optimization Ranges**: period [10-25], std_dev [1.5-3.0]
821 |
822 | #### 6. Momentum (`momentum`)
823 | - **Description**: Buy on strong momentum and sell on weak momentum, judged against a return threshold
824 | - **Default Parameters**: `lookback=20, threshold=0.05`
825 | - **Best For**: Trending markets with clear momentum
826 | - **Optimization Ranges**: lookback [10-30], threshold [0.02-0.10]
827 |
828 | #### 7. Mean Reversion (`mean_reversion`)
829 | - **Description**: Buy when the price is below its moving average by the entry threshold
830 | - **Default Parameters**: `ma_period=20, entry_threshold=0.02, exit_threshold=0.01`
831 | - **Best For**: Sideways/ranging markets
832 | - **Optimization Ranges**: ma_period [15-50], entry_threshold [0.01-0.05]
833 |
834 | #### 8. Channel Breakout (`breakout`)
835 | - **Description**: Buy on a breakout above the rolling high; sell on a breakdown below the rolling low
836 | - **Default Parameters**: `lookback=20, exit_lookback=10`
837 | - **Best For**: Volatile trending markets
838 | - **Optimization Ranges**: lookback [10-50], exit_lookback [5-20]
839 |
840 | #### 9. Volume-Weighted Momentum (`volume_momentum`)
841 | - **Description**: Momentum strategy filtered by volume surge
842 | - **Default Parameters**: `momentum_period=20, volume_period=20, momentum_threshold=0.05, volume_multiplier=1.5`
843 | - **Best For**: Markets with significant volume participation
844 | - **Optimization Ranges**: momentum_period [10-30], volume_multiplier [1.2-2.0]
845 |
846 | ### ML-Enhanced Strategies
847 |
848 | #### 1. ML Predictor (`ml_predictor`)
849 | - Uses machine learning models (Random Forest, etc.) to predict future price movements
850 | - Features: Technical indicators, price patterns, volume analysis
851 | - Training/testing split with out-of-sample validation
852 |
853 | #### 2. Adaptive Strategy (`adaptive`)
854 | - Adapts base strategy parameters based on recent performance
855 | - Uses gradient-based or momentum-based adaptation methods
856 | - Continuously learns from market feedback
857 |
858 | #### 3. Strategy Ensemble (`ensemble`)
859 | - Combines multiple base strategies with dynamic weighting
860 | - Weighting methods: performance-based, equal-weight, volatility-adjusted
861 | - Provides diversification benefits
862 |
863 | #### 4. Regime-Aware Strategy (`regime_aware`)
864 | - Automatically switches between different strategies based on detected market regime
865 | - Uses Hidden Markov Models or clustering for regime detection
866 | - Optimizes strategy selection for current market conditions
867 |
868 | ## Performance Considerations
869 |
870 | ### Execution Times
871 | - **Simple Backtest**: 2-5 seconds
872 | - **Strategy Optimization**: 30-120 seconds (depending on level)
873 | - **Walk-Forward Analysis**: 60-300 seconds
874 | - **Monte Carlo Simulation**: 45-90 seconds
875 | - **ML Strategy Training**: 60-180 seconds
876 | - **Intelligent Backtest**: 120-300 seconds (full workflow)
877 |
878 | ### Memory Usage
879 | - **Single Symbol**: 50-200 MB
880 | - **Portfolio (5 symbols)**: 200-500 MB
881 | - **ML Training**: 100-1000 MB (depending on data size)
882 |
883 | ### Optimization Levels
884 | - **Coarse**: 16-36 parameter combinations, fastest
885 | - **Medium**: 36-100 combinations, balanced speed/accuracy
886 | - **Fine**: 100-500+ combinations, most thorough
887 |
888 | ## Error Handling
889 |
890 | ### Common Errors
891 |
892 | #### Insufficient Data
893 | ```json
894 | {
895 | "error": "Insufficient data for backtest (minimum 100 data points)",
896 | "symbol": "PENNY_STOCK",
897 | "message": "Please use a longer time period or different symbol"
898 | }
899 | ```
900 |
901 | #### Invalid Strategy
902 | ```json
903 | {
904 | "error": "Unknown strategy type: invalid_strategy",
905 | "available_strategies": ["sma_cross", "rsi", "macd", ...],
906 | "message": "Please use one of the available strategy types"
907 | }
908 | ```
909 |
910 | #### Parameter Validation
911 | ```json
912 | {
913 | "error": "Invalid parameter value",
914 | "parameter": "fast_period",
915 | "value": -5,
916 | "message": "fast_period must be positive integer"
917 | }
918 | ```
919 |
920 | #### ML Training Errors
921 | ```json
922 | {
923 | "error": "ML training failed: Insufficient data for training (minimum 200 data points)",
924 | "symbol": "LOW_DATA_STOCK",
925 | "message": "ML strategies require more historical data"
926 | }
927 | ```
928 |
929 | ### Troubleshooting
930 |
931 | 1. **Data Issues**: Ensure sufficient historical data (minimum 100 points, 200+ for ML)
932 | 2. **Parameter Validation**: Check parameter types and ranges
933 | 3. **Memory Issues**: Reduce number of symbols in portfolio backtests
934 | 4. **Timeout Issues**: Use coarse optimization for faster results
935 | 5. **Strategy Parsing**: Use exact strategy names from `list_strategies`
936 |
937 | ## Integration Examples
938 |
939 | ### Claude Desktop Usage
940 |
941 | ```
942 | # Basic backtest
943 | "Run a backtest for AAPL using RSI strategy with 14-day period"
944 |
945 | # Strategy comparison
946 | "Compare SMA crossover, RSI, and MACD strategies on Tesla stock"
947 |
948 | # Intelligent analysis
949 | "Run intelligent backtest on Microsoft stock and recommend the best strategy"
950 |
951 | # Portfolio backtest
952 | "Backtest momentum strategy on AAPL, MSFT, GOOGL, AMZN, and TSLA"
953 |
954 | # Optimization
955 | "Optimize MACD parameters for Netflix stock over the last 2 years"
956 |
957 | # ML strategies
958 | "Train an ML predictor on Amazon stock and test its performance"
959 | ```
960 |
961 | ### API Integration
962 |
963 | ```python
964 | # Using MCP client
965 | import mcp
966 |
967 | client = mcp.Client("maverick-mcp")
968 |
969 | # Run backtest
970 | result = await client.call_tool("run_backtest", {
971 | "symbol": "AAPL",
972 | "strategy": "sma_cross",
973 | "fast_period": 10,
974 | "slow_period": 20,
975 | "initial_capital": 50000
976 | })
977 |
978 | # Optimize strategy
979 | optimization = await client.call_tool("optimize_strategy", {
980 | "symbol": "TSLA",
981 | "strategy": "rsi",
982 | "optimization_level": "medium",
983 | "optimization_metric": "sharpe_ratio"
984 | })
985 |
986 | # Intelligent backtest
987 | intelligent_result = await client.call_tool("run_intelligent_backtest", {
988 | "symbol": "MSFT",
989 | "start_date": "2022-01-01",
990 | "end_date": "2023-12-31"
991 | })
992 | ```
993 |
994 | ### Workflow Integration
995 |
996 | ```python
997 | # Complete backtesting workflow
998 | symbols = ["AAPL", "MSFT", "GOOGL"]
999 | strategies = ["sma_cross", "rsi", "macd"]
1000 |
1001 | for symbol in symbols:
1002 | # 1. Quick regime analysis
1003 | regime = await client.call_tool("quick_market_regime_analysis", {
1004 | "symbol": symbol
1005 | })
1006 |
1007 | # 2. Strategy comparison
1008 | comparison = await client.call_tool("compare_strategies", {
1009 | "symbol": symbol,
1010 | "strategies": strategies
1011 | })
1012 |
1013 | # 3. Optimize best strategy
1014 | best_strategy = comparison["best_overall"]
1015 | optimization = await client.call_tool("optimize_strategy", {
1016 | "symbol": symbol,
1017 | "strategy": best_strategy
1018 | })
1019 |
1020 | # 4. Validate with walk-forward
1021 | validation = await client.call_tool("walk_forward_analysis", {
1022 | "symbol": symbol,
1023 | "strategy": best_strategy
1024 | })
1025 | ```
1026 |
1027 | ## Best Practices
1028 |
1029 | ### Strategy Selection
1030 | 1. **Trending Markets**: Use sma_cross, ema_cross, macd, breakout, momentum
1031 | 2. **Ranging Markets**: Use rsi, bollinger, mean_reversion
1032 | 3. **Volatile Markets**: Use breakout, volatility_breakout with wider stops
1033 | 4. **Unknown Conditions**: Use run_intelligent_backtest for automatic selection
1034 |
1035 | ### Parameter Optimization
1036 | 1. **Start with Default**: Test default parameters first
1037 | 2. **Use Medium Level**: Good balance of thoroughness and speed
1038 | 3. **Validate Results**: Always use walk-forward analysis for final validation
1039 | 4. **Avoid Overfitting**: Check for consistent out-of-sample performance
1040 |
1041 | ### Risk Management
1042 | 1. **Position Sizing**: Never risk more than 1-2% per trade
1043 | 2. **Diversification**: Test strategies across multiple symbols
1044 | 3. **Regime Awareness**: Monitor market regime changes
1045 | 4. **Drawdown Limits**: Set maximum acceptable drawdown levels
1046 |
1047 | ### Performance Optimization
1048 | 1. **Parallel Processing**: Use portfolio backtests for batch analysis
1049 | 2. **Caching**: Results are cached for faster repeated analysis
1050 | 3. **Data Efficiency**: Use appropriate date ranges to balance data needs and speed
1051 | 4. **ML Considerations**: Ensure sufficient training data for ML strategies
1052 |
1053 | This comprehensive API documentation provides everything needed to use the MaverickMCP backtesting system effectively. Each tool is designed to work independently or as part of larger workflows, with extensive error handling and performance optimization built in.
```
--------------------------------------------------------------------------------
/maverick_mcp/agents/optimized_research.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Optimized Deep Research Agent with LLM-side optimizations to prevent timeouts.
3 |
4 | This module integrates the comprehensive LLM optimization strategies including:
5 | - Adaptive model selection based on time constraints
6 | - Progressive token budgeting with confidence tracking
7 | - Parallel LLM processing with intelligent load balancing
8 | - Optimized prompt engineering for speed
9 | - Early termination based on confidence thresholds
10 | - Content filtering to reduce processing overhead
11 | """
12 |
13 | import asyncio
14 | import logging
15 | import time
16 | from datetime import datetime
17 | from typing import Any
18 |
19 | from langchain_core.messages import HumanMessage, SystemMessage
20 | from langgraph.checkpoint.memory import MemorySaver
21 |
22 | from maverick_mcp.agents.deep_research import (
23 | PERSONA_RESEARCH_FOCUS,
24 | RESEARCH_DEPTH_LEVELS,
25 | ContentAnalyzer,
26 | DeepResearchAgent,
27 | )
28 | from maverick_mcp.providers.openrouter_provider import OpenRouterProvider, TaskType
29 | from maverick_mcp.utils.llm_optimization import (
30 | AdaptiveModelSelector,
31 | ConfidenceTracker,
32 | IntelligentContentFilter,
33 | OptimizedPromptEngine,
34 | ParallelLLMProcessor,
35 | ProgressiveTokenBudgeter,
36 | )
37 | from maverick_mcp.utils.orchestration_logging import (
38 | get_orchestration_logger,
39 | log_method_call,
40 | log_performance_metrics,
41 | )
42 |
43 | # Import moved to avoid circular dependency
44 |
45 | logger = logging.getLogger(__name__)
46 |
47 |
class OptimizedContentAnalyzer(ContentAnalyzer):
    """Enhanced ContentAnalyzer with LLM optimizations.

    Each analysis goes through an OpenRouter provider. The model, prompt and
    timeout are chosen per call from the remaining time budget. When the LLM
    call times out or fails, the result of ``_fallback_analysis`` is returned
    instead.
    """

    def __init__(self, openrouter_provider: OpenRouterProvider):
        """Build the optimization helpers around ``openrouter_provider``.

        NOTE(review): ``ContentAnalyzer.__init__`` is not called. This class
        uses the OpenRouter provider instead of a single LLM. Confirm that
        inherited helpers such as ``_fallback_analysis`` do not rely on state
        set by the base initializer.
        """
        self.openrouter_provider = openrouter_provider
        self.model_selector = AdaptiveModelSelector(openrouter_provider)
        self.prompt_engine = OptimizedPromptEngine()
        self.parallel_processor = ParallelLLMProcessor(openrouter_provider)

    async def analyze_content_optimized(
        self,
        content: str,
        persona: str,
        analysis_focus: str = "general",
        time_budget_seconds: float = 30.0,
        current_confidence: float = 0.0,
    ) -> dict[str, Any]:
        """Analyze content with time-optimized LLM selection and prompting.

        Args:
            content: Text to analyze. Empty or whitespace-only text returns
                ``_create_empty_analysis()`` without calling the LLM.
            persona: Investor persona the analysis is tailored to.
            analysis_focus: Focus area used for complexity scoring and the prompt.
            time_budget_seconds: Time available. Drives model selection and the
                prompt variant.
            current_confidence: Confidence accumulated so far in the research run.

        Returns:
            The parsed analysis dict. On success it is augmented with
            ``execution_time``, ``model_used`` and ``optimization_applied``.
            On timeout or any other failure, ``_fallback_analysis(content,
            persona)`` is returned instead.
        """
        if not content or not content.strip():
            return self._create_empty_analysis()

        # Calculate content complexity
        complexity_score = self.model_selector.calculate_task_complexity(
            content, TaskType.SENTIMENT_ANALYSIS, [analysis_focus]
        )

        # Select optimal model for time budget
        model_config = self.model_selector.select_model_for_time_budget(
            task_type=TaskType.SENTIMENT_ANALYSIS,
            time_remaining_seconds=time_budget_seconds,
            complexity_score=complexity_score,
            content_size_tokens=len(content) // 4,  # Rough token estimate
            current_confidence=current_confidence,
        )

        # Create optimized prompt
        optimized_prompt = self.prompt_engine.get_optimized_prompt(
            prompt_type="content_analysis",
            time_remaining=time_budget_seconds,
            confidence_level=current_confidence,
            # Limit content size to 3 characters per max_tokens of the chosen model
            content=content[: model_config.max_tokens * 3],
            persona=persona,
            focus_areas=analysis_focus,
        )

        # Execute with optimized LLM
        try:
            llm = self.openrouter_provider.get_llm(
                model_override=model_config.model_id,
                temperature=model_config.temperature,
                max_tokens=model_config.max_tokens,
            )

            start_time = time.time()
            response = await asyncio.wait_for(
                llm.ainvoke(
                    [
                        SystemMessage(
                            content="You are a financial content analyst. Return structured JSON analysis."
                        ),
                        HumanMessage(content=optimized_prompt),
                    ]
                ),
                timeout=model_config.timeout_seconds,
            )

            execution_time = time.time() - start_time

            # Parse response
            analysis = self._parse_optimized_response(response.content, persona)
            analysis["execution_time"] = execution_time
            analysis["model_used"] = model_config.model_id
            analysis["optimization_applied"] = True

            return analysis

        except TimeoutError:
            logger.warning(
                f"Content analysis timed out after {model_config.timeout_seconds}s"
            )
            return self._fallback_analysis(content, persona)
        except Exception as e:
            logger.warning(f"Optimized content analysis failed: {e}")
            return self._fallback_analysis(content, persona)

    async def batch_analyze_content(
        self,
        sources: list[dict],
        persona: str,
        analysis_type: str,
        time_budget_seconds: float,
        current_confidence: float = 0.0,
    ) -> list[dict]:
        """Analyze multiple sources using parallel processing.

        Thin delegation to ``ParallelLLMProcessor.parallel_content_analysis``
        with the same arguments.
        """
        return await self.parallel_processor.parallel_content_analysis(
            sources=sources,
            analysis_type=analysis_type,
            persona=persona,
            time_budget_seconds=time_budget_seconds,
            current_confidence=current_confidence,
        )

    def _parse_optimized_response(
        self, response_content: str, persona: str
    ) -> dict[str, Any]:
        """Parse an LLM response, preferring JSON and degrading gracefully.

        Parsing is attempted in this order:

        1. A JSON object, optionally wrapped in a Markdown code fence
           (for example ```json ... ```).
        2. Regex-based structured-text parsing.
        3. ``_fallback_analysis``.
        """
        import json
        import re

        if not isinstance(response_content, str):
            # NOTE(review): message content may not be a plain string. Keep the
            # previous behavior, which ended in the heuristic fallback.
            return self._fallback_analysis(response_content, persona)

        text = response_content.strip()
        # Models often wrap JSON in a Markdown fence even when asked not to.
        # Unwrap it so the JSON path is not skipped.
        fenced = re.match(r"^```[\w-]*\s*(.*?)\s*```$", text, re.DOTALL)
        if fenced:
            text = fenced.group(1)

        if text.startswith("{"):
            try:
                parsed = json.loads(text)
            except ValueError:
                logger.debug("JSON parse failed; trying structured text parsing")
            else:
                if isinstance(parsed, dict):
                    return parsed

        # Try structured text parsing
        try:
            return self._parse_structured_response(response_content, persona)
        except Exception:
            # Final fallback
            return self._fallback_analysis(response_content, persona)

    def _parse_structured_response(self, response: str, persona: str) -> dict[str, Any]:
        """Extract an analysis dict from free-form ``label: value`` text.

        Labels are matched case-insensitively, in singular or plural form:
        ``sentiment`` (with optional ``confidence``), ``insight``,
        ``key point``, ``finding``, ``risk``, ``opportunity``, and
        ``relevance`` / ``credibility`` (optionally followed by ``score``).

        Defaults when a label is missing:
        - confidence is 0.6 if a sentiment is found without one, and 0.5 if
          no sentiment is found;
        - relevance is 0.6;
        - credibility is 0.7.
        """
        import re

        # A decimal literal such as "7", "0.8" or ".75". The previous [\d.]+
        # also matched a bare "." (e.g. "high relevance."), so float() raised
        # and the whole structured parse was discarded.
        number = r"(\d+(?:\.\d+)?|\.\d+)"
        # The text after a label, up to the end of the sentence or line.
        phrase = r":?\s*([^\n.]+)"
        lowered = response.lower()

        # Extract sentiment
        sentiment_match = re.search(
            rf"sentiment:?\s*(\w+)[,\s]*(?:confidence:?\s*{number})?", lowered
        )
        if sentiment_match:
            direction = sentiment_match.group(1)
            confidence = float(sentiment_match.group(2) or 0.6)

            # Normalize sentiment terms
            if direction in ("bull", "bullish", "positive", "buy"):
                direction = "bullish"
            elif direction in ("bear", "bearish", "negative", "sell"):
                direction = "bearish"
            else:
                direction = "neutral"
        else:
            direction = "neutral"
            confidence = 0.5

        def _extract(label: str) -> list[str]:
            """Return the non-empty, stripped phrases that follow ``label``."""
            # Word boundaries keep "risky"/"insightful" from matching, and the
            # optional plural stops "Risks: x" being captured as "s: x".
            matches = re.findall(rf"\b{label}\b{phrase}", response, re.IGNORECASE)
            return [m.strip() for m in matches if m.strip()]

        # Extract insights
        insights = [
            *_extract(r"insights?"),
            *_extract(r"key points?"),
            *_extract(r"findings?"),
        ]

        # Extract risks and opportunities
        risks = _extract(r"risks?")
        opportunities = _extract(r"opportunit(?:y|ies)")

        # Extract scores
        relevance_match = re.search(
            rf"relevance(?:\s+score)?:?\s*{number}", lowered
        )
        relevance_score = float(relevance_match.group(1)) if relevance_match else 0.6

        credibility_match = re.search(
            rf"credibility(?:\s+score)?:?\s*{number}", lowered
        )
        credibility_score = (
            float(credibility_match.group(1)) if credibility_match else 0.7
        )

        return {
            "insights": insights[:5],
            "sentiment": {"direction": direction, "confidence": confidence},
            "risk_factors": risks[:3],
            "opportunities": opportunities[:3],
            "credibility_score": credibility_score,
            "relevance_score": relevance_score,
            "summary": f"Analysis for {persona} investor using optimized processing",
            "analysis_timestamp": datetime.now(),
            "structured_parsing": True,
        }

    def _create_empty_analysis(self) -> dict[str, Any]:
        """Create an empty, zero-confidence analysis for blank content."""
        return {
            "insights": [],
            "sentiment": {"direction": "neutral", "confidence": 0.0},
            "risk_factors": [],
            "opportunities": [],
            "credibility_score": 0.0,
            "relevance_score": 0.0,
            "summary": "No content to analyze",
            "analysis_timestamp": datetime.now(),
            "empty_content": True,
        }
249 |
250 |
251 | class OptimizedDeepResearchAgent(DeepResearchAgent):
252 | """
253 | Deep research agent with comprehensive LLM-side optimizations to prevent timeouts.
254 |
255 | Integrates all optimization strategies:
256 | - Adaptive model selection
257 | - Progressive token budgeting
258 | - Parallel LLM processing
259 | - Optimized prompting
260 | - Early termination
261 | - Content filtering
262 | """
263 |
def __init__(
    self,
    openrouter_provider: OpenRouterProvider,
    persona: str = "moderate",
    checkpointer: MemorySaver | None = None,
    ttl_hours: int = 24,
    exa_api_key: str | None = None,
    default_depth: str = "standard",
    max_sources: int | None = None,
    research_depth: str | None = None,
    enable_parallel_execution: bool = True,
    parallel_config=None,  # Type: ParallelResearchConfig | None
    optimization_enabled: bool = True,
):
    """Initialize optimized deep research agent.

    Steps, in order:

    1. Store the OpenRouter provider and the ``optimization_enabled`` flag.
    2. If optimizations are enabled, build the optimization helpers.
    3. Initialize ``DeepResearchAgent`` with a general-purpose LLM obtained
       from the provider. Every other argument is passed through to the base
       class unchanged.
    """
    provider = openrouter_provider
    self.openrouter_provider = provider
    self.optimization_enabled = optimization_enabled

    if optimization_enabled:
        self.model_selector = AdaptiveModelSelector(provider)
        self.token_budgeter = None  # built per request by research_comprehensive
        self.prompt_engine = OptimizedPromptEngine()
        self.confidence_tracker = None  # built per request by research_comprehensive
        self.content_filter = IntelligentContentFilter()
        self.parallel_processor = ParallelLLMProcessor(provider)
        # Optimized stand-in for the base class's content analysis.
        self.optimized_analyzer = OptimizedContentAnalyzer(provider)

    # The base class needs an LLM. It is only a placeholder here, because the
    # optimized code paths request their own models from the provider.
    placeholder_llm = provider.get_llm(TaskType.GENERAL)
    base_options = {
        "persona": persona,
        "checkpointer": checkpointer,
        "ttl_hours": ttl_hours,
        "exa_api_key": exa_api_key,
        "default_depth": default_depth,
        "max_sources": max_sources,
        "research_depth": research_depth,
        "enable_parallel_execution": enable_parallel_execution,
        "parallel_config": parallel_config,
    }
    super().__init__(llm=placeholder_llm, **base_options)

    logger.info("OptimizedDeepResearchAgent initialized")
314 |
@log_method_call(component="OptimizedDeepResearchAgent", include_timing=True)
async def research_comprehensive(
    self,
    topic: str,
    session_id: str,
    depth: str | None = None,
    focus_areas: list[str] | None = None,
    timeframe: str = "30d",
    time_budget_seconds: float = 120.0,  # Default 2 minutes
    target_confidence: float = 0.75,
    **kwargs,
) -> dict[str, Any]:
    """
    Comprehensive research with LLM optimizations to prevent timeouts.

    Runs three time-budgeted phases:

    1. Search plus content filtering: 20% of the budget, capped at 30s.
    2. Parallel content analysis: 70% of the time remaining after phase 1.
    3. Synthesis: whatever time is left.

    If fewer than 10s remain after the search phase, an emergency response is
    returned. If fewer than 5s remain after analysis, synthesis is skipped.

    Args:
        topic: Research topic or company/symbol
        session_id: Session identifier
        depth: Key of ``RESEARCH_DEPTH_LEVELS``
            (basic/standard/comprehensive/exhaustive). Defaults to
            ``self.default_depth``.
        focus_areas: Specific areas to focus on
        timeframe: Time range for research. Only forwarded to the parent
            implementation when optimizations are disabled.
        time_budget_seconds: Maximum time allowed for research
        target_confidence: Target confidence level for early termination
        **kwargs: Additional parameters. Only forwarded to the parent
            implementation when optimizations are disabled.

    Returns:
        Comprehensive research results with optimization metrics. Failures
        inside the optimized pipeline, including an unknown ``depth``, are
        returned as a dict with ``status == "error"`` rather than raised.
    """
    if not self.optimization_enabled:
        # Fall back to parent implementation
        return await super().research_comprehensive(
            topic, session_id, depth, focus_areas, timeframe, **kwargs
        )

    # Check if search providers are available
    if not self.search_providers:
        return {
            "error": "Research functionality unavailable - no search providers configured",
            "details": "Please configure EXA_API_KEY or TAVILY_API_KEY environment variables",
            "topic": topic,
            "optimization_enabled": self.optimization_enabled,
        }

    # Wall-clock start. It is also handed to _create_emergency_response.
    start_time = time.time()
    depth = depth or self.default_depth

    orchestration_logger = get_orchestration_logger("OptimizedDeepResearchAgent")
    orchestration_logger.set_request_context(
        session_id=session_id,
        topic=topic[:50],
        time_budget=time_budget_seconds,
        target_confidence=target_confidence,
    )

    try:
        # Validate depth inside the try. Previously an unknown depth raised a
        # bare KeyError before the try, which skipped the structured error
        # response and the failure log below.
        if depth not in RESEARCH_DEPTH_LEVELS:
            raise ValueError(
                f"Unknown research depth {depth!r}; expected one of: "
                f"{', '.join(RESEARCH_DEPTH_LEVELS)}"
            )
        depth_config = RESEARCH_DEPTH_LEVELS[depth]

        # Initialize optimization components for this request.
        # NOTE(review): these are stored on ``self`` and replaced on every call,
        # so concurrent calls on one agent instance would overwrite each
        # other's state. Confirm the agent is not shared across requests.
        self.token_budgeter = ProgressiveTokenBudgeter(
            total_time_budget_seconds=time_budget_seconds,
            confidence_target=target_confidence,
        )
        self.confidence_tracker = ConfidenceTracker(
            target_confidence=target_confidence,
            min_sources=3,
            max_sources=depth_config["max_sources"],
        )

        orchestration_logger.info(
            "🚀 OPTIMIZED_RESEARCH_START",
            depth=depth,
            focus_areas=focus_areas,
        )

        # Phase 1: Search and Content Filtering (20% of budget, max 30s)
        orchestration_logger.info("📋 PHASE_1_SEARCH_START")
        search_time_budget = min(time_budget_seconds * 0.2, 30)

        search_results = await self._optimized_search_phase(
            topic, depth, focus_areas, search_time_budget
        )
        filtered_sources = search_results.get("filtered_sources", [])

        orchestration_logger.info(
            "✅ PHASE_1_COMPLETE",
            sources_found=len(filtered_sources),
        )

        # Phase 2: Content Analysis with Parallel Processing
        remaining_time = time_budget_seconds - (time.time() - start_time)
        if remaining_time < 10:
            orchestration_logger.warning(
                "⚠️ TIME_CONSTRAINT_CRITICAL", remaining=f"{remaining_time:.1f}s"
            )
            return self._create_emergency_response(
                topic, search_results, start_time
            )

        orchestration_logger.info("🔬 PHASE_2_ANALYSIS_START")
        analysis_time_budget = remaining_time * 0.7  # 70% of remaining time

        analysis_results = await self._optimized_analysis_phase(
            filtered_sources,
            topic,
            focus_areas,
            analysis_time_budget,
        )

        orchestration_logger.info(
            "✅ PHASE_2_COMPLETE",
            sources_analyzed=len(analysis_results["analyzed_sources"]),
            confidence=f"{analysis_results['final_confidence']:.2f}",
        )

        # Phase 3: Synthesis with Remaining Time
        remaining_time = time_budget_seconds - (time.time() - start_time)
        if remaining_time < 5:
            # Skip synthesis if very little time left
            synthesis_results = {
                "synthesis": "Time constraints prevented full synthesis"
            }
        else:
            orchestration_logger.info("🧠 PHASE_3_SYNTHESIS_START")
            synthesis_results = await self._optimized_synthesis_phase(
                analysis_results["analyzed_sources"], topic, remaining_time
            )
            orchestration_logger.info("✅ PHASE_3_COMPLETE")

        # Compile final results
        execution_time = time.time() - start_time
        final_results = self._compile_optimized_results(
            topic=topic,
            session_id=session_id,
            depth=depth,
            search_results=search_results,
            analysis_results=analysis_results,
            synthesis_results=synthesis_results,
            execution_time=execution_time,
            time_budget=time_budget_seconds,
        )

        # Log performance metrics. time_budget_seconds is positive here:
        # with a budget below 10s the emergency return above is taken first.
        log_performance_metrics(
            "OptimizedDeepResearchAgent",
            {
                "total_execution_time": execution_time,
                "time_budget_used_pct": (execution_time / time_budget_seconds)
                * 100,
                "sources_processed": len(analysis_results["analyzed_sources"]),
                "final_confidence": analysis_results["final_confidence"],
                "optimization_enabled": True,
                "phases_completed": 3,
            },
        )

        orchestration_logger.info(
            "🎉 OPTIMIZED_RESEARCH_COMPLETE",
            duration=f"{execution_time:.2f}s",
            confidence=f"{analysis_results['final_confidence']:.2f}",
        )

        return final_results

    except Exception as e:
        execution_time = time.time() - start_time
        orchestration_logger.error(
            "💥 OPTIMIZED_RESEARCH_FAILED",
            error=str(e),
            execution_time=f"{execution_time:.2f}s",
        )

        return {
            "status": "error",
            "error": str(e),
            "execution_time_ms": execution_time * 1000,
            "agent_type": "optimized_deep_research",
            "optimization_enabled": True,
            "topic": topic,
        }
494 |
495 | async def _optimized_search_phase(
496 | self, topic: str, depth: str, focus_areas: list[str], time_budget_seconds: float
497 | ) -> dict[str, Any]:
498 | """Execute search phase with content filtering."""
499 |
500 | # Generate search queries (reuse parent logic)
501 | persona_focus = PERSONA_RESEARCH_FOCUS[self.persona.name.lower()]
502 | search_queries = await self._generate_search_queries(
503 | topic, persona_focus, RESEARCH_DEPTH_LEVELS[depth]
504 | )
505 |
506 | # Execute searches (reuse parent logic but with time limits)
507 | all_results = []
508 | max_searches = min(len(search_queries), 4) # Limit searches for speed
509 |
510 | search_tasks = []
511 | for query in search_queries[:max_searches]:
512 | for provider in self.search_providers[
513 | :1
514 | ]: # Use only first provider for speed
515 | task = self._search_with_timeout(
516 | provider, query, time_budget_seconds / max_searches
517 | )
518 | search_tasks.append(task)
519 |
520 | search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
521 |
522 | # Collect valid results
523 | for result in search_results:
524 | if isinstance(result, list):
525 | all_results.extend(result)
526 |
527 | # Apply intelligent content filtering
528 | current_confidence = 0.0 # Starting confidence
529 | research_focus = focus_areas[0] if focus_areas else "fundamental"
530 |
531 | filtered_sources = await self.content_filter.filter_and_prioritize_sources(
532 | sources=all_results,
533 | research_focus=research_focus,
534 | time_budget=time_budget_seconds,
535 | current_confidence=current_confidence,
536 | )
537 |
538 | return {
539 | "raw_results": all_results,
540 | "filtered_sources": filtered_sources,
541 | "search_queries": search_queries[:max_searches],
542 | "filtering_applied": True,
543 | }
544 |
545 | async def _search_with_timeout(
546 | self, provider, query: str, timeout: float
547 | ) -> list[dict]:
548 | """Execute search with timeout."""
549 | try:
550 | return await asyncio.wait_for(
551 | provider.search(query, num_results=5), timeout=timeout
552 | )
553 | except TimeoutError:
554 | logger.warning(f"Search timeout for query: {query}")
555 | return []
556 | except Exception as e:
557 | logger.warning(f"Search failed for {query}: {e}")
558 | return []
559 |
560 | async def _optimized_analysis_phase(
561 | self,
562 | sources: list[dict],
563 | topic: str,
564 | focus_areas: list[str],
565 | time_budget_seconds: float,
566 | ) -> dict[str, Any]:
567 | """Execute content analysis with optimizations and early termination."""
568 |
569 | if not sources:
570 | return {
571 | "analyzed_sources": [],
572 | "final_confidence": 0.0,
573 | "early_terminated": False,
574 | "termination_reason": "no_sources",
575 | }
576 |
577 | analyzed_sources = []
578 | current_confidence = 0.0
579 | sources_to_process = sources.copy()
580 |
581 | # Calculate time per source
582 | time_per_source = time_budget_seconds / len(sources_to_process)
583 |
584 | # Use batch processing if time allows
585 | if len(sources_to_process) > 3 and time_per_source < 8:
586 | # Use parallel batch processing
587 |
588 | analyzed_sources = await self.optimized_analyzer.batch_analyze_content(
589 | sources=sources_to_process,
590 | persona=self.persona.name.lower(),
591 | analysis_type=focus_areas[0] if focus_areas else "general",
592 | time_budget_seconds=time_budget_seconds,
593 | current_confidence=current_confidence,
594 | )
595 |
596 | # Calculate final confidence from batch results
597 | confidence_sum = 0
598 | for source in analyzed_sources:
599 | analysis = source.get("analysis", {})
600 | sentiment = analysis.get("sentiment", {})
601 | source_confidence = sentiment.get("confidence", 0.5)
602 | credibility = analysis.get("credibility_score", 0.5)
603 | confidence_sum += source_confidence * credibility
604 |
605 | final_confidence = (
606 | confidence_sum / len(analyzed_sources) if analyzed_sources else 0.0
607 | )
608 |
609 | return {
610 | "analyzed_sources": analyzed_sources,
611 | "final_confidence": final_confidence,
612 | "early_terminated": False,
613 | "termination_reason": "batch_processing_complete",
614 | "processing_mode": "parallel_batch",
615 | }
616 |
617 | else:
618 | # Use sequential processing with early termination
619 |
620 | for _, source in enumerate(sources_to_process):
621 | remaining_time = time_budget_seconds - (
622 | len(analyzed_sources) * time_per_source
623 | )
624 |
625 | if remaining_time < 5: # Reserve minimum time
626 | break
627 |
628 | # Analyze source with optimizations
629 | analysis_result = (
630 | await self.optimized_analyzer.analyze_content_optimized(
631 | content=source.get("content", ""),
632 | persona=self.persona.name.lower(),
633 | analysis_focus=focus_areas[0] if focus_areas else "general",
634 | time_budget_seconds=min(
635 | remaining_time / 2, 15
636 | ), # Max 15s per source
637 | current_confidence=current_confidence,
638 | )
639 | )
640 |
641 | # Add analysis to source
642 | source["analysis"] = analysis_result
643 | analyzed_sources.append(source)
644 |
645 | # Update confidence tracker
646 | credibility_score = analysis_result.get("credibility_score", 0.5)
647 | confidence_update = self.confidence_tracker.update_confidence(
648 | analysis_result, credibility_score
649 | )
650 |
651 | current_confidence = confidence_update["current_confidence"]
652 |
653 | # Check for early termination
654 | if not confidence_update["should_continue"]:
655 | logger.info(
656 | f"Early termination after {len(analyzed_sources)} sources: {confidence_update['early_termination_reason']}"
657 | )
658 | return {
659 | "analyzed_sources": analyzed_sources,
660 | "final_confidence": current_confidence,
661 | "early_terminated": True,
662 | "termination_reason": confidence_update[
663 | "early_termination_reason"
664 | ],
665 | "processing_mode": "sequential_early_termination",
666 | }
667 |
668 | return {
669 | "analyzed_sources": analyzed_sources,
670 | "final_confidence": current_confidence,
671 | "early_terminated": False,
672 | "termination_reason": "all_sources_processed",
673 | "processing_mode": "sequential_complete",
674 | }
675 |
676 | async def _optimized_synthesis_phase(
677 | self, analyzed_sources: list[dict], topic: str, time_budget_seconds: float
678 | ) -> dict[str, Any]:
679 | """Execute synthesis with optimized model selection."""
680 |
681 | if not analyzed_sources:
682 | return {"synthesis": "No sources available for synthesis"}
683 |
684 | # Select optimal model for synthesis
685 | combined_content = "\n".join(
686 | [str(source.get("analysis", {})) for source in analyzed_sources[:5]]
687 | )
688 |
689 | complexity_score = self.model_selector.calculate_task_complexity(
690 | combined_content, TaskType.RESULT_SYNTHESIS
691 | )
692 |
693 | model_config = self.model_selector.select_model_for_time_budget(
694 | task_type=TaskType.RESULT_SYNTHESIS,
695 | time_remaining_seconds=time_budget_seconds,
696 | complexity_score=complexity_score,
697 | content_size_tokens=len(combined_content) // 4,
698 | )
699 |
700 | # Create optimized synthesis prompt
701 | synthesis_prompt = self.prompt_engine.create_time_optimized_synthesis_prompt(
702 | sources=analyzed_sources,
703 | persona=self.persona.name,
704 | time_remaining=time_budget_seconds,
705 | current_confidence=0.8, # Assume good confidence at synthesis stage
706 | )
707 |
708 | # Execute synthesis
709 | try:
710 | llm = self.openrouter_provider.get_llm(
711 | model_override=model_config.model_id,
712 | temperature=model_config.temperature,
713 | max_tokens=model_config.max_tokens,
714 | )
715 |
716 | response = await asyncio.wait_for(
717 | llm.ainvoke(
718 | [
719 | SystemMessage(
720 | content="You are a financial research synthesizer."
721 | ),
722 | HumanMessage(content=synthesis_prompt),
723 | ]
724 | ),
725 | timeout=model_config.timeout_seconds,
726 | )
727 |
728 | return {
729 | "synthesis": response.content,
730 | "model_used": model_config.model_id,
731 | "synthesis_optimized": True,
732 | }
733 |
734 | except Exception as e:
735 | logger.warning(f"Optimized synthesis failed: {e}")
736 | return {
737 | "synthesis": f"Synthesis of {len(analyzed_sources)} sources completed with basic processing due to constraints.",
738 | "fallback_used": True,
739 | }
740 |
741 | def _create_emergency_response(
742 | self, topic: str, search_results: dict, start_time: float
743 | ) -> dict[str, Any]:
744 | """Create emergency response when time is critically low."""
745 |
746 | execution_time = time.time() - start_time
747 | source_count = len(search_results.get("filtered_sources", []))
748 |
749 | return {
750 | "status": "partial_success",
751 | "agent_type": "optimized_deep_research",
752 | "emergency_mode": True,
753 | "topic": topic,
754 | "sources_found": source_count,
755 | "execution_time_ms": execution_time * 1000,
756 | "findings": {
757 | "synthesis": f"Emergency mode: Found {source_count} relevant sources for {topic}. "
758 | "Full analysis was prevented by time constraints.",
759 | "confidence_score": 0.3,
760 | "sources_analyzed": source_count,
761 | },
762 | "optimization_metrics": {
763 | "time_budget_exceeded": True,
764 | "phases_completed": 1,
765 | "emergency_fallback": True,
766 | },
767 | }
768 |
769 | def _compile_optimized_results(
770 | self,
771 | topic: str,
772 | session_id: str,
773 | depth: str,
774 | search_results: dict,
775 | analysis_results: dict,
776 | synthesis_results: dict,
777 | execution_time: float,
778 | time_budget: float,
779 | ) -> dict[str, Any]:
780 | """Compile final optimized research results."""
781 |
782 | analyzed_sources = analysis_results["analyzed_sources"]
783 |
784 | # Create citations
785 | citations = []
786 | for i, source in enumerate(analyzed_sources, 1):
787 | analysis = source.get("analysis", {})
788 | citation = {
789 | "id": i,
790 | "title": source.get("title", f"Source {i}"),
791 | "url": source.get("url", ""),
792 | "published_date": source.get("published_date"),
793 | "credibility_score": analysis.get("credibility_score", 0.5),
794 | "relevance_score": analysis.get("relevance_score", 0.5),
795 | "optimized_analysis": analysis.get("optimization_applied", False),
796 | }
797 | citations.append(citation)
798 |
799 | return {
800 | "status": "success",
801 | "agent_type": "optimized_deep_research",
802 | "optimization_enabled": True,
803 | "persona": self.persona.name,
804 | "research_topic": topic,
805 | "research_depth": depth,
806 | "findings": {
807 | "synthesis": synthesis_results.get(
808 | "synthesis", "No synthesis available"
809 | ),
810 | "confidence_score": analysis_results["final_confidence"],
811 | "early_terminated": analysis_results.get("early_terminated", False),
812 | "termination_reason": analysis_results.get("termination_reason"),
813 | "processing_mode": analysis_results.get("processing_mode", "unknown"),
814 | },
815 | "sources_analyzed": len(analyzed_sources),
816 | "citations": citations,
817 | "execution_time_ms": execution_time * 1000,
818 | "optimization_metrics": {
819 | "time_budget_seconds": time_budget,
820 | "time_used_seconds": execution_time,
821 | "time_utilization_pct": (execution_time / time_budget) * 100,
822 | "sources_found": len(search_results.get("raw_results", [])),
823 | "sources_filtered": len(search_results.get("filtered_sources", [])),
824 | "sources_processed": len(analyzed_sources),
825 | "content_filtering_applied": search_results.get(
826 | "filtering_applied", False
827 | ),
828 | "parallel_processing_used": "batch"
829 | in analysis_results.get("processing_mode", ""),
830 | "synthesis_optimized": synthesis_results.get(
831 | "synthesis_optimized", False
832 | ),
833 | "optimization_features_used": [
834 | "adaptive_model_selection",
835 | "progressive_token_budgeting",
836 | "content_filtering",
837 | "optimized_prompts",
838 | ]
839 | + (
840 | ["parallel_processing"]
841 | if "batch" in analysis_results.get("processing_mode", "")
842 | else []
843 | )
844 | + (
845 | ["early_termination"]
846 | if analysis_results.get("early_terminated")
847 | else []
848 | ),
849 | },
850 | "search_queries_used": search_results.get("search_queries", []),
851 | "session_id": session_id,
852 | }
853 |
854 |
855 | # Factory function for easy integration
def create_optimized_research_agent(
    openrouter_api_key: str,
    persona: str = "moderate",
    time_budget_seconds: float = 120.0,
    target_confidence: float = 0.75,
    **kwargs,
) -> OptimizedDeepResearchAgent:
    """Build an ``OptimizedDeepResearchAgent`` backed by OpenRouter, optimizations on.

    Args:
        openrouter_api_key: Key used to construct the ``OpenRouterProvider``.
        persona: Persona name forwarded to the agent.
        time_budget_seconds: Accepted but currently not forwarded anywhere.
        target_confidence: Accepted but currently not forwarded anywhere.
        **kwargs: Extra keyword arguments passed straight to the agent
            constructor. Passing ``optimization_enabled`` here raises
            ``TypeError`` (duplicate keyword argument).

    NOTE(review): ``time_budget_seconds`` and ``target_confidence`` are ignored;
    confirm whether they should reach the agent or its research call.
    """
    return OptimizedDeepResearchAgent(
        openrouter_provider=OpenRouterProvider(openrouter_api_key),
        persona=persona,
        optimization_enabled=True,  # this factory always enables optimizations
        **kwargs,
    )
873 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/circuit_breaker.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive circuit breaker implementation for all external API calls.
3 |
4 | This module provides circuit breakers for:
5 | - yfinance (Yahoo Finance)
6 | - Tiingo API
7 | - FRED API
8 | - OpenRouter AI API
9 | - Exa Search API
10 | - Any other external services
11 |
12 | Circuit breakers help prevent cascade failures and provide graceful degradation.
13 | """
14 |
15 | import asyncio
16 | import functools
17 | import logging
18 | import threading
19 | import time
20 | from collections import deque
21 | from collections.abc import Callable
22 | from enum import Enum
23 | from typing import Any, ParamSpec, TypeVar, cast
24 |
25 | from maverick_mcp.config.settings import get_settings
26 | from maverick_mcp.exceptions import CircuitBreakerError, ExternalServiceError
27 |
# Module-wide logger.
logger = logging.getLogger(__name__)
# Settings loaded once at import time; the circuit_breaker decorator reads its
# default failure threshold and recovery timeout from ``settings.agent``.
settings = get_settings()

# Typing helpers: P/T capture a decorated callable's parameters and return type
# (used as ``Callable[P, T]``); F is a type variable bound to any callable.
P = ParamSpec("P")
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
34 |
35 |
class CircuitState(Enum):
    """Circuit breaker states.

    Transitions performed by ``EnhancedCircuitBreaker``: CLOSED -> OPEN when
    the configured detection strategy trips; OPEN -> HALF_OPEN on the first
    call after ``recovery_timeout`` seconds; HALF_OPEN -> CLOSED after
    ``success_threshold`` successes, or HALF_OPEN -> OPEN on any failure.
    A manual ``reset()`` returns to CLOSED from any state.
    """

    CLOSED = "closed"  # Normal operation: calls pass through
    OPEN = "open"  # Failing: calls are rejected with CircuitBreakerError
    HALF_OPEN = "half_open"  # Trial calls allowed to test whether the service recovered
42 |
43 |
class FailureDetectionStrategy(Enum):
    """Strategies for deciding when a CLOSED circuit breaker should open.

    The two rate-based strategies only apply once the metrics window holds at
    least 5 calls, and both compare against ``failure_rate_threshold``.
    """

    CONSECUTIVE_FAILURES = "consecutive"  # >= failure_threshold failures in a row
    FAILURE_RATE = "failure_rate"  # Share of failed calls in the metrics window
    TIMEOUT_RATE = "timeout_rate"  # Share of calls classified as timeouts in the window
    COMBINED = "combined"  # Opens if any of the above would open
51 |
52 |
class CircuitBreakerConfig:
    """Tuning parameters for one ``EnhancedCircuitBreaker``.

    Each constructor argument is stored unchanged as an attribute of the same
    name.
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        failure_rate_threshold: float = 0.5,
        timeout_threshold: float = 10.0,
        recovery_timeout: int = 60,
        success_threshold: int = 3,
        window_size: int = 60,
        detection_strategy: FailureDetectionStrategy = FailureDetectionStrategy.COMBINED,
        expected_exceptions: tuple[type[Exception], ...] = (Exception,),
    ):
        """
        Store a circuit breaker's configuration.

        Args:
            name: Breaker name; also its key in the module-level registry.
            failure_threshold: Consecutive failures that open the circuit.
            failure_rate_threshold: Failure (or timeout) ratio in [0, 1] that
                opens the circuit under the rate-based strategies.
            timeout_threshold: Per-call timeout in seconds, enforced by the
                async call path.
            recovery_timeout: Seconds an open circuit waits before letting a
                half-open trial call through.
            success_threshold: Half-open successes needed to close again.
            window_size: Rolling window, in seconds, for rate statistics.
            detection_strategy: Rule used to decide when to open.
            expected_exceptions: Exception types counted as failures; any other
                exception propagates without touching the breaker's counters.
        """
        # Identity
        self.name = name

        # When to open
        self.failure_threshold = failure_threshold
        self.failure_rate_threshold = failure_rate_threshold
        self.timeout_threshold = timeout_threshold

        # How to recover
        self.recovery_timeout = recovery_timeout
        self.success_threshold = success_threshold

        # How failures are measured
        self.window_size = window_size
        self.detection_strategy = detection_strategy
        self.expected_exceptions = expected_exceptions
91 |
92 |
class CircuitBreakerMetrics:
    """Rolling-window call statistics and state-change history for a breaker.

    Call records older than ``window_size`` seconds are pruned, and state
    changes are kept for ``10 * window_size`` seconds. Methods that read or
    modify the deques do so under an internal ``threading.RLock``.
    """

    def __init__(self, window_size: int = 300, timeout_threshold: float = 10.0):
        """
        Initialize empty metrics.

        Args:
            window_size: Rolling window, in seconds, for call statistics.
            timeout_threshold: Failed calls lasting at least this many seconds
                count as timeouts in ``timeout_rate`` (default 10.0 seconds).
        """
        self.window_size = window_size
        self.timeout_threshold = timeout_threshold
        # Entries are (timestamp, success, duration).
        self.calls: deque[tuple[float, bool, float]] = deque()
        # Entries are (timestamp, new_state).
        self.state_changes: deque[tuple[float, CircuitState]] = deque()
        self._lock = threading.RLock()

    def record_call(self, success: bool, duration: float):
        """Record one call outcome and its duration in seconds."""
        with self._lock:
            now = time.time()
            self.calls.append((now, success, duration))
            self._cleanup_old_data(now)

    def record_state_change(self, new_state: CircuitState):
        """Record a transition into ``new_state``."""
        with self._lock:
            now = time.time()
            self.state_changes.append((now, new_state))
            self._cleanup_old_data(now)

    def get_stats(self) -> dict[str, Any]:
        """Summarize the calls in the current window.

        Returns:
            Dict with ``total_calls``, ``success_rate``, ``failure_rate``,
            ``avg_duration``, ``timeout_rate``, ``min_duration`` and
            ``max_duration``. The same keys are returned for an empty window
            (success rate 1.0, everything else 0), so callers can index them
            unconditionally.
        """
        with self._lock:
            self._cleanup_old_data(time.time())

            total = len(self.calls)
            if total == 0:
                return {
                    "total_calls": 0,
                    "success_rate": 1.0,
                    "failure_rate": 0.0,
                    "avg_duration": 0.0,
                    "timeout_rate": 0.0,
                    "min_duration": 0.0,
                    "max_duration": 0.0,
                }

            durations = [duration for _, _, duration in self.calls]
            successes = sum(1 for _, success, _ in self.calls if success)
            # Heuristic: a failure that ran at least timeout_threshold seconds
            # is treated as a timeout.
            timeouts = sum(
                1
                for _, success, duration in self.calls
                if not success and duration >= self.timeout_threshold
            )

            return {
                "total_calls": total,
                "success_rate": successes / total,
                "failure_rate": (total - successes) / total,
                "avg_duration": sum(durations) / total,
                "timeout_rate": timeouts / total,
                "min_duration": min(durations),
                "max_duration": max(durations),
            }

    def get_total_calls(self) -> int:
        """Get the number of calls in the window."""
        with self._lock:
            self._cleanup_old_data(time.time())
            return len(self.calls)

    def get_success_rate(self) -> float:
        """Get the success rate in the window (1.0 when there were no calls)."""
        return self.get_stats()["success_rate"]

    def get_failure_rate(self) -> float:
        """Get the failure rate in the window (0.0 when there were no calls)."""
        return self.get_stats()["failure_rate"]

    def get_average_response_time(self) -> float:
        """Get the mean call duration in the window, in seconds."""
        return self.get_stats()["avg_duration"]

    def get_last_failure_time(self) -> float | None:
        """Get the timestamp of the most recent recorded failure, or None.

        This does not prune first, so the failure may lie outside the window
        if no call has been recorded since it expired.
        """
        with self._lock:
            for timestamp, success, _ in reversed(self.calls):
                if not success:
                    return timestamp
            return None

    def get_uptime_percentage(self) -> float:
        """Get the percentage of the window spent in the CLOSED state.

        OPEN and HALF_OPEN both count as downtime. The state at the start of
        the window is taken from the latest older change, or assumed CLOSED.
        """
        with self._lock:
            if not self.state_changes:
                return 100.0

            now = time.time()
            window_start = now - self.window_size
            uptime = 0.0
            last_time = window_start
            last_state = CircuitState.CLOSED

            for timestamp, state in self.state_changes:
                if timestamp < window_start:
                    # Before the window: only track which state we enter it in.
                    last_state = state
                    continue

                if last_state == CircuitState.CLOSED:
                    uptime += timestamp - last_time

                last_time = timestamp
                last_state = state

            # Account for the interval from the last change until now.
            if last_state == CircuitState.CLOSED:
                uptime += now - last_time

            total_time = now - window_start
            return (uptime / total_time * 100) if total_time > 0 else 100.0

    def _cleanup_old_data(self, now: float):
        """Drop expired entries; callers hold ``self._lock``."""
        cutoff = now - self.window_size
        while self.calls and self.calls[0][0] < cutoff:
            self.calls.popleft()

        # State history is kept 10x longer than call data.
        state_cutoff = now - (self.window_size * 10)
        while self.state_changes and self.state_changes[0][0] < state_cutoff:
            self.state_changes.popleft()
225 |
226 |
class EnhancedCircuitBreaker:
    """
    Circuit breaker with consecutive-failure, failure-rate and timeout-rate
    detection plus rolling metrics.

    Usable from sync code (``call`` / ``call_sync``) and async code
    (``call_async``). Every read or transition of breaker state happens under a
    single ``threading.RLock``. The async path only holds it for bookkeeping
    that never awaits, so one breaker can be shared by several threads and
    event loops.
    """

    # Rate-based detection is ignored until the window holds this many calls.
    _MIN_CALLS_FOR_RATE = 5

    def __init__(self, config: CircuitBreakerConfig):
        """Create a breaker in the CLOSED state, tuned by ``config``."""
        self.config = config
        self._state = CircuitState.CLOSED
        self._consecutive_failures = 0
        self._half_open_successes = 0
        self._last_failure_time: float | None = None
        self._metrics = CircuitBreakerMetrics(config.window_size)

        # Single lock shared by the sync and async code paths.
        self._lock = threading.RLock()
        # Retained for backward compatibility only; no longer used internally.
        # An asyncio.Lock does not exclude other threads (which mutate state
        # under self._lock) and cannot safely be shared across event loops.
        self._async_lock = asyncio.Lock()

    @property
    def state(self) -> CircuitState:
        """Current circuit state."""
        with self._lock:
            return self._state

    @property
    def consecutive_failures(self) -> int:
        """Number of failures recorded since the last success."""
        with self._lock:
            return self._consecutive_failures

    @property
    def is_open(self) -> bool:
        """True if the circuit is currently OPEN."""
        return self.state == CircuitState.OPEN

    @property
    def is_closed(self) -> bool:
        """True if the circuit is currently CLOSED."""
        return self.state == CircuitState.CLOSED

    def get_metrics(self) -> CircuitBreakerMetrics:
        """Return the live metrics object (not a copy)."""
        return self._metrics

    def time_until_retry(self) -> float | None:
        """Seconds until an OPEN circuit admits a trial call; None if not OPEN."""
        with self._lock:
            return self._remaining_recovery_time()

    def _remaining_recovery_time(self) -> float | None:
        """Cool-down left while OPEN, floored at 0; None otherwise.

        Caller must hold ``self._lock``.
        """
        if self._state != CircuitState.OPEN or self._last_failure_time is None:
            return None
        elapsed = time.time() - self._last_failure_time
        return max(0, self.config.recovery_timeout - elapsed)

    def _should_open(self) -> bool:
        """Decide, per ``config.detection_strategy``, whether to trip open."""
        stats = self._metrics.get_stats()
        rate_threshold = self.config.failure_rate_threshold
        has_sample = stats["total_calls"] >= self._MIN_CALLS_FOR_RATE

        consecutive_tripped = (
            self._consecutive_failures >= self.config.failure_threshold
        )
        failure_rate_tripped = has_sample and stats["failure_rate"] >= rate_threshold
        timeout_rate_tripped = has_sample and stats["timeout_rate"] >= rate_threshold

        strategy = self.config.detection_strategy
        if strategy == FailureDetectionStrategy.CONSECUTIVE_FAILURES:
            return consecutive_tripped
        if strategy == FailureDetectionStrategy.FAILURE_RATE:
            return failure_rate_tripped
        if strategy == FailureDetectionStrategy.TIMEOUT_RATE:
            return timeout_rate_tripped
        # COMBINED: any single rule is enough.
        return consecutive_tripped or failure_rate_tripped or timeout_rate_tripped

    def _should_attempt_reset(self) -> bool:
        """True once ``recovery_timeout`` seconds have passed since the last failure."""
        if self._last_failure_time is None:
            return True
        return (time.time() - self._last_failure_time) >= self.config.recovery_timeout

    def _transition_state(self, new_state: CircuitState):
        """Move to ``new_state`` (no-op if unchanged), logging and recording it."""
        if self._state != new_state:
            logger.info(
                f"Circuit breaker '{self.config.name}' transitioning from {self._state.value} to {new_state.value}"
            )
            self._state = new_state
            self._metrics.record_state_change(new_state)

    def _on_success(self, duration: float):
        """Record a success; close the circuit after enough half-open successes."""
        with self._lock:
            self._metrics.record_call(True, duration)
            self._consecutive_failures = 0

            if self._state == CircuitState.HALF_OPEN:
                self._half_open_successes += 1
                if self._half_open_successes >= self.config.success_threshold:
                    self._transition_state(CircuitState.CLOSED)
                    self._half_open_successes = 0

    def _on_failure(self, duration: float):
        """Record a failure; reopen from HALF_OPEN or open from CLOSED if tripped."""
        with self._lock:
            self._metrics.record_call(False, duration)
            self._consecutive_failures += 1
            self._last_failure_time = time.time()

            if self._state == CircuitState.HALF_OPEN:
                # Any failure during the trial period reopens the circuit.
                self._transition_state(CircuitState.OPEN)
                self._half_open_successes = 0
            elif self._state == CircuitState.CLOSED and self._should_open():
                self._transition_state(CircuitState.OPEN)

    def _before_call(self) -> None:
        """Gate a call according to the current state.

        OPEN with the cool-down elapsed moves to HALF_OPEN and lets the call
        through; OPEN otherwise raises. CLOSED and HALF_OPEN let the call
        through unchanged. Holds ``self._lock`` while deciding.

        Raises:
            CircuitBreakerError: If the circuit is OPEN and still cooling down.
        """
        with self._lock:
            if self._state != CircuitState.OPEN:
                return
            if self._should_attempt_reset():
                self._transition_state(CircuitState.HALF_OPEN)
                self._half_open_successes = 0
                return

            retry_in = self._remaining_recovery_time()
            if retry_in is None:
                retry_in = self.config.recovery_timeout
            raise CircuitBreakerError(
                service=self.config.name,
                failure_count=self._consecutive_failures,
                threshold=self.config.failure_threshold,
                context={
                    "state": self._state.value,
                    "time_until_retry": round(retry_in, 1),
                },
            )

    def call(self, func: Callable[..., Any], *args, **kwargs) -> Any:
        """Alias for ``call_sync``."""
        return self.call_sync(func, *args, **kwargs)

    async def call_async(self, func: Callable[..., Any], *args, **kwargs) -> Any:
        """
        Await ``func(*args, **kwargs)`` through the breaker, with a timeout.

        Args:
            func: Async function to call
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Function result

        Raises:
            CircuitBreakerError: If the circuit is open and still cooling down.
            ExternalServiceError: If the call exceeds ``config.timeout_threshold``
                (recorded as a failure).
            Exception: Anything ``func`` raises; exceptions matching
                ``config.expected_exceptions`` are recorded as failures first.
        """
        self._before_call()

        start_time = time.time()
        try:
            result = await asyncio.wait_for(
                func(*args, **kwargs), timeout=self.config.timeout_threshold
            )
        # asyncio.TimeoutError is only an alias of TimeoutError from 3.11 on.
        except (TimeoutError, asyncio.TimeoutError) as e:
            duration = time.time() - start_time
            self._on_failure(duration)
            logger.warning(
                f"Circuit breaker '{self.config.name}' timeout after {duration:.2f}s"
            )
            raise ExternalServiceError(
                service=self.config.name,
                message=f"Service timed out after {self.config.timeout_threshold}s",
                context={
                    "timeout": self.config.timeout_threshold,
                },
            ) from e
        except self.config.expected_exceptions:
            self._on_failure(time.time() - start_time)
            raise

        self._on_success(time.time() - start_time)
        return result

    def call_sync(self, func: Callable[..., Any], *args, **kwargs) -> Any:
        """
        Call sync ``func(*args, **kwargs)`` through the breaker.

        No timeout is enforced here; HTTP calls should use their own timeout
        parameters.

        Raises:
            CircuitBreakerError: If the circuit is open and still cooling down.
            Exception: Anything ``func`` raises; exceptions matching
                ``config.expected_exceptions`` are recorded as failures first.
        """
        self._before_call()

        start_time = time.time()
        try:
            result = func(*args, **kwargs)
        except self.config.expected_exceptions:
            self._on_failure(time.time() - start_time)
            raise

        self._on_success(time.time() - start_time)
        return result

    def reset(self):
        """Manually force the breaker back to CLOSED and clear its counters.

        Recorded metrics are kept.
        """
        with self._lock:
            self._transition_state(CircuitState.CLOSED)
            self._consecutive_failures = 0
            self._half_open_successes = 0
            self._last_failure_time = None
            logger.info(f"Circuit breaker '{self.config.name}' manually reset")

    def get_status(self) -> dict[str, Any]:
        """Return state, metrics and config as a plain dict.

        ``time_until_retry`` is None unless the circuit is OPEN; 0 means a
        trial call will be admitted immediately.
        """
        with self._lock:
            stats = self._metrics.get_stats()
            retry_in = self._remaining_recovery_time()

            return {
                "name": self.config.name,
                "state": self._state.value,
                "consecutive_failures": self._consecutive_failures,
                "time_until_retry": None if retry_in is None else round(retry_in, 1),
                "metrics": stats,
                "config": {
                    "failure_threshold": self.config.failure_threshold,
                    "failure_rate_threshold": self.config.failure_rate_threshold,
                    "timeout_threshold": self.config.timeout_threshold,
                    "recovery_timeout": self.config.recovery_timeout,
                    "detection_strategy": self.config.detection_strategy.value,
                },
            }
512 |
513 |
# Global registry of circuit breakers, keyed by breaker name
# (``config.name``). Mutations of the dict go through _breakers_lock.
_breakers: dict[str, EnhancedCircuitBreaker] = {}
_breakers_lock = threading.Lock()
517 |
518 |
def _get_or_create_breaker(config: CircuitBreakerConfig) -> EnhancedCircuitBreaker:
    """Return the registered breaker named ``config.name``, creating it on first use.

    When a breaker with that name already exists it is returned as-is and
    ``config`` is ignored.
    """
    with _breakers_lock:
        if config.name in _breakers:
            return _breakers[config.name]
        breaker = _breakers[config.name] = EnhancedCircuitBreaker(config)
        return breaker
525 |
526 |
def register_circuit_breaker(name: str, breaker: EnhancedCircuitBreaker):
    """Store ``breaker`` under ``name``, replacing any breaker already registered there."""
    with _breakers_lock:
        _breakers.update({name: breaker})
        logger.debug(f"Registered circuit breaker: {name}")
532 |
533 |
def get_circuit_breaker(name: str) -> EnhancedCircuitBreaker | None:
    """Look up the breaker registered as ``name``; ``None`` if there is none."""
    return _breakers.get(name, None)
537 |
538 |
def get_all_circuit_breakers() -> dict[str, EnhancedCircuitBreaker]:
    """Return a shallow copy of the registry (name -> breaker).

    Mutating the returned dict does not affect the registry itself.
    """
    return dict(_breakers)
542 |
543 |
def reset_all_circuit_breakers():
    """Reset every circuit breaker in the global registry.

    The registry is copied while holding ``_breakers_lock`` and the resets run
    after the lock is released. This means a breaker registered concurrently,
    or from inside a ``reset()``, cannot trigger "dictionary changed size
    during iteration". It also means no breaker's ``reset()`` runs while the
    registry lock is held.
    """
    with _breakers_lock:
        breakers = list(_breakers.values())
    for breaker in breakers:
        breaker.reset()
548 |
549 |
def get_circuit_breaker_status() -> dict[str, dict[str, Any]]:
    """Get status of all circuit breakers.

    The registry is copied while holding ``_breakers_lock``, and each
    breaker's ``get_status()`` is called after the lock is released. A breaker
    registered concurrently therefore cannot raise "dictionary changed size
    during iteration".

    Returns:
        Mapping of registry name to that breaker's ``get_status()`` result.
    """
    with _breakers_lock:
        snapshot = list(_breakers.items())
    return {name: breaker.get_status() for name, breaker in snapshot}
553 |
554 |
555 | def circuit_breaker(
556 | name: str | None = None,
557 | failure_threshold: int | None = None,
558 | failure_rate_threshold: float | None = None,
559 | timeout_threshold: float | None = None,
560 | recovery_timeout: int | None = None,
561 | expected_exceptions: tuple[type[Exception], ...] | None = None,
562 | ) -> Callable:
563 | """
564 | Decorator to apply circuit breaker to a function.
565 |
566 | Args:
567 | name: Circuit breaker name (defaults to function name)
568 | failure_threshold: Override default failure threshold
569 | failure_rate_threshold: Override default failure rate threshold
570 | timeout_threshold: Override default timeout threshold
571 | recovery_timeout: Override default recovery timeout
572 | expected_exceptions: Exceptions to catch (defaults to Exception)
573 | """
574 |
575 | def decorator(func: Callable[P, T]) -> Callable[P, T]:
576 | # Create config with overrides
577 | cb_name = name or f"{func.__module__}.{getattr(func, '__name__', 'unknown')}"
578 | config = CircuitBreakerConfig(
579 | name=cb_name,
580 | failure_threshold=failure_threshold
581 | or settings.agent.circuit_breaker_failure_threshold,
582 | failure_rate_threshold=failure_rate_threshold or 0.5,
583 | timeout_threshold=timeout_threshold or 30.0,
584 | recovery_timeout=recovery_timeout
585 | or settings.agent.circuit_breaker_recovery_timeout,
586 | expected_exceptions=expected_exceptions or (Exception,),
587 | )
588 |
589 | # Get or create circuit breaker for this function
590 | breaker = _get_or_create_breaker(config)
591 |
592 | if asyncio.iscoroutinefunction(func):
593 |
594 | @functools.wraps(func)
595 | async def async_wrapper(*args, **kwargs):
596 | return await breaker.call_async(func, *args, **kwargs)
597 |
598 | return cast(Callable[..., T], async_wrapper)
599 | else:
600 |
601 | @functools.wraps(func)
602 | def sync_wrapper(*args, **kwargs):
603 | return breaker.call_sync(func, *args, **kwargs)
604 |
605 | return cast(Callable[..., T], sync_wrapper)
606 |
607 | return decorator
608 |
609 |
# Circuit breaker configurations for different services.
# Every entry uses FailureDetectionStrategy.COMBINED and
# expected_exceptions=(Exception,); the per-service table columns are:
#   (name, failure_threshold, failure_rate_threshold, timeout_threshold,
#    recovery_timeout, success_threshold, window_size)
CIRCUIT_BREAKER_CONFIGS = {
    service: CircuitBreakerConfig(
        name=service,
        failure_threshold=fail_count,
        failure_rate_threshold=fail_rate,
        timeout_threshold=timeout_limit,
        recovery_timeout=recovery,
        success_threshold=successes,
        window_size=window,
        detection_strategy=FailureDetectionStrategy.COMBINED,
        expected_exceptions=(Exception,),
    )
    for (
        service,
        fail_count,
        fail_rate,
        timeout_limit,
        recovery,
        successes,
        window,
    ) in (
        ("yfinance", 3, 0.6, 30.0, 120, 2, 300),
        ("tiingo", 5, 0.7, 15.0, 60, 3, 300),
        ("fred_api", 3, 0.5, 20.0, 180, 2, 600),
        ("openrouter", 5, 0.6, 60.0, 120, 2, 300),  # AI APIs can be slower
        ("exa", 4, 0.6, 30.0, 90, 2, 300),
        ("news_api", 3, 0.5, 25.0, 120, 2, 300),
        ("finviz", 3, 0.6, 20.0, 150, 2, 300),
        ("external_api", 4, 0.6, 25.0, 120, 2, 300),
    )
}
701 |
702 |
def initialize_circuit_breakers() -> dict[str, EnhancedCircuitBreaker]:
    """Create and register a new breaker for every entry in ``CIRCUIT_BREAKER_CONFIGS``.

    Each breaker is registered under its service key via
    ``register_circuit_breaker``, which replaces any breaker already registered
    under that name. A service whose breaker cannot be created or registered
    is logged and skipped instead of aborting the remaining services.

    Returns:
        Mapping of service name to each breaker that was set up successfully.
    """
    ready = {}
    for service_name, service_config in CIRCUIT_BREAKER_CONFIGS.items():
        try:
            fresh = EnhancedCircuitBreaker(service_config)
            register_circuit_breaker(service_name, fresh)
            ready[service_name] = fresh
            logger.info(f"Initialized circuit breaker for {service_name}")
        except Exception as e:
            logger.error(
                f"Failed to initialize circuit breaker for {service_name}: {e}"
            )

    logger.info(f"Initialized {len(ready)} circuit breakers")
    return ready
720 |
721 |
722 | def with_circuit_breaker(service_name: str):
723 | """Decorator to wrap functions with a circuit breaker.
724 |
725 | Args:
726 | service_name: Name of the service/circuit breaker to use
727 |
728 | Usage:
729 | @with_circuit_breaker("yfinance")
730 | def fetch_stock_data(symbol: str):
731 | # API call code here
732 | pass
733 | """
734 |
735 | def decorator(func: Callable[P, T]) -> Callable[P, T]:
736 | @functools.wraps(func)
737 | def wrapper(*args, **kwargs) -> T:
738 | breaker = get_circuit_breaker(service_name)
739 | if not breaker:
740 | logger.warning(
741 | f"Circuit breaker '{service_name}' not found, executing without protection"
742 | )
743 | return func(*args, **kwargs)
744 |
745 | return breaker.call(func, *args, **kwargs)
746 |
747 | return wrapper
748 |
749 | return decorator
750 |
751 |
def with_async_circuit_breaker(service_name: str):
    """Build a decorator that routes an async function through a registered breaker.

    The breaker is resolved with ``get_circuit_breaker(service_name)`` on every
    call, not once at decoration time. When no breaker is registered under that
    name, a warning is logged and the coroutine function is awaited directly.

    Args:
        service_name: Name of the service/circuit breaker to use

    Usage:
        @with_async_circuit_breaker("tiingo")
        async def fetch_real_time_data(symbol: str):
            # Async API call code here
            pass
    """

    def apply(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        async def guarded(*args, **kwargs) -> T:
            breaker = get_circuit_breaker(service_name)
            if breaker:
                return await breaker.call_async(func, *args, **kwargs)
            logger.warning(
                f"Circuit breaker '{service_name}' not found, executing without protection"
            )
            return await func(*args, **kwargs)

        return guarded

    return apply
780 |
781 |
class CircuitBreakerManager:
    """Owner of the configured external-service circuit breakers.

    Initialization is lazy: ``get_breaker``, ``get_all_breakers`` and
    ``get_health_status`` call ``initialize`` first if it has not yet
    succeeded. ``reset_all_breakers`` does not, and only touches breakers that
    already exist.
    """

    def __init__(self):
        self._initialized = False
        self._breakers = {}

    def _ensure_initialized(self) -> None:
        """Call ``initialize`` if the manager has not been initialized yet."""
        if not self._initialized:
            self.initialize()

    def initialize(self) -> bool:
        """Create and register all configured breakers via ``initialize_circuit_breakers``.

        Returns:
            True if the manager is (now) initialized, False if
            ``initialize_circuit_breakers`` raised.
        """
        if self._initialized:
            return True

        try:
            self._breakers = initialize_circuit_breakers()
            self._initialized = True
            logger.info("Circuit breaker manager initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize circuit breaker manager: {e}")
            return False
        return True

    def get_breaker(self, service_name: str) -> EnhancedCircuitBreaker | None:
        """Return the managed breaker for ``service_name``, or ``None`` if unknown."""
        self._ensure_initialized()
        return self._breakers.get(service_name)

    def get_all_breakers(self) -> dict[str, EnhancedCircuitBreaker]:
        """Return a shallow copy of the managed service-name -> breaker mapping."""
        self._ensure_initialized()
        return dict(self._breakers)

    def reset_breaker(self, service_name: str) -> bool:
        """Reset a single managed breaker.

        Returns:
            True if a breaker was found and reset, False otherwise.
        """
        target = self.get_breaker(service_name)
        if not target:
            return False
        target.reset()
        logger.info(f"Reset circuit breaker for {service_name}")
        return True

    def reset_all_breakers(self) -> int:
        """Reset every managed breaker, logging (not raising) individual failures.

        Returns:
            The number of breakers that were reset successfully.
        """
        succeeded = 0
        for service_name, target in self._breakers.items():
            try:
                target.reset()
                succeeded += 1
                logger.info(f"Reset circuit breaker for {service_name}")
            except Exception as e:
                logger.error(f"Failed to reset circuit breaker for {service_name}: {e}")

        logger.info(f"Reset {succeeded} circuit breakers")
        return succeeded

    @staticmethod
    def _describe(service_name: str, breaker: EnhancedCircuitBreaker) -> dict[str, Any]:
        """Build the health entry for one breaker; any exception propagates to the caller."""
        metrics = breaker.get_metrics()
        return {
            "name": service_name,
            "state": breaker.state.value,
            "consecutive_failures": breaker.consecutive_failures,
            "time_until_retry": breaker.time_until_retry(),
            "metrics": {
                "total_calls": metrics.get_total_calls(),
                "success_rate": metrics.get_success_rate(),
                "failure_rate": metrics.get_failure_rate(),
                "avg_response_time": metrics.get_average_response_time(),
                "last_failure_time": metrics.get_last_failure_time(),
                "uptime_percentage": metrics.get_uptime_percentage(),
            },
        }

    def get_health_status(self) -> dict[str, dict[str, Any]]:
        """Return a per-service health report.

        A breaker whose details cannot be read is reported with
        ``state == "error"`` and the exception text instead of raising.
        """
        self._ensure_initialized()

        report = {}
        for service_name, breaker in self._breakers.items():
            try:
                report[service_name] = self._describe(service_name, breaker)
            except Exception as e:
                report[service_name] = {
                    "name": service_name,
                    "state": "error",
                    "error": str(e),
                }

        return report
871 |
872 |
# Shared module-level manager used by the convenience helpers in this module.
_circuit_breaker_manager = CircuitBreakerManager()


def get_circuit_breaker_manager() -> CircuitBreakerManager:
    """Return the shared module-level ``CircuitBreakerManager`` instance."""
    manager = _circuit_breaker_manager
    return manager
880 |
881 |
def initialize_all_circuit_breakers() -> bool:
    """Initialize the shared manager's breakers (convenience wrapper).

    Returns:
        The result of ``initialize()`` on the shared manager.
    """
    manager = _circuit_breaker_manager
    return manager.initialize()
885 |
886 |
def get_all_circuit_breaker_status() -> dict[str, dict[str, Any]]:
    """Return the shared manager's per-service health report (convenience wrapper)."""
    manager = _circuit_breaker_manager
    return manager.get_health_status()
890 |
891 |
892 | # Specific circuit breaker decorators for common services
893 |
894 |
def with_yfinance_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect a yfinance API call with the registered "yfinance" breaker.

    Thin wrapper over ``with_circuit_breaker("yfinance")``.
    """
    guard = with_circuit_breaker("yfinance")
    return cast(F, guard(func))
898 |
899 |
def with_tiingo_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect a Tiingo API call with the registered "tiingo" breaker.

    Thin wrapper over ``with_circuit_breaker("tiingo")``.
    """
    guard = with_circuit_breaker("tiingo")
    return cast(F, guard(func))
903 |
904 |
def with_fred_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect a FRED API call with the registered "fred_api" breaker.

    Thin wrapper over ``with_circuit_breaker("fred_api")``.
    """
    guard = with_circuit_breaker("fred_api")
    return cast(F, guard(func))
908 |
909 |
def with_openrouter_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect an OpenRouter API call with the registered "openrouter" breaker.

    Thin wrapper over ``with_circuit_breaker("openrouter")``.
    """
    guard = with_circuit_breaker("openrouter")
    return cast(F, guard(func))
913 |
914 |
def with_exa_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect an Exa API call with the registered "exa" breaker.

    Thin wrapper over ``with_circuit_breaker("exa")``.
    """
    guard = with_circuit_breaker("exa")
    return cast(F, guard(func))
918 |
919 |
920 | # Async versions
921 |
922 |
def with_async_yfinance_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect an async yfinance API call with the registered "yfinance" breaker.

    Thin wrapper over ``with_async_circuit_breaker("yfinance")``.
    """
    guard = with_async_circuit_breaker("yfinance")
    return cast(F, guard(func))
926 |
927 |
def with_async_tiingo_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect an async Tiingo API call with the registered "tiingo" breaker.

    Thin wrapper over ``with_async_circuit_breaker("tiingo")``.
    """
    guard = with_async_circuit_breaker("tiingo")
    return cast(F, guard(func))
931 |
932 |
def with_async_fred_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect an async FRED API call with the registered "fred_api" breaker.

    Thin wrapper over ``with_async_circuit_breaker("fred_api")``.
    """
    guard = with_async_circuit_breaker("fred_api")
    return cast(F, guard(func))
936 |
937 |
def with_async_openrouter_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect an async OpenRouter API call with the registered "openrouter" breaker.

    Thin wrapper over ``with_async_circuit_breaker("openrouter")``.
    """
    guard = with_async_circuit_breaker("openrouter")
    return cast(F, guard(func))
941 |
942 |
def with_async_exa_circuit_breaker(func: F) -> F:  # noqa: UP047
    """Protect an async Exa API call with the registered "exa" breaker.

    Thin wrapper over ``with_async_circuit_breaker("exa")``.
    """
    guard = with_async_circuit_breaker("exa")
    return cast(F, guard(func))
946 |
```