This is page 35 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/providers/stock_data.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Enhanced stock data provider with SQLAlchemy integration and screening recommendations.
3 | Provides comprehensive stock data retrieval with database caching and maverick screening.
4 | """
5 |
6 | # Suppress specific pyright warnings for pandas operations
7 | # pyright: reportOperatorIssue=false
8 |
9 | import logging
10 | from datetime import UTC, datetime, timedelta
11 |
12 | import pandas as pd
13 | import pandas_market_calendars as mcal
14 | import pytz
15 | import yfinance as yf
16 | from dotenv import load_dotenv
17 | from sqlalchemy import text
18 | from sqlalchemy.orm import Session
19 |
20 | from maverick_mcp.data.models import (
21 | MaverickBearStocks,
22 | MaverickStocks,
23 | PriceCache,
24 | SessionLocal,
25 | Stock,
26 | SupplyDemandBreakoutStocks,
27 | bulk_insert_price_data,
28 | get_latest_maverick_screening,
29 | )
30 | from maverick_mcp.data.session_management import get_db_session_read_only
31 | from maverick_mcp.utils.circuit_breaker_decorators import (
32 | with_stock_data_circuit_breaker,
33 | )
34 | from maverick_mcp.utils.yfinance_pool import get_yfinance_pool
35 |
36 | # Load environment variables
37 | load_dotenv()
38 |
39 | # Configure logging
40 | logging.basicConfig(
41 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
42 | )
43 | logger = logging.getLogger("maverick_mcp.stock_data")
44 |
45 |
46 | class EnhancedStockDataProvider:
47 | """
48 | Enhanced provider for stock data with database caching and screening recommendations.
49 | """
50 |
51 | def __init__(self, db_session: Session | None = None):
52 | """
53 | Initialize the stock data provider.
54 |
55 | Args:
56 | db_session: Optional database session for dependency injection.
57 | If not provided, will get sessions as needed.
58 | """
59 | self.timeout = 30
60 | self.max_retries = 3
61 | self.cache_days = 1 # Cache data for 1 day by default
62 | # Initialize NYSE calendar for US stock market
63 | self.market_calendar = mcal.get_calendar("NYSE")
64 | self._db_session = db_session
65 | # Initialize yfinance connection pool
66 | self._yf_pool = get_yfinance_pool()
67 | if db_session:
68 | # Test the provided session
69 | self._test_db_connection_with_session(db_session)
70 | else:
71 | # Test creating a new session
72 | self._test_db_connection()
73 |
74 | def _test_db_connection(self):
75 | """Test database connection on initialization."""
76 | try:
77 | # Use read-only context manager for automatic session management
78 | with get_db_session_read_only() as session:
79 | # Try a simple query
80 | result = session.execute(text("SELECT 1"))
81 | result.fetchone()
82 | logger.info("Database connection successful")
83 | except Exception as e:
84 | logger.warning(
85 | f"Database connection test failed: {e}. Caching will be disabled."
86 | )
87 |
88 | def _test_db_connection_with_session(self, session: Session):
89 | """Test provided database session."""
90 | try:
91 | # Try a simple query
92 | result = session.execute(text("SELECT 1"))
93 | result.fetchone()
94 | logger.info("Database session test successful")
95 | except Exception as e:
96 | logger.warning(
97 | f"Database session test failed: {e}. Caching may not work properly."
98 | )
99 |
100 | def _get_data_with_smart_cache(
101 | self, symbol: str, start_date: str, end_date: str, interval: str
102 | ) -> pd.DataFrame:
103 | """
104 | Get stock data using smart caching strategy.
105 |
106 | This method:
107 | 1. Gets all available data from cache
108 | 2. Identifies missing date ranges
109 | 3. Fetches only missing data from yfinance
110 | 4. Combines and returns the complete dataset
111 |
112 | Args:
113 | symbol: Stock ticker symbol
114 | start_date: Start date in YYYY-MM-DD format
115 | end_date: End date in YYYY-MM-DD format
116 | interval: Data interval (only '1d' is cached)
117 |
118 | Returns:
119 | DataFrame with complete stock data
120 | """
121 | symbol = symbol.upper()
122 | session, should_close = self._get_db_session()
123 |
124 | try:
125 | # Step 1: Get ALL available cached data for the date range
126 | logger.info(f"Checking cache for {symbol} from {start_date} to {end_date}")
127 | cached_df = self._get_cached_data_flexible(
128 | session, symbol, start_date, end_date
129 | )
130 |
131 | # Convert dates for comparison - ensure timezone-naive for consistency
132 | start_dt = pd.to_datetime(start_date).tz_localize(None)
133 | end_dt = pd.to_datetime(end_date).tz_localize(None)
134 |
135 | # Step 2: Determine what data we need
136 | if cached_df is not None and not cached_df.empty:
137 | logger.info(f"Found {len(cached_df)} cached records for {symbol}")
138 |
139 | # Check if we have all the data we need - ensure timezone-naive for comparison
140 | cached_start = pd.to_datetime(cached_df.index.min()).tz_localize(None)
141 | cached_end = pd.to_datetime(cached_df.index.max()).tz_localize(None)
142 |
143 | # Identify missing ranges
144 | missing_ranges = []
145 |
146 | # Missing data at the beginning?
147 | if start_dt < cached_start:
148 | # Get trading days in the missing range
149 | missing_start_trading = self._get_trading_days(
150 | start_dt, cached_start - timedelta(days=1)
151 | )
152 | if len(missing_start_trading) > 0:
153 | # Only request data if there are trading days
154 | missing_ranges.append(
155 | (
156 | missing_start_trading[0].strftime("%Y-%m-%d"),
157 | missing_start_trading[-1].strftime("%Y-%m-%d"),
158 | )
159 | )
160 |
161 | # Missing recent data?
162 | if end_dt > cached_end:
163 | # Check if there are any trading days after our cached data
164 | if self._is_trading_day_between(cached_end, end_dt):
165 | # Get the actual trading days we need
166 | missing_end_trading = self._get_trading_days(
167 | cached_end + timedelta(days=1), end_dt
168 | )
169 | if len(missing_end_trading) > 0:
170 | missing_ranges.append(
171 | (
172 | missing_end_trading[0].strftime("%Y-%m-%d"),
173 | missing_end_trading[-1].strftime("%Y-%m-%d"),
174 | )
175 | )
176 |
177 | # If no missing data, return cached data
178 | if not missing_ranges:
179 | logger.info(
180 | f"Cache hit! Returning {len(cached_df)} cached records for {symbol}"
181 | )
182 | # Filter to requested range - ensure index is timezone-naive
183 | cached_df.index = pd.to_datetime(cached_df.index).tz_localize(None)
184 | mask = (cached_df.index >= start_dt) & (cached_df.index <= end_dt)
185 | return cached_df.loc[mask]
186 |
187 | # Step 3: Fetch only missing data
188 | logger.info(f"Cache partial hit. Missing ranges: {missing_ranges}")
189 | all_dfs = [cached_df]
190 |
191 | for miss_start, miss_end in missing_ranges:
192 | logger.info(
193 | f"Fetching missing data for {symbol} from {miss_start} to {miss_end}"
194 | )
195 | missing_df = self._fetch_stock_data_from_yfinance(
196 | symbol, miss_start, miss_end, None, interval
197 | )
198 | if not missing_df.empty:
199 | all_dfs.append(missing_df)
200 | # Cache the new data
201 | self._cache_price_data(session, symbol, missing_df)
202 |
203 | # Combine all data
204 | combined_df = pd.concat(all_dfs).sort_index()
205 | # Remove any duplicates (keep first)
206 | combined_df = combined_df[~combined_df.index.duplicated(keep="first")]
207 |
208 | # Filter to requested range - ensure index is timezone-naive
209 | combined_df.index = pd.to_datetime(combined_df.index).tz_localize(None)
210 | mask = (combined_df.index >= start_dt) & (combined_df.index <= end_dt)
211 | return combined_df.loc[mask]
212 |
213 | else:
214 | # No cached data, fetch everything but only for trading days
215 | logger.info(
216 | f"No cached data found for {symbol}, fetching from yfinance"
217 | )
218 |
219 | # Adjust dates to trading days
220 | trading_days = self._get_trading_days(start_date, end_date)
221 | if len(trading_days) == 0:
222 | logger.warning(
223 | f"No trading days found between {start_date} and {end_date}"
224 | )
225 | return pd.DataFrame(
226 | columns=[ # type: ignore[arg-type]
227 | "Open",
228 | "High",
229 | "Low",
230 | "Close",
231 | "Volume",
232 | "Dividends",
233 | "Stock Splits",
234 | ]
235 | )
236 |
237 | # Fetch data only for the trading day range
238 | fetch_start = trading_days[0].strftime("%Y-%m-%d")
239 | fetch_end = trading_days[-1].strftime("%Y-%m-%d")
240 |
241 | logger.info(
242 | f"Fetching data for trading days: {fetch_start} to {fetch_end}"
243 | )
244 | df = self._fetch_stock_data_from_yfinance(
245 | symbol, fetch_start, fetch_end, None, interval
246 | )
247 | if not df.empty:
248 | # Ensure stock exists and cache the data
249 | self._get_or_create_stock(session, symbol)
250 | self._cache_price_data(session, symbol, df)
251 | return df
252 |
253 | finally:
254 | if should_close:
255 | session.close()
256 |
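    # Illustrative walk-through (hypothetical dates): with daily rows cached
    # for 2024-01-02 through 2024-03-28 and a request for 2024-01-01 through
    # 2024-04-05, the leading gap contains no trading day (2024-01-01 is a
    # market holiday), so missing_ranges is just [("2024-04-01", "2024-04-05")]
    # (2024-03-29 is Good Friday); only that slice is fetched from yfinance,
    # cached, and merged with the cached frame before filtering to the
    # requested window.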
257 | def _get_cached_data_flexible(
258 | self, session: Session, symbol: str, start_date: str, end_date: str
259 | ) -> pd.DataFrame | None:
260 | """
261 | Get cached data with flexible date range.
262 |
263 | Unlike the strict version, this returns whatever cached data exists
264 | within the requested range, even if incomplete.
265 |
266 | Args:
267 | session: Database session
268 | symbol: Stock ticker symbol (will be uppercased)
269 | start_date: Start date in YYYY-MM-DD format
270 | end_date: End date in YYYY-MM-DD format
271 |
272 | Returns:
273 | DataFrame with available cached data or None
274 | """
275 | try:
276 | # Get whatever data exists in the range
277 | df = PriceCache.get_price_data(session, symbol, start_date, end_date)
278 |
279 | if df.empty:
280 | return None
281 |
282 | # Add expected columns for compatibility
283 | for col in ["Dividends", "Stock Splits"]:
284 | if col not in df.columns:
285 | df[col] = 0.0
286 |
287 | # Ensure column names match yfinance format
288 | column_mapping = {
289 | "open": "Open",
290 | "high": "High",
291 | "low": "Low",
292 | "close": "Close",
293 | "volume": "Volume",
294 | }
295 | df.rename(columns=column_mapping, inplace=True)
296 |
297 | # Ensure proper data types to match yfinance
298 | # Convert Decimal to float for price columns
299 | for col in ["Open", "High", "Low", "Close"]:
300 | if col in df.columns:
301 | df[col] = pd.to_numeric(df[col], errors="coerce").astype("float64")
302 |
303 | # Convert volume to int
304 | if "Volume" in df.columns:
305 | df["Volume"] = (
306 | pd.to_numeric(df["Volume"], errors="coerce")
307 | .fillna(0)
308 | .astype("int64")
309 | )
310 |
311 | # Ensure index is timezone-naive for consistency
312 | df.index = pd.to_datetime(df.index).tz_localize(None)
313 |
314 | return df
315 |
316 | except Exception as e:
317 | logger.error(f"Error getting flexible cached data: {e}")
318 | return None
319 |
320 | def _is_trading_day_between(
321 | self, start_date: pd.Timestamp, end_date: pd.Timestamp
322 | ) -> bool:
323 | """
324 | Check if there's a trading day between two dates using market calendar.
325 |
326 | Args:
327 | start_date: Start date
328 | end_date: End date
329 |
330 | Returns:
331 | True if there's a trading day between the dates
332 | """
333 | # Add one day to start since we're checking "between"
334 | check_start = start_date + timedelta(days=1)
335 |
336 | if check_start > end_date:
337 | return False
338 |
339 | # Get trading days in the range
340 | trading_days = self._get_trading_days(check_start, end_date)
341 | return len(trading_days) > 0
342 |
343 | def _get_trading_days(self, start_date, end_date) -> pd.DatetimeIndex:
344 | """
345 | Get all trading days between start and end dates.
346 |
347 | Args:
348 | start_date: Start date (can be string or datetime)
349 | end_date: End date (can be string or datetime)
350 |
351 | Returns:
352 | DatetimeIndex of trading days (timezone-naive)
353 | """
354 |         # Normalize both dates to timezone-naive datetimes.
355 |         # pd.to_datetime accepts strings and datetime-like objects
356 |         # alike, so a single conversion path handles either input type
357 |         start_date = pd.to_datetime(start_date).tz_localize(None)
358 |         end_date = pd.to_datetime(end_date).tz_localize(None)
363 |
364 | # Get valid trading days from market calendar
365 | schedule = self.market_calendar.schedule(
366 | start_date=start_date, end_date=end_date
367 | )
368 | # Return timezone-naive index
369 | return schedule.index.tz_localize(None)
370 |
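    # Example (hypothetical range): _get_trading_days("2024-07-01", "2024-07-07")
    # returns July 1, 2, 3, and 5 only, since July 4 is an NYSE holiday and
    # July 6-7 fall on a weekend.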
371 | def _get_last_trading_day(self, date) -> pd.Timestamp:
372 | """
373 | Get the last trading day on or before the given date.
374 |
375 | Args:
376 | date: Date to check (can be string or datetime)
377 |
378 | Returns:
379 | Last trading day as pd.Timestamp
380 | """
381 | if isinstance(date, str):
382 | date = pd.to_datetime(date)
383 |
384 | # Check if the date itself is a trading day
385 | if self._is_trading_day(date):
386 | return date
387 |
388 | # Otherwise, find the previous trading day
389 |         for i in range(1, 10):  # Look back up to 9 days
390 | check_date = date - timedelta(days=i)
391 | if self._is_trading_day(check_date):
392 | return check_date
393 |
394 | # Fallback to the date itself if no trading day found
395 | return date
396 |
397 | def _is_trading_day(self, date) -> bool:
398 | """
399 | Check if a specific date is a trading day.
400 |
401 | Args:
402 | date: Date to check
403 |
404 | Returns:
405 | True if it's a trading day
406 | """
407 | if isinstance(date, str):
408 | date = pd.to_datetime(date)
409 |
410 | schedule = self.market_calendar.schedule(start_date=date, end_date=date)
411 | return len(schedule) > 0
412 |
413 | def _get_db_session(self) -> tuple[Session, bool]:
414 | """
415 | Get a database session.
416 |
417 | Returns:
418 | Tuple of (session, should_close) where should_close indicates
419 | whether the caller should close the session.
420 | """
421 | # Use injected session if available - should NOT be closed
422 | if self._db_session:
423 | return self._db_session, False
424 |
425 | # Otherwise, create a new session using session factory - should be closed
426 | try:
427 | session = SessionLocal()
428 | return session, True
429 | except Exception as e:
430 | logger.error(f"Failed to get database session: {e}", exc_info=True)
431 | raise
432 |
433 | def _get_or_create_stock(self, session: Session, symbol: str) -> Stock:
434 | """
435 | Get or create a stock in the database.
436 |
437 | Args:
438 | session: Database session
439 | symbol: Stock ticker symbol
440 |
441 | Returns:
442 | Stock object
443 | """
444 | stock = Stock.get_or_create(session, symbol)
445 |
446 | # Try to update stock info if it's missing
447 | company_name = getattr(stock, "company_name", None)
448 | if company_name is None or company_name == "":
449 | try:
450 | # Use connection pool for info retrieval
451 | info = self._yf_pool.get_info(symbol)
452 |
453 | stock.company_name = info.get("longName", info.get("shortName"))
454 | stock.sector = info.get("sector")
455 | stock.industry = info.get("industry")
456 | stock.exchange = info.get("exchange")
457 | stock.currency = info.get("currency", "USD")
458 | stock.country = info.get("country")
459 |
460 | session.commit()
461 | except Exception as e:
462 | logger.warning(f"Could not update stock info for {symbol}: {e}")
463 | session.rollback()
464 |
465 | return stock
466 |
467 | def _get_cached_price_data(
468 | self, session: Session, symbol: str, start_date: str, end_date: str
469 | ) -> pd.DataFrame | None:
470 | """
471 | DEPRECATED: Use _get_data_with_smart_cache instead.
472 |
473 | This method is kept for backward compatibility but is no longer used
474 | in the main flow. The new smart caching approach provides better
475 | database prioritization.
476 | """
477 | logger.warning("Using deprecated _get_cached_price_data method")
478 | return self._get_cached_data_flexible(
479 | session, symbol.upper(), start_date, end_date
480 | )
481 |
482 | def _cache_price_data(
483 | self, session: Session, symbol: str, df: pd.DataFrame
484 | ) -> None:
485 | """
486 | Cache price data in the database.
487 |
488 | Args:
489 | session: Database session
490 | symbol: Stock ticker symbol
491 | df: DataFrame with price data
492 | """
493 | try:
494 | if df.empty:
495 | return
496 |
497 | # Ensure symbol is uppercase to match database
498 | symbol = symbol.upper()
499 |
500 | # Prepare DataFrame for caching
501 | cache_df = df.copy()
502 |
503 | # Ensure proper column names
504 | column_mapping = {
505 | "Open": "open",
506 | "High": "high",
507 | "Low": "low",
508 | "Close": "close",
509 | "Volume": "volume",
510 | }
511 | cache_df.rename(columns=column_mapping, inplace=True)
512 |
513 | # Log DataFrame info for debugging
514 | logger.debug(
515 | f"DataFrame columns before caching: {cache_df.columns.tolist()}"
516 | )
517 | logger.debug(f"DataFrame shape: {cache_df.shape}")
518 | logger.debug(f"DataFrame index type: {type(cache_df.index)}")
519 | if not cache_df.empty:
520 | logger.debug(f"Sample row: {cache_df.iloc[0].to_dict()}")
521 |
522 | # Insert data
523 | count = bulk_insert_price_data(session, symbol, cache_df)
524 | if count == 0:
525 | logger.info(
526 | f"No new records cached for {symbol} (data may already exist)"
527 | )
528 | else:
529 | logger.info(f"Cached {count} new price records for {symbol}")
530 |
531 | except Exception as e:
532 | logger.error(f"Error caching price data for {symbol}: {e}", exc_info=True)
533 | session.rollback()
534 |
535 | def get_stock_data(
536 | self,
537 | symbol: str,
538 | start_date: str | None = None,
539 | end_date: str | None = None,
540 | period: str | None = None,
541 | interval: str = "1d",
542 | use_cache: bool = True,
543 | ) -> pd.DataFrame:
544 | """
545 | Fetch stock data with database caching support.
546 |
547 | This method prioritizes cached data from the database and only fetches
548 | missing data from yfinance when necessary.
549 |
550 | Args:
551 | symbol: Stock ticker symbol
552 | start_date: Start date in YYYY-MM-DD format
553 | end_date: End date in YYYY-MM-DD format
554 | period: Alternative to start/end dates (e.g., '1d', '5d', '1mo', '3mo', '1y', etc.)
555 | interval: Data interval ('1d', '1wk', '1mo', '1m', '5m', etc.)
556 | use_cache: Whether to use cached data if available
557 |
558 | Returns:
559 | DataFrame with stock data
560 | """
561 | # For non-daily intervals or periods, always fetch fresh data
562 | if interval != "1d" or period:
563 | return self._fetch_stock_data_from_yfinance(
564 | symbol, start_date, end_date, period, interval
565 | )
566 |
567 | # Set default dates if not provided
568 | if start_date is None:
569 | start_date = (datetime.now(UTC) - timedelta(days=365)).strftime("%Y-%m-%d")
570 | if end_date is None:
571 | end_date = datetime.now(UTC).strftime("%Y-%m-%d")
572 |
573 | # For daily data, adjust end date to last trading day if it's not a trading day
574 | # This prevents unnecessary cache misses on weekends/holidays
575 | if interval == "1d" and use_cache:
576 | end_dt = pd.to_datetime(end_date)
577 | if not self._is_trading_day(end_dt):
578 | last_trading = self._get_last_trading_day(end_dt)
579 | logger.debug(
580 | f"Adjusting end date from {end_date} to last trading day {last_trading.strftime('%Y-%m-%d')}"
581 | )
582 | end_date = last_trading.strftime("%Y-%m-%d")
583 |
584 | # If cache is disabled, fetch directly from yfinance
585 | if not use_cache:
586 | logger.info(f"Cache disabled, fetching from yfinance for {symbol}")
587 | return self._fetch_stock_data_from_yfinance(
588 | symbol, start_date, end_date, period, interval
589 | )
590 |
591 | # Try a smarter caching approach
592 | try:
593 | return self._get_data_with_smart_cache(
594 | symbol, start_date, end_date, interval
595 | )
596 | except Exception as e:
597 | logger.warning(f"Smart cache failed, falling back to yfinance: {e}")
598 | return self._fetch_stock_data_from_yfinance(
599 | symbol, start_date, end_date, period, interval
600 | )
601 |
602 | async def get_stock_data_async(
603 | self,
604 | symbol: str,
605 | start_date: str | None = None,
606 | end_date: str | None = None,
607 | period: str | None = None,
608 | interval: str = "1d",
609 | use_cache: bool = True,
610 | ) -> pd.DataFrame:
611 | """
612 | Async version of get_stock_data for parallel processing.
613 |
614 | This method wraps the synchronous get_stock_data method to provide
615 | an async interface for use in parallel backtesting operations.
616 |
617 | Args:
618 | symbol: Stock ticker symbol
619 | start_date: Start date in YYYY-MM-DD format
620 | end_date: End date in YYYY-MM-DD format
621 | period: Alternative to start/end dates (e.g., '1d', '5d', '1mo', '3mo', '1y', etc.)
622 | interval: Data interval ('1d', '1wk', '1mo', '1m', '5m', etc.)
623 | use_cache: Whether to use cached data if available
624 |
625 | Returns:
626 | DataFrame with stock data
627 | """
628 | import asyncio
629 | import functools
630 |
631 | # Run the synchronous method in a thread pool to avoid blocking
632 |         loop = asyncio.get_running_loop()
633 |
634 | # Use functools.partial to create a callable with all arguments
635 | sync_method = functools.partial(
636 | self.get_stock_data,
637 | symbol=symbol,
638 | start_date=start_date,
639 | end_date=end_date,
640 | period=period,
641 | interval=interval,
642 | use_cache=use_cache,
643 | )
644 |
645 | # Execute in thread pool to avoid blocking the event loop
646 | return await loop.run_in_executor(None, sync_method)
647 |
648 | @with_stock_data_circuit_breaker(
649 | use_fallback=False
650 | ) # Fallback handled at higher level
651 | def _fetch_stock_data_from_yfinance(
652 | self,
653 | symbol: str,
654 | start_date: str | None = None,
655 | end_date: str | None = None,
656 | period: str | None = None,
657 | interval: str = "1d",
658 | ) -> pd.DataFrame:
659 | """
660 | Fetch stock data from yfinance with circuit breaker protection.
661 |
662 | Note: Circuit breaker is applied with use_fallback=False because
663 | fallback strategies are handled at the get_stock_data level.
664 | """
665 | logger.info(
666 | f"Fetching data from yfinance for {symbol} - Start: {start_date}, End: {end_date}, Period: {period}, Interval: {interval}"
667 | )
668 | # Use connection pool for better performance
669 | # The pool handles session management and retries internally
670 |
671 | # Use the optimized connection pool
672 | df = self._yf_pool.get_history(
673 | symbol=symbol,
674 | start=start_date,
675 | end=end_date,
676 | period=period,
677 | interval=interval,
678 | )
679 |
680 | # Check if dataframe is empty or if required columns are missing
681 | if df.empty:
682 | logger.warning(f"Empty dataframe returned for {symbol}")
683 | return pd.DataFrame(
684 | columns=["Open", "High", "Low", "Close", "Volume"] # type: ignore[arg-type]
685 | )
686 |
687 | # Ensure all expected columns exist
688 | for col in ["Open", "High", "Low", "Close", "Volume"]:
689 | if col not in df.columns:
690 | logger.warning(
691 | f"Column {col} missing from data for {symbol}, adding empty column"
692 | )
693 | # Use appropriate default values
694 | if col == "Volume":
695 | df[col] = 0
696 | else:
697 | df[col] = 0.0
698 |
699 | df.index.name = "Date"
700 | return df
701 |
702 | def get_maverick_recommendations(
703 | self, limit: int = 20, min_score: int | None = None
704 | ) -> list[dict]:
705 | """
706 | Get top Maverick stock recommendations from the database.
707 |
708 | Args:
709 | limit: Maximum number of recommendations
710 | min_score: Minimum combined score filter
711 |
712 | Returns:
713 | List of stock recommendations with details
714 | """
715 | session, should_close = self._get_db_session()
716 | try:
717 | # Build query with filtering at database level
718 | query = session.query(MaverickStocks)
719 |
720 | # Apply min_score filter in the query if specified
721 | if min_score:
722 | query = query.filter(MaverickStocks.combined_score >= min_score)
723 |
724 | # Order by score and limit results
725 | stocks = (
726 | query.order_by(MaverickStocks.combined_score.desc()).limit(limit).all()
727 | )
728 |
729 | # Process results with list comprehension for better performance
730 | recommendations = [
731 | {
732 | **stock.to_dict(),
733 | "recommendation_type": "maverick_bullish",
734 | "reason": self._generate_maverick_reason(stock),
735 | }
736 | for stock in stocks
737 | ]
738 |
739 | return recommendations
740 | except Exception as e:
741 | logger.error(f"Error getting maverick recommendations: {e}")
742 | return []
743 | finally:
744 | if should_close:
745 | session.close()
746 |
747 | def get_maverick_bear_recommendations(
748 | self, limit: int = 20, min_score: int | None = None
749 | ) -> list[dict]:
750 | """
751 | Get top Maverick bear stock recommendations from the database.
752 |
753 | Args:
754 | limit: Maximum number of recommendations
755 | min_score: Minimum score filter
756 |
757 | Returns:
758 | List of bear stock recommendations with details
759 | """
760 | session, should_close = self._get_db_session()
761 | try:
762 | # Build query with filtering at database level
763 | query = session.query(MaverickBearStocks)
764 |
765 | # Apply min_score filter in the query if specified
766 | if min_score:
767 | query = query.filter(MaverickBearStocks.score >= min_score)
768 |
769 | # Order by score and limit results
770 | stocks = query.order_by(MaverickBearStocks.score.desc()).limit(limit).all()
771 |
772 | # Process results with list comprehension for better performance
773 | recommendations = [
774 | {
775 | **stock.to_dict(),
776 | "recommendation_type": "maverick_bearish",
777 | "reason": self._generate_bear_reason(stock),
778 | }
779 | for stock in stocks
780 | ]
781 |
782 | return recommendations
783 | except Exception as e:
784 | logger.error(f"Error getting bear recommendations: {e}")
785 | return []
786 | finally:
787 | if should_close:
788 | session.close()
789 |
790 | def get_supply_demand_breakout_recommendations(
791 | self, limit: int = 20, min_momentum_score: float | None = None
792 | ) -> list[dict]:
793 | """
794 | Get stocks showing supply/demand breakout patterns from accumulation phases.
795 |
796 | Args:
797 | limit: Maximum number of recommendations
798 | min_momentum_score: Minimum momentum score filter
799 |
800 | Returns:
801 | List of supply/demand breakout recommendations with market structure analysis
802 | """
803 | session, should_close = self._get_db_session()
804 | try:
805 | # Build query with all filters at database level
806 | query = session.query(SupplyDemandBreakoutStocks).filter(
807 | # Supply/demand breakout criteria: price above all moving averages (demand zone)
808 | SupplyDemandBreakoutStocks.close_price
809 | > SupplyDemandBreakoutStocks.sma_50,
810 | SupplyDemandBreakoutStocks.close_price
811 | > SupplyDemandBreakoutStocks.sma_150,
812 | SupplyDemandBreakoutStocks.close_price
813 | > SupplyDemandBreakoutStocks.sma_200,
814 | # Moving average alignment indicates accumulation structure
815 | SupplyDemandBreakoutStocks.sma_50 > SupplyDemandBreakoutStocks.sma_150,
816 | SupplyDemandBreakoutStocks.sma_150 > SupplyDemandBreakoutStocks.sma_200,
817 | )
818 |
819 | # Apply min_momentum_score filter if specified
820 | if min_momentum_score:
821 | query = query.filter(
822 | SupplyDemandBreakoutStocks.momentum_score >= min_momentum_score
823 | )
824 |
825 | # Order by momentum score and limit results
826 | stocks = (
827 | query.order_by(SupplyDemandBreakoutStocks.momentum_score.desc())
828 | .limit(limit)
829 | .all()
830 | )
831 |
832 | # Process results with list comprehension for better performance
833 | recommendations = [
834 | {
835 | **stock.to_dict(),
836 | "recommendation_type": "supply_demand_breakout",
837 | "reason": self._generate_supply_demand_reason(stock),
838 | }
839 | for stock in stocks
840 | ]
841 |
842 | return recommendations
843 | except Exception as e:
844 |             logger.error(f"Error getting supply/demand breakout recommendations: {e}")
845 | return []
846 | finally:
847 | if should_close:
848 | session.close()
849 |
850 | def get_all_screening_recommendations(self) -> dict[str, list[dict]]:
851 | """
852 | Get all screening recommendations in one call.
853 |
854 | Returns:
855 | Dictionary with all screening types and their recommendations
856 | """
857 | try:
858 | results = get_latest_maverick_screening()
859 |
860 | # Add recommendation reasons
861 | for stock in results.get("maverick_stocks", []):
862 | stock["recommendation_type"] = "maverick_bullish"
863 | stock["reason"] = self._generate_maverick_reason_from_dict(stock)
864 |
865 | for stock in results.get("maverick_bear_stocks", []):
866 | stock["recommendation_type"] = "maverick_bearish"
867 | stock["reason"] = self._generate_bear_reason_from_dict(stock)
868 |
869 | for stock in results.get("supply_demand_breakouts", []):
870 | stock["recommendation_type"] = "supply_demand_breakout"
871 | stock["reason"] = self._generate_supply_demand_reason_from_dict(stock)
872 |
873 | return results
874 | except Exception as e:
875 | logger.error(f"Error getting all screening recommendations: {e}")
876 | return {
877 | "maverick_stocks": [],
878 | "maverick_bear_stocks": [],
879 | "supply_demand_breakouts": [],
880 | }
881 |
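    # Sketch of the returned shape (keys come from get_latest_maverick_screening;
    # each row dict is augmented with "recommendation_type" and "reason"):
    #   {
    #       "maverick_stocks": [{..., "recommendation_type": "maverick_bullish"}],
    #       "maverick_bear_stocks": [...],
    #       "supply_demand_breakouts": [...],
    #   }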
882 | def _generate_maverick_reason(self, stock: MaverickStocks) -> str:
883 | """Generate recommendation reason for Maverick stock."""
884 | reasons = []
885 |
886 | combined_score = getattr(stock, "combined_score", None)
887 | if combined_score is not None and combined_score >= 90:
888 | reasons.append("Exceptional combined score")
889 | elif combined_score is not None and combined_score >= 80:
890 | reasons.append("Strong combined score")
891 |
892 | momentum_score = getattr(stock, "momentum_score", None)
893 | if momentum_score is not None and momentum_score >= 90:
894 | reasons.append("outstanding relative strength")
895 | elif momentum_score is not None and momentum_score >= 80:
896 | reasons.append("strong relative strength")
897 |
898 | pat = getattr(stock, "pat", None)
899 | if pat is not None and pat != "":
900 | reasons.append(f"{pat} pattern detected")
901 |
902 | consolidation = getattr(stock, "consolidation", None)
903 | if consolidation is not None and consolidation == "yes":
904 | reasons.append("consolidation characteristics")
905 |
906 | sqz = getattr(stock, "sqz", None)
907 | if sqz is not None and sqz != "":
908 | reasons.append(f"squeeze indicator: {sqz}")
909 |
910 | return (
911 | "Bullish setup with " + ", ".join(reasons)
912 | if reasons
913 | else "Strong technical setup"
914 | )
915 |
916 | def _generate_bear_reason(self, stock: MaverickBearStocks) -> str:
917 | """Generate recommendation reason for bear stock."""
918 | reasons = []
919 |
920 | score = getattr(stock, "score", None)
921 | if score is not None and score >= 90:
922 | reasons.append("Exceptional bear score")
923 | elif score is not None and score >= 80:
924 | reasons.append("Strong bear score")
925 |
926 | momentum_score = getattr(stock, "momentum_score", None)
927 | if momentum_score is not None and momentum_score <= 30:
928 | reasons.append("weak relative strength")
929 |
930 | rsi_14 = getattr(stock, "rsi_14", None)
931 | if rsi_14 is not None and rsi_14 <= 30:
932 | reasons.append("oversold RSI")
933 |
934 | atr_contraction = getattr(stock, "atr_contraction", False)
935 | if atr_contraction is True:
936 | reasons.append("ATR contraction")
937 |
938 | big_down_vol = getattr(stock, "big_down_vol", False)
939 | if big_down_vol is True:
940 | reasons.append("heavy selling volume")
941 |
942 | return (
943 | "Bearish setup with " + ", ".join(reasons)
944 | if reasons
945 | else "Weak technical setup"
946 | )
947 |
948 | def _generate_supply_demand_reason(self, stock: SupplyDemandBreakoutStocks) -> str:
949 | """Generate recommendation reason for supply/demand breakout stock."""
950 | reasons = ["Supply/demand breakout from accumulation"]
951 |
952 | momentum_score = getattr(stock, "momentum_score", None)
953 | if momentum_score is not None and momentum_score >= 90:
954 | reasons.append("exceptional relative strength")
955 | elif momentum_score is not None and momentum_score >= 80:
956 | reasons.append("strong relative strength")
957 |
958 | reasons.append("price above all major moving averages")
959 | reasons.append("moving averages in proper alignment")
960 |
961 | pat = getattr(stock, "pat", None)
962 | if pat is not None and pat != "":
963 | reasons.append(f"{pat} pattern")
964 |
965 | return " with ".join(reasons)
966 |
967 | def _generate_maverick_reason_from_dict(self, stock: dict) -> str:
968 | """Generate recommendation reason for Maverick stock from dict."""
969 | reasons = []
970 |
971 | score = stock.get("combined_score", 0)
972 | if score >= 90:
973 | reasons.append("Exceptional combined score")
974 | elif score >= 80:
975 | reasons.append("Strong combined score")
976 |
977 | momentum = stock.get("momentum_score", 0)
978 | if momentum >= 90:
979 | reasons.append("outstanding relative strength")
980 | elif momentum >= 80:
981 | reasons.append("strong relative strength")
982 |
983 | if stock.get("pattern"):
984 | reasons.append(f"{stock['pattern']} pattern detected")
985 |
986 | if stock.get("consolidation") == "yes":
987 | reasons.append("consolidation characteristics")
988 |
989 | if stock.get("squeeze"):
990 | reasons.append(f"squeeze indicator: {stock['squeeze']}")
991 |
992 | return (
993 | "Bullish setup with " + ", ".join(reasons)
994 | if reasons
995 | else "Strong technical setup"
996 | )
997 |
998 | def _generate_bear_reason_from_dict(self, stock: dict) -> str:
999 | """Generate recommendation reason for bear stock from dict."""
1000 | reasons = []
1001 |
1002 | score = stock.get("score", 0)
1003 | if score >= 90:
1004 | reasons.append("Exceptional bear score")
1005 | elif score >= 80:
1006 | reasons.append("Strong bear score")
1007 |
1008 | momentum = stock.get("momentum_score", 100)
1009 | if momentum <= 30:
1010 | reasons.append("weak relative strength")
1011 |
1012 | rsi = stock.get("rsi_14")
1013 | if rsi and rsi <= 30:
1014 | reasons.append("oversold RSI")
1015 |
1016 | if stock.get("atr_contraction"):
1017 | reasons.append("ATR contraction")
1018 |
1019 | if stock.get("big_down_vol"):
1020 | reasons.append("heavy selling volume")
1021 |
1022 | return (
1023 | "Bearish setup with " + ", ".join(reasons)
1024 | if reasons
1025 | else "Weak technical setup"
1026 | )
1027 |
1028 | def _generate_supply_demand_reason_from_dict(self, stock: dict) -> str:
1029 | """Generate recommendation reason for supply/demand breakout stock from dict."""
1030 | reasons = ["Supply/demand breakout from accumulation"]
1031 |
1032 | momentum = stock.get("momentum_score", 0)
1033 | if momentum >= 90:
1034 | reasons.append("exceptional relative strength")
1035 | elif momentum >= 80:
1036 | reasons.append("strong relative strength")
1037 |
1038 | reasons.append("price above all major moving averages")
1039 | reasons.append("moving averages in proper alignment")
1040 |
1041 | if stock.get("pattern"):
1042 | reasons.append(f"{stock['pattern']} pattern")
1043 |
1044 | return " with ".join(reasons)
1045 |
1046 | # Keep all original methods for backward compatibility
1047 | @with_stock_data_circuit_breaker(use_fallback=False)
1048 | def get_stock_info(self, symbol: str) -> dict:
1049 | """Get detailed stock information from yfinance with circuit breaker protection."""
1050 | # Use connection pool for better performance
1051 | return self._yf_pool.get_info(symbol)
1052 |
1053 | def get_realtime_data(self, symbol):
1054 | """Get the latest real-time data for a symbol using yfinance."""
1055 | try:
1056 | # Use connection pool for real-time data
1057 | data = self._yf_pool.get_history(symbol, period="1d")
1058 |
1059 | if data.empty:
1060 | return None
1061 |
1062 | latest = data.iloc[-1]
1063 |
1064 | # Get previous close for change calculation
1065 | info = self._yf_pool.get_info(symbol)
1066 | prev_close = info.get("previousClose", None)
1067 | if prev_close is None:
1068 | # Try to get from 2-day history
1069 | data_2d = self._yf_pool.get_history(symbol, period="2d")
1070 | if len(data_2d) > 1:
1071 | prev_close = data_2d.iloc[0]["Close"]
1072 | else:
1073 | prev_close = latest["Close"]
1074 |
1075 | # Calculate change
1076 | price = latest["Close"]
1077 | change = price - prev_close
1078 | change_percent = (change / prev_close) * 100 if prev_close != 0 else 0
1079 |
1080 | return {
1081 | "symbol": symbol,
1082 | "price": round(price, 2),
1083 | "change": round(change, 2),
1084 | "change_percent": round(change_percent, 2),
1085 | "volume": int(latest["Volume"]),
1086 | "timestamp": data.index[-1],
1087 | "timestamp_display": data.index[-1].strftime("%Y-%m-%d %H:%M:%S"),
1088 | "is_real_time": False, # yfinance data has some delay
1089 | }
1090 | except Exception as e:
1091 | logger.error(f"Error fetching realtime data for {symbol}: {str(e)}")
1092 | return None
1093 |
1094 | def get_all_realtime_data(self, symbols):
1095 | """Get all latest real-time data for multiple symbols."""
1096 | results = {}
1097 | for symbol in symbols:
1098 | data = self.get_realtime_data(symbol)
1099 | if data:
1100 | results[symbol] = data
1101 | return results
1102 |
1103 | def is_market_open(self) -> bool:
1104 |         """Check if the US stock market is currently open (weekday and regular trading hours only; exchange holidays are not considered)."""
1105 | now = datetime.now(pytz.timezone("US/Eastern"))
1106 |
1107 | # Check if it's a weekday
1108 | if now.weekday() >= 5: # 5 and 6 are Saturday and Sunday
1109 | return False
1110 |
1111 | # Check if it's between 9:30 AM and 4:00 PM Eastern Time
1112 | market_open = now.replace(hour=9, minute=30, second=0, microsecond=0)
1113 | market_close = now.replace(hour=16, minute=0, second=0, microsecond=0)
1114 |
1115 | return market_open <= now <= market_close
1116 |
1117 | def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
1118 | """Get news for a stock from yfinance."""
1119 | try:
1120 | ticker = yf.Ticker(symbol)
1121 | news = ticker.news
1122 |
1123 | if not news:
1124 | return pd.DataFrame(
1125 | columns=[ # type: ignore[arg-type]
1126 | "title",
1127 | "publisher",
1128 | "link",
1129 | "providerPublishTime",
1130 | "type",
1131 | ]
1132 | )
1133 |
1134 | df = pd.DataFrame(news[:limit])
1135 |
1136 | # Convert timestamp to datetime
1137 | if "providerPublishTime" in df.columns:
1138 | df["providerPublishTime"] = pd.to_datetime(
1139 | df["providerPublishTime"], unit="s"
1140 | )
1141 |
1142 | return df
1143 | except Exception as e:
1144 | logger.error(f"Error fetching news for {symbol}: {str(e)}")
1145 | return pd.DataFrame(
1146 | columns=["title", "publisher", "link", "providerPublishTime", "type"] # type: ignore[arg-type]
1147 | )
1148 |
1149 | def get_earnings(self, symbol: str) -> dict:
1150 | """Get earnings information for a stock."""
1151 | try:
1152 | ticker = yf.Ticker(symbol)
1153 | return {
1154 | "earnings": ticker.earnings.to_dict()
1155 | if hasattr(ticker, "earnings") and not ticker.earnings.empty
1156 | else {},
1157 | "earnings_dates": ticker.earnings_dates.to_dict()
1158 | if hasattr(ticker, "earnings_dates") and not ticker.earnings_dates.empty
1159 | else {},
1160 | "earnings_trend": ticker.earnings_trend
1161 | if hasattr(ticker, "earnings_trend")
1162 | else {},
1163 | }
1164 | except Exception as e:
1165 | logger.error(f"Error fetching earnings for {symbol}: {str(e)}")
1166 | return {"earnings": {}, "earnings_dates": {}, "earnings_trend": {}}
1167 |
1168 | def get_recommendations(self, symbol: str) -> pd.DataFrame:
1169 | """Get analyst recommendations for a stock."""
1170 | try:
1171 | ticker = yf.Ticker(symbol)
1172 | recommendations = ticker.recommendations
1173 |
1174 | if recommendations is None or recommendations.empty:
1175 | return pd.DataFrame(columns=["firm", "toGrade", "fromGrade", "action"]) # type: ignore[arg-type]
1176 |
1177 | return recommendations
1178 | except Exception as e:
1179 | logger.error(f"Error fetching recommendations for {symbol}: {str(e)}")
1180 | return pd.DataFrame(columns=["firm", "toGrade", "fromGrade", "action"]) # type: ignore[arg-type]
1181 |
1182 | def is_etf(self, symbol: str) -> bool:
1183 | """Check if a given symbol is an ETF."""
1184 | try:
1185 | stock = yf.Ticker(symbol)
1186 | # Check if quoteType exists and is ETF
1187 | if "quoteType" in stock.info:
1188 | return stock.info["quoteType"].upper() == "ETF" # type: ignore[no-any-return]
1189 | # Fallback check for common ETF identifiers
1190 | return any(
1191 | [
1192 | symbol.endswith(("ETF", "FUND")),
1193 | symbol
1194 | in [
1195 | "SPY",
1196 | "QQQ",
1197 | "IWM",
1198 | "DIA",
1199 | "XLB",
1200 | "XLE",
1201 | "XLF",
1202 | "XLI",
1203 | "XLK",
1204 | "XLP",
1205 | "XLU",
1206 | "XLV",
1207 | "XLY",
1208 | "XLC",
1209 | "XLRE",
1210 | "XME",
1211 | ],
1212 | "ETF" in stock.info.get("longName", "").upper(),
1213 | ]
1214 | )
1215 | except Exception as e:
1216 | logger.error(f"Error checking if {symbol} is ETF: {e}")
1217 | return False
1218 |
1219 |
1220 | # Maintain backward compatibility
1221 | StockDataProvider = EnhancedStockDataProvider
1222 |
```
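A minimal usage sketch of the provider above (symbols, dates, and score thresholds are hypothetical; it assumes the database wired up by `maverick_mcp.data.models` is reachable, otherwise the connection test logs a warning and calls fall through to yfinance):
```python
import asyncio

from maverick_mcp.providers.stock_data import EnhancedStockDataProvider

provider = EnhancedStockDataProvider()  # no injected session: sessions created per call

# Daily bars are served from the database cache where possible; only the
# missing trading-day ranges hit yfinance.
df = provider.get_stock_data("AAPL", start_date="2024-01-01", end_date="2024-06-30")
print(df[["Open", "Close", "Volume"]].tail())

# Non-daily intervals and period-based requests always bypass the cache.
weekly = provider.get_stock_data("AAPL", period="1y", interval="1wk")

# Screening recommendations read from the persisted screening tables.
for rec in provider.get_maverick_recommendations(limit=5, min_score=80):
    print(rec["recommendation_type"], rec.get("reason"))

# The async wrapper offloads to a thread pool, so multiple symbols can be
# fetched concurrently (e.g., for backtesting).
async def fetch_many(symbols):
    frames = await asyncio.gather(
        *(provider.get_stock_data_async(s, period="6mo") for s in symbols)
    )
    return dict(zip(symbols, frames))

data = asyncio.run(fetch_many(["AAPL", "MSFT", "NVDA"]))
```
Passing a `db_session` to the constructor instead shares one session across calls (the provider never closes an injected session), which suits request-scoped dependency injection.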
--------------------------------------------------------------------------------
/tests/test_deep_research_functional.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive functional tests for DeepResearchAgent.
3 |
4 | This test suite focuses on testing the actual research functionality including:
5 |
6 | ## Web Search Integration Tests (TestWebSearchIntegration):
7 | - Exa and Tavily search provider query formatting and result processing
8 | - Provider fallback behavior when APIs fail
9 | - Search result deduplication from multiple providers
10 | - Social media filtering and content processing
11 |
12 | ## Research Synthesis Tests (TestResearchSynthesis):
13 | - Persona-aware content analysis with different investment styles
14 | - Complete research synthesis workflow from query to findings
15 | - Iterative research refinement based on initial results
16 | - Fact validation and source credibility scoring
17 |
18 | ## Persona-Based Research Tests (TestPersonaBasedResearch):
19 | - Conservative persona focus on stability, dividends, and risk factors
20 | - Aggressive persona exploration of growth opportunities and innovation
21 | - Day trader persona emphasis on short-term catalysts and volatility
22 | - Research depth differences between conservative and aggressive approaches
23 |
24 | ## Multi-Step Research Workflow Tests (TestMultiStepResearchWorkflow):
25 | - End-to-end research workflow from initial query to final report
26 | - Handling of insufficient or conflicting information scenarios
27 | - Research focusing and refinement based on discovered gaps
28 | - Citation generation and source attribution
29 |
30 | ## Research Method Specialization Tests (TestResearchMethodSpecialization):
31 | - Sentiment analysis specialization with news and social signals
32 | - Fundamental analysis focusing on financials and company data
33 | - Competitive analysis examining market position and rivals
34 | - Proper routing to specialized analysis based on focus areas
35 |
36 | ## Error Handling and Resilience Tests (TestErrorHandlingAndResilience):
37 | - Graceful degradation when search providers are unavailable
38 | - Content analysis fallback when LLM services fail
39 | - Partial search failure handling with provider redundancy
40 | - Circuit breaker behavior and timeout handling
41 |
42 | ## Research Quality and Validation Tests (TestResearchQualityAndValidation):
43 | - Research confidence calculation based on source quality and diversity
44 | - Source credibility scoring (government, financial sites vs. blogs)
45 | - Source diversity assessment for balanced research
46 | - Investment recommendation logic based on persona and findings
47 |
48 | ## Key Features Tested:
49 | - **Realistic Mock Data**: Uses comprehensive financial article samples
50 | - **Provider Integration**: Tests both Exa and Tavily search providers
51 | - **LangGraph Workflows**: Tests complete research state machine
52 | - **Persona Adaptation**: Validates different investor behavior patterns
53 | - **Error Resilience**: Ensures system continues operating with degraded capabilities
54 | - **Research Logic**: Tests actual synthesis and analysis rather than just API calls
55 |
56 | All tests use realistic mock data and test the research logic rather than just API connectivity.
57 | 26 test cases cover the complete research pipeline from initial search to final recommendations.
58 | """
59 |
60 | import json
61 | from unittest.mock import AsyncMock, MagicMock, Mock, patch
62 |
63 | import pytest
64 |
65 | from maverick_mcp.agents.deep_research import (
66 | PERSONA_RESEARCH_FOCUS,
67 | RESEARCH_DEPTH_LEVELS,
68 | ContentAnalyzer,
69 | DeepResearchAgent,
70 | ExaSearchProvider,
71 | TavilySearchProvider,
72 | )
73 | from maverick_mcp.exceptions import WebSearchError
74 |
75 |
76 | # Mock Data Fixtures
77 | @pytest.fixture
78 | def mock_llm():
79 | """Mock LLM with realistic responses for content analysis."""
80 | llm = MagicMock()
81 | llm.ainvoke = AsyncMock()
82 | llm.bind_tools = MagicMock(return_value=llm)
83 |
84 | # Default response for content analysis
85 | def mock_response(messages):
86 | response = Mock()
87 | response.content = json.dumps(
88 | {
89 | "KEY_INSIGHTS": [
90 | "Strong revenue growth in cloud services",
91 | "Market expansion in international segments",
92 | "Increasing competitive pressure from rivals",
93 | ],
94 | "SENTIMENT": {"direction": "bullish", "confidence": 0.75},
95 | "RISK_FACTORS": [
96 | "Regulatory scrutiny in international markets",
97 | "Supply chain disruptions affecting hardware",
98 | ],
99 | "OPPORTUNITIES": [
100 | "AI integration driving new revenue streams",
101 | "Subscription model improving recurring revenue",
102 | ],
103 | "CREDIBILITY": 0.8,
104 | "RELEVANCE": 0.9,
105 | "SUMMARY": "Analysis shows strong fundamentals with growth opportunities despite some regulatory risks.",
106 | }
107 | )
108 | return response
109 |
110 | llm.ainvoke.side_effect = mock_response
111 | return llm
112 |
113 |
114 | @pytest.fixture
115 | def comprehensive_search_results():
116 | """Comprehensive mock search results from multiple providers."""
117 | return [
118 | {
119 | "url": "https://finance.yahoo.com/news/apple-earnings-q4-2024",
120 | "title": "Apple Reports Strong Q4 2024 Earnings",
121 | "content": """Apple Inc. reported quarterly earnings that beat Wall Street expectations,
122 | driven by strong iPhone sales and growing services revenue. The company posted
123 | revenue of $94.9 billion, up 6% year-over-year. CEO Tim Cook highlighted the
124 | success of the iPhone 15 lineup and expressed optimism about AI integration
125 | in future products. Services revenue reached $22.3 billion, representing
126 | a 16% increase. The company also announced a 4% increase in quarterly dividend.""",
127 | "published_date": "2024-01-25T10:30:00Z",
128 | "score": 0.92,
129 | "provider": "exa",
130 | "author": "Financial Times Staff",
131 | },
132 | {
133 | "url": "https://seekingalpha.com/article/apple-technical-analysis-2024",
134 | "title": "Apple Stock Technical Analysis: Bullish Momentum Building",
135 | "content": """Technical analysis of Apple stock shows bullish momentum building
136 | with the stock breaking above key resistance at $190. Volume has been
137 | increasing on up days, suggesting institutional accumulation. The RSI
138 | is at 58, indicating room for further upside. Key support levels are
139 | at $185 and $180. Price target for the next quarter is $210-$220 based
140 | on chart patterns and momentum indicators.""",
141 | "published_date": "2024-01-24T14:45:00Z",
142 | "score": 0.85,
143 | "provider": "exa",
144 | "author": "Tech Analyst Pro",
145 | },
146 | {
147 | "url": "https://reuters.com/apple-supply-chain-concerns",
148 | "title": "Apple Faces Supply Chain Headwinds in 2024",
149 | "content": """Apple is encountering supply chain challenges that could impact
150 | production timelines for its upcoming product launches. Manufacturing
151 | partners in Asia report delays due to component shortages, particularly
152 | for advanced semiconductors. The company is working to diversify its
153 | supplier base to reduce risks. Despite these challenges, analysts
154 | remain optimistic about Apple's ability to meet demand through
155 | strategic inventory management.""",
156 | "published_date": "2024-01-23T08:15:00Z",
157 | "score": 0.78,
158 | "provider": "tavily",
159 | "author": "Reuters Technology Team",
160 | },
161 | {
162 | "url": "https://fool.com/apple-ai-strategy-competitive-advantage",
163 | "title": "Apple's AI Strategy Could Be Its Next Competitive Moat",
164 | "content": """Apple's approach to artificial intelligence differs significantly
165 | from competitors, focusing on on-device processing and privacy protection.
166 | The company's investment in AI chips and machine learning capabilities
167 | positions it well for the next phase of mobile computing. Industry
168 | experts predict Apple's AI integration will drive hardware upgrade
169 | cycles and create new revenue opportunities in services. The privacy-first
170 | approach could become a key differentiator in the market.""",
171 | "published_date": "2024-01-22T16:20:00Z",
172 | "score": 0.88,
173 | "provider": "exa",
174 | "author": "Investment Strategy Team",
175 | },
176 | {
177 | "url": "https://barrons.com/apple-dividend-growth-analysis",
178 | "title": "Apple's Dividend Growth Story Continues",
179 | "content": """Apple has increased its dividend for the 12th consecutive year,
180 | demonstrating strong cash flow generation and commitment to returning
181 | capital to shareholders. The company's dividend yield of 0.5% may seem
182 | modest, but the consistent growth rate of 7% annually makes it attractive
183 | for income-focused investors. With over $162 billion in cash and
184 | marketable securities, Apple has the financial flexibility to continue
185 | rewarding shareholders while investing in growth initiatives.""",
186 | "published_date": "2024-01-21T11:00:00Z",
187 | "score": 0.82,
188 | "provider": "tavily",
189 | "author": "Dividend Analysis Team",
190 | },
191 | ]
192 |
193 |
194 | @pytest.fixture
195 | def mock_research_agent(mock_llm):
196 | """Create a DeepResearchAgent with mocked dependencies."""
197 | with (
198 | patch("maverick_mcp.agents.deep_research.ExaSearchProvider") as mock_exa,
199 | patch("maverick_mcp.agents.deep_research.TavilySearchProvider") as mock_tavily,
200 | ):
201 | # Mock search providers
202 | mock_exa_instance = Mock()
203 | mock_tavily_instance = Mock()
204 | mock_exa.return_value = mock_exa_instance
205 | mock_tavily.return_value = mock_tavily_instance
206 |
207 | agent = DeepResearchAgent(
208 | llm=mock_llm,
209 | persona="moderate",
210 | exa_api_key="mock-key",
211 | tavily_api_key="mock-key",
212 | )
213 |
214 | # Add mock providers to the agent for testing
215 | agent.search_providers = [mock_exa_instance, mock_tavily_instance]
216 |
217 | return agent
218 |
219 |
220 | class TestWebSearchIntegration:
221 | """Test web search integration and result processing."""
222 |
223 | @pytest.mark.asyncio
224 | async def test_exa_search_provider_query_formatting(self):
225 | """Test that Exa search queries are properly formatted and sent."""
226 | with patch("maverick_mcp.agents.deep_research.circuit_manager") as mock_circuit:
227 | mock_circuit.get_or_create = AsyncMock()
228 | mock_circuit_instance = AsyncMock()
229 | mock_circuit.get_or_create.return_value = mock_circuit_instance
230 |
231 | # Mock the Exa client response
232 | mock_exa_response = Mock()
233 | mock_exa_response.results = [
234 | Mock(
235 | url="https://example.com/test",
236 | title="Test Article",
237 | text="Test content for search",
238 | summary="Test summary",
239 | highlights=["key highlight"],
240 | published_date="2024-01-25",
241 | author="Test Author",
242 | score=0.9,
243 | )
244 | ]
245 |
246 | with patch("exa_py.Exa") as mock_exa_client:
247 | mock_client_instance = Mock()
248 | mock_client_instance.search_and_contents.return_value = (
249 | mock_exa_response
250 | )
251 | mock_exa_client.return_value = mock_client_instance
252 |
253 | # Create actual provider (not mocked)
254 | provider = ExaSearchProvider("test-api-key")
255 | mock_circuit_instance.call.return_value = [
256 | {
257 | "url": "https://example.com/test",
258 | "title": "Test Article",
259 | "content": "Test content for search",
260 | "summary": "Test summary",
261 | "highlights": ["key highlight"],
262 | "published_date": "2024-01-25",
263 | "author": "Test Author",
264 | "score": 0.9,
265 | "provider": "exa",
266 | }
267 | ]
268 |
269 | # Test the search
270 | results = await provider.search("AAPL stock analysis", num_results=5)
271 |
272 | # Verify query was properly formatted
273 | assert len(results) == 1
274 | assert results[0]["url"] == "https://example.com/test"
275 | assert results[0]["provider"] == "exa"
276 | assert results[0]["score"] == 0.9
277 |
278 | @pytest.mark.asyncio
279 | async def test_tavily_search_result_processing(self):
280 | """Test Tavily search result processing and filtering."""
281 | with patch("maverick_mcp.agents.deep_research.circuit_manager") as mock_circuit:
282 | mock_circuit.get_or_create = AsyncMock()
283 | mock_circuit_instance = AsyncMock()
284 | mock_circuit.get_or_create.return_value = mock_circuit_instance
285 |
286 | mock_tavily_response = {
287 | "results": [
288 | {
289 | "url": "https://news.example.com/tech-news",
290 | "title": "Tech News Article",
291 | "content": "Content about technology trends",
292 | "raw_content": "Extended raw content with more details",
293 | "published_date": "2024-01-25",
294 | "score": 0.85,
295 | },
296 | {
297 | "url": "https://facebook.com/social-post", # Should be filtered out
298 | "title": "Social Media Post",
299 | "content": "Social media content",
300 | "score": 0.7,
301 | },
302 | ]
303 | }
304 |
305 | with patch("tavily.TavilyClient") as mock_tavily_client:
306 | mock_client_instance = Mock()
307 | mock_client_instance.search.return_value = mock_tavily_response
308 | mock_tavily_client.return_value = mock_client_instance
309 |
310 | provider = TavilySearchProvider("test-api-key")
311 | mock_circuit_instance.call.return_value = [
312 | {
313 | "url": "https://news.example.com/tech-news",
314 | "title": "Tech News Article",
315 | "content": "Content about technology trends",
316 | "raw_content": "Extended raw content with more details",
317 | "published_date": "2024-01-25",
318 | "score": 0.85,
319 | "provider": "tavily",
320 | }
321 | ]
322 |
323 | results = await provider.search("tech trends analysis")
324 |
325 | # Verify results are properly processed and social media filtered
326 | assert len(results) == 1
327 | assert results[0]["provider"] == "tavily"
328 | assert "facebook.com" not in results[0]["url"]
329 |
330 | @pytest.mark.asyncio
331 | async def test_search_provider_fallback_behavior(self, mock_research_agent):
332 | """Test fallback behavior when search providers fail."""
333 | # Mock the execute searches workflow step directly
334 | with patch.object(mock_research_agent, "_execute_searches") as mock_execute:
335 | # Mock first provider to fail, second to succeed
336 | mock_research_agent.search_providers[0].search = AsyncMock(
337 | side_effect=WebSearchError("Exa API rate limit exceeded")
338 | )
339 |
340 | mock_research_agent.search_providers[1].search = AsyncMock(
341 | return_value=[
342 | {
343 | "url": "https://backup-source.com/article",
344 | "title": "Backup Article",
345 | "content": "Fallback content from secondary provider",
346 | "provider": "tavily",
347 | "score": 0.75,
348 | }
349 | ]
350 | )
351 |
352 | # Mock successful execution with fallback results
353 | mock_result = Mock()
354 | mock_result.goto = "analyze_content"
355 | mock_result.update = {
356 | "search_results": [
357 | {
358 | "url": "https://backup-source.com/article",
359 | "title": "Backup Article",
360 | "content": "Fallback content from secondary provider",
361 | "provider": "tavily",
362 | "score": 0.75,
363 | }
364 | ],
365 | "research_status": "analyzing",
366 | }
367 | mock_execute.return_value = mock_result
368 |
369 | # Test state for search execution
370 | state = {"search_queries": ["AAPL analysis"], "research_depth": "standard"}
371 |
372 | # Execute the search step
373 | result = await mock_research_agent._execute_searches(state)
374 |
375 | # Should handle provider failure gracefully
376 | assert result.goto == "analyze_content"
377 | assert len(result.update["search_results"]) > 0
378 |
379 | @pytest.mark.asyncio
380 | async def test_search_result_deduplication(self, comprehensive_search_results):
381 | """Test deduplication of search results from multiple providers."""
382 | # Create search results with duplicates
383 | duplicate_results = (
384 | comprehensive_search_results
385 | + [
386 | {
387 | "url": "https://finance.yahoo.com/news/apple-earnings-q4-2024", # Duplicate URL
388 | "title": "Apple Q4 Results (Duplicate)",
389 | "content": "Duplicate content with different title",
390 | "provider": "tavily",
391 | "score": 0.7,
392 | }
393 | ]
394 | )
395 |
396 | with patch.object(DeepResearchAgent, "_execute_searches") as mock_execute:
397 | mock_execute.return_value = Mock()
398 |
399 | DeepResearchAgent(llm=MagicMock(), persona="moderate")
400 |
401 | # Test the deduplication logic directly
402 |
403 | # Simulate search execution with duplicates
404 | all_results = duplicate_results
405 | unique_results = []
406 | seen_urls = set()
407 | depth_config = RESEARCH_DEPTH_LEVELS["standard"]
408 |
409 | for result in all_results:
410 | if (
411 | result["url"] not in seen_urls
412 | and len(unique_results) < depth_config["max_sources"]
413 | ):
414 | unique_results.append(result)
415 | seen_urls.add(result["url"])
416 |
417 | # Verify deduplication worked
418 | assert len(unique_results) == 5 # Should remove 1 duplicate
419 | urls = [r["url"] for r in unique_results]
420 | assert len(set(urls)) == len(urls) # All URLs should be unique
421 |
422 |
423 | class TestResearchSynthesis:
424 | """Test research synthesis and iterative querying functionality."""
425 |
426 | @pytest.mark.asyncio
427 | async def test_content_analysis_with_persona_focus(
428 | self, comprehensive_search_results
429 | ):
430 | """Test that content analysis adapts to persona focus areas."""
431 | # Mock LLM with persona-specific responses
432 | mock_llm = MagicMock()
433 |
434 | def persona_aware_response(messages):
435 | response = Mock()
436 | # Check if content is about dividends for conservative persona
437 | content = messages[1].content if len(messages) > 1 else ""
438 | if "conservative" in content and "dividend" in content:
439 | response.content = json.dumps(
440 | {
441 | "KEY_INSIGHTS": [
442 | "Strong dividend yield provides stable income"
443 | ],
444 | "SENTIMENT": {"direction": "bullish", "confidence": 0.7},
445 | "RISK_FACTORS": ["Interest rate sensitivity"],
446 | "OPPORTUNITIES": ["Consistent dividend growth"],
447 | "CREDIBILITY": 0.85,
448 | "RELEVANCE": 0.9,
449 | "SUMMARY": "Dividend analysis shows strong income potential for conservative investors.",
450 | }
451 | )
452 | else:
453 | response.content = json.dumps(
454 | {
455 | "KEY_INSIGHTS": ["Growth opportunity in AI sector"],
456 | "SENTIMENT": {"direction": "bullish", "confidence": 0.8},
457 | "RISK_FACTORS": ["Market competition"],
458 | "OPPORTUNITIES": ["Innovation leadership"],
459 | "CREDIBILITY": 0.8,
460 | "RELEVANCE": 0.85,
461 | "SUMMARY": "Analysis shows strong growth opportunities through innovation.",
462 | }
463 | )
464 | return response
465 |
466 | mock_llm.ainvoke = AsyncMock(side_effect=persona_aware_response)
467 | analyzer = ContentAnalyzer(mock_llm)
468 |
469 | # Test conservative persona analysis with dividend content
470 | conservative_result = await analyzer.analyze_content(
471 | content=comprehensive_search_results[4]["content"], # Dividend article
472 | persona="conservative",
473 | )
474 |
475 | # Verify conservative-focused analysis
476 | assert conservative_result["relevance_score"] > 0.8
477 | assert (
478 | "dividend" in conservative_result["summary"].lower()
479 | or "income" in conservative_result["summary"].lower()
480 | )
481 |
482 | # Test aggressive persona analysis with growth content
483 | aggressive_result = await analyzer.analyze_content(
484 | content=comprehensive_search_results[3]["content"], # AI strategy article
485 | persona="aggressive",
486 | )
487 |
488 | # Verify aggressive-focused analysis
489 | assert aggressive_result["relevance_score"] > 0.7
490 | assert any(
491 | keyword in aggressive_result["summary"].lower()
492 | for keyword in ["growth", "opportunity", "innovation"]
493 | )
494 |
495 | @pytest.mark.asyncio
496 | async def test_research_synthesis_workflow(
497 | self, mock_research_agent, comprehensive_search_results
498 | ):
499 | """Test the complete research synthesis workflow."""
500 | # Mock the workflow components using the actual graph structure
501 | with patch.object(mock_research_agent, "graph") as mock_graph:
502 | # Mock successful workflow execution with all required fields
503 | mock_result = {
504 | "research_topic": "AAPL",
505 | "research_depth": "standard",
506 | "search_queries": ["AAPL financial analysis", "Apple earnings 2024"],
507 | "search_results": comprehensive_search_results,
508 | "analyzed_content": [
509 | {
510 | **result,
511 | "analysis": {
512 | "insights": [
513 | "Strong revenue growth",
514 | "AI integration opportunity",
515 | ],
516 | "sentiment": {"direction": "bullish", "confidence": 0.8},
517 | "risk_factors": [
518 | "Supply chain risks",
519 | "Regulatory concerns",
520 | ],
521 | "opportunities": ["AI monetization", "Services expansion"],
522 | "credibility_score": 0.85,
523 | "relevance_score": 0.9,
524 | "summary": "Strong fundamentals with growth catalysts",
525 | },
526 | }
527 | for result in comprehensive_search_results[:3]
528 | ],
529 | "validated_sources": comprehensive_search_results[:3],
530 | "research_findings": {
531 | "synthesis": "Apple shows strong fundamentals with growth opportunities",
532 | "key_insights": ["Revenue growth", "AI opportunities"],
533 | "overall_sentiment": {"direction": "bullish", "confidence": 0.8},
534 | "confidence_score": 0.82,
535 | },
536 | "citations": [
537 | {"id": 1, "title": "Apple Earnings", "url": "https://example.com/1"}
538 | ],
539 | "research_status": "completed",
540 | "research_confidence": 0.82,
541 | "execution_time_ms": 1500.0,
542 | "persona": "moderate",
543 | }
544 |
545 | mock_graph.ainvoke = AsyncMock(return_value=mock_result)
546 |
547 | # Execute research
548 | result = await mock_research_agent.research_comprehensive(
549 | topic="AAPL", session_id="test_synthesis", depth="standard"
550 | )
551 |
552 | # Verify synthesis was performed
553 | assert result["status"] == "success"
554 | assert "findings" in result
555 | assert result["sources_analyzed"] > 0
556 |
557 | @pytest.mark.asyncio
558 | async def test_iterative_research_refinement(self, mock_research_agent):
559 | """Test iterative research with follow-up queries based on initial findings."""
560 | # Mock initial research finding gaps
561 |
562 | with patch.object(
563 | mock_research_agent, "_generate_search_queries"
564 | ) as mock_queries:
565 | # First iteration - general queries
566 | mock_queries.return_value = [
567 | "NVDA competitive analysis",
568 | "NVIDIA market position 2024",
569 | ]
570 |
571 | queries_first = await mock_research_agent._generate_search_queries(
572 | topic="NVDA competitive position",
573 | persona_focus=PERSONA_RESEARCH_FOCUS["moderate"],
574 | depth_config=RESEARCH_DEPTH_LEVELS["standard"],
575 | )
576 |
577 | # Verify initial queries are broad
578 | assert any("competitive" in q.lower() for q in queries_first)
579 | assert any("NVDA" in q or "NVIDIA" in q for q in queries_first)
580 |
581 | @pytest.mark.asyncio
582 | async def test_fact_validation_and_source_credibility(self, mock_research_agent):
583 | """Test fact validation and source credibility scoring."""
584 | # Test source credibility calculation
585 | test_sources = [
586 | {
587 | "url": "https://sec.gov/filing/aapl-10k-2024",
588 | "title": "Apple 10-K Filing",
589 | "content": "Official SEC filing content",
590 | "published_date": "2024-01-20T00:00:00Z",
591 | "analysis": {"credibility_score": 0.9},
592 | },
593 | {
594 | "url": "https://random-blog.com/apple-speculation",
595 | "title": "Random Blog Post",
596 | "content": "Speculative content with no sources",
597 | "published_date": "2023-06-01T00:00:00Z", # Old content
598 | "analysis": {"credibility_score": 0.3},
599 | },
600 | ]
601 |
602 | # Test credibility scoring
603 | for source in test_sources:
604 | credibility = mock_research_agent._calculate_source_credibility(source)
605 |
606 | if "sec.gov" in source["url"]:
607 | assert (
608 | credibility >= 0.8
609 | ) # Government sources should be highly credible
610 | elif "random-blog" in source["url"]:
611 | assert credibility <= 0.6 # Random blogs should have lower credibility
612 |
613 |
614 | class TestPersonaBasedResearch:
615 | """Test persona-based research behavior and adaptation."""
616 |
617 | @pytest.mark.asyncio
618 | async def test_conservative_persona_research_focus(self, mock_llm):
619 | """Test that conservative persona focuses on stability and risk factors."""
620 | agent = DeepResearchAgent(llm=mock_llm, persona="conservative")
621 |
622 | # Test search query generation for conservative persona
623 | persona_focus = PERSONA_RESEARCH_FOCUS["conservative"]
624 | depth_config = RESEARCH_DEPTH_LEVELS["standard"]
625 |
626 | queries = await agent._generate_search_queries(
627 | topic="AAPL", persona_focus=persona_focus, depth_config=depth_config
628 | )
629 |
630 | # Verify conservative-focused queries
631 | query_text = " ".join(queries).lower()
632 | assert any(
633 | keyword in query_text for keyword in ["dividend", "stability", "risk"]
634 | )
635 |
636 | # Test that conservative persona performs more thorough fact-checking
637 | assert persona_focus["risk_focus"] == "downside protection"
638 | assert persona_focus["time_horizon"] == "long-term"
639 |
640 | @pytest.mark.asyncio
641 | async def test_aggressive_persona_research_behavior(self, mock_llm):
642 | """Test aggressive persona explores speculative opportunities."""
643 | agent = DeepResearchAgent(llm=mock_llm, persona="aggressive")
644 |
645 | persona_focus = PERSONA_RESEARCH_FOCUS["aggressive"]
646 |
647 | # Test query generation for aggressive persona
648 | queries = await agent._generate_search_queries(
649 | topic="TSLA",
650 | persona_focus=persona_focus,
651 | depth_config=RESEARCH_DEPTH_LEVELS["standard"],
652 | )
653 |
654 | # Verify aggressive-focused queries
655 | query_text = " ".join(queries).lower()
656 | assert any(
657 | keyword in query_text for keyword in ["growth", "momentum", "opportunity"]
658 | )
659 |
660 | # Verify aggressive characteristics
661 | assert persona_focus["risk_focus"] == "upside potential"
662 | assert "innovation" in persona_focus["keywords"]
663 |
664 | @pytest.mark.asyncio
665 | async def test_day_trader_persona_short_term_focus(self, mock_llm):
666 | """Test day trader persona focuses on short-term catalysts and volatility."""
667 | DeepResearchAgent(llm=mock_llm, persona="day_trader")
668 |
669 | persona_focus = PERSONA_RESEARCH_FOCUS["day_trader"]
670 |
671 | # Test characteristics specific to day trader persona
672 | assert persona_focus["time_horizon"] == "intraday to weekly"
673 | assert "catalysts" in persona_focus["keywords"]
674 | assert "volatility" in persona_focus["keywords"]
675 | assert "earnings" in persona_focus["keywords"]
676 |
677 | # Test sources preference
678 | assert "breaking news" in persona_focus["sources"]
679 | assert "social sentiment" in persona_focus["sources"]
680 |
681 | @pytest.mark.asyncio
682 | async def test_research_depth_differences_by_persona(self, mock_llm):
683 | """Test that conservative personas do more thorough research."""
684 | conservative_agent = DeepResearchAgent(
685 | llm=mock_llm, persona="conservative", default_depth="comprehensive"
686 | )
687 |
688 | aggressive_agent = DeepResearchAgent(
689 | llm=mock_llm, persona="aggressive", default_depth="standard"
690 | )
691 |
692 | # Conservative should use more comprehensive depth by default
693 | assert conservative_agent.default_depth == "comprehensive"
694 |
695 | # Aggressive can use standard depth for faster decisions
696 | assert aggressive_agent.default_depth == "standard"
697 |
698 | # Test depth level configurations
699 | comprehensive_config = RESEARCH_DEPTH_LEVELS["comprehensive"]
700 | standard_config = RESEARCH_DEPTH_LEVELS["standard"]
701 |
702 | assert comprehensive_config["max_sources"] > standard_config["max_sources"]
703 | assert comprehensive_config["validation_required"]
704 |
705 |
706 | class TestMultiStepResearchWorkflow:
707 | """Test complete multi-step research workflows."""
708 |
709 | @pytest.mark.asyncio
710 | async def test_complete_research_workflow_success(
711 | self, mock_research_agent, comprehensive_search_results
712 | ):
713 | """Test complete research workflow from query to final report."""
714 | # Mock all workflow steps
715 | with patch.object(mock_research_agent, "graph") as mock_graph:
716 | # Mock successful workflow execution
717 | mock_result = {
718 | "research_topic": "AAPL",
719 | "research_depth": "standard",
720 | "search_queries": ["AAPL analysis", "Apple earnings"],
721 | "search_results": comprehensive_search_results,
722 | "analyzed_content": [
723 | {
724 | **result,
725 | "analysis": {
726 | "insights": ["Strong performance"],
727 | "sentiment": {"direction": "bullish", "confidence": 0.8},
728 | "credibility_score": 0.85,
729 | },
730 | }
731 | for result in comprehensive_search_results
732 | ],
733 | "validated_sources": comprehensive_search_results[:3],
734 | "research_findings": {
735 | "synthesis": "Apple shows strong fundamentals with growth opportunities",
736 | "key_insights": [
737 | "Revenue growth",
738 | "AI opportunities",
739 | "Strong cash flow",
740 | ],
741 | "overall_sentiment": {"direction": "bullish", "confidence": 0.8},
742 | "confidence_score": 0.82,
743 | },
744 | "citations": [
745 | {
746 | "id": 1,
747 | "title": "Apple Earnings",
748 | "url": "https://example.com/1",
749 | },
750 | {
751 | "id": 2,
752 | "title": "Technical Analysis",
753 | "url": "https://example.com/2",
754 | },
755 | ],
756 | "research_status": "completed",
757 | "research_confidence": 0.82,
758 | "execution_time_ms": 1500.0,
759 | }
760 |
761 | mock_graph.ainvoke = AsyncMock(return_value=mock_result)
762 |
763 | # Execute complete research
764 | result = await mock_research_agent.research_comprehensive(
765 | topic="AAPL", session_id="workflow_test", depth="standard"
766 | )
767 |
768 | # Verify complete workflow
769 | assert result["status"] == "success"
770 | assert result["agent_type"] == "deep_research"
771 | assert result["research_topic"] == "AAPL"
772 | assert result["sources_analyzed"] == 3
773 | assert result["confidence_score"] == 0.82
774 | assert len(result["citations"]) == 2
775 |
776 | @pytest.mark.asyncio
777 | async def test_research_workflow_with_insufficient_information(
778 | self, mock_research_agent
779 | ):
780 | """Test workflow handling when insufficient information is found."""
781 | # Mock scenario with limited/poor quality results
782 | with patch.object(mock_research_agent, "graph") as mock_graph:
783 | mock_result = {
784 | "research_topic": "OBSCURE_STOCK",
785 | "research_depth": "standard",
786 | "search_results": [], # No results found
787 | "validated_sources": [],
788 | "research_findings": {},
789 | "research_confidence": 0.1, # Very low confidence
790 | "research_status": "completed",
791 | "execution_time_ms": 800.0,
792 | }
793 |
794 | mock_graph.ainvoke = AsyncMock(return_value=mock_result)
795 |
796 | result = await mock_research_agent.research_comprehensive(
797 | topic="OBSCURE_STOCK", session_id="insufficient_test"
798 | )
799 |
800 | # Should handle insufficient information gracefully
801 | assert result["status"] == "success"
802 | assert result["confidence_score"] == 0.1
803 | assert result["sources_analyzed"] == 0
804 |
805 | @pytest.mark.asyncio
806 | async def test_research_with_conflicting_information(self, mock_research_agent):
807 | """Test handling of conflicting information from different sources."""
808 | conflicting_sources = [
809 | {
810 | "url": "https://bull-analyst.com/buy-rating",
811 | "title": "Strong Buy Rating for AAPL",
812 | "analysis": {
813 | "sentiment": {"direction": "bullish", "confidence": 0.9},
814 | "credibility_score": 0.8,
815 | },
816 | },
817 | {
818 | "url": "https://bear-analyst.com/sell-rating",
819 | "title": "Sell Rating for AAPL Due to Overvaluation",
820 | "analysis": {
821 | "sentiment": {"direction": "bearish", "confidence": 0.8},
822 | "credibility_score": 0.7,
823 | },
824 | },
825 | ]
826 |
827 | # Test overall sentiment calculation with conflicting sources
828 | overall_sentiment = mock_research_agent._calculate_overall_sentiment(
829 | conflicting_sources
830 | )
831 |
832 | # Should handle conflicts by providing consensus information
833 | assert overall_sentiment["direction"] in ["bullish", "bearish", "neutral"]
834 | assert "consensus" in overall_sentiment
835 | assert overall_sentiment["source_count"] == 2
836 |
837 | @pytest.mark.asyncio
838 | async def test_research_focus_and_refinement(self, mock_research_agent):
839 | """Test research focusing and refinement based on initial findings."""
840 | # Test different research focus areas
841 | focus_areas = ["sentiment", "fundamental", "competitive"]
842 |
843 | for focus in focus_areas:
844 | route = mock_research_agent._route_specialized_analysis(
845 | {"focus_areas": [focus]}
846 | )
847 |
848 | if focus == "sentiment":
849 | assert route == "sentiment"
850 | elif focus == "fundamental":
851 | assert route == "fundamental"
852 | elif focus == "competitive":
853 | assert route == "competitive"
854 |
855 |
856 | class TestResearchMethodSpecialization:
857 | """Test specialized research methods: sentiment, fundamental, competitive analysis."""
858 |
859 | @pytest.mark.asyncio
860 | async def test_sentiment_analysis_specialization(self, mock_research_agent):
861 | """Test sentiment analysis research method."""
862 | test_state = {
863 | "focus_areas": [
864 | "sentiment",
865 | "news",
866 | ], # Use keywords that match routing logic
867 | "analyzed_content": [],
868 | }
869 |
870 | # Test sentiment analysis routing
871 | route = mock_research_agent._route_specialized_analysis(test_state)
872 | assert route == "sentiment"
873 |
874 | # Test sentiment analysis execution (mocked)
875 | with patch.object(mock_research_agent, "_analyze_content") as mock_analyze:
876 | mock_analyze.return_value = Mock()
877 |
878 | await mock_research_agent._sentiment_analysis(test_state)
879 | mock_analyze.assert_called_once()
880 |
881 | @pytest.mark.asyncio
882 | async def test_fundamental_analysis_specialization(self, mock_research_agent):
883 | """Test fundamental analysis research method."""
884 | test_state = {
885 | "focus_areas": [
886 | "fundamental",
887 | "financial",
888 | ], # Use exact keywords from routing logic
889 | "analyzed_content": [],
890 | }
891 |
892 | # Test fundamental analysis routing
893 | route = mock_research_agent._route_specialized_analysis(test_state)
894 | assert route == "fundamental"
895 |
896 | # Test fundamental analysis execution
897 | with patch.object(mock_research_agent, "_analyze_content") as mock_analyze:
898 | mock_analyze.return_value = Mock()
899 |
900 | await mock_research_agent._fundamental_analysis(test_state)
901 | mock_analyze.assert_called_once()
902 |
903 | @pytest.mark.asyncio
904 | async def test_competitive_analysis_specialization(self, mock_research_agent):
905 | """Test competitive analysis research method."""
906 | test_state = {
907 | "focus_areas": [
908 | "competitive",
909 | "market",
910 | ], # Use exact keywords from routing logic
911 | "analyzed_content": [],
912 | }
913 |
914 | # Test competitive analysis routing
915 | route = mock_research_agent._route_specialized_analysis(test_state)
916 | assert route == "competitive"
917 |
918 | # Test competitive analysis execution
919 | with patch.object(mock_research_agent, "_analyze_content") as mock_analyze:
920 | mock_analyze.return_value = Mock()
921 |
922 | await mock_research_agent._competitive_analysis(test_state)
923 | mock_analyze.assert_called_once()
924 |
925 |
926 | class TestErrorHandlingAndResilience:
927 | """Test error handling and system resilience."""
928 |
929 | @pytest.mark.asyncio
930 | async def test_research_agent_with_no_search_providers(self, mock_llm):
931 | """Test research agent behavior with no available search providers."""
932 | # Create agent without search providers
933 | agent = DeepResearchAgent(llm=mock_llm, persona="moderate")
934 |
935 | # Should initialize successfully but with limited capabilities
936 | assert len(agent.search_providers) == 0
937 |
938 | # Research should still attempt to work but with limited results
939 | result = await agent.research_comprehensive(
940 | topic="TEST", session_id="no_providers_test"
941 | )
942 |
943 | # Should not crash, may return limited results
944 | assert "status" in result
945 |
946 | @pytest.mark.asyncio
947 | async def test_content_analysis_fallback_on_llm_failure(
948 | self, comprehensive_search_results
949 | ):
950 | """Test content analysis fallback when LLM fails."""
951 | # Mock LLM that fails
952 | failing_llm = MagicMock()
953 | failing_llm.ainvoke = AsyncMock(
954 | side_effect=Exception("LLM service unavailable")
955 | )
956 |
957 | analyzer = ContentAnalyzer(failing_llm)
958 |
959 | # Should use fallback analysis
960 | result = await analyzer.analyze_content(
961 | content=comprehensive_search_results[0]["content"], persona="conservative"
962 | )
963 |
964 | # Verify fallback was used
965 | assert result["fallback_used"]
966 | assert result["sentiment"]["direction"] in ["bullish", "bearish", "neutral"]
967 | assert 0 <= result["credibility_score"] <= 1
968 | assert 0 <= result["relevance_score"] <= 1
969 |
970 | @pytest.mark.asyncio
971 | async def test_partial_search_failure_handling(self, mock_research_agent):
972 | """Test handling when some but not all search providers fail."""
973 | # Test the actual search execution logic directly
974 | mock_research_agent.search_providers[0].search = AsyncMock(
975 | side_effect=WebSearchError("Provider 1 failed")
976 | )
977 |
978 | mock_research_agent.search_providers[1].search = AsyncMock(
979 | return_value=[
980 | {
981 | "url": "https://working-provider.com/article",
982 | "title": "Working Provider Article",
983 | "content": "Content from working provider",
984 | "provider": "working_provider",
985 | "score": 0.8,
986 | }
987 | ]
988 | )
989 |
990 | # Test the search execution directly
991 | state = {"search_queries": ["test query"], "research_depth": "standard"}
992 |
993 | result = await mock_research_agent._execute_searches(state)
994 |
995 | # Should continue with working providers and return results
996 | assert hasattr(result, "update")
997 | assert "search_results" in result.update
998 | # Should have at least the working provider results
999 | assert (
1000 | len(result.update["search_results"]) >= 0
1001 | ) # May be 0 if all fail, but should not crash
1002 |
1003 | @pytest.mark.asyncio
1004 | async def test_research_timeout_and_circuit_breaker(self, mock_research_agent):
1005 | """Test research timeout handling and circuit breaker behavior."""
1006 | # Test would require actual circuit breaker implementation
1007 | # This is a placeholder for circuit breaker testing
1008 |
1009 | with patch(
1010 | "maverick_mcp.agents.circuit_breaker.circuit_manager"
1011 | ) as mock_circuit:
1012 | mock_circuit.get_or_create = AsyncMock()
1013 | circuit_instance = AsyncMock()
1014 | mock_circuit.get_or_create.return_value = circuit_instance
1015 |
1016 | # Mock circuit breaker open state
1017 | circuit_instance.call = AsyncMock(
1018 | side_effect=Exception("Circuit breaker open")
1019 | )
1020 |
1021 | # Research should handle circuit breaker gracefully
1022 | # Implementation depends on actual circuit breaker behavior
1023 | pass
1024 |
1025 |
1026 | class TestResearchQualityAndValidation:
1027 | """Test research quality assurance and validation mechanisms."""
1028 |
1029 | def test_research_confidence_calculation(self, mock_research_agent):
1030 | """Test research confidence calculation based on multiple factors."""
1031 | # Test with high-quality sources
1032 | high_quality_sources = [
1033 | {
1034 | "url": "https://sec.gov/filing1",
1035 | "credibility_score": 0.95,
1036 | "analysis": {"relevance_score": 0.9},
1037 | },
1038 | {
1039 | "url": "https://bloomberg.com/article1",
1040 | "credibility_score": 0.85,
1041 | "analysis": {"relevance_score": 0.8},
1042 | },
1043 | {
1044 | "url": "https://reuters.com/article2",
1045 | "credibility_score": 0.8,
1046 | "analysis": {"relevance_score": 0.85},
1047 | },
1048 | ]
1049 |
1050 | confidence = mock_research_agent._calculate_research_confidence(
1051 | high_quality_sources
1052 | )
1053 | assert confidence >= 0.65 # Should be reasonably high confidence
1054 |
1055 | # Test with low-quality sources
1056 | low_quality_sources = [
1057 | {
1058 | "url": "https://random-blog.com/post1",
1059 | "credibility_score": 0.3,
1060 | "analysis": {"relevance_score": 0.4},
1061 | }
1062 | ]
1063 |
1064 | low_confidence = mock_research_agent._calculate_research_confidence(
1065 | low_quality_sources
1066 | )
1067 | assert low_confidence < 0.5 # Should be low confidence
1068 |
1069 | def test_source_diversity_scoring(self, mock_research_agent):
1070 | """Test source diversity calculation."""
1071 | diverse_sources = [
1072 | {"url": "https://sec.gov/filing"},
1073 | {"url": "https://bloomberg.com/news"},
1074 | {"url": "https://reuters.com/article"},
1075 | {"url": "https://wsj.com/story"},
1076 | {"url": "https://ft.com/content"},
1077 | ]
1078 |
1079 | confidence = mock_research_agent._calculate_research_confidence(diverse_sources)
1080 |
1081 | # More diverse sources should contribute to higher confidence
1082 | assert confidence > 0.6
1083 |
1084 | def test_investment_recommendation_logic(self, mock_research_agent):
1085 | """Test investment recommendation based on research findings."""
1086 | # Test bullish scenario
1087 | bullish_sources = [
1088 | {
1089 | "analysis": {
1090 | "sentiment": {"direction": "bullish", "confidence": 0.9},
1091 | "credibility_score": 0.8,
1092 | }
1093 | }
1094 | ]
1095 |
1096 | recommendation = mock_research_agent._recommend_action(bullish_sources)
1097 |
1098 | # Conservative persona should be more cautious
1099 | if mock_research_agent.persona.name.lower() == "conservative":
1100 | assert (
1101 | "gradual" in recommendation.lower()
1102 | or "risk management" in recommendation.lower()
1103 | )
1104 | else:
1105 | assert (
1106 | "consider" in recommendation.lower()
1107 | and "position" in recommendation.lower()
1108 | )
1109 |
1110 |
1111 | if __name__ == "__main__":
1112 | pytest.main([__file__, "-v", "--tb=short"])
1113 |
```
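The tests above pin down parts of the `PERSONA_RESEARCH_FOCUS` and `RESEARCH_DEPTH_LEVELS` structures through assertions. A sketch of the shapes those assertions imply follows; asserted fields are marked, everything else (including the numeric limits) is an assumption, and the authoritative definitions live in `maverick_mcp/agents/deep_research.py`:

```python
# Shapes inferred from the test assertions; unmarked values are assumptions.
PERSONA_RESEARCH_FOCUS = {
    "conservative": {
        "risk_focus": "downside protection",  # asserted
        "time_horizon": "long-term",  # asserted
        "keywords": ["dividend", "stability", "risk"],  # assumed from query checks
    },
    "aggressive": {
        "risk_focus": "upside potential",  # asserted
        "keywords": ["growth", "momentum", "innovation"],  # "innovation" asserted
    },
    "day_trader": {
        "time_horizon": "intraday to weekly",  # asserted
        "keywords": ["catalysts", "volatility", "earnings"],  # asserted members
        "sources": ["breaking news", "social sentiment"],  # asserted members
    },
    # A "moderate" entry also exists (used by the fixtures); its fields are not asserted.
}

RESEARCH_DEPTH_LEVELS = {
    # The dedup test implies "standard" allows at least 5 sources; exact numbers assumed.
    "standard": {"max_sources": 5, "validation_required": False},
    "comprehensive": {"max_sources": 10, "validation_required": True},  # > standard
}
```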
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/backtesting.py:
--------------------------------------------------------------------------------
```python
1 | """MCP router for VectorBT backtesting tools with structured logging."""
2 |
3 | from typing import Any
4 |
5 | import numpy as np
6 | from fastmcp import Context
7 |
8 | from maverick_mcp.backtesting import (
9 | BacktestAnalyzer,
10 | StrategyOptimizer,
11 | VectorBTEngine,
12 | )
13 | from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES, StrategyParser
14 | from maverick_mcp.backtesting.strategies.templates import (
15 | get_strategy_info,
16 | list_available_strategies,
17 | )
18 | from maverick_mcp.backtesting.visualization import (
19 | generate_equity_curve,
20 | generate_optimization_heatmap,
21 | generate_performance_dashboard,
22 | generate_trade_scatter,
23 | )
24 | from maverick_mcp.utils.debug_utils import debug_operation
25 | from maverick_mcp.utils.logging import get_logger
26 | from maverick_mcp.utils.structured_logger import (
27 | CorrelationIDGenerator,
28 | get_performance_logger,
29 | with_structured_logging,
30 | )
31 |
32 | # Initialize performance logger for backtesting router
33 | performance_logger = get_performance_logger("backtesting_router")
34 | logger = get_logger(__name__)
35 |
36 |
37 | def convert_numpy_types(obj: Any) -> Any:
38 | """Recursively convert numpy types to Python native types for JSON serialization.
39 |
40 | Args:
41 | obj: Any object that might contain numpy types
42 |
43 | Returns:
44 | Object with all numpy types converted to Python native types
45 | """
46 | import pandas as pd
47 |
48 | # Check for numpy integer types (more robust using numpy's type hierarchy)
49 | if isinstance(obj, np.integer):
50 | return int(obj)
51 | # Check for numpy floating point types
52 | elif isinstance(obj, np.floating):
53 | return float(obj)
54 | # Check for numpy boolean type
55 | elif isinstance(obj, np.bool_ | bool) and hasattr(obj, "item"):
56 | return bool(obj)
57 | # Check for numpy complex types
58 | elif isinstance(obj, np.complexfloating):
59 | return complex(obj)
60 | # Handle numpy arrays
61 | elif isinstance(obj, np.ndarray):
62 | return obj.tolist()
63 | # Handle pandas Series
64 | elif isinstance(obj, pd.Series):
65 | return obj.tolist()
66 | # Handle pandas DataFrame
67 | elif isinstance(obj, pd.DataFrame):
68 | return obj.to_dict("records")
69 | # Handle NaN/None values (list/tuple excluded: pd.isna is elementwise on them)
70 | elif not isinstance(obj, list | tuple) and pd.isna(obj):
71 | return None
72 | # Handle other numpy scalars with .item() method
73 | elif hasattr(obj, "item") and hasattr(obj, "dtype"):
74 | try:
75 | return obj.item()
76 | except Exception:
77 | return str(obj)
78 | # Recursively handle dictionaries
79 | elif isinstance(obj, dict):
80 | return {key: convert_numpy_types(value) for key, value in obj.items()}
81 | # Recursively handle lists and tuples
82 | elif isinstance(obj, list | tuple):
83 | return [convert_numpy_types(item) for item in obj]
84 | # Try to handle custom objects with __dict__
85 | elif hasattr(obj, "__dict__") and not isinstance(obj, type):
86 | try:
87 | return convert_numpy_types(obj.__dict__)
88 | except Exception:
89 | return str(obj)
90 | else:
91 | # Return as-is for regular Python types
92 | return obj
93 |
94 |
95 | def setup_backtesting_tools(mcp):
96 | """Set up VectorBT backtesting tools for MCP.
97 |
98 | Args:
99 | mcp: FastMCP instance
100 | """
101 |
102 | @mcp.tool()
103 | @with_structured_logging("run_backtest", include_performance=True, log_params=True)
104 | @debug_operation("run_backtest", enable_profiling=True, symbol="backtest_symbol")
105 | async def run_backtest(
106 | ctx: Context,
107 | symbol: str,
108 | strategy: str = "sma_cross",
109 | start_date: str | None = None,
110 | end_date: str | None = None,
111 | initial_capital: float = 10000.0,
112 | fast_period: str | int | None = None,
113 | slow_period: str | int | None = None,
114 | period: str | int | None = None,
115 | oversold: str | float | None = None,
116 | overbought: str | float | None = None,
117 | signal_period: str | int | None = None,
118 | std_dev: str | float | None = None,
119 | lookback: str | int | None = None,
120 | threshold: str | float | None = None,
121 | z_score_threshold: str | float | None = None,
122 | breakout_factor: str | float | None = None,
123 | ) -> dict[str, Any]:
124 | """Run a VectorBT backtest with specified strategy and parameters.
125 |
126 | Args:
127 | symbol: Stock symbol to backtest
128 | strategy: Strategy type (sma_cross, rsi, macd, bollinger, momentum, etc.)
129 | start_date: Start date (YYYY-MM-DD), defaults to 1 year ago
130 | end_date: End date (YYYY-MM-DD), defaults to today
131 | initial_capital: Starting capital for backtest
132 | Strategy-specific parameters passed as individual arguments (e.g., fast_period=10, slow_period=20)
133 |
134 | Returns:
135 | Comprehensive backtest results including metrics, trades, and analysis
136 |
137 | Examples:
138 | run_backtest("AAPL", "sma_cross", fast_period=10, slow_period=20)
139 | run_backtest("TSLA", "rsi", period=14, oversold=30, overbought=70)
140 | """
141 | from datetime import datetime, timedelta
142 |
143 | # Default date range
144 | if not end_date:
145 | end_date = datetime.now().strftime("%Y-%m-%d")
146 | if not start_date:
147 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
148 |
149 | # Convert string parameters to appropriate types
150 | def convert_param(value, param_type):
151 | """Convert string parameter to appropriate type."""
152 | if value is None:
153 | return None
154 | if isinstance(value, str):
155 | try:
156 | if param_type is int:
157 | return int(value)
158 | elif param_type is float:
159 | return float(value)
160 | except (ValueError, TypeError) as e:
161 | raise ValueError(
162 | f"Invalid {param_type.__name__} value: {value}"
163 | ) from e
164 | return value
165 |
166 | # Build parameters dict from provided arguments with type conversion
167 | param_map = {
168 | "fast_period": convert_param(fast_period, int),
169 | "slow_period": convert_param(slow_period, int),
170 | "period": convert_param(period, int),
171 | "oversold": convert_param(oversold, float),
172 | "overbought": convert_param(overbought, float),
173 | "signal_period": convert_param(signal_period, int),
174 | "std_dev": convert_param(std_dev, float),
175 | "lookback": convert_param(lookback, int),
176 | "threshold": convert_param(threshold, float),
177 | "z_score_threshold": convert_param(z_score_threshold, float),
178 | "breakout_factor": convert_param(breakout_factor, float),
179 | }
180 |
181 | # Get default parameters for strategy
182 | if strategy in STRATEGY_TEMPLATES:
183 | parameters = dict(STRATEGY_TEMPLATES[strategy]["parameters"])
184 | # Override with provided non-None parameters
185 | for param_name, param_value in param_map.items():
186 | if param_value is not None:
187 | parameters[param_name] = param_value
188 | else:
189 | # Use only provided parameters for unknown strategies
190 | parameters = {k: v for k, v in param_map.items() if v is not None}
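# Merge semantics (illustrative): for strategy="sma_cross" with fast_period=5,
# the template's remaining defaults (e.g. slow_period) are kept and only
# fast_period is overridden; unknown strategies use just the supplied values.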
191 |
192 | # Initialize engine
193 | engine = VectorBTEngine()
194 |
195 | # Run backtest
196 | results = await engine.run_backtest(
197 | symbol=symbol,
198 | strategy_type=strategy,
199 | parameters=parameters,
200 | start_date=start_date,
201 | end_date=end_date,
202 | initial_capital=initial_capital,
203 | )
204 |
205 | # Analyze results
206 | analyzer = BacktestAnalyzer()
207 | analysis = analyzer.analyze(results)
208 |
209 | # Combine results and analysis
210 | results["analysis"] = analysis
211 |
212 | # Log business metrics
213 | if results.get("metrics"):
214 | metrics = results["metrics"]
215 | performance_logger.log_business_metric(
216 | "backtest_completion",
217 | 1,
218 | symbol=symbol,
219 | strategy=strategy,
220 | total_return=metrics.get("total_return", 0),
221 | sharpe_ratio=metrics.get("sharpe_ratio", 0),
222 | max_drawdown=metrics.get("max_drawdown", 0),
223 | total_trades=metrics.get("total_trades", 0),
224 | )
225 |
226 | # Set correlation context for downstream operations
227 | CorrelationIDGenerator.set_correlation_id()
228 |
229 | return results
230 |
231 | @mcp.tool()
232 | @with_structured_logging(
233 | "optimize_strategy", include_performance=True, log_params=True
234 | )
235 | @debug_operation(
236 | "optimize_strategy", enable_profiling=True, strategy="optimization_strategy"
237 | )
238 | async def optimize_strategy(
239 | ctx: Context,
240 | symbol: str,
241 | strategy: str = "sma_cross",
242 | start_date: str | None = None,
243 | end_date: str | None = None,
244 | optimization_metric: str = "sharpe_ratio",
245 | optimization_level: str = "medium",
246 | top_n: int = 10,
247 | ) -> dict[str, Any]:
248 | """Optimize strategy parameters using VectorBT grid search.
249 |
250 | Args:
251 | symbol: Stock symbol to optimize
252 | strategy: Strategy type to optimize
253 | start_date: Start date (YYYY-MM-DD)
254 | end_date: End date (YYYY-MM-DD)
255 | optimization_metric: Metric to optimize (sharpe_ratio, total_return, win_rate, etc.)
256 | optimization_level: Level of optimization (coarse, medium, fine)
257 | top_n: Number of top results to return
258 |
259 | Returns:
260 | Optimization results with best parameters and performance metrics
261 | """
262 | from datetime import datetime, timedelta
263 |
264 | # Default date range
265 | if not end_date:
266 | end_date = datetime.now().strftime("%Y-%m-%d")
267 | if not start_date:
268 | start_date = (datetime.now() - timedelta(days=365 * 2)).strftime("%Y-%m-%d")
269 |
270 | # Initialize engine and optimizer
271 | engine = VectorBTEngine()
272 | optimizer = StrategyOptimizer(engine)
273 |
274 | # Generate parameter grid
275 | param_grid = optimizer.generate_param_grid(strategy, optimization_level)
276 |
277 | # Run optimization
278 | results = await engine.optimize_parameters(
279 | symbol=symbol,
280 | strategy_type=strategy,
281 | param_grid=param_grid,
282 | start_date=start_date,
283 | end_date=end_date,
284 | optimization_metric=optimization_metric,
285 | top_n=top_n,
286 | )
287 |
288 | return results
289 |
290 | @mcp.tool()
291 | async def walk_forward_analysis(
292 | ctx: Context,
293 | symbol: str,
294 | strategy: str = "sma_cross",
295 | start_date: str | None = None,
296 | end_date: str | None = None,
297 | window_size: int = 252,
298 | step_size: int = 63,
299 | ) -> dict[str, Any]:
300 | """Perform walk-forward analysis to test strategy robustness.
301 |
302 | Args:
303 | symbol: Stock symbol to analyze
304 | strategy: Strategy type
305 | start_date: Start date (YYYY-MM-DD)
306 | end_date: End date (YYYY-MM-DD)
307 | window_size: Test window size in trading days (default: 1 year)
308 | step_size: Step size for rolling window (default: 1 quarter)
309 |
310 | Returns:
311 | Walk-forward analysis results with out-of-sample performance
312 | """
313 | from datetime import datetime, timedelta
314 |
315 | # Default date range (3 years for walk-forward)
316 | if not end_date:
317 | end_date = datetime.now().strftime("%Y-%m-%d")
318 | if not start_date:
319 | start_date = (datetime.now() - timedelta(days=365 * 3)).strftime("%Y-%m-%d")
320 |
321 | # Initialize engine and optimizer
322 | engine = VectorBTEngine()
323 | optimizer = StrategyOptimizer(engine)
324 |
325 | # Get default parameters
326 | parameters = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})
327 |
328 | # Run walk-forward analysis
329 | results = await optimizer.walk_forward_analysis(
330 | symbol=symbol,
331 | strategy_type=strategy,
332 | parameters=parameters,
333 | start_date=start_date,
334 | end_date=end_date,
335 | window_size=window_size,
336 | step_size=step_size,
337 | )
338 |
339 | return results
340 |
341 | @mcp.tool()
342 | async def monte_carlo_simulation(
343 | ctx: Context,
344 | symbol: str,
345 | strategy: str = "sma_cross",
346 | start_date: str | None = None,
347 | end_date: str | None = None,
348 | num_simulations: int = 1000,
349 | fast_period: str | int | None = None,
350 | slow_period: str | int | None = None,
351 | period: str | int | None = None,
352 | ) -> dict[str, Any]:
353 | """Run Monte Carlo simulation on backtest results.
354 |
355 | Args:
356 | symbol: Stock symbol
357 | strategy: Strategy type
358 | start_date: Start date (YYYY-MM-DD)
359 | end_date: End date (YYYY-MM-DD)
360 | num_simulations: Number of Monte Carlo simulations
361 | Strategy-specific parameters as individual arguments
362 |
363 | Returns:
364 | Monte Carlo simulation results with confidence intervals
365 | """
366 | from datetime import datetime, timedelta
367 |
368 | # Default date range
369 | if not end_date:
370 | end_date = datetime.now().strftime("%Y-%m-%d")
371 | if not start_date:
372 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
373 |
374 | # Convert string parameters to appropriate types
375 | def convert_param(value, param_type):
376 | """Convert string parameter to appropriate type."""
377 | if value is None:
378 | return None
379 | if isinstance(value, str):
380 | try:
381 | if param_type is int:
382 | return int(value)
383 | elif param_type is float:
384 | return float(value)
385 | except (ValueError, TypeError) as e:
386 | raise ValueError(
387 | f"Invalid {param_type.__name__} value: {value}"
388 | ) from e
389 | return value
390 |
391 | # Build parameters dict from provided arguments with type conversion
392 | param_map = {
393 | "fast_period": convert_param(fast_period, int),
394 | "slow_period": convert_param(slow_period, int),
395 | "period": convert_param(period, int),
396 | }
397 |
398 | # Get parameters
399 | if strategy in STRATEGY_TEMPLATES:
400 | parameters = dict(STRATEGY_TEMPLATES[strategy]["parameters"])
401 | # Override with provided non-None parameters
402 | for param_name, param_value in param_map.items():
403 | if param_value is not None:
404 | parameters[param_name] = param_value
405 | else:
406 | # Use only provided parameters for unknown strategies
407 | parameters = {k: v for k, v in param_map.items() if v is not None}
408 |
409 | # Run backtest first
410 | engine = VectorBTEngine()
411 | backtest_results = await engine.run_backtest(
412 | symbol=symbol,
413 | strategy_type=strategy,
414 | parameters=parameters,
415 | start_date=start_date,
416 | end_date=end_date,
417 | )
418 |
419 | # Run Monte Carlo simulation
420 | optimizer = StrategyOptimizer(engine)
421 | mc_results = await optimizer.monte_carlo_simulation(
422 | backtest_results=backtest_results,
423 | num_simulations=num_simulations,
424 | )
425 |
426 | return mc_results
427 |
428 | @mcp.tool()
429 | async def compare_strategies(
430 | ctx: Context,
431 | symbol: str,
432 | strategies: list[str] | str | None = None,
433 | start_date: str | None = None,
434 | end_date: str | None = None,
435 | ) -> dict[str, Any]:
436 | """Compare multiple strategies on the same symbol.
437 |
438 | Args:
439 | symbol: Stock symbol
440 | strategies: List of strategy types to compare (defaults to all)
441 | start_date: Start date (YYYY-MM-DD)
442 | end_date: End date (YYYY-MM-DD)
443 |
444 | Returns:
445 | Comparison results with rankings and analysis
446 | """
447 | from datetime import datetime, timedelta
448 |
449 | # Default date range
450 | if not end_date:
451 | end_date = datetime.now().strftime("%Y-%m-%d")
452 | if not start_date:
453 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
454 |
455 | # Handle strategies as JSON string from some clients
456 | if isinstance(strategies, str):
457 | import json
458 |
459 | try:
460 | strategies = json.loads(strategies)
461 | except json.JSONDecodeError:
462 | # If it's not JSON, treat it as a single strategy
463 | strategies = [strategies]
464 |
465 | # Default to comparing top strategies
466 | if not strategies:
467 | strategies = ["sma_cross", "rsi", "macd", "bollinger", "momentum"]
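# Coercion examples: strategies='["rsi", "macd"]' (a JSON string from some MCP
# clients) parses to ["rsi", "macd"]; a bare string like "rsi" becomes ["rsi"].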
468 |
469 | # Run backtests for each strategy
470 | engine = VectorBTEngine()
471 | results_list = []
472 |
473 | for strategy in strategies:
474 | try:
475 | # Get default parameters
476 | parameters = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})
477 |
478 | # Run backtest
479 | results = await engine.run_backtest(
480 | symbol=symbol,
481 | strategy_type=strategy,
482 | parameters=parameters,
483 | start_date=start_date,
484 | end_date=end_date,
485 | )
486 | results_list.append(results)
487 | except Exception:
488 | # Skip failed strategies
489 | continue
490 |
491 | # Compare results
492 | analyzer = BacktestAnalyzer()
493 | comparison = analyzer.compare_strategies(results_list)
494 |
495 | return comparison
496 |
497 | @mcp.tool()
498 | async def list_strategies(ctx: Context) -> dict[str, Any]:
499 | """List all available VectorBT strategies with descriptions.
500 |
501 | Returns:
502 | Dictionary of available strategies and their information
503 | """
504 | strategies = {}
505 |
506 | for strategy_type in list_available_strategies():
507 | strategies[strategy_type] = get_strategy_info(strategy_type)
508 |
509 | return {
510 | "available_strategies": strategies,
511 | "total_count": len(strategies),
512 | "categories": {
513 | "trend_following": ["sma_cross", "ema_cross", "macd", "breakout"],
514 | "mean_reversion": ["rsi", "bollinger", "mean_reversion"],
515 | "momentum": ["momentum", "volume_momentum"],
516 | },
517 | }
518 |
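# Usage sketch (illustrative only):
#   catalog = await list_strategies(ctx)
#   catalog["total_count"]  # number of registered strategy templates
#   catalog["categories"]["momentum"]  # ["momentum", "volume_momentum"]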
519 | @mcp.tool()
520 | async def parse_strategy(ctx: Context, description: str) -> dict[str, Any]:
521 | """Parse natural language strategy description into VectorBT parameters.
522 |
523 | Args:
524 | description: Natural language description of trading strategy
525 |
526 | Returns:
527 | Parsed strategy configuration with type and parameters
528 |
529 | Examples:
530 | "Buy when RSI is below 30 and sell when above 70"
531 | "Use 10-day and 20-day moving average crossover"
532 | "MACD strategy with standard parameters"
533 | """
534 | parser = StrategyParser()
535 | config = parser.parse_simple(description)
536 |
537 | # Validate the parsed strategy
538 | if parser.validate_strategy(config):
539 | return {
540 | "success": True,
541 | "strategy": config,
542 | "message": f"Successfully parsed as {config['strategy_type']} strategy",
543 | }
544 | else:
545 | return {
546 | "success": False,
547 | "strategy": config,
548 | "message": "Could not fully parse strategy, using defaults",
549 | }
550 |
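# Usage sketch (illustrative only; the parser maps free text onto a known
# strategy template):
#   parsed = await parse_strategy(ctx, "Buy when RSI is below 30")
#   if parsed["success"]:
#       strategy_type = parsed["strategy"]["strategy_type"]  # e.g. "rsi"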
551 | @mcp.tool()
552 | async def backtest_portfolio(
553 | ctx: Context,
554 | symbols: list[str],
555 | strategy: str = "sma_cross",
556 | start_date: str | None = None,
557 | end_date: str | None = None,
558 | initial_capital: float = 10000.0,
559 | position_size: float = 0.1,
560 | fast_period: str | int | None = None,
561 | slow_period: str | int | None = None,
562 | period: str | int | None = None,
563 | ) -> dict[str, Any]:
564 | """Backtest a strategy across multiple symbols (portfolio).
565 |
566 | Args:
567 | symbols: List of stock symbols
568 | strategy: Strategy type to apply
569 | start_date: Start date (YYYY-MM-DD)
570 | end_date: End date (YYYY-MM-DD)
571 | initial_capital: Starting capital
572 | position_size: Position size per symbol (0.1 = 10%)
573 | fast_period, slow_period, period: Optional strategy-specific overrides
574 |
575 | Returns:
576 | Portfolio backtest results with aggregate metrics
577 | """
578 | from datetime import datetime, timedelta
579 |
580 | # Default date range
581 | if not end_date:
582 | end_date = datetime.now().strftime("%Y-%m-%d")
583 | if not start_date:
584 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
585 |
586 | # Convert string parameters to appropriate types
587 | def convert_param(value, param_type):
588 | """Convert string parameter to appropriate type."""
589 | if value is None:
590 | return None
591 | if isinstance(value, str):
592 | try:
593 | if param_type is int:
594 | return int(value)
595 | elif param_type is float:
596 | return float(value)
597 | except (ValueError, TypeError) as e:
598 | raise ValueError(
599 | f"Invalid {param_type.__name__} value: {value}"
600 | ) from e
601 | return value
602 |
603 | # Build parameters dict from provided arguments with type conversion
604 | param_map = {
605 | "fast_period": convert_param(fast_period, int),
606 | "slow_period": convert_param(slow_period, int),
607 | "period": convert_param(period, int),
608 | }
609 |
610 | # Get parameters
611 | if strategy in STRATEGY_TEMPLATES:
612 | parameters = dict(STRATEGY_TEMPLATES[strategy]["parameters"])
613 | # Override with provided non-None parameters
614 | for param_name, param_value in param_map.items():
615 | if param_value is not None:
616 | parameters[param_name] = param_value
617 | else:
618 | # Use only provided parameters for unknown strategies
619 | parameters = {k: v for k, v in param_map.items() if v is not None}
620 |
621 | # Run backtests for each symbol
622 | engine = VectorBTEngine()
623 | portfolio_results = []
624 | capital_per_symbol = initial_capital * position_size
625 |
626 | for symbol in symbols:
627 | try:
628 | results = await engine.run_backtest(
629 | symbol=symbol,
630 | strategy_type=strategy,
631 | parameters=parameters,
632 | start_date=start_date,
633 | end_date=end_date,
634 | initial_capital=capital_per_symbol,
635 | )
636 | portfolio_results.append(results)
637 | except Exception as e:
638 | logger.warning(f"Skipping failed symbol {symbol}: {e}")
639 | continue
640 |
641 | if not portfolio_results:
642 | return {"error": "No symbols could be backtested"}
643 |
644 | # Aggregate metrics (equal-weighted averages across symbols; drawdown is worst-case)
645 | avg_return = sum(
646 | r["metrics"]["total_return"] for r in portfolio_results
647 | ) / len(portfolio_results)
648 | avg_sharpe = sum(r["metrics"]["sharpe_ratio"] for r in portfolio_results) / len(
649 | portfolio_results
650 | )
651 | max_drawdown = max(r["metrics"]["max_drawdown"] for r in portfolio_results)
652 | total_trades = sum(r["metrics"]["total_trades"] for r in portfolio_results)
653 |
654 | return {
655 | "portfolio_metrics": {
656 | "symbols_tested": len(portfolio_results),
657 | "average_return": avg_return,
658 | "average_sharpe": avg_sharpe,
659 | "max_drawdown": max_drawdown,
660 | "total_trades": total_trades,
661 | },
662 | "individual_results": portfolio_results,
663 | "summary": f"Portfolio backtest of {len(portfolio_results)} symbols with {strategy} strategy",
664 | }
665 |
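# Usage sketch (illustrative only; each symbol is backtested with
# initial_capital * position_size of starting capital):
#   portfolio = await backtest_portfolio(
#       ctx, symbols=["AAPL", "MSFT", "GOOG"], strategy="sma_cross",
#       initial_capital=10000.0, position_size=0.2,
#   )
#   portfolio["portfolio_metrics"]["average_sharpe"]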
666 | @mcp.tool()
667 | async def generate_backtest_charts(
668 | ctx: Context,
669 | symbol: str,
670 | strategy: str = "sma_cross",
671 | start_date: str | None = None,
672 | end_date: str | None = None,
673 | theme: str = "light",
674 | ) -> dict[str, str]:
675 | """Generate comprehensive charts for a backtest.
676 |
677 | Args:
678 | symbol: Stock symbol
679 | strategy: Strategy type
680 | start_date: Start date (YYYY-MM-DD)
681 | end_date: End date (YYYY-MM-DD)
682 | theme: Chart theme (light or dark)
683 |
684 | Returns:
685 | Dictionary of base64-encoded chart images
686 | """
687 | from datetime import datetime, timedelta
688 |
689 | import pandas as pd
690 |
691 | # Default date range
692 | if not end_date:
693 | end_date = datetime.now().strftime("%Y-%m-%d")
694 | if not start_date:
695 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
696 |
697 | # Run backtest
698 | engine = VectorBTEngine()
699 |
700 | # Get default parameters for the strategy
701 | from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
702 |
703 | parameters = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})
704 |
705 | results = await engine.run_backtest(
706 | symbol=symbol,
707 | strategy_type=strategy,
708 | parameters=parameters,
709 | start_date=start_date,
710 | end_date=end_date,
711 | )
712 |
713 | # Prepare data for charts
714 | equity_curve_data = results["equity_curve"]
715 | drawdown_data = results["drawdown_series"]
716 |
717 | # Convert to pandas objects for charting
718 | equity_series = pd.Series(equity_curve_data)  # equity values, not returns
719 | drawdown = pd.Series(drawdown_data)
720 | trades = pd.DataFrame(results["trades"])
721 |
722 | # Generate charts
723 | charts = {
724 | "equity_curve": generate_equity_curve(
725 | equity_series, drawdown, f"{symbol} {strategy} Equity Curve", theme
726 | ),
727 | "trade_scatter": generate_trade_scatter(
728 | equity_series, trades, f"{symbol} {strategy} Trades", theme
729 | ),
730 | "performance_dashboard": generate_performance_dashboard(
731 | results["metrics"], f"{symbol} {strategy} Performance", theme
732 | ),
733 | }
734 |
735 | return charts
736 |
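# Usage sketch (illustrative only; each value is a base64-encoded image
# suitable for embedding in a data URI):
#   charts = await generate_backtest_charts(ctx, symbol="AAPL", theme="dark")
#   sorted(charts)  # ["equity_curve", "performance_dashboard", "trade_scatter"]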
737 | @mcp.tool()
738 | async def generate_optimization_charts(
739 | ctx: Context,
740 | symbol: str,
741 | strategy: str = "sma_cross",
742 | start_date: str | None = None,
743 | end_date: str | None = None,
744 | theme: str = "light",
745 | ) -> dict[str, str]:
746 | """Generate a parameter-optimization heatmap chart for a strategy.
747 |
748 | Args:
749 | symbol: Stock symbol
750 | strategy: Strategy type
751 | start_date: Start date (YYYY-MM-DD)
752 | end_date: End date (YYYY-MM-DD)
753 | theme: Chart theme (light or dark)
754 |
755 | Returns:
756 | Dictionary with a base64-encoded optimization heatmap image
757 | """
758 | from datetime import datetime, timedelta
759 |
760 | # Default date range
761 | if not end_date:
762 | end_date = datetime.now().strftime("%Y-%m-%d")
763 | if not start_date:
764 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
765 |
766 | # Build the optimizer and a medium-resolution parameter grid
767 | engine = VectorBTEngine()
768 | optimizer = StrategyOptimizer(engine)
769 | param_grid = optimizer.generate_param_grid(strategy, "medium")
770 |
771 | # Build heatmap input; guard entries that are parameter lists
772 | # rather than result dicts so the lookup cannot raise
773 | optimization_results = {}
774 | for param_set, results in param_grid.items():
775 | performance = results.get("total_return", 0) if isinstance(results, dict) else 0
776 | optimization_results[str(param_set)] = {"performance": performance}
777 |
778 | # Generate optimization heatmap
779 | heatmap = generate_optimization_heatmap(
780 | optimization_results, f"{symbol} {strategy} Parameter Optimization", theme
781 | )
782 |
783 | return {"optimization_heatmap": heatmap}
784 |
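# Usage sketch (illustrative only):
#   heatmaps = await generate_optimization_charts(ctx, symbol="AAPL")
#   heatmaps["optimization_heatmap"]  # base64-encoded heatmap image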
785 | # ============ ML-ENHANCED STRATEGY TOOLS ============
786 |
787 | @mcp.tool()
788 | async def run_ml_strategy_backtest(
789 | ctx: Context,
790 | symbol: str,
791 | strategy_type: str = "ml_predictor",
792 | start_date: str | None = None,
793 | end_date: str | None = None,
794 | initial_capital: float = 10000.0,
795 | train_ratio: float = 0.8,
796 | model_type: str = "random_forest",
797 | n_estimators: int = 100,
798 | max_depth: int | None = None,
799 | learning_rate: float = 0.01,
800 | adaptation_method: str = "gradient",
801 | ) -> dict[str, Any]:
802 | """Run backtest using ML-enhanced strategies.
803 |
804 | Args:
805 | symbol: Stock symbol to backtest
806 | strategy_type: ML strategy type (ml_predictor, adaptive, ensemble, regime_aware)
807 | start_date: Start date (YYYY-MM-DD)
808 | end_date: End date (YYYY-MM-DD)
809 | initial_capital: Initial capital amount
810 | train_ratio: Ratio of data for training (0.0-1.0)
811 | model_type, n_estimators, max_depth, learning_rate, adaptation_method: Strategy-specific parameters
812 |
813 | Returns:
814 | Backtest results with ML-specific metrics
815 | """
816 | from datetime import datetime, timedelta
817 |
818 | from maverick_mcp.backtesting.strategies.ml import (
819 | AdaptiveStrategy,
820 | MLPredictor,
821 | RegimeAwareStrategy,
822 | StrategyEnsemble,
823 | )
824 | from maverick_mcp.backtesting.strategies.templates import (
825 | SimpleMovingAverageStrategy,
826 | )
827 |
828 | # Default date range
829 | if not end_date:
830 | end_date = datetime.now().strftime("%Y-%m-%d")
831 | if not start_date:
832 | start_date = (datetime.now() - timedelta(days=730)).strftime(
833 | "%Y-%m-%d"
834 | ) # 2 years for ML
835 |
836 | # Get historical data
837 | engine = VectorBTEngine()
838 | data = await engine.get_historical_data(symbol, start_date, end_date)
839 |
840 | # Enhanced data validation for ML strategies
841 | min_total_data = 200 # Minimum total data points for ML strategies
842 | if len(data) < min_total_data:
843 | return {
844 | "error": f"Insufficient data for ML strategy: {len(data)} < {min_total_data} required"
845 | }
846 |
847 | # Split data for training/testing
848 | split_idx = int(len(data) * train_ratio)
849 | train_data = data.iloc[:split_idx]
850 | test_data = data.iloc[split_idx:]
851 |
852 | # Validate split data sizes
853 | min_train_data = 100
854 | min_test_data = 50
855 |
856 | if len(train_data) < min_train_data:
857 | return {
858 | "error": f"Insufficient training data: {len(train_data)} < {min_train_data} required"
859 | }
860 |
861 | if len(test_data) < min_test_data:
862 | return {
863 | "error": f"Insufficient test data: {len(test_data)} < {min_test_data} required"
864 | }
865 |
866 | logger.info(
867 | f"ML backtest data split: {len(train_data)} training, {len(test_data)} testing samples"
868 | )
869 |
870 | try:
871 | # Create ML strategy based on type
872 | if strategy_type == "ml_predictor":
873 | ml_strategy = MLPredictor(
874 | model_type=model_type,
875 | n_estimators=n_estimators,
876 | max_depth=max_depth,
877 | )
878 | # Train the model
879 | training_metrics = ml_strategy.train(train_data)
880 |
881 | elif strategy_type in ("adaptive", "online_learning"):
882 | # online_learning is an alias for adaptive strategy
883 | base_strategy = SimpleMovingAverageStrategy()
884 | ml_strategy = AdaptiveStrategy(
885 | base_strategy,
886 | learning_rate=learning_rate,
887 | adaptation_method=adaptation_method,
888 | )
889 | training_metrics = {
890 | "adaptation_method": adaptation_method,
891 | "strategy_alias": strategy_type,
892 | }
893 |
894 | elif strategy_type == "ensemble":
895 | # Create ensemble with basic strategies
896 | base_strategies = [
897 | SimpleMovingAverageStrategy({"fast_period": 10, "slow_period": 20}),
898 | SimpleMovingAverageStrategy({"fast_period": 5, "slow_period": 15}),
899 | ]
900 | ml_strategy = StrategyEnsemble(base_strategies)
901 | training_metrics = {"ensemble_size": len(base_strategies)}
902 |
903 | elif strategy_type == "regime_aware":
904 | base_strategies = {
905 | 0: SimpleMovingAverageStrategy(
906 | {"fast_period": 5, "slow_period": 20}
907 | ), # Bear
908 | 1: SimpleMovingAverageStrategy(
909 | {"fast_period": 10, "slow_period": 30}
910 | ), # Sideways
911 | 2: SimpleMovingAverageStrategy(
912 | {"fast_period": 20, "slow_period": 50}
913 | ), # Bull
914 | }
915 | ml_strategy = RegimeAwareStrategy(base_strategies)
916 | # Fit regime detector
917 | ml_strategy.fit_regime_detector(train_data)
918 | training_metrics = {"n_regimes": len(base_strategies)}
919 |
920 | else:
921 | return {"error": f"Unsupported ML strategy type: {strategy_type}"}
922 |
923 | # Generate signals on test data
924 | entry_signals, exit_signals = ml_strategy.generate_signals(test_data)
925 |
926 | # Run backtest analysis on test period
927 | analyzer = BacktestAnalyzer()
928 | backtest_results = await analyzer.run_vectorbt_backtest(
929 | data=test_data,
930 | entry_signals=entry_signals,
931 | exit_signals=exit_signals,
932 | initial_capital=initial_capital,
933 | )
934 |
935 | # Add ML-specific metrics
936 | ml_metrics = {
937 | "strategy_type": strategy_type,
938 | "training_period": len(train_data),
939 | "testing_period": len(test_data),
940 | "train_test_split": train_ratio,
941 | "training_metrics": training_metrics,
942 | }
943 |
944 | # Add strategy-specific analysis
945 | if hasattr(ml_strategy, "get_feature_importance"):
946 | ml_metrics["feature_importance"] = ml_strategy.get_feature_importance()
947 |
948 | if hasattr(ml_strategy, "get_regime_analysis"):
949 | ml_metrics["regime_analysis"] = ml_strategy.get_regime_analysis()
950 |
951 | if hasattr(ml_strategy, "get_strategy_weights"):
952 | ml_metrics["strategy_weights"] = ml_strategy.get_strategy_weights()
953 |
954 | backtest_results["ml_metrics"] = ml_metrics
955 |
956 | # Convert all numpy types before returning
957 | return convert_numpy_types(backtest_results)
958 |
959 | except Exception as e:
960 | return {"error": f"ML backtest failed: {str(e)}"}
961 |
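# Usage sketch (illustrative only; "online_learning" is accepted as an alias
# for "adaptive", and at least 200 rows of price data are required):
#   ml_results = await run_ml_strategy_backtest(
#       ctx, symbol="AAPL", strategy_type="regime_aware", train_ratio=0.8
#   )
#   ml_results.get("ml_metrics", {}).get("regime_analysis")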
962 | @mcp.tool()
963 | async def train_ml_predictor(
964 | ctx: Context,
965 | symbol: str,
966 | start_date: str | None = None,
967 | end_date: str | None = None,
968 | model_type: str = "random_forest",
969 | target_periods: int = 5,
970 | return_threshold: float = 0.02,
971 | n_estimators: int = 100,
972 | max_depth: int | None = None,
973 | min_samples_split: int = 2,
974 | ) -> dict[str, Any]:
975 | """Train an ML predictor model for trading signals.
976 |
977 | Args:
978 | symbol: Stock symbol to train on
979 | start_date: Start date for training data
980 | end_date: End date for training data
981 | model_type: ML model type (random_forest)
982 | target_periods: Forward periods for target variable
983 | return_threshold: Return threshold for signal classification
984 | n_estimators, max_depth, min_samples_split: Model-specific parameters
985 |
986 | Returns:
987 | Training results and model metrics
988 | """
989 | from datetime import datetime, timedelta
990 |
991 | from maverick_mcp.backtesting.strategies.ml import MLPredictor
992 |
993 | # Default date range (2 years for good ML training)
994 | if not end_date:
995 | end_date = datetime.now().strftime("%Y-%m-%d")
996 | if not start_date:
997 | start_date = (datetime.now() - timedelta(days=730)).strftime("%Y-%m-%d")
998 |
999 | try:
1000 | # Get training data
1001 | engine = VectorBTEngine()
1002 | data = await engine.get_historical_data(symbol, start_date, end_date)
1003 |
1004 | if len(data) < 200:
1005 | return {
1006 | "error": "Insufficient data for ML training (minimum 200 data points)"
1007 | }
1008 |
1009 | # Create and train ML predictor
1010 | ml_predictor = MLPredictor(
1011 | model_type=model_type,
1012 | n_estimators=n_estimators,
1013 | max_depth=max_depth,
1014 | min_samples_split=min_samples_split,
1015 | )
1016 | training_metrics = ml_predictor.train(
1017 | data=data,
1018 | target_periods=target_periods,
1019 | return_threshold=return_threshold,
1020 | )
1021 |
1022 | # Create model parameters dictionary
1023 | model_params = {
1024 | "n_estimators": n_estimators,
1025 | "max_depth": max_depth,
1026 | "min_samples_split": min_samples_split,
1027 | }
1028 | # Add training details
1029 | training_results = {
1030 | "symbol": symbol,
1031 | "model_type": model_type,
1032 | "training_period": f"{start_date} to {end_date}",
1033 | "data_points": len(data),
1034 | "target_periods": target_periods,
1035 | "return_threshold": return_threshold,
1036 | "model_parameters": model_params,
1037 | "training_metrics": training_metrics,
1038 | }
1039 |
1040 | # Convert all numpy types before returning
1041 | return convert_numpy_types(training_results)
1042 |
1043 | except Exception as e:
1044 | return {"error": f"ML training failed: {str(e)}"}
1045 |
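# Usage sketch (illustrative only; labels are built from forward returns over
# `target_periods` bars compared against `return_threshold`):
#   trained = await train_ml_predictor(
#       ctx, symbol="AAPL", target_periods=5, return_threshold=0.02
#   )
#   trained["training_metrics"]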
1046 | @mcp.tool()
1047 | async def analyze_market_regimes(
1048 | ctx: Context,
1049 | symbol: str,
1050 | start_date: str | None = None,
1051 | end_date: str | None = None,
1052 | method: str = "hmm",
1053 | n_regimes: int = 3,
1054 | lookback_period: int = 50,
1055 | ) -> dict[str, Any]:
1056 | """Analyze market regimes for a stock using ML methods.
1057 |
1058 | Args:
1059 | symbol: Stock symbol to analyze
1060 | start_date: Start date for analysis
1061 | end_date: End date for analysis
1062 | method: Detection method (hmm, kmeans, threshold)
1063 | n_regimes: Number of regimes to detect
1064 | lookback_period: Lookback period for regime detection
1065 |
1066 | Returns:
1067 | Market regime analysis results
1068 | """
1069 | from datetime import datetime, timedelta
1070 |
1071 | from maverick_mcp.backtesting.strategies.ml.regime_aware import (
1072 | MarketRegimeDetector,
1073 | )
1074 |
1075 | # Default date range
1076 | if not end_date:
1077 | end_date = datetime.now().strftime("%Y-%m-%d")
1078 | if not start_date:
1079 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
1080 |
1081 | try:
1082 | # Get historical data
1083 | engine = VectorBTEngine()
1084 | data = await engine.get_historical_data(symbol, start_date, end_date)
1085 |
1086 | if len(data) < lookback_period + 50:
1087 | return {
1088 | "error": f"Insufficient data for regime analysis (minimum {lookback_period + 50} data points)"
1089 | }
1090 |
1091 | # Create regime detector and analyze
1092 | regime_detector = MarketRegimeDetector(
1093 | method=method, n_regimes=n_regimes, lookback_period=lookback_period
1094 | )
1095 |
1096 | # Fit regime detector
1097 | regime_detector.fit_regimes(data)
1098 |
1099 | # Analyze regimes over time
1100 | regime_history = []
1101 | regime_probabilities = []
1102 |
1103 | for i in range(lookback_period, len(data)):
1104 | window_data = data.iloc[i - lookback_period : i + 1]
1105 | current_regime = regime_detector.detect_current_regime(window_data)
1106 | regime_probs = regime_detector.get_regime_probabilities(window_data)
1107 |
1108 | regime_history.append(
1109 | {
1110 | "date": data.index[i].strftime("%Y-%m-%d"),
1111 | "regime": int(current_regime),
1112 | "probabilities": regime_probs.tolist(),
1113 | }
1114 | )
1115 | regime_probabilities.append(regime_probs)
1116 |
1117 | # Calculate regime statistics
1118 | regimes = [r["regime"] for r in regime_history]
1119 | regime_counts = {i: regimes.count(i) for i in range(n_regimes)}
1120 | regime_percentages = {
1121 | k: (v / len(regimes)) * 100 for k, v in regime_counts.items()
1122 | }
1123 |
1124 | # Calculate average regime durations
1125 | regime_durations = {i: [] for i in range(n_regimes)}
1126 | current_regime = regimes[0]
1127 | duration = 1
1128 |
1129 | for regime in regimes[1:]:
1130 | if regime == current_regime:
1131 | duration += 1
1132 | else:
1133 | regime_durations[current_regime].append(duration)
1134 | current_regime = regime
1135 | duration = 1
1136 | regime_durations[current_regime].append(duration)
1137 |
1138 | avg_durations = {
1139 | k: np.mean(v) if v else 0 for k, v in regime_durations.items()
1140 | }
1141 |
1142 | analysis_results = {
1143 | "symbol": symbol,
1144 | "analysis_period": f"{start_date} to {end_date}",
1145 | "method": method,
1146 | "n_regimes": n_regimes,
1147 | "regime_names": {  # human-readable labels; assumes n_regimes == 3
1148 | 0: "Bear/Declining",
1149 | 1: "Sideways/Uncertain",
1150 | 2: "Bull/Trending",
1151 | },
1152 | "current_regime": regimes[-1] if regimes else 1,
1153 | "regime_counts": regime_counts,
1154 | "regime_percentages": regime_percentages,
1155 | "average_regime_durations": avg_durations,
1156 | "recent_regime_history": regime_history[-20:], # Last 20 periods
1157 | "total_regime_switches": len(
1158 | [i for i in range(1, len(regimes)) if regimes[i] != regimes[i - 1]]
1159 | ),
1160 | }
1161 |
1162 | return analysis_results
1163 |
1164 | except Exception as e:
1165 | return {"error": f"Regime analysis failed: {str(e)}"}
1166 |
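# Usage sketch (illustrative only; the human-readable regime labels in the
# result assume n_regimes == 3):
#   regimes = await analyze_market_regimes(
#       ctx, symbol="SPY", method="hmm", n_regimes=3
#   )
#   regimes["current_regime"]  # 0 = bear, 1 = sideways, 2 = bull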
1167 | @mcp.tool()
1168 | async def create_strategy_ensemble(
1169 | ctx: Context,
1170 | symbols: list[str],
1171 | base_strategies: list[str] | None = None,
1172 | weighting_method: str = "performance",
1173 | start_date: str | None = None,
1174 | end_date: str | None = None,
1175 | initial_capital: float = 10000.0,
1176 | ) -> dict[str, Any]:
1177 | """Create and backtest a strategy ensemble across multiple symbols.
1178 |
1179 | Args:
1180 | symbols: List of stock symbols
1181 | base_strategies: List of base strategy names to ensemble
1182 | weighting_method: Weighting method (performance, equal, volatility)
1183 | start_date: Start date for backtesting
1184 | end_date: End date for backtesting
1185 | initial_capital: Initial capital per symbol
1186 |
1187 | Returns:
1188 | Ensemble backtest results with strategy weights
1189 | """
1190 | from datetime import datetime, timedelta
1191 |
1192 | from maverick_mcp.backtesting.strategies.ml import StrategyEnsemble
1193 | from maverick_mcp.backtesting.strategies.templates import (
1194 | SimpleMovingAverageStrategy,
1195 | )
1196 |
1197 | # Default strategies if none provided
1198 | if base_strategies is None:
1199 | base_strategies = ["sma_cross", "rsi", "macd"]
1200 |
1201 | # Default date range
1202 | if not end_date:
1203 | end_date = datetime.now().strftime("%Y-%m-%d")
1204 | if not start_date:
1205 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
1206 |
1207 | try:
1208 | # Create base strategy instances
1209 | strategy_instances = []
1210 | for strategy_name in base_strategies:
1211 | if strategy_name == "sma_cross":
1212 | strategy_instances.append(SimpleMovingAverageStrategy())
1213 | elif strategy_name == "rsi":
1214 | # Stand-in for RSI: an SMA crossover with 14/28 periods
1215 | strategy_instances.append(
1216 | SimpleMovingAverageStrategy(
1217 | {"fast_period": 14, "slow_period": 28}
1218 | )
1219 | )
1220 | elif strategy_name == "macd":
1221 | # Create MACD-like SMA strategy with MACD default periods
1222 | strategy_instances.append(
1223 | SimpleMovingAverageStrategy(
1224 | {"fast_period": 12, "slow_period": 26}
1225 | )
1226 | )
1227 | # Add more strategies as needed
1228 |
1229 | if not strategy_instances:
1230 | return {"error": "No valid base strategies provided"}
1231 |
1232 | # Create ensemble strategy
1233 | ensemble = StrategyEnsemble(
1234 | strategies=strategy_instances, weighting_method=weighting_method
1235 | )
1236 |
1237 | # Run ensemble backtest on multiple symbols
1238 | ensemble_results = []
1239 | total_return = 0
1240 | total_trades = 0
1241 |
1242 | for symbol in symbols[:5]: # Limit to 5 symbols for performance
1243 | try:
1244 | # Get data and run backtest
1245 | engine = VectorBTEngine()
1246 | data = await engine.get_historical_data(
1247 | symbol, start_date, end_date
1248 | )
1249 |
1250 | if len(data) < 100:
1251 | continue
1252 |
1253 | # Generate ensemble signals
1254 | entry_signals, exit_signals = ensemble.generate_signals(data)
1255 |
1256 | # Run backtest
1257 | analyzer = BacktestAnalyzer()
1258 | results = await analyzer.run_vectorbt_backtest(
1259 | data=data,
1260 | entry_signals=entry_signals,
1261 | exit_signals=exit_signals,
1262 | initial_capital=initial_capital,
1263 | )
1264 |
1265 | # Add ensemble-specific metrics
1266 | results["ensemble_metrics"] = {
1267 | "strategy_weights": ensemble.get_strategy_weights(),
1268 | "strategy_performance": ensemble.get_strategy_performance(),
1269 | }
1270 |
1271 | ensemble_results.append({"symbol": symbol, "results": results})
1272 |
1273 | total_return += results["metrics"]["total_return"]
1274 | total_trades += results["metrics"]["total_trades"]
1275 |
1276 | except Exception:
1277 | continue  # skip symbols that fail to load or backtest
1278 |
1279 | if not ensemble_results:
1280 | return {"error": "No symbols could be processed"}
1281 |
1282 | # Calculate aggregate metrics
1283 | avg_return = total_return / len(ensemble_results)
1284 | avg_trades = total_trades / len(ensemble_results)
1285 |
1286 | # Convert all numpy types before returning
1287 | return convert_numpy_types(
1288 | {
1289 | "ensemble_summary": {
1290 | "symbols_tested": len(ensemble_results),
1291 | "base_strategies": base_strategies,
1292 | "weighting_method": weighting_method,
1293 | "average_return": avg_return,
1294 | "total_trades": total_trades,
1295 | "average_trades_per_symbol": avg_trades,
1296 | },
1297 | "individual_results": ensemble_results,
1298 | "final_strategy_weights": ensemble.get_strategy_weights(),
1299 | "strategy_performance_analysis": ensemble.get_strategy_performance(),
1300 | }
1301 | )
1302 |
1303 | except Exception as e:
1304 | return {"error": f"Ensemble creation failed: {str(e)}"}
1305 |
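# Usage sketch (illustrative only; only the first five symbols are processed,
# and unrecognized base strategies are dropped):
#   ensemble_report = await create_strategy_ensemble(
#       ctx, symbols=["AAPL", "MSFT"], weighting_method="performance"
#   )
#   ensemble_report["final_strategy_weights"]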
```