This is page 20 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/batch_processing.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Batch Processing Extensions for VectorBTEngine.
3 |
4 | This module adds batch processing capabilities to the VectorBT engine,
5 | allowing for parallel execution of multiple backtest strategies,
6 | parameter optimization, and strategy validation.
7 | """
8 |
9 | import asyncio
10 | import gc
11 | import time
12 | from typing import Any
13 |
14 | import numpy as np
15 |
16 | from maverick_mcp.utils.memory_profiler import (
17 | cleanup_dataframes,
18 | get_memory_stats,
19 | profile_memory,
20 | )
21 | from maverick_mcp.utils.structured_logger import (
22 | get_structured_logger,
23 | with_structured_logging,
24 | )
25 |
26 | logger = get_structured_logger(__name__)
27 |
28 |
29 | class BatchProcessingMixin:
30 | """Mixin class to add batch processing methods to VectorBTEngine."""
31 |
32 | @with_structured_logging(
33 | "run_batch_backtest", include_performance=True, log_params=True
34 | )
35 | @profile_memory(log_results=True, threshold_mb=100.0)
36 | async def run_batch_backtest(
37 | self,
38 | batch_configs: list[dict[str, Any]],
39 | max_workers: int = 6,
40 | chunk_size: int = 10,
41 | validate_data: bool = True,
42 | fail_fast: bool = False,
43 | ) -> dict[str, Any]:
44 | """
45 | Run multiple backtest strategies in parallel with optimized batch processing.
46 |
47 | Args:
48 | batch_configs: List of backtest configurations, each containing:
49 | - symbol: Stock symbol
50 | - strategy_type: Strategy type name
51 | - parameters: Strategy parameters dict
52 | - start_date: Start date string
53 | - end_date: End date string
54 | - initial_capital: Starting capital (optional, default 10000)
55 | - fees: Trading fees (optional, default 0.001)
56 | - slippage: Slippage factor (optional, default 0.001)
57 | max_workers: Maximum concurrent workers
58 | chunk_size: Number of configs to process per batch
59 | validate_data: Whether to validate input data
60 | fail_fast: Whether to stop on first failure
61 |
62 | Returns:
63 | Dictionary containing batch results and summary statistics
64 | """
65 | from maverick_mcp.backtesting.strategy_executor import (
66 | ExecutionContext,
67 | ExecutionResult,
68 | StrategyExecutor,
69 | )
70 |
71 | start_time = time.time()
72 | batch_id = f"batch_{int(start_time)}"
73 |
74 | logger.info(
75 | f"Starting batch backtest {batch_id} with {len(batch_configs)} configurations"
76 | )
77 |
78 | # Validate input data if requested
79 | if validate_data:
80 | validation_errors = []
81 | for i, config in enumerate(batch_configs):
82 | try:
83 | self._validate_batch_config(config, f"config_{i}")
84 | except Exception as e:
85 | validation_errors.append(f"Config {i}: {str(e)}")
86 |
87 | if validation_errors:
88 | if fail_fast:
89 | raise ValueError(
90 | f"Batch validation failed: {'; '.join(validation_errors)}"
91 | )
92 | else:
93 | logger.warning(
94 | f"Validation warnings for batch {batch_id}: {validation_errors}"
95 | )
96 |
97 | # Initialize executor
98 | executor = StrategyExecutor(
99 | max_concurrent_strategies=max_workers,
100 | cache_manager=getattr(self, "cache", None),
101 | )
102 |
103 | # Convert configs to execution contexts
104 | contexts = []
105 | for i, config in enumerate(batch_configs):
106 | context = ExecutionContext(
107 | strategy_id=f"{batch_id}_strategy_{i}",
108 | symbol=config["symbol"],
109 | strategy_type=config["strategy_type"],
110 | parameters=config["parameters"],
111 | start_date=config["start_date"],
112 | end_date=config["end_date"],
113 | initial_capital=config.get("initial_capital", 10000.0),
114 | fees=config.get("fees", 0.001),
115 | slippage=config.get("slippage", 0.001),
116 | )
117 | contexts.append(context)
118 |
119 | # Process in chunks to manage memory
120 | all_results = []
121 | successful_results = []
122 | failed_results = []
123 |
124 | for chunk_start in range(0, len(contexts), chunk_size):
125 | chunk_end = min(chunk_start + chunk_size, len(contexts))
126 | chunk_contexts = contexts[chunk_start:chunk_end]
127 |
128 | logger.info(
129 | f"Processing chunk {chunk_start // chunk_size + 1} ({len(chunk_contexts)} items)"
130 | )
131 |
132 | try:
133 | # Execute chunk in parallel
134 | chunk_results = await executor.execute_strategies(chunk_contexts)
135 |
136 | # Process results
137 | for result in chunk_results:
138 | all_results.append(result)
139 | if result.success:
140 | successful_results.append(result)
141 |                     else:
142 |                         failed_results.append(result)
143 |                 if fail_fast and failed_results:  # abort before the next chunk
144 |                     logger.error(f"Batch failed fast on: {failed_results[-1].error}")
145 |                     break
146 |
147 | # Memory cleanup between chunks
148 | if getattr(self, "enable_memory_profiling", False):
149 | cleanup_dataframes()
150 | gc.collect()
151 |
152 | except Exception as e:
153 | logger.error(f"Chunk processing failed: {e}")
154 | if fail_fast:
155 | raise
156 | # Add failed result for chunk
157 | for context in chunk_contexts:
158 | failed_results.append(
159 | ExecutionResult(
160 | context=context,
161 | success=False,
162 | error=f"Chunk processing error: {e}",
163 | )
164 | )
165 |
166 | # Cleanup executor
167 | await executor.cleanup()
168 |
169 | # Calculate summary statistics
170 | total_execution_time = time.time() - start_time
171 | success_rate = (
172 | len(successful_results) / len(all_results) if all_results else 0.0
173 | )
174 |
175 | summary = {
176 | "batch_id": batch_id,
177 | "total_configs": len(batch_configs),
178 | "successful": len(successful_results),
179 | "failed": len(failed_results),
180 | "success_rate": success_rate,
181 | "total_execution_time": total_execution_time,
182 | "avg_execution_time": total_execution_time / len(all_results)
183 | if all_results
184 | else 0.0,
185 | "memory_stats": get_memory_stats()
186 | if getattr(self, "enable_memory_profiling", False)
187 | else None,
188 | }
189 |
190 | logger.info(f"Batch backtest {batch_id} completed: {summary}")
191 |
192 | return {
193 | "batch_id": batch_id,
194 | "summary": summary,
195 | "successful_results": [r.result for r in successful_results if r.result],
196 | "failed_results": [
197 | {
198 | "strategy_id": r.context.strategy_id,
199 | "symbol": r.context.symbol,
200 | "strategy_type": r.context.strategy_type,
201 | "error": r.error,
202 | }
203 | for r in failed_results
204 | ],
205 | "all_results": all_results,
206 | }
207 |
208 | @with_structured_logging(
209 | "batch_optimize_parameters", include_performance=True, log_params=True
210 | )
211 | async def batch_optimize_parameters(
212 | self,
213 | optimization_configs: list[dict[str, Any]],
214 | max_workers: int = 4,
215 | optimization_method: str = "grid_search",
216 | max_iterations: int = 100,
217 | ) -> dict[str, Any]:
218 | """
219 | Optimize strategy parameters for multiple symbols/strategies in parallel.
220 |
221 | Args:
222 | optimization_configs: List of optimization configurations, each containing:
223 | - symbol: Stock symbol
224 | - strategy_type: Strategy type name
225 | - parameter_ranges: Dictionary of parameter ranges to optimize
226 | - start_date: Start date string
227 | - end_date: End date string
228 | - optimization_metric: Metric to optimize (default: sharpe_ratio)
229 | - initial_capital: Starting capital
230 | max_workers: Maximum concurrent workers
231 | optimization_method: Optimization method ('grid_search', 'random_search')
232 | max_iterations: Maximum optimization iterations per config
233 |
234 | Returns:
235 | Dictionary containing optimization results for all configurations
236 | """
237 | start_time = time.time()
238 | batch_id = f"optimize_batch_{int(start_time)}"
239 |
240 | logger.info(
241 | f"Starting batch optimization {batch_id} with {len(optimization_configs)} configurations"
242 | )
243 |
244 | # Process optimizations in parallel
245 | optimization_tasks = []
246 | for i, config in enumerate(optimization_configs):
247 | task = self._run_single_optimization(
248 | config, f"{batch_id}_opt_{i}", optimization_method, max_iterations
249 | )
250 | optimization_tasks.append(task)
251 |
252 | # Execute with concurrency limit
253 | semaphore = asyncio.BoundedSemaphore(max_workers)
254 |
255 | async def limited_optimization(task):
256 | async with semaphore:
257 | return await task
258 |
259 | # Run all optimizations
260 | optimization_results = await asyncio.gather(
261 | *[limited_optimization(task) for task in optimization_tasks],
262 | return_exceptions=True,
263 | )
264 |
265 | # Process results
266 | successful_optimizations = []
267 | failed_optimizations = []
268 |
269 | for i, result in enumerate(optimization_results):
270 | if isinstance(result, Exception):
271 | failed_optimizations.append(
272 | {
273 | "config_index": i,
274 | "config": optimization_configs[i],
275 | "error": str(result),
276 | }
277 | )
278 | else:
279 | successful_optimizations.append(result)
280 |
281 | # Calculate summary
282 | total_execution_time = time.time() - start_time
283 | success_rate = (
284 | len(successful_optimizations) / len(optimization_configs)
285 | if optimization_configs
286 | else 0.0
287 | )
288 |
289 | summary = {
290 | "batch_id": batch_id,
291 | "total_optimizations": len(optimization_configs),
292 | "successful": len(successful_optimizations),
293 | "failed": len(failed_optimizations),
294 | "success_rate": success_rate,
295 | "total_execution_time": total_execution_time,
296 | "optimization_method": optimization_method,
297 | "max_iterations": max_iterations,
298 | }
299 |
300 | logger.info(f"Batch optimization {batch_id} completed: {summary}")
301 |
302 | return {
303 | "batch_id": batch_id,
304 | "summary": summary,
305 | "successful_optimizations": successful_optimizations,
306 | "failed_optimizations": failed_optimizations,
307 | }
308 |
309 | async def batch_validate_strategies(
310 | self,
311 | validation_configs: list[dict[str, Any]],
312 | validation_start_date: str,
313 | validation_end_date: str,
314 | max_workers: int = 6,
315 | ) -> dict[str, Any]:
316 | """
317 | Validate multiple strategies against out-of-sample data in parallel.
318 |
319 | Args:
320 | validation_configs: List of validation configurations with optimized parameters
321 | validation_start_date: Start date for validation period
322 | validation_end_date: End date for validation period
323 | max_workers: Maximum concurrent workers
324 |
325 | Returns:
326 | Dictionary containing validation results and performance comparison
327 | """
328 | start_time = time.time()
329 | batch_id = f"validate_batch_{int(start_time)}"
330 |
331 | logger.info(
332 | f"Starting batch validation {batch_id} with {len(validation_configs)} strategies"
333 | )
334 |
335 | # Create validation backtest configs
336 | validation_batch_configs = []
337 | for config in validation_configs:
338 | validation_config = {
339 | "symbol": config["symbol"],
340 | "strategy_type": config["strategy_type"],
341 | "parameters": config.get(
342 | "optimized_parameters", config.get("parameters", {})
343 | ),
344 | "start_date": validation_start_date,
345 | "end_date": validation_end_date,
346 | "initial_capital": config.get("initial_capital", 10000.0),
347 | "fees": config.get("fees", 0.001),
348 | "slippage": config.get("slippage", 0.001),
349 | }
350 | validation_batch_configs.append(validation_config)
351 |
352 | # Run validation backtests
353 | validation_results = await self.run_batch_backtest(
354 | validation_batch_configs,
355 | max_workers=max_workers,
356 | validate_data=True,
357 | fail_fast=False,
358 | )
359 |
360 | # Calculate validation metrics
361 | validation_metrics = self._calculate_validation_metrics(
362 | validation_configs, validation_results["successful_results"]
363 | )
364 |
365 | total_execution_time = time.time() - start_time
366 |
367 | return {
368 | "batch_id": batch_id,
369 | "validation_period": {
370 | "start_date": validation_start_date,
371 | "end_date": validation_end_date,
372 | },
373 | "summary": {
374 | "total_strategies": len(validation_configs),
375 | "validated_strategies": len(validation_results["successful_results"]),
376 | "validation_success_rate": len(validation_results["successful_results"])
377 | / len(validation_configs)
378 | if validation_configs
379 | else 0.0,
380 | "total_execution_time": total_execution_time,
381 | },
382 | "validation_results": validation_results["successful_results"],
383 | "validation_metrics": validation_metrics,
384 | "failed_validations": validation_results["failed_results"],
385 | }
386 |
387 | async def get_batch_results(
388 | self, batch_id: str, include_detailed_results: bool = False
389 | ) -> dict[str, Any] | None:
390 | """
391 | Retrieve results for a completed batch operation.
392 |
393 | Args:
394 | batch_id: Batch ID to retrieve results for
395 | include_detailed_results: Whether to include full result details
396 |
397 | Returns:
398 | Dictionary containing batch results or None if not found
399 | """
400 | # This would typically retrieve from a persistence layer
401 | # For now, return None as results are returned directly
402 | logger.warning(f"Batch result retrieval not implemented for {batch_id}")
403 | logger.info(
404 | "Batch results are currently returned directly from batch operations"
405 | )
406 |
407 | return None
408 |
409 | # Alias method for backward compatibility
410 | async def batch_optimize(self, *args, **kwargs):
411 | """Alias for batch_optimize_parameters for backward compatibility."""
412 | return await self.batch_optimize_parameters(*args, **kwargs)
413 |
414 | # =============================================================================
415 | # BATCH PROCESSING HELPER METHODS
416 | # =============================================================================
417 |
418 | def _validate_batch_config(self, config: dict[str, Any], config_name: str) -> None:
419 | """Validate a single batch configuration."""
420 | required_fields = [
421 | "symbol",
422 | "strategy_type",
423 | "parameters",
424 | "start_date",
425 | "end_date",
426 | ]
427 |
428 | for field in required_fields:
429 | if field not in config:
430 | raise ValueError(f"Missing required field '{field}' in {config_name}")
431 |
432 | # Validate dates
433 | try:
434 | from maverick_mcp.data.validation import DataValidator
435 |
436 | DataValidator.validate_date_range(config["start_date"], config["end_date"])
437 | except Exception as e:
438 | raise ValueError(f"Invalid date range in {config_name}: {e}") from e
439 |
440 | # Validate symbol
441 | if not isinstance(config["symbol"], str) or len(config["symbol"]) == 0:
442 | raise ValueError(f"Invalid symbol in {config_name}")
443 |
444 | # Validate strategy type
445 | if not isinstance(config["strategy_type"], str):
446 | raise ValueError(f"Invalid strategy_type in {config_name}")
447 |
448 | # Validate parameters
449 | if not isinstance(config["parameters"], dict):
450 | raise ValueError(f"Parameters must be a dictionary in {config_name}")
451 |
452 | async def _run_single_optimization(
453 | self,
454 | config: dict[str, Any],
455 | optimization_id: str,
456 | method: str,
457 | max_iterations: int,
458 | ) -> dict[str, Any]:
459 | """Run optimization for a single configuration."""
460 | try:
461 | # Extract configuration
462 | symbol = config["symbol"]
463 | strategy_type = config["strategy_type"]
464 | parameter_ranges = config["parameter_ranges"]
465 | start_date = config["start_date"]
466 | end_date = config["end_date"]
467 | optimization_metric = config.get("optimization_metric", "sharpe_ratio")
468 | initial_capital = config.get("initial_capital", 10000.0)
469 |
470 | # Simple parameter optimization (placeholder - would use actual optimizer)
471 | # For now, return basic result structure
472 | best_params = {}
473 | for param, ranges in parameter_ranges.items():
474 | if isinstance(ranges, list) and len(ranges) >= 2:
475 | # Use middle value as "optimized"
476 | best_params[param] = ranges[len(ranges) // 2]
477 | elif isinstance(ranges, dict):
478 | if "min" in ranges and "max" in ranges:
479 | best_params[param] = (ranges["min"] + ranges["max"]) / 2
480 |
481 | # Run a basic backtest with these parameters
482 | backtest_result = await self.run_backtest(
483 | symbol=symbol,
484 | strategy_type=strategy_type,
485 | parameters=best_params,
486 | start_date=start_date,
487 | end_date=end_date,
488 | initial_capital=initial_capital,
489 | )
490 |
491 | best_score = backtest_result.get("metrics", {}).get(
492 | optimization_metric, 0.0
493 | )
494 |
495 | return {
496 | "optimization_id": optimization_id,
497 | "symbol": symbol,
498 | "strategy_type": strategy_type,
499 | "optimized_parameters": best_params,
500 | "best_score": best_score,
501 | "optimization_history": [
502 | {"parameters": best_params, "score": best_score}
503 | ],
504 | "execution_time": 0.0,
505 | }
506 |
507 | except Exception as e:
508 | logger.error(f"Optimization failed for {optimization_id}: {e}")
509 | raise
510 |
511 | def _calculate_validation_metrics(
512 | self,
513 | original_configs: list[dict[str, Any]],
514 | validation_results: list[dict[str, Any]],
515 | ) -> dict[str, Any]:
516 | """Calculate validation metrics comparing in-sample vs out-of-sample performance."""
517 | metrics = {
518 | "strategy_comparisons": [],
519 | "aggregate_metrics": {
520 | "avg_in_sample_sharpe": 0.0,
521 | "avg_out_of_sample_sharpe": 0.0,
522 | "sharpe_degradation": 0.0,
523 | "strategies_with_positive_validation": 0,
524 | },
525 | }
526 |
527 | if not original_configs or not validation_results:
528 | return metrics
529 |
530 | sharpe_ratios_in_sample = []
531 | sharpe_ratios_out_of_sample = []
532 |
533 | for i, (original, validation) in enumerate(
534 | zip(original_configs, validation_results, strict=False)
535 | ):
536 | # Get in-sample performance (from original optimization)
537 | in_sample_sharpe = original.get("best_score", 0.0)
538 |
539 | # Get out-of-sample performance
540 | out_of_sample_sharpe = validation.get("metrics", {}).get(
541 | "sharpe_ratio", 0.0
542 | )
543 |
544 | strategy_comparison = {
545 | "strategy_index": i,
546 | "symbol": original["symbol"],
547 | "strategy_type": original["strategy_type"],
548 | "in_sample_sharpe": in_sample_sharpe,
549 | "out_of_sample_sharpe": out_of_sample_sharpe,
550 | "sharpe_degradation": in_sample_sharpe - out_of_sample_sharpe,
551 | "validation_success": out_of_sample_sharpe > 0,
552 | }
553 |
554 | metrics["strategy_comparisons"].append(strategy_comparison)
555 | sharpe_ratios_in_sample.append(in_sample_sharpe)
556 | sharpe_ratios_out_of_sample.append(out_of_sample_sharpe)
557 |
558 | # Calculate aggregate metrics
559 | if sharpe_ratios_in_sample and sharpe_ratios_out_of_sample:
560 | metrics["aggregate_metrics"]["avg_in_sample_sharpe"] = np.mean(
561 | sharpe_ratios_in_sample
562 | )
563 | metrics["aggregate_metrics"]["avg_out_of_sample_sharpe"] = np.mean(
564 | sharpe_ratios_out_of_sample
565 | )
566 | metrics["aggregate_metrics"]["sharpe_degradation"] = (
567 | metrics["aggregate_metrics"]["avg_in_sample_sharpe"]
568 | - metrics["aggregate_metrics"]["avg_out_of_sample_sharpe"]
569 | )
570 | metrics["aggregate_metrics"]["strategies_with_positive_validation"] = sum(
571 | 1
572 | for comp in metrics["strategy_comparisons"]
573 | if comp["validation_success"]
574 | )
575 |
576 | return metrics
577 |
```
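A minimal usage sketch for the mixin above. It assumes `VectorBTEngine` applies `BatchProcessingMixin` (as the module docstring implies) and constructs with defaults; the symbols, dates, and parameters are illustrative only:

```python
# Hedged sketch: assumes VectorBTEngine mixes in BatchProcessingMixin and can be
# constructed with defaults; symbols, dates, and parameters are illustrative.
import asyncio

from maverick_mcp.backtesting import VectorBTEngine


async def main() -> None:
    engine = VectorBTEngine()
    batch_configs = [
        {
            "symbol": symbol,
            "strategy_type": "sma_cross",
            "parameters": {"fast_period": 10, "slow_period": 30},
            "start_date": "2022-01-01",
            "end_date": "2023-12-31",
            # initial_capital, fees, and slippage fall back to the documented
            # defaults (10000.0, 0.001, 0.001) when omitted.
        }
        for symbol in ("AAPL", "MSFT", "NVDA")
    ]
    results = await engine.run_batch_backtest(
        batch_configs, max_workers=4, chunk_size=2, fail_fast=False
    )
    summary = results["summary"]
    print(f"{summary['successful']}/{summary['total_configs']} backtests succeeded")


if __name__ == "__main__":
    asyncio.run(main())
```

The same config list, with `optimized_parameters` substituted in, is what `batch_validate_strategies` builds internally for the out-of-sample pass.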
--------------------------------------------------------------------------------
/maverick_mcp/workflows/agents/optimizer_agent.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Optimizer Agent for intelligent parameter optimization.
3 |
4 | This agent performs regime-aware parameter optimization for selected strategies,
5 | using adaptive grid sizes and optimization metrics based on market conditions.
6 | """
7 |
8 | import asyncio
9 | import logging
10 | from datetime import datetime, timedelta
11 | from typing import Any
12 |
13 | from maverick_mcp.backtesting import StrategyOptimizer, VectorBTEngine
14 | from maverick_mcp.backtesting.strategies.templates import get_strategy_info
15 | from maverick_mcp.workflows.state import BacktestingWorkflowState
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class OptimizerAgent:
21 | """Intelligent parameter optimizer with regime-aware optimization."""
22 |
23 | def __init__(
24 | self,
25 | vectorbt_engine: VectorBTEngine | None = None,
26 | strategy_optimizer: StrategyOptimizer | None = None,
27 | ):
28 | """Initialize optimizer agent.
29 |
30 | Args:
31 | vectorbt_engine: VectorBT backtesting engine
32 | strategy_optimizer: Strategy optimization engine
33 | """
34 | self.engine = vectorbt_engine or VectorBTEngine()
35 | self.optimizer = strategy_optimizer or StrategyOptimizer(self.engine)
36 |
37 | # Optimization configurations for different regimes
38 | self.REGIME_OPTIMIZATION_CONFIG = {
39 | "trending": {
40 | "optimization_metric": "total_return", # Focus on capturing trends
41 | "grid_size": "medium",
42 | "min_trades": 10,
43 | "max_drawdown_limit": 0.25,
44 | },
45 | "ranging": {
46 | "optimization_metric": "sharpe_ratio", # Focus on risk-adjusted returns
47 | "grid_size": "fine", # More precision needed for ranging markets
48 | "min_trades": 15,
49 | "max_drawdown_limit": 0.15,
50 | },
51 | "volatile": {
52 | "optimization_metric": "calmar_ratio", # Risk-adjusted for volatility
53 | "grid_size": "coarse", # Avoid overfitting in volatile conditions
54 | "min_trades": 8,
55 | "max_drawdown_limit": 0.35,
56 | },
57 | "volatile_trending": {
58 | "optimization_metric": "sortino_ratio", # Focus on downside risk
59 | "grid_size": "medium",
60 | "min_trades": 10,
61 | "max_drawdown_limit": 0.30,
62 | },
63 | "low_volume": {
64 | "optimization_metric": "win_rate", # Consistency important in low volume
65 | "grid_size": "medium",
66 | "min_trades": 12,
67 | "max_drawdown_limit": 0.20,
68 | },
69 | "unknown": {
70 | "optimization_metric": "sharpe_ratio", # Balanced approach
71 | "grid_size": "medium",
72 | "min_trades": 10,
73 | "max_drawdown_limit": 0.20,
74 | },
75 | }
76 |
77 | # Strategy-specific optimization parameters
78 | self.STRATEGY_PARAM_RANGES = {
79 | "sma_cross": {
80 | "fast_period": {
81 | "coarse": [5, 10, 15],
82 | "medium": [5, 8, 10, 12, 15, 20],
83 | "fine": list(range(5, 21)),
84 | },
85 | "slow_period": {
86 | "coarse": [20, 30, 50],
87 | "medium": [20, 25, 30, 40, 50],
88 | "fine": list(range(20, 51, 5)),
89 | },
90 | },
91 | "rsi": {
92 | "period": {
93 | "coarse": [10, 14, 21],
94 | "medium": [10, 12, 14, 16, 21],
95 | "fine": list(range(8, 25)),
96 | },
97 | "oversold": {
98 | "coarse": [25, 30, 35],
99 | "medium": [20, 25, 30, 35],
100 | "fine": list(range(15, 36, 5)),
101 | },
102 | "overbought": {
103 | "coarse": [65, 70, 75],
104 | "medium": [65, 70, 75, 80],
105 | "fine": list(range(65, 86, 5)),
106 | },
107 | },
108 | "macd": {
109 | "fast_period": {
110 | "coarse": [8, 12, 16],
111 | "medium": [8, 10, 12, 14, 16],
112 | "fine": list(range(8, 17)),
113 | },
114 | "slow_period": {
115 | "coarse": [21, 26, 32],
116 | "medium": [21, 24, 26, 28, 32],
117 | "fine": list(range(21, 35)),
118 | },
119 | "signal_period": {
120 | "coarse": [6, 9, 12],
121 | "medium": [6, 8, 9, 10, 12],
122 | "fine": list(range(6, 15)),
123 | },
124 | },
125 | "bollinger": {
126 | "period": {
127 | "coarse": [15, 20, 25],
128 | "medium": [15, 18, 20, 22, 25],
129 | "fine": list(range(12, 28)),
130 | },
131 | "std_dev": {
132 | "coarse": [1.5, 2.0, 2.5],
133 | "medium": [1.5, 1.8, 2.0, 2.2, 2.5],
134 | "fine": [1.0, 1.5, 1.8, 2.0, 2.2, 2.5, 3.0],
135 | },
136 | },
137 | "momentum": {
138 | "lookback": {
139 | "coarse": [10, 20, 30],
140 | "medium": [10, 15, 20, 25, 30],
141 | "fine": list(range(5, 31, 5)),
142 | },
143 | "threshold": {
144 | "coarse": [0.03, 0.05, 0.08],
145 | "medium": [0.02, 0.03, 0.05, 0.07, 0.10],
146 | "fine": [0.01, 0.02, 0.03, 0.04, 0.05, 0.07, 0.10, 0.15],
147 | },
148 | },
149 | }
150 |
151 | logger.info("OptimizerAgent initialized")
152 |
153 | async def optimize_parameters(
154 | self, state: BacktestingWorkflowState
155 | ) -> BacktestingWorkflowState:
156 | """Optimize parameters for selected strategies.
157 |
158 | Args:
159 | state: Current workflow state with selected strategies
160 |
161 | Returns:
162 | Updated state with optimization results
163 | """
164 | start_time = datetime.now()
165 |
166 | try:
167 | logger.info(
168 | f"Optimizing parameters for {len(state.selected_strategies)} strategies on {state.symbol}"
169 | )
170 |
171 | # Get optimization configuration based on regime
172 | optimization_config = self._get_optimization_config(
173 | state.market_regime, state.regime_confidence
174 | )
175 |
176 | # Generate parameter grids for each strategy
177 | parameter_grids = self._generate_parameter_grids(
178 | state.selected_strategies, optimization_config["grid_size"]
179 | )
180 |
181 | # Optimize each strategy
182 | optimization_results = {}
183 | best_parameters = {}
184 | total_iterations = 0
185 |
186 | # Use shorter timeframe for optimization to avoid overfitting
187 | opt_start_date = self._calculate_optimization_window(
188 | state.start_date, state.end_date
189 | )
190 |
191 | for strategy in state.selected_strategies:
192 | try:
193 | logger.info(f"Optimizing {strategy} strategy...")
194 |
195 |                     param_grid = parameter_grids.get(strategy, {})
196 |                     if not param_grid:
197 |                         logger.warning(f"No parameter grid for {strategy}, using defaults")
198 |                         strategy_info = get_strategy_info(strategy)
199 |                         best_parameters[strategy] = strategy_info.get("parameters", {})
200 |                         continue
201 |
202 | # Run optimization
203 | result = await self.engine.optimize_parameters(
204 | symbol=state.symbol,
205 | strategy_type=strategy,
206 | param_grid=param_grid,
207 | start_date=opt_start_date,
208 | end_date=state.end_date,
209 | optimization_metric=optimization_config["optimization_metric"],
210 | initial_capital=state.initial_capital,
211 | top_n=min(
212 | 10, len(state.selected_strategies) * 2
213 | ), # Adaptive top_n
214 | )
215 |
216 | # Filter results by quality metrics
217 | filtered_result = self._filter_optimization_results(
218 | result, optimization_config
219 | )
220 |
221 | optimization_results[strategy] = filtered_result
222 | best_parameters[strategy] = filtered_result.get(
223 | "best_parameters", {}
224 | )
225 | total_iterations += filtered_result.get("valid_combinations", 0)
226 |
227 | logger.info(
228 | f"Optimized {strategy}: {filtered_result.get('best_metric_value', 0):.3f} {optimization_config['optimization_metric']}"
229 | )
230 |
231 | except Exception as e:
232 | logger.error(f"Failed to optimize {strategy}: {e}")
233 | # Use default parameters as fallback
234 | strategy_info = get_strategy_info(strategy)
235 | best_parameters[strategy] = strategy_info.get("parameters", {})
236 | state.fallback_strategies_used.append(
237 | f"{strategy}_optimization_fallback"
238 | )
239 |
240 | # Update state
241 | state.optimization_config = optimization_config
242 | state.parameter_grids = parameter_grids
243 | state.optimization_results = optimization_results
244 | state.best_parameters = best_parameters
245 | state.optimization_iterations = total_iterations
246 |
247 | # Record execution time
248 | execution_time = (datetime.now() - start_time).total_seconds() * 1000
249 | state.optimization_time_ms = execution_time
250 |
251 | # Update workflow status
252 | state.workflow_status = "validating"
253 | state.current_step = "optimization_completed"
254 | state.steps_completed.append("parameter_optimization")
255 |
256 | logger.info(
257 | f"Parameter optimization completed for {state.symbol}: "
258 | f"{total_iterations} combinations tested in {execution_time:.0f}ms"
259 | )
260 |
261 | return state
262 |
263 | except Exception as e:
264 | error_info = {
265 | "step": "parameter_optimization",
266 | "error": str(e),
267 | "timestamp": datetime.now().isoformat(),
268 | "symbol": state.symbol,
269 | }
270 | state.errors_encountered.append(error_info)
271 |
272 | # Fallback to default parameters
273 | default_params = {}
274 | for strategy in state.selected_strategies:
275 | strategy_info = get_strategy_info(strategy)
276 | default_params[strategy] = strategy_info.get("parameters", {})
277 |
278 | state.best_parameters = default_params
279 | state.fallback_strategies_used.append("optimization_fallback")
280 |
281 | logger.error(f"Parameter optimization failed for {state.symbol}: {e}")
282 | return state
283 |
284 | def _get_optimization_config(
285 | self, regime: str, regime_confidence: float
286 | ) -> dict[str, Any]:
287 | """Get optimization configuration based on market regime."""
288 | base_config = self.REGIME_OPTIMIZATION_CONFIG.get(
289 | regime, self.REGIME_OPTIMIZATION_CONFIG["unknown"]
290 | ).copy()
291 |
292 | # Adjust grid size based on confidence
293 | if regime_confidence < 0.5:
294 | # Low confidence -> use coarser grid to avoid overfitting
295 | if base_config["grid_size"] == "fine":
296 | base_config["grid_size"] = "medium"
297 | elif base_config["grid_size"] == "medium":
298 | base_config["grid_size"] = "coarse"
299 |
300 | return base_config
301 |
302 | def _generate_parameter_grids(
303 | self, strategies: list[str], grid_size: str
304 | ) -> dict[str, dict[str, list]]:
305 | """Generate parameter grids for optimization."""
306 | parameter_grids = {}
307 |
308 | for strategy in strategies:
309 | if strategy in self.STRATEGY_PARAM_RANGES:
310 | param_ranges = self.STRATEGY_PARAM_RANGES[strategy]
311 | grid = {}
312 |
313 | for param_name, size_ranges in param_ranges.items():
314 | if grid_size in size_ranges:
315 | grid[param_name] = size_ranges[grid_size]
316 | else:
317 | # Fallback to medium if requested size not available
318 | grid[param_name] = size_ranges.get(
319 | "medium", size_ranges["coarse"]
320 | )
321 |
322 | parameter_grids[strategy] = grid
323 | else:
324 | # For strategies not in our predefined ranges, use default minimal grid
325 | parameter_grids[strategy] = self._generate_default_grid(
326 | strategy, grid_size
327 | )
328 |
329 | return parameter_grids
330 |
331 | def _generate_default_grid(self, strategy: str, grid_size: str) -> dict[str, list]:
332 | """Generate default parameter grid for unknown strategies."""
333 | # Get strategy info to understand default parameters
334 | strategy_info = get_strategy_info(strategy)
335 | default_params = strategy_info.get("parameters", {})
336 |
337 | grid = {}
338 |
339 | # Generate basic variations around default values
340 | for param_name, default_value in default_params.items():
341 | if isinstance(default_value, int | float):
342 | if grid_size == "coarse":
343 | variations = [
344 | default_value * 0.8,
345 | default_value,
346 | default_value * 1.2,
347 | ]
348 | elif grid_size == "fine":
349 | variations = [
350 | default_value * 0.7,
351 | default_value * 0.8,
352 | default_value * 0.9,
353 | default_value,
354 | default_value * 1.1,
355 | default_value * 1.2,
356 | default_value * 1.3,
357 | ]
358 | else: # medium
359 | variations = [
360 | default_value * 0.8,
361 | default_value * 0.9,
362 | default_value,
363 | default_value * 1.1,
364 | default_value * 1.2,
365 | ]
366 |
367 | # Convert back to appropriate type and filter valid values
368 | if isinstance(default_value, int):
369 | grid[param_name] = [max(1, int(v)) for v in variations]
370 | else:
371 | grid[param_name] = [max(0.001, v) for v in variations]
372 | else:
373 | # For non-numeric parameters, just use the default
374 | grid[param_name] = [default_value]
375 |
376 | return grid
377 |
378 | def _calculate_optimization_window(self, start_date: str, end_date: str) -> str:
379 | """Calculate optimization window to prevent overfitting."""
380 | start_dt = datetime.strptime(start_date, "%Y-%m-%d")
381 | end_dt = datetime.strptime(end_date, "%Y-%m-%d")
382 |
383 | total_days = (end_dt - start_dt).days
384 |
385 |         # Optimize on the most recent 70% of the data; the earliest 30% is held out
386 | opt_days = int(total_days * 0.7)
387 | opt_start = end_dt - timedelta(days=opt_days)
388 |
389 | return opt_start.strftime("%Y-%m-%d")
390 |
391 | def _filter_optimization_results(
392 | self, result: dict[str, Any], optimization_config: dict[str, Any]
393 | ) -> dict[str, Any]:
394 | """Filter optimization results based on quality criteria."""
395 | if "top_results" not in result or not result["top_results"]:
396 | return result
397 |
398 | # Quality filters
399 | min_trades = optimization_config.get("min_trades", 10)
400 | max_drawdown_limit = optimization_config.get("max_drawdown_limit", 0.25)
401 |
402 | # Filter results by quality criteria
403 | filtered_results = []
404 | for res in result["top_results"]:
405 | # Check minimum trades
406 | if res.get("total_trades", 0) < min_trades:
407 | continue
408 |
409 | # Check maximum drawdown
410 | if abs(res.get("max_drawdown", 0)) > max_drawdown_limit:
411 | continue
412 |
413 | filtered_results.append(res)
414 |
415 | # If no results pass filters, relax criteria
416 | if not filtered_results and result["top_results"]:
417 | logger.warning("No results passed quality filters, relaxing criteria")
418 | # Take top results but with warning
419 | filtered_results = result["top_results"][:3] # Top 3 regardless of quality
420 |
421 | # Update result with filtered data
422 | filtered_result = result.copy()
423 | filtered_result["top_results"] = filtered_results
424 |
425 | if filtered_results:
426 | filtered_result["best_parameters"] = filtered_results[0]["parameters"]
427 |             filtered_result["best_metric_value"] = filtered_results[0].get(
428 |                 optimization_config["optimization_metric"], 0.0
429 |             )
430 | else:
431 | # Complete fallback
432 | filtered_result["best_parameters"] = {}
433 | filtered_result["best_metric_value"] = 0.0
434 |
435 | return filtered_result
436 |
437 | def get_optimization_summary(
438 | self, state: BacktestingWorkflowState
439 | ) -> dict[str, Any]:
440 | """Get summary of optimization results."""
441 | if not state.optimization_results:
442 | return {"summary": "No optimization results available"}
443 |
444 | summary = {
445 | "total_strategies": len(state.selected_strategies),
446 | "optimized_strategies": len(state.optimization_results),
447 | "total_iterations": state.optimization_iterations,
448 | "execution_time_ms": state.optimization_time_ms,
449 | "optimization_config": state.optimization_config,
450 | "strategy_results": {},
451 | }
452 |
453 | for strategy, results in state.optimization_results.items():
454 | if results:
455 | summary["strategy_results"][strategy] = {
456 | "best_metric": results.get("best_metric_value", 0),
457 | "metric_type": state.optimization_config.get(
458 | "optimization_metric", "unknown"
459 | ),
460 | "valid_combinations": results.get("valid_combinations", 0),
461 | "best_parameters": state.best_parameters.get(strategy, {}),
462 | }
463 |
464 | return summary
465 |
466 | async def parallel_optimization(
467 | self, state: BacktestingWorkflowState, max_concurrent: int = 3
468 | ) -> BacktestingWorkflowState:
469 | """Run optimization for multiple strategies in parallel."""
470 | if len(state.selected_strategies) <= 1:
471 | return await self.optimize_parameters(state)
472 |
473 | start_time = datetime.now()
474 | logger.info(
475 | f"Running parallel optimization for {len(state.selected_strategies)} strategies"
476 | )
477 |
478 | # Create semaphore to limit concurrent optimizations
479 | semaphore = asyncio.Semaphore(max_concurrent)
480 |
481 | async def optimize_single_strategy(strategy: str) -> tuple[str, dict[str, Any]]:
482 | async with semaphore:
483 | try:
484 | optimization_config = self._get_optimization_config(
485 | state.market_regime, state.regime_confidence
486 | )
487 | parameter_grids = self._generate_parameter_grids(
488 | [strategy], optimization_config["grid_size"]
489 | )
490 |
491 | opt_start_date = self._calculate_optimization_window(
492 | state.start_date, state.end_date
493 | )
494 |
495 | result = await self.engine.optimize_parameters(
496 | symbol=state.symbol,
497 | strategy_type=strategy,
498 | param_grid=parameter_grids.get(strategy, {}),
499 | start_date=opt_start_date,
500 | end_date=state.end_date,
501 | optimization_metric=optimization_config["optimization_metric"],
502 | initial_capital=state.initial_capital,
503 | top_n=10,
504 | )
505 |
506 | filtered_result = self._filter_optimization_results(
507 | result, optimization_config
508 | )
509 | return strategy, filtered_result
510 |
511 | except Exception as e:
512 | logger.error(f"Failed to optimize {strategy}: {e}")
513 | return strategy, {"error": str(e)}
514 |
515 | # Run optimizations in parallel
516 | tasks = [
517 | optimize_single_strategy(strategy) for strategy in state.selected_strategies
518 | ]
519 | results = await asyncio.gather(*tasks, return_exceptions=True)
520 |
521 | # Process results
522 | optimization_results = {}
523 | best_parameters = {}
524 | total_iterations = 0
525 |
526 | for result in results:
527 | if isinstance(result, Exception):
528 | logger.error(f"Parallel optimization failed: {result}")
529 | continue
530 |
531 | strategy, opt_result = result
532 | if "error" not in opt_result:
533 | optimization_results[strategy] = opt_result
534 | best_parameters[strategy] = opt_result.get("best_parameters", {})
535 | total_iterations += opt_result.get("valid_combinations", 0)
536 |
537 | # Update state
538 | optimization_config = self._get_optimization_config(
539 | state.market_regime, state.regime_confidence
540 | )
541 | state.optimization_config = optimization_config
542 | state.optimization_results = optimization_results
543 | state.best_parameters = best_parameters
544 | state.optimization_iterations = total_iterations
545 |
546 | # Record execution time
547 | execution_time = (datetime.now() - start_time).total_seconds() * 1000
548 | state.optimization_time_ms = execution_time
549 |
550 | logger.info(f"Parallel optimization completed in {execution_time:.0f}ms")
551 | return state
552 |
```
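To make the regime-aware coarsening concrete, a small sketch exercising the agent's internal helpers (shown for illustration; it assumes `OptimizerAgent()` constructs with its default engine, and the regime labels and expected values come straight from the tables above):

```python
# Hedged sketch of the regime-aware grid selection; exercises internal helpers
# for illustration and assumes OptimizerAgent() constructs with defaults.
from maverick_mcp.workflows.agents.optimizer_agent import OptimizerAgent

agent = OptimizerAgent()

# A "ranging" regime normally gets a fine grid optimized for sharpe_ratio...
confident = agent._get_optimization_config("ranging", regime_confidence=0.8)
assert confident["grid_size"] == "fine"
assert confident["optimization_metric"] == "sharpe_ratio"

# ...but below 0.5 confidence the grid is coarsened one step to curb overfitting.
uncertain = agent._get_optimization_config("ranging", regime_confidence=0.4)
assert uncertain["grid_size"] == "medium"

# The chosen grid size then selects candidate values from STRATEGY_PARAM_RANGES.
grids = agent._generate_parameter_grids(["rsi"], uncertain["grid_size"])
print(grids["rsi"]["period"])  # [10, 12, 14, 16, 21]
```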
--------------------------------------------------------------------------------
/scripts/run_stock_screening.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Stock screening script for self-contained Maverick-MCP database.
4 |
5 | This script runs various stock screening algorithms and populates the
6 | screening tables with results, making the system completely self-contained.
7 |
8 | Usage:
9 | python scripts/run_stock_screening.py --all
10 | python scripts/run_stock_screening.py --maverick
11 | python scripts/run_stock_screening.py --bear
12 | python scripts/run_stock_screening.py --supply-demand
13 | """
14 |
15 | import argparse
16 | import asyncio
17 | import logging
18 | import sys
19 | from datetime import datetime, timedelta
20 | from pathlib import Path
21 |
22 | import numpy as np
23 | import pandas as pd
24 | import talib
25 |
26 | # Add parent directory to path for imports
27 | sys.path.append(str(Path(__file__).parent.parent))
28 |
29 | from maverick_mcp.config.database_self_contained import (
30 | SelfContainedDatabaseSession,
31 | init_self_contained_database,
32 | )
33 | from maverick_mcp.data.models import (
34 | MaverickBearStocks,
35 | MaverickStocks,
36 | PriceCache,
37 | Stock,
38 | SupplyDemandBreakoutStocks,
39 | bulk_insert_screening_data,
40 | )
41 |
42 | # Set up logging
43 | logging.basicConfig(
44 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
45 | )
46 | logger = logging.getLogger("stock_screener")
47 |
48 |
49 | class TechnicalAnalyzer:
50 | """Calculates technical indicators for stock screening."""
51 |
52 | @staticmethod
53 | def calculate_moving_averages(df: pd.DataFrame) -> pd.DataFrame:
54 | """Calculate various moving averages."""
55 | df["SMA_20"] = talib.SMA(df["close"].values, timeperiod=20)
56 | df["SMA_50"] = talib.SMA(df["close"].values, timeperiod=50)
57 | df["SMA_150"] = talib.SMA(df["close"].values, timeperiod=150)
58 | df["SMA_200"] = talib.SMA(df["close"].values, timeperiod=200)
59 | df["EMA_21"] = talib.EMA(df["close"].values, timeperiod=21)
60 | return df
61 |
62 | @staticmethod
63 | def calculate_rsi(df: pd.DataFrame, period: int = 14) -> pd.DataFrame:
64 | """Calculate RSI indicator."""
65 | df[f"RSI_{period}"] = talib.RSI(df["close"].values, timeperiod=period)
66 | return df
67 |
68 | @staticmethod
69 | def calculate_macd(df: pd.DataFrame) -> pd.DataFrame:
70 | """Calculate MACD indicator."""
71 | macd, macd_signal, macd_hist = talib.MACD(df["close"].values)
72 | df["MACD"] = macd
73 | df["MACD_Signal"] = macd_signal
74 | df["MACD_Histogram"] = macd_hist
75 | return df
76 |
77 | @staticmethod
78 | def calculate_atr(df: pd.DataFrame, period: int = 14) -> pd.DataFrame:
79 | """Calculate Average True Range."""
80 | df[f"ATR_{period}"] = talib.ATR(
81 | df["high"].values, df["low"].values, df["close"].values, timeperiod=period
82 | )
83 | return df
84 |
85 | @staticmethod
86 | def calculate_relative_strength(
87 | df: pd.DataFrame, benchmark_df: pd.DataFrame
88 | ) -> pd.DataFrame:
89 | """Calculate relative strength vs benchmark (simplified)."""
90 | # Simplified RS calculation - in production would use proper technical analysis methodology
91 | stock_returns = df["close"].pct_change(periods=252).fillna(0) # 1 year
92 | benchmark_returns = benchmark_df["close"].pct_change(periods=252).fillna(0)
93 |
94 | # Momentum Score approximation (0-100 scale)
95 | relative_performance = stock_returns - benchmark_returns
96 | df["Momentum_Score"] = np.clip((relative_performance + 1) * 50, 0, 100)
97 | return df
98 |
99 | @staticmethod
100 | def detect_patterns(df: pd.DataFrame) -> pd.DataFrame:
101 | """Detect chart patterns (simplified)."""
102 | df["Pattern"] = "None"
103 | df["Squeeze"] = "None"
104 | df["Consolidation"] = "None"
105 | df["Entry"] = "None"
106 |
107 | # Simplified pattern detection
108 | latest = df.iloc[-1]
109 |
110 | # Basic trend detection
111 | if (
112 | latest["close"] > latest["SMA_20"]
113 | and latest["SMA_20"] > latest["SMA_50"]
114 | and latest["SMA_50"] > latest["SMA_150"]
115 | ):
116 | df.loc[df.index[-1], "Pattern"] = "Uptrend"
117 |
118 |         # Basic squeeze detection: current ATR below its own 20-day average
119 | if latest["ATR_14"] < df["ATR_14"].rolling(20).mean().iloc[-1]:
120 | df.loc[df.index[-1], "Squeeze"] = "Yes"
121 |
122 | return df
123 |
124 |
125 | class StockScreener:
126 | """Runs various stock screening algorithms."""
127 |
128 | def __init__(self):
129 | self.analyzer = TechnicalAnalyzer()
130 |
131 | async def get_stock_data(
132 | self, session, symbol: str, days: int = 365
133 | ) -> pd.DataFrame | None:
134 | """
135 | Get stock price data from database.
136 |
137 | Args:
138 | session: Database session
139 | symbol: Stock ticker symbol
140 | days: Number of days of historical data
141 |
142 | Returns:
143 | DataFrame with price data or None
144 | """
145 | cutoff_date = datetime.now().date() - timedelta(days=days)
146 |
147 | query = (
148 | session.query(PriceCache)
149 | .join(Stock)
150 | .filter(Stock.ticker_symbol == symbol, PriceCache.date >= cutoff_date)
151 | .order_by(PriceCache.date)
152 | )
153 |
154 | records = query.all()
155 | if not records:
156 | return None
157 |
158 | data = []
159 | for record in records:
160 | data.append(
161 | {
162 | "date": record.date,
163 | "open": float(record.open_price) if record.open_price else 0,
164 | "high": float(record.high_price) if record.high_price else 0,
165 | "low": float(record.low_price) if record.low_price else 0,
166 | "close": float(record.close_price) if record.close_price else 0,
167 | "volume": record.volume or 0,
168 | }
169 | )
170 |
171 | if not data:
172 | return None
173 |
174 | df = pd.DataFrame(data)
175 | df.set_index("date", inplace=True)
176 |
177 | return df
178 |
179 | async def run_maverick_screening(self, session) -> list[dict]:
180 | """
181 | Run Maverick momentum screening algorithm.
182 |
183 | Returns:
184 | List of screening results
185 | """
186 | logger.info("Running Maverick momentum screening...")
187 |
188 | # Get all active stocks
189 | stocks = session.query(Stock).filter(Stock.is_active).all()
190 | results = []
191 |
192 | for stock in stocks:
193 | try:
194 | df = await self.get_stock_data(session, stock.ticker_symbol, days=365)
195 | if df is None or len(df) < 200:
196 | continue
197 |
198 | # Calculate technical indicators
199 | df = self.analyzer.calculate_moving_averages(df)
200 | df = self.analyzer.calculate_rsi(df)
201 | df = self.analyzer.calculate_atr(df)
202 | df = self.analyzer.detect_patterns(df)
203 |
204 | latest = df.iloc[-1]
205 |
206 | # Maverick screening criteria (simplified)
207 | score = 0
208 |
209 | # Price above moving averages
210 | if latest["close"] > latest["SMA_50"]:
211 | score += 25
212 | if latest["close"] > latest["SMA_150"]:
213 | score += 25
214 | if latest["close"] > latest["SMA_200"]:
215 | score += 25
216 |
217 | # Moving average alignment
218 | if (
219 | latest["SMA_50"] > latest["SMA_150"]
220 | and latest["SMA_150"] > latest["SMA_200"]
221 | ):
222 | score += 25
223 |
224 | # Volume above average
225 | avg_volume = df["volume"].rolling(30).mean().iloc[-1]
226 | if latest["volume"] > avg_volume * 1.5:
227 | score += 10
228 |
229 | # RSI not overbought
230 | if latest["RSI_14"] < 80:
231 | score += 10
232 |
233 | # Pattern detection bonus
234 | if latest["Pattern"] == "Uptrend":
235 | score += 15
236 |
237 | if score >= 50: # Minimum threshold
238 | result = {
239 | "ticker": stock.ticker_symbol,
240 | "open_price": latest["open"],
241 | "high_price": latest["high"],
242 | "low_price": latest["low"],
243 | "close_price": latest["close"],
244 | "volume": int(latest["volume"]),
245 | "ema_21": latest["EMA_21"],
246 | "sma_50": latest["SMA_50"],
247 | "sma_150": latest["SMA_150"],
248 | "sma_200": latest["SMA_200"],
249 | "momentum_score": latest.get("Momentum_Score", 50),
250 | "avg_vol_30d": avg_volume,
251 | "adr_pct": (
252 | (latest["high"] - latest["low"]) / latest["close"] * 100
253 | ),
254 | "atr": latest["ATR_14"],
255 | "pattern_type": latest["Pattern"],
256 | "squeeze_status": latest["Squeeze"],
257 | "consolidation_status": latest["Consolidation"],
258 | "entry_signal": latest["Entry"],
259 | "compression_score": min(score // 10, 10),
260 | "pattern_detected": 1 if latest["Pattern"] != "None" else 0,
261 | "combined_score": score,
262 | }
263 | results.append(result)
264 |
265 | except Exception as e:
266 | logger.warning(f"Error screening {stock.ticker_symbol}: {e}")
267 | continue
268 |
269 | logger.info(f"Maverick screening found {len(results)} candidates")
270 | return results
271 |
272 | async def run_bear_screening(self, session) -> list[dict]:
273 | """
274 | Run bear market screening algorithm.
275 |
276 | Returns:
277 | List of screening results
278 | """
279 | logger.info("Running bear market screening...")
280 |
281 | stocks = session.query(Stock).filter(Stock.is_active).all()
282 | results = []
283 |
284 | for stock in stocks:
285 | try:
286 | df = await self.get_stock_data(session, stock.ticker_symbol, days=365)
287 | if df is None or len(df) < 200:
288 | continue
289 |
290 | # Calculate technical indicators
291 | df = self.analyzer.calculate_moving_averages(df)
292 | df = self.analyzer.calculate_rsi(df)
293 | df = self.analyzer.calculate_macd(df)
294 | df = self.analyzer.calculate_atr(df)
295 |
296 | latest = df.iloc[-1]
297 |
298 | # Bear screening criteria
299 | score = 0
300 |
301 | # Price below moving averages (bearish)
302 | if latest["close"] < latest["SMA_50"]:
303 | score += 20
304 | if latest["close"] < latest["SMA_200"]:
305 | score += 20
306 |
307 | # RSI oversold
308 | if latest["RSI_14"] < 30:
309 | score += 15
310 | elif latest["RSI_14"] < 40:
311 | score += 10
312 |
313 | # MACD bearish
314 | if latest["MACD"] < latest["MACD_Signal"]:
315 | score += 15
316 |
317 | # High volume decline
318 | avg_volume = df["volume"].rolling(30).mean().iloc[-1]
319 | if (
320 | latest["volume"] > avg_volume * 1.2
321 | and latest["close"] < df["close"].iloc[-2]
322 | ):
323 | score += 20
324 |
325 | # ATR contraction (consolidation)
326 | atr_avg = df["ATR_14"].rolling(20).mean().iloc[-1]
327 | atr_contraction = latest["ATR_14"] < atr_avg * 0.8
328 | if atr_contraction:
329 | score += 10
330 |
331 | if score >= 40: # Minimum threshold for bear candidates
332 | # Calculate distance from 20-day SMA
333 | sma_20 = df["close"].rolling(20).mean().iloc[-1]
334 | dist_from_sma20 = (latest["close"] - sma_20) / sma_20 * 100
335 |
336 | result = {
337 | "ticker": stock.ticker_symbol,
338 | "open_price": latest["open"],
339 | "high_price": latest["high"],
340 | "low_price": latest["low"],
341 | "close_price": latest["close"],
342 | "volume": int(latest["volume"]),
343 | "momentum_score": latest.get("Momentum_Score", 50),
344 | "ema_21": latest["EMA_21"],
345 | "sma_50": latest["SMA_50"],
346 | "sma_200": latest["SMA_200"],
347 | "rsi_14": latest["RSI_14"],
348 | "macd": latest["MACD"],
349 | "macd_signal": latest["MACD_Signal"],
350 | "macd_histogram": latest["MACD_Histogram"],
351 | "dist_days_20": int(abs(dist_from_sma20)),
352 | "adr_pct": (
353 | (latest["high"] - latest["low"]) / latest["close"] * 100
354 | ),
355 | "atr_contraction": atr_contraction,
356 | "atr": latest["ATR_14"],
357 | "avg_vol_30d": avg_volume,
358 | "big_down_vol": (
359 | latest["volume"] > avg_volume * 1.5
360 | and latest["close"] < df["close"].iloc[-2]
361 | ),
362 | "squeeze_status": "Contraction" if atr_contraction else "None",
363 | "consolidation_status": "None",
364 | "score": score,
365 | }
366 | results.append(result)
367 |
368 | except Exception as e:
369 | logger.warning(f"Error in bear screening {stock.ticker_symbol}: {e}")
370 | continue
371 |
372 | logger.info(f"Bear screening found {len(results)} candidates")
373 | return results
374 |
375 | async def run_supply_demand_screening(self, session) -> list[dict]:
376 | """
377 | Run supply/demand breakout screening algorithm.
378 |
379 | Returns:
380 | List of screening results
381 | """
382 | logger.info("Running supply/demand breakout screening...")
383 |
384 | stocks = session.query(Stock).filter(Stock.is_active).all()
385 | results = []
386 |
387 | for stock in stocks:
388 | try:
389 | df = await self.get_stock_data(session, stock.ticker_symbol, days=365)
390 | if df is None or len(df) < 200:
391 | continue
392 |
393 | # Calculate technical indicators
394 | df = self.analyzer.calculate_moving_averages(df)
395 | df = self.analyzer.calculate_atr(df)
396 | df = self.analyzer.detect_patterns(df)
397 |
398 | latest = df.iloc[-1]
399 |
400 | # Supply/Demand criteria (Technical Breakout Analysis)
401 | meets_criteria = True
402 |
403 | # Criterion 1: Current price above both the 150-day and 200-day SMAs
404 | if not (
405 | latest["close"] > latest["SMA_150"]
406 | and latest["close"] > latest["SMA_200"]
407 | ):
408 | meets_criteria = False
409 |
410 | # Criterion 2: 150-day SMA > 200-day SMA
411 | if not (latest["SMA_150"] > latest["SMA_200"]):
412 | meets_criteria = False
413 |
414 | # Criterion 3: 200-day SMA trending up for at least 1 month (~22 trading days)
415 | sma_200_1m_ago = (
416 | df["SMA_200"].iloc[-22] if len(df) > 22 else df["SMA_200"].iloc[0]
417 | )
418 | if not (latest["SMA_200"] > sma_200_1m_ago):
419 | meets_criteria = False
420 |
421 | # Criterion 4: 50-day SMA above both the 150-day and 200-day SMAs
422 | if not (
423 | latest["SMA_50"] > latest["SMA_150"]
424 | and latest["SMA_50"] > latest["SMA_200"]
425 | ):
426 | meets_criteria = False
427 |
428 | # Criterion 5: Current price above the 50-day SMA
429 | if not (latest["close"] > latest["SMA_50"]):
430 | meets_criteria = False
431 |
432 | # Additional scoring for quality
433 | accumulation_rating = 0
434 | distribution_rating = 0
435 | breakout_strength = 0
436 |
437 | # Price above all MAs = accumulation
438 | if (
439 | latest["close"]
440 | > latest["SMA_50"]
441 | > latest["SMA_150"]
442 | > latest["SMA_200"]
443 | ):
444 | accumulation_rating = 85
445 |
446 | # Volume above average = institutional interest
447 | avg_volume = df["volume"].rolling(30).mean().iloc[-1]
448 | if latest["volume"] > avg_volume * 1.2:
449 | breakout_strength += 25
450 |
451 | # Price near 52-week high
452 | high_52w = df["high"].rolling(252).max().iloc[-1]
453 | if latest["close"] > high_52w * 0.75: # Within 25% of 52-week high
454 | breakout_strength += 25
455 |
456 | if meets_criteria:
457 | result = {
458 | "ticker": stock.ticker_symbol,
459 | "open_price": latest["open"],
460 | "high_price": latest["high"],
461 | "low_price": latest["low"],
462 | "close_price": latest["close"],
463 | "volume": int(latest["volume"]),
464 | "ema_21": latest["EMA_21"],
465 | "sma_50": latest["SMA_50"],
466 | "sma_150": latest["SMA_150"],
467 | "sma_200": latest["SMA_200"],
468 | "momentum_score": latest.get(
469 | "Momentum_Score", 75
470 | ), # Higher default for qualified stocks
471 | "avg_volume_30d": avg_volume,
472 | "adr_pct": (
473 | (latest["high"] - latest["low"]) / latest["close"] * 100
474 | ),
475 | "atr": latest["ATR_14"],
476 | "pattern_type": latest["Pattern"],
477 | "squeeze_status": latest["Squeeze"],
478 | "consolidation_status": latest["Consolidation"],
479 | "entry_signal": latest["Entry"],
480 | "accumulation_rating": accumulation_rating,
481 | "distribution_rating": distribution_rating,
482 | "breakout_strength": breakout_strength,
483 | }
484 | results.append(result)
485 |
486 | except Exception as e:
487 | logger.warning(
488 | f"Error in supply/demand screening {stock.ticker_symbol}: {e}"
489 | )
490 | continue
491 |
492 | logger.info(f"Supply/demand screening found {len(results)} candidates")
493 | return results
494 |
495 |
496 | async def main():
497 | """Main function to run stock screening."""
498 | parser = argparse.ArgumentParser(description="Run stock screening algorithms")
499 | parser.add_argument(
500 | "--all", action="store_true", help="Run all screening algorithms"
501 | )
502 | parser.add_argument(
503 | "--maverick", action="store_true", help="Run Maverick momentum screening"
504 | )
505 | parser.add_argument("--bear", action="store_true", help="Run bear market screening")
506 | parser.add_argument(
507 | "--supply-demand", action="store_true", help="Run supply/demand screening"
508 | )
509 | parser.add_argument("--database-url", type=str, help="Override database URL")
510 |
511 | args = parser.parse_args()
512 |
513 | if not any([args.all, args.maverick, args.bear, args.supply_demand]):
514 | parser.print_help()
515 | sys.exit(1)
516 |
517 | # Initialize database
518 | try:
519 | init_self_contained_database(database_url=args.database_url)
520 | logger.info("Self-contained database initialized")
521 | except Exception as e:
522 | logger.error(f"Database initialization failed: {e}")
523 | sys.exit(1)
524 |
525 | # Initialize screener
526 | screener = StockScreener()
527 | today = datetime.now().date()
528 |
529 | with SelfContainedDatabaseSession() as session:
530 | # Run Maverick screening
531 | if args.all or args.maverick:
532 | try:
533 | maverick_results = await screener.run_maverick_screening(session)
534 | if maverick_results:
535 | count = bulk_insert_screening_data(
536 | session, MaverickStocks, maverick_results, today
537 | )
538 | logger.info(f"Inserted {count} Maverick screening results")
539 | except Exception as e:
540 | logger.error(f"Maverick screening failed: {e}")
541 |
542 | # Run Bear screening
543 | if args.all or args.bear:
544 | try:
545 | bear_results = await screener.run_bear_screening(session)
546 | if bear_results:
547 | count = bulk_insert_screening_data(
548 | session, MaverickBearStocks, bear_results, today
549 | )
550 | logger.info(f"Inserted {count} Bear screening results")
551 | except Exception as e:
552 | logger.error(f"Bear screening failed: {e}")
553 |
554 | # Run Supply/Demand screening
555 | if args.all or args.supply_demand:
556 | try:
557 | sd_results = await screener.run_supply_demand_screening(session)
558 | if sd_results:
559 | count = bulk_insert_screening_data(
560 | session, SupplyDemandBreakoutStocks, sd_results, today
561 | )
562 | logger.info(f"Inserted {count} Supply/Demand screening results")
563 | except Exception as e:
564 | logger.error(f"Supply/Demand screening failed: {e}")
565 |
566 | # Display final stats
567 | from maverick_mcp.config.database_self_contained import get_self_contained_db_config
568 |
569 | db_config = get_self_contained_db_config()
570 | stats = db_config.get_database_stats()
571 |
572 | print("\n📊 Final Database Statistics:")
573 | print(f" Total Records: {stats['total_records']}")
574 | for table, count in stats["tables"].items():
575 | if "screening" in table or "maverick" in table or "supply_demand" in table:
576 | print(f" {table}: {count}")
577 |
578 | print("\n✅ Stock screening completed successfully!")
579 |
580 |
581 | if __name__ == "__main__":
582 | asyncio.run(main())
583 |
```
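A minimal usage sketch for the screening script above: invoking a single screen programmatically instead of through the argparse CLI. It assumes the script is importable as a module (the `scripts.run_stock_screening` path is hypothetical) and that `init_self_contained_database` and `SelfContainedDatabaseSession` resolve from `maverick_mcp.config.database_self_contained`, as the script's own imports suggest.

```python
# Hedged sketch: run only the Maverick screen without the CLI wrapper.
# The scripts.run_stock_screening module path is an assumption.
import asyncio

from maverick_mcp.config.database_self_contained import (
    SelfContainedDatabaseSession,
    init_self_contained_database,
)
from scripts.run_stock_screening import StockScreener  # hypothetical path


async def run_maverick_only() -> list[dict]:
    init_self_contained_database()  # uses the default DATABASE_URL
    screener = StockScreener()
    with SelfContainedDatabaseSession() as session:
        return await screener.run_maverick_screening(session)


if __name__ == "__main__":
    candidates = asyncio.run(run_maverick_only())
    print(f"{len(candidates)} Maverick candidates found")
```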
--------------------------------------------------------------------------------
/maverick_mcp/utils/debug_utils.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Debug utilities for backtesting system troubleshooting.
3 |
4 | This module provides comprehensive debugging tools including:
5 | - Request/response logging
6 | - Performance profiling
7 | - Memory analysis
8 | - Error tracking
9 | - Debug mode utilities
10 | """
11 |
12 | import inspect
13 | import json
14 | import time
15 | import traceback
16 | from collections.abc import Callable, Generator
17 | from contextlib import contextmanager
18 | from datetime import UTC, datetime
19 | from functools import wraps
20 | from typing import Any
21 |
22 | import psutil
23 |
24 | from maverick_mcp.utils.structured_logger import (
25 | CorrelationIDGenerator,
26 | get_logger_manager,
27 | get_performance_logger,
28 | get_structured_logger,
29 | )
30 |
31 |
32 | class DebugProfiler:
33 | """Comprehensive debug profiler for performance analysis."""
34 |
35 | def __init__(self):
36 | self.logger = get_structured_logger("maverick_mcp.debug")
37 | self.performance_logger = get_performance_logger("debug_profiler")
38 | self._profiles: dict[str, dict[str, Any]] = {}
39 |
40 | def start_profile(self, profile_name: str, **context):
41 | """Start a debug profiling session."""
42 | profile_id = f"{profile_name}_{int(time.time() * 1000)}"
43 |
44 | profile_data = {
45 | "profile_name": profile_name,
46 | "profile_id": profile_id,
47 | "start_time": time.time(),
48 | "start_memory": self._get_memory_usage(),
49 | "start_cpu": self._get_cpu_usage(),
50 | "context": context,
51 | "checkpoints": [],
52 | }
53 |
54 | self._profiles[profile_id] = profile_data
55 |
56 | self.logger.debug(
57 | f"Started debug profile: {profile_name}",
58 | extra={
59 | "profile_id": profile_id,
60 | "start_memory_mb": profile_data["start_memory"],
61 | "start_cpu_percent": profile_data["start_cpu"],
62 | **context,
63 | },
64 | )
65 |
66 | return profile_id
67 |
68 | def checkpoint(self, profile_id: str, checkpoint_name: str, **data):
69 | """Add a checkpoint to an active profile."""
70 | if profile_id not in self._profiles:
71 | self.logger.warning(f"Profile {profile_id} not found for checkpoint")
72 | return
73 |
74 | profile = self._profiles[profile_id]
75 | current_time = time.time()
76 | elapsed_ms = (current_time - profile["start_time"]) * 1000
77 |
78 | checkpoint_data = {
79 | "name": checkpoint_name,
80 | "timestamp": current_time,
81 | "elapsed_ms": elapsed_ms,
82 | "memory_mb": self._get_memory_usage(),
83 | "cpu_percent": self._get_cpu_usage(),
84 | "data": data,
85 | }
86 |
87 | profile["checkpoints"].append(checkpoint_data)
88 |
89 | self.logger.debug(
90 | f"Profile checkpoint: {checkpoint_name} at {elapsed_ms:.2f}ms",
91 | extra={
92 | "profile_id": profile_id,
93 | "checkpoint": checkpoint_name,
94 | "elapsed_ms": elapsed_ms,
95 | "memory_mb": checkpoint_data["memory_mb"],
96 | **data,
97 | },
98 | )
99 |
100 | def end_profile(
101 | self, profile_id: str, success: bool = True, **final_data
102 | ) -> dict[str, Any]:
103 | """End a debug profiling session and return comprehensive results."""
104 | if profile_id not in self._profiles:
105 | self.logger.warning(f"Profile {profile_id} not found for ending")
106 | return {}
107 |
108 | profile = self._profiles.pop(profile_id)
109 | end_time = time.time()
110 | total_duration_ms = (end_time - profile["start_time"]) * 1000
111 |
112 | # Calculate memory and CPU deltas
113 | end_memory = self._get_memory_usage()
114 | end_cpu = self._get_cpu_usage()
115 | memory_delta = end_memory - profile["start_memory"]
116 |
117 | results = {
118 | "profile_name": profile["profile_name"],
119 | "profile_id": profile_id,
120 | "success": success,
121 | "total_duration_ms": total_duration_ms,
122 | "start_time": profile["start_time"],
123 | "end_time": end_time,
124 | "memory_stats": {
125 | "start_mb": profile["start_memory"],
126 | "end_mb": end_memory,
127 | "delta_mb": memory_delta,
128 | "peak_usage": max(cp["memory_mb"] for cp in profile["checkpoints"])
129 | if profile["checkpoints"]
130 | else end_memory,
131 | },
132 | "cpu_stats": {
133 | "start_percent": profile["start_cpu"],
134 | "end_percent": end_cpu,
135 | "avg_percent": sum(cp["cpu_percent"] for cp in profile["checkpoints"])
136 | / len(profile["checkpoints"])
137 | if profile["checkpoints"]
138 | else end_cpu,
139 | },
140 | "checkpoints": profile["checkpoints"],
141 | "checkpoint_count": len(profile["checkpoints"]),
142 | "context": profile["context"],
143 | "final_data": final_data,
144 | }
145 |
146 | # Log profile completion
147 | log_level = "info" if success else "error"
148 | getattr(self.logger, log_level)(
149 | f"Completed debug profile: {profile['profile_name']} in {total_duration_ms:.2f}ms",
150 | extra={
151 | "profile_results": results,
152 | "performance_summary": {
153 | "duration_ms": total_duration_ms,
154 | "memory_delta_mb": memory_delta,
155 | "checkpoint_count": len(profile["checkpoints"]),
156 | "success": success,
157 | },
158 | },
159 | )
160 |
161 | return results
162 |
163 | @staticmethod
164 | def _get_memory_usage() -> float:
165 | """Get current memory usage in MB."""
166 | try:
167 | process = psutil.Process()
168 | return process.memory_info().rss / 1024 / 1024
169 | except (psutil.NoSuchProcess, psutil.AccessDenied):
170 | return 0.0
171 |
172 | @staticmethod
173 | def _get_cpu_usage() -> float:
174 | """Get current CPU usage percentage."""
175 | try:
176 | process = psutil.Process()
177 | return process.cpu_percent(interval=None)
178 | except (psutil.NoSuchProcess, psutil.AccessDenied):
179 | return 0.0
180 |
181 |
182 | class RequestResponseLogger:
183 | """Detailed request/response logging for debugging."""
184 |
185 | def __init__(self, max_payload_size: int = 5000):
186 | self.logger = get_structured_logger("maverick_mcp.requests")
187 | self.max_payload_size = max_payload_size
188 |
189 | def log_request(self, operation: str, **request_data):
190 | """Log detailed request information."""
191 | correlation_id = CorrelationIDGenerator.get_correlation_id()
192 |
193 | # Sanitize and truncate request data
194 | sanitized_data = self._sanitize_data(request_data)
195 | truncated_data = self._truncate_data(sanitized_data)
196 |
197 | self.logger.info(
198 | f"Request: {operation}",
199 | extra={
200 | "operation": operation,
201 | "correlation_id": correlation_id,
202 | "request_data": truncated_data,
203 | "request_size": len(json.dumps(request_data, default=str)),
204 | "timestamp": datetime.now(UTC).isoformat(),
205 | },
206 | )
207 |
208 | def log_response(
209 | self, operation: str, success: bool, duration_ms: float, **response_data
210 | ):
211 | """Log detailed response information."""
212 | correlation_id = CorrelationIDGenerator.get_correlation_id()
213 |
214 | # Sanitize and truncate response data
215 | sanitized_data = self._sanitize_data(response_data)
216 | truncated_data = self._truncate_data(sanitized_data)
217 |
218 | log_method = self.logger.info if success else self.logger.error
219 |
220 | log_method(
221 | f"Response: {operation} ({'success' if success else 'failure'}) in {duration_ms:.2f}ms",
222 | extra={
223 | "operation": operation,
224 | "correlation_id": correlation_id,
225 | "success": success,
226 | "duration_ms": duration_ms,
227 | "response_data": truncated_data,
228 | "response_size": len(json.dumps(response_data, default=str)),
229 | "timestamp": datetime.now(UTC).isoformat(),
230 | },
231 | )
232 |
233 | def _sanitize_data(self, data: Any) -> Any:
234 | """Remove sensitive information from data."""
235 | if isinstance(data, dict):
236 | sanitized = {}
237 | for key, value in data.items():
238 | if any(
239 | sensitive in key.lower()
240 | for sensitive in ["password", "token", "key", "secret"]
241 | ):
242 | sanitized[key] = "***REDACTED***"
243 | else:
244 | sanitized[key] = self._sanitize_data(value)
245 | return sanitized
246 | elif isinstance(data, list | tuple):
247 | return [self._sanitize_data(item) for item in data]
248 | else:
249 | return data
250 |
251 | def _truncate_data(self, data: Any) -> Any:
252 | """Truncate data to prevent log overflow."""
253 | data_str = json.dumps(data, default=str)
254 | if len(data_str) > self.max_payload_size:
255 | truncated = data_str[: self.max_payload_size]
256 | return f"{truncated}... (truncated, original size: {len(data_str)})"
257 | return data
258 |
259 |
260 | class ErrorTracker:
261 | """Comprehensive error tracking and analysis."""
262 |
263 | def __init__(self):
264 | self.logger = get_structured_logger("maverick_mcp.errors")
265 | self._error_stats: dict[str, dict[str, Any]] = {}
266 |
267 | def track_error(
268 | self,
269 | error: Exception,
270 | operation: str,
271 | context: dict[str, Any],
272 | severity: str = "error",
273 | ):
274 | """Track error with detailed context and statistics."""
275 | error_type = type(error).__name__
276 | error_key = f"{operation}_{error_type}"
277 |
278 | # Update error statistics
279 | if error_key not in self._error_stats:
280 | self._error_stats[error_key] = {
281 | "first_seen": datetime.now(UTC),
282 | "last_seen": datetime.now(UTC),
283 | "count": 0,
284 | "operation": operation,
285 | "error_type": error_type,
286 | "contexts": [],
287 | }
288 |
289 | stats = self._error_stats[error_key]
290 | stats["last_seen"] = datetime.now(UTC)
291 | stats["count"] += 1
292 |
293 | # Keep only recent contexts (last 10)
294 | stats["contexts"].append(
295 | {
296 | "timestamp": datetime.now(UTC).isoformat(),
297 | "context": context,
298 | "error_message": str(error),
299 | }
300 | )
301 | stats["contexts"] = stats["contexts"][-10:] # Keep only last 10
302 |
303 | # Get stack trace
304 | stack_trace = traceback.format_exception(
305 | type(error), error, error.__traceback__
306 | )
307 |
308 | # Log the error
309 | correlation_id = CorrelationIDGenerator.get_correlation_id()
310 |
311 | log_data = {
312 | "operation": operation,
313 | "correlation_id": correlation_id,
314 | "error_type": error_type,
315 | "error_message": str(error),
316 | "error_count": stats["count"],
317 | "first_seen": stats["first_seen"].isoformat(),
318 | "last_seen": stats["last_seen"].isoformat(),
319 | "context": context,
320 | "stack_trace": stack_trace,
321 | "severity": severity,
322 | }
323 |
324 | if severity == "critical":
325 | self.logger.critical(
326 | f"Critical error in {operation}: {error}", extra=log_data
327 | )
328 | elif severity == "error":
329 | self.logger.error(f"Error in {operation}: {error}", extra=log_data)
330 | elif severity == "warning":
331 | self.logger.warning(f"Warning in {operation}: {error}", extra=log_data)
332 |
333 | def get_error_summary(self) -> dict[str, Any]:
334 | """Get comprehensive error statistics summary."""
335 | if not self._error_stats:
336 | return {"message": "No errors tracked"}
337 |
338 | summary = {
339 | "total_error_types": len(self._error_stats),
340 | "total_errors": sum(stats["count"] for stats in self._error_stats.values()),
341 | "error_breakdown": {},
342 | "most_common_errors": [],
343 | "recent_errors": [],
344 | }
345 |
346 | # Error breakdown by type
347 | for _error_key, stats in self._error_stats.items():
348 | summary["error_breakdown"][stats["error_type"]] = (
349 | summary["error_breakdown"].get(stats["error_type"], 0) + stats["count"]
350 | )
351 |
352 | # Most common errors
353 | sorted_errors = sorted(
354 | self._error_stats.items(), key=lambda x: x[1]["count"], reverse=True
355 | )
356 | summary["most_common_errors"] = [
357 | {
358 | "operation": stats["operation"],
359 | "error_type": stats["error_type"],
360 | "count": stats["count"],
361 | "first_seen": stats["first_seen"].isoformat(),
362 | "last_seen": stats["last_seen"].isoformat(),
363 | }
364 | for _, stats in sorted_errors[:10]
365 | ]
366 |
367 | # Recent errors
368 | all_contexts = []
369 | for stats in self._error_stats.values():
370 | for context in stats["contexts"]:
371 | all_contexts.append(
372 | {
373 | "operation": stats["operation"],
374 | "error_type": stats["error_type"],
375 | **context,
376 | }
377 | )
378 |
379 | summary["recent_errors"] = sorted(
380 | all_contexts, key=lambda x: x["timestamp"], reverse=True
381 | )[:20]
382 |
383 | return summary
384 |
385 |
386 | class DebugContextManager:
387 | """Context manager for debug sessions with automatic cleanup."""
388 |
389 | def __init__(
390 | self,
391 | operation_name: str,
392 | enable_profiling: bool = True,
393 | enable_request_logging: bool = True,
394 | enable_error_tracking: bool = True,
395 | **context,
396 | ):
397 | self.operation_name = operation_name
398 | self.enable_profiling = enable_profiling
399 | self.enable_request_logging = enable_request_logging
400 | self.enable_error_tracking = enable_error_tracking
401 | self.context = context
402 |
403 | # Initialize components
404 | self.profiler = DebugProfiler() if enable_profiling else None
405 | self.request_logger = (
406 | RequestResponseLogger() if enable_request_logging else None
407 | )
408 | self.error_tracker = ErrorTracker() if enable_error_tracking else None
409 |
410 | self.profile_id = None
411 | self.start_time = None
412 |
413 | def __enter__(self):
414 | """Enter debug context."""
415 | self.start_time = time.time()
416 |
417 | # Set correlation ID if not present
418 | if not CorrelationIDGenerator.get_correlation_id():
419 | CorrelationIDGenerator.set_correlation_id()
420 |
421 | # Start profiling
422 | if self.profiler:
423 | self.profile_id = self.profiler.start_profile(
424 | self.operation_name, **self.context
425 | )
426 |
427 | # Log request
428 | if self.request_logger:
429 | self.request_logger.log_request(self.operation_name, **self.context)
430 |
431 | return self
432 |
433 | def __exit__(self, exc_type, exc_val, exc_tb):
434 | """Exit debug context with cleanup."""
435 | duration_ms = (time.time() - self.start_time) * 1000 if self.start_time else 0
436 | success = exc_type is None
437 |
438 | # Track error if occurred
439 | if not success and self.error_tracker and exc_val:
440 | self.error_tracker.track_error(
441 | exc_val, self.operation_name, self.context, severity="error"
442 | )
443 |
444 | # End profiling
445 | if self.profiler and self.profile_id:
446 | self.profiler.end_profile(
447 | self.profile_id,
448 | success=success,
449 | exception_type=exc_type.__name__ if exc_type else None,
450 | )
451 |
452 | # Log response
453 | if self.request_logger:
454 | response_data = {"exception": str(exc_val)} if exc_val else {}
455 | self.request_logger.log_response(
456 | self.operation_name, success, duration_ms, **response_data
457 | )
458 |
459 | def checkpoint(self, name: str, **data):
460 | """Add a checkpoint during debug session."""
461 | if self.profiler and self.profile_id:
462 | self.profiler.checkpoint(self.profile_id, name, **data)
463 |
464 |
465 | # Decorator for automatic debug wrapping
466 | def debug_operation(
467 | operation_name: str | None = None,
468 | enable_profiling: bool = True,
469 | enable_request_logging: bool = True,
470 | enable_error_tracking: bool = True,
471 | **default_context,
472 | ):
473 | """Decorator to automatically wrap operations with debug context."""
474 |
475 | def decorator(func: Callable) -> Callable:
476 | actual_operation_name = operation_name or func.__name__
477 |
478 | @wraps(func)
479 | async def async_wrapper(*args, **kwargs):
480 | # Extract additional context from function signature
481 | sig = inspect.signature(func)
482 | bound_args = sig.bind(*args, **kwargs)
483 | bound_args.apply_defaults()
484 |
485 | context = {**default_context}
486 | # Add non-sensitive parameters to context
487 | for param_name, param_value in bound_args.arguments.items():
488 | if not any(
489 | sensitive in param_name.lower()
490 | for sensitive in ["password", "token", "key", "secret"]
491 | ):
492 | if (
493 | isinstance(param_value, str | int | float | bool)
494 | or param_value is None
495 | ):
496 | context[param_name] = param_value
497 |
498 | with DebugContextManager(
499 | actual_operation_name,
500 | enable_profiling,
501 | enable_request_logging,
502 | enable_error_tracking,
503 | **context,
504 | ) as debug_ctx:
505 | result = await func(*args, **kwargs)
506 | debug_ctx.checkpoint(
507 | "function_completed", result_type=type(result).__name__
508 | )
509 | return result
510 |
511 | @wraps(func)
512 | def sync_wrapper(*args, **kwargs):
513 | # Similar logic for sync functions
514 | sig = inspect.signature(func)
515 | bound_args = sig.bind(*args, **kwargs)
516 | bound_args.apply_defaults()
517 |
518 | context = {**default_context}
519 | for param_name, param_value in bound_args.arguments.items():
520 | if not any(
521 | sensitive in param_name.lower()
522 | for sensitive in ["password", "token", "key", "secret"]
523 | ):
524 | if (
525 | isinstance(param_value, str | int | float | bool)
526 | or param_value is None
527 | ):
528 | context[param_name] = param_value
529 |
530 | with DebugContextManager(
531 | actual_operation_name,
532 | enable_profiling,
533 | enable_request_logging,
534 | enable_error_tracking,
535 | **context,
536 | ) as debug_ctx:
537 | result = func(*args, **kwargs)
538 | debug_ctx.checkpoint(
539 | "function_completed", result_type=type(result).__name__
540 | )
541 | return result
542 |
543 | return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
544 |
545 | return decorator
546 |
547 |
548 | @contextmanager
549 | def debug_session(
550 | session_name: str, **context
551 | ) -> Generator[DebugContextManager, None, None]:
552 | """Context manager for manual debug sessions."""
553 | with DebugContextManager(session_name, **context) as debug_ctx:
554 | yield debug_ctx
555 |
556 |
557 | # Global debug utilities
558 | _debug_profiler = DebugProfiler()
559 | _error_tracker = ErrorTracker()
560 |
561 |
562 | def get_debug_profiler() -> DebugProfiler:
563 | """Get global debug profiler instance."""
564 | return _debug_profiler
565 |
566 |
567 | def get_error_tracker() -> ErrorTracker:
568 | """Get global error tracker instance."""
569 | return _error_tracker
570 |
571 |
572 | def print_debug_summary():
573 | """Print comprehensive debug summary to console."""
574 | print("\n" + "=" * 80)
575 | print("MAVERICK MCP DEBUG SUMMARY")
576 | print("=" * 80)
577 |
578 | # Performance metrics
579 | print("\n📊 PERFORMANCE METRICS")
580 | print("-" * 40)
581 | try:
582 | manager = get_logger_manager()
583 | dashboard_data = manager.create_dashboard_metrics()
584 |
585 | print(
586 | f"Log Level Counts: {dashboard_data.get('system_metrics', {}).get('log_level_counts', {})}"
587 | )
588 | print(
589 | f"Active Correlation IDs: {dashboard_data.get('system_metrics', {}).get('active_correlation_ids', 0)}"
590 | )
591 |
592 | if "memory_stats" in dashboard_data:
593 | memory_stats = dashboard_data["memory_stats"]
594 | print(
595 | f"Memory Usage: {memory_stats.get('rss_mb', 0):.1f}MB RSS, {memory_stats.get('cpu_percent', 0):.1f}% CPU"
596 | )
597 |
598 | except Exception as e:
599 | print(f"Error getting performance metrics: {e}")
600 |
601 | # Error summary
602 | print("\n🚨 ERROR SUMMARY")
603 | print("-" * 40)
604 | try:
605 | error_summary = _error_tracker.get_error_summary()
606 | if "message" in error_summary:
607 | print(error_summary["message"])
608 | else:
609 | print(f"Total Error Types: {error_summary['total_error_types']}")
610 | print(f"Total Errors: {error_summary['total_errors']}")
611 |
612 | if error_summary["most_common_errors"]:
613 | print("\nMost Common Errors:")
614 | for error in error_summary["most_common_errors"][:5]:
615 | print(
616 | f" {error['error_type']} in {error['operation']}: {error['count']} times"
617 | )
618 |
619 | except Exception as e:
620 | print(f"Error getting error summary: {e}")
621 |
622 | print("\n" + "=" * 80)
623 |
624 |
625 | def enable_debug_mode():
626 | """Enable comprehensive debug mode."""
627 | import os
628 |
629 | os.environ["MAVERICK_DEBUG"] = "true"
630 | print("🐛 Debug mode enabled")
631 | print(" - Verbose logging activated")
632 | print(" - Request/response logging enabled")
633 | print(" - Performance profiling enabled")
634 | print(" - Error tracking enhanced")
635 |
636 |
637 | def disable_debug_mode():
638 | """Disable debug mode."""
639 | import os
640 |
641 | if "MAVERICK_DEBUG" in os.environ:
642 | del os.environ["MAVERICK_DEBUG"]
643 | print("🐛 Debug mode disabled")
644 |
```
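For orientation, a short sketch of how the two entry points above compose: the `debug_operation` decorator for automatic wrapping and `debug_session` for manual checkpoints. The `fetch_prices` coroutine, its arguments, and its return payload are illustrative only.

```python
# Usage sketch for debug_utils; fetch_prices and its payload are hypothetical.
import asyncio

from maverick_mcp.utils.debug_utils import (
    debug_operation,
    debug_session,
    print_debug_summary,
)


@debug_operation(operation_name="fetch_prices")
async def fetch_prices(symbol: str) -> dict:
    await asyncio.sleep(0.1)  # stand-in for real I/O
    return {"symbol": symbol, "bars": 252}


async def main():
    # Decorator path: profiling, request/response logging, error tracking.
    await fetch_prices("AAPL")

    # Manual path: explicit checkpoints inside a named debug session.
    with debug_session("nightly_screen", universe="sp500") as ctx:
        ctx.checkpoint("data_loaded", rows=1000)
        ctx.checkpoint("indicators_computed")

    print_debug_summary()


asyncio.run(main())
```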
--------------------------------------------------------------------------------
/maverick_mcp/agents/market_analysis.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Market Analysis Agent using LangGraph best practices with professional features.
3 | """
4 |
5 | import hashlib
6 | import logging
7 | from datetime import datetime
8 | from typing import Any
9 |
10 | from langchain_core.messages import HumanMessage
11 | from langchain_core.tools import BaseTool
12 | from langgraph.checkpoint.memory import MemorySaver
13 | from langgraph.graph import END, START, StateGraph
14 |
15 | from maverick_mcp.agents.circuit_breaker import circuit_manager
16 | from maverick_mcp.config.settings import get_settings
17 | from maverick_mcp.exceptions import (
18 | AgentInitializationError,
19 | PersonaConfigurationError,
20 | ToolRegistrationError,
21 | )
22 | from maverick_mcp.langchain_tools import get_tool_registry
23 | from maverick_mcp.memory import ConversationStore
24 | from maverick_mcp.tools.risk_management import (
25 | PositionSizeTool,
26 | RiskMetricsTool,
27 | TechnicalStopsTool,
28 | )
29 | from maverick_mcp.tools.sentiment_analysis import (
30 | MarketBreadthTool,
31 | NewsSentimentTool,
32 | SectorSentimentTool,
33 | )
34 | from maverick_mcp.workflows.state import MarketAnalysisState
35 |
36 | from .base import PersonaAwareAgent
37 |
38 | logger = logging.getLogger(__name__)
39 | settings = get_settings()
40 |
41 |
42 | class MarketAnalysisAgent(PersonaAwareAgent):
43 | """
44 | Professional market analysis agent with advanced screening and risk assessment.
45 |
46 | Features:
47 | - Multi-strategy screening (momentum, mean reversion, breakout)
48 | - Sector rotation analysis
49 | - Market regime detection
50 | - Risk-adjusted recommendations
51 | - Real-time sentiment integration
52 | - Circuit breaker protection for API calls
53 | """
54 |
55 | VALID_PERSONAS = ["conservative", "moderate", "aggressive", "day_trader"]
56 |
57 | def __init__(
58 | self,
59 | llm,
60 | persona: str = "moderate",
61 | ttl_hours: int | None = None,
62 | ):
63 | """
64 | Initialize market analysis agent.
65 |
66 | Args:
67 | llm: Language model
68 | persona: Investor persona
69 | ttl_hours: Cache TTL in hours (uses config default if None)
70 |
71 |
72 | Raises:
73 | PersonaConfigurationError: If persona is invalid
74 | AgentInitializationError: If initialization fails
75 | """
76 | try:
77 | # Validate persona
78 | if persona.lower() not in self.VALID_PERSONAS:
79 | raise PersonaConfigurationError(
80 | persona=persona, valid_personas=self.VALID_PERSONAS
81 | )
82 |
83 | # Store persona temporarily for tool configuration
84 | self._temp_persona = persona.lower()
85 |
86 | # Get comprehensive tool set
87 | tools = self._get_comprehensive_tools()
88 |
89 | if not tools:
90 | raise AgentInitializationError(
91 | agent_type="MarketAnalysisAgent",
92 | reason="No tools available for initialization",
93 | )
94 |
95 | # Use default TTL from config if not provided
96 | if ttl_hours is None:
97 | ttl_hours = settings.agent.conversation_cache_ttl_hours
98 |
99 | # Initialize with MemorySaver
100 | super().__init__(
101 | llm=llm,
102 | tools=tools,
103 | persona=persona,
104 | checkpointer=MemorySaver(),
105 | ttl_hours=ttl_hours,
106 | )
107 |
108 | except (PersonaConfigurationError, AgentInitializationError):
109 | raise
110 | except Exception as e:
111 | logger.error(f"Failed to initialize MarketAnalysisAgent: {str(e)}")
112 | error = AgentInitializationError(
113 | agent_type="MarketAnalysisAgent",
114 | reason=str(e),
115 | )
116 | error.context["original_error"] = type(e).__name__
117 | raise error
118 |
119 | # Initialize conversation store
120 | self.conversation_store = ConversationStore(ttl_hours=ttl_hours)
121 |
122 | # Circuit breakers for external APIs
123 | self.circuit_breakers = {
124 | "screening": None,
125 | "sentiment": None,
126 | "market_data": None,
127 | }
128 |
129 | def _get_comprehensive_tools(self) -> list[BaseTool]:
130 | """Get comprehensive set of market analysis tools.
131 |
132 | Returns:
133 | List of configured tools
134 |
135 | Raises:
136 | ToolRegistrationError: If critical tools cannot be loaded
137 | """
138 | try:
139 | registry = get_tool_registry()
140 | except Exception as e:
141 | logger.error(f"Failed to get tool registry: {str(e)}")
142 | raise ToolRegistrationError(tool_name="registry", reason=str(e))
143 |
144 | # Core screening tools
145 | screening_tools = [
146 | registry.get_tool("get_maverick_stocks"),
147 | registry.get_tool("get_maverick_bear_stocks"),
148 | registry.get_tool("get_supply_demand_breakouts"),
149 | registry.get_tool("get_all_screening_recommendations"),
150 | ]
151 |
152 | # Technical analysis tools
153 | technical_tools = [
154 | registry.get_tool("get_technical_indicators"),
155 | registry.get_tool("calculate_support_resistance"),
156 | registry.get_tool("detect_chart_patterns"),
157 | ]
158 |
159 | # Market data tools
160 | market_tools = [
161 | registry.get_tool("get_market_movers"),
162 | registry.get_tool("get_sector_performance"),
163 | registry.get_tool("get_market_indices"),
164 | ]
165 |
166 | # Risk management tools (persona-aware)
167 | risk_tools = [
168 | PositionSizeTool(),
169 | TechnicalStopsTool(),
170 | RiskMetricsTool(),
171 | ]
172 |
173 | # Sentiment analysis tools
174 | sentiment_tools = [
175 | NewsSentimentTool(),
176 | MarketBreadthTool(),
177 | SectorSentimentTool(),
178 | ]
179 |
180 | # Combine all tools and filter None
181 | all_tools = (
182 | screening_tools
183 | + technical_tools
184 | + market_tools
185 | + risk_tools
186 | + sentiment_tools
187 | )
188 | tools = [t for t in all_tools if t is not None]
189 |
190 | # Configure persona for PersonaAwareTools
191 | for tool in tools:
192 | try:
193 | if hasattr(tool, "set_persona"):
194 | tool.set_persona(self._temp_persona)
195 | except Exception as e:
196 | logger.warning(
197 | f"Failed to set persona for tool {tool.__class__.__name__}: {str(e)}"
198 | )
199 | # Continue with other tools
200 |
201 | if not tools:
202 | logger.warning("No tools available for market analysis")
203 | return []
204 |
205 | logger.info(f"Loaded {len(tools)} tools for {self._temp_persona} persona")
206 | return tools
207 |
208 | def get_state_schema(self) -> type:
209 | """Return enhanced state schema for market analysis."""
210 | return MarketAnalysisState
211 |
212 | def _build_system_prompt(self) -> str:
213 | """Build comprehensive system prompt for professional market analysis."""
214 | base_prompt = super()._build_system_prompt()
215 |
216 | market_prompt = f"""
217 |
218 | You are a professional market analyst specializing in systematic screening and analysis.
219 | Current market date: {datetime.now().strftime("%Y-%m-%d")}
220 |
221 | ## Core Responsibilities:
222 |
223 | 1. **Multi-Strategy Screening**:
224 | - Momentum: High RS stocks breaking out on volume
225 | - Mean Reversion: Oversold quality stocks at support
226 | - Breakout: Stocks clearing resistance with volume surge
227 | - Trend Following: Stocks in established uptrends
228 |
229 | 2. **Market Regime Analysis**:
230 | - Identify current market regime (bull/bear/sideways)
231 | - Analyze sector rotation patterns
232 | - Monitor breadth indicators and sentiment
233 | - Detect risk-on vs risk-off environments
234 |
235 | 3. **Risk-Adjusted Selection**:
236 | - Filter stocks by persona risk tolerance
237 | - Calculate position sizes using Kelly Criterion
238 | - Set appropriate stop losses using ATR
239 | - Consider correlation and portfolio heat
240 |
241 | 4. **Professional Reporting**:
242 | - Provide actionable entry/exit levels
243 | - Include risk/reward ratios
244 | - Highlight key catalysts and risks
245 | - Suggest portfolio allocation
246 |
247 | ## Screening Criteria by Persona:
248 |
249 | **Conservative ({self.persona.name if self.persona.name == "Conservative" else "N/A"})**:
250 | - Large-cap stocks (>$10B market cap)
251 | - Dividend yield > 2%
252 | - Low debt/equity < 1.5
253 | - Beta < 1.2
254 | - Established uptrends only
255 |
256 | **Moderate ({self.persona.name if self.persona.name == "Moderate" else "N/A"})**:
257 | - Mid to large-cap stocks (>$2B)
258 | - Balanced growth/value metrics
259 | - Moderate volatility accepted
260 | - Mix of dividend and growth stocks
261 |
262 | **Aggressive ({self.persona.name if self.persona.name == "Aggressive" else "N/A"})**:
263 | - All market caps considered
264 | - High growth rates prioritized
265 | - Momentum and relative strength focus
266 | - Higher volatility tolerated
267 |
268 | **Day Trader ({self.persona.name if self.persona.name == "Day Trader" else "N/A"})**:
269 | - High liquidity (>1M avg volume)
270 | - Tight spreads (<0.1%)
271 | - High ATR for movement
272 | - Technical patterns emphasized
273 |
274 | ## Analysis Framework:
275 |
276 | 1. Start with market regime assessment
277 | 2. Identify leading/lagging sectors
278 | 3. Screen for stocks matching criteria
279 | 4. Apply technical analysis filters
280 | 5. Calculate risk metrics
281 | 6. Generate recommendations with specific levels
282 |
283 | Remember to:
284 | - Cite specific data points
285 | - Explain your reasoning
286 | - Highlight risks clearly
287 | - Provide actionable insights
288 | - Consider time horizon
289 | """
290 |
291 | return base_prompt + market_prompt
292 |
293 | def _build_graph(self):
294 | """Build enhanced graph with multiple analysis nodes."""
295 | workflow = StateGraph(MarketAnalysisState)
296 |
297 | # Add specialized nodes with unique names
298 | workflow.add_node("analyze_market_regime", self._analyze_market_regime)
299 | workflow.add_node("analyze_sectors", self._analyze_sectors)
300 | workflow.add_node("run_screening", self._run_screening)
301 | workflow.add_node("assess_risks", self._assess_risks)
302 | workflow.add_node("agent", self._agent_node)
303 |
304 | # Create tool node if tools available
305 | if self.tools:
306 | from langgraph.prebuilt import ToolNode
307 |
308 | tool_node = ToolNode(self.tools)
309 | workflow.add_node("tools", tool_node)
310 |
311 | # Define flow
312 | workflow.add_edge(START, "analyze_market_regime")
313 | workflow.add_edge("analyze_market_regime", "analyze_sectors")
314 | workflow.add_edge("analyze_sectors", "run_screening")
315 | workflow.add_edge("run_screening", "assess_risks")
316 | workflow.add_edge("assess_risks", "agent")
317 |
318 | if self.tools:
319 | workflow.add_conditional_edges(
320 | "agent",
321 | self._should_continue,
322 | {
323 | "continue": "tools",
324 | "end": END,
325 | },
326 | )
327 | workflow.add_edge("tools", "agent")
328 | else:
329 | workflow.add_edge("agent", END)
330 |
331 | return workflow.compile(checkpointer=self.checkpointer)
332 |
333 | async def _analyze_market_regime(
334 | self, state: MarketAnalysisState
335 | ) -> dict[str, Any]:
336 | """Analyze current market regime."""
337 | try:
338 | # Use market breadth tool
339 | breadth_tool = next(
340 | (t for t in self.tools if t.name == "analyze_market_breadth"), None
341 | )
342 |
343 | if breadth_tool:
344 | circuit_breaker = await circuit_manager.get_or_create("market_data")
345 |
346 | async def get_breadth():
347 | return await breadth_tool.ainvoke({"index": "SPY"})
348 |
349 | breadth_data = await circuit_breaker.call(get_breadth)
350 |
351 | # Parse results to determine regime
352 | # Handle both string and dict responses
353 | if isinstance(breadth_data, str):
354 | # Try to extract sentiment from string response
355 | if "Bullish" in breadth_data:
356 | state["market_regime"] = "bullish"
357 | elif "Bearish" in breadth_data:
358 | state["market_regime"] = "bearish"
359 | else:
360 | state["market_regime"] = "neutral"
361 | elif (
362 | isinstance(breadth_data, dict) and "market_breadth" in breadth_data
363 | ):
364 | sentiment = breadth_data["market_breadth"].get(
365 | "sentiment", "Neutral"
366 | )
367 | state["market_regime"] = sentiment.lower()
368 | else:
369 | state["market_regime"] = "neutral"
370 | else:
371 | state["market_regime"] = "neutral"
372 |
373 | except Exception as e:
374 | logger.error(f"Error analyzing market regime: {e}")
375 | state["market_regime"] = "unknown"
376 |
377 | state["api_calls_made"] = state.get("api_calls_made", 0) + 1
378 | return {"market_regime": state.get("market_regime", "neutral")}
379 |
380 | async def _analyze_sectors(self, state: MarketAnalysisState) -> dict[str, Any]:
381 | """Analyze sector rotation patterns."""
382 | try:
383 | # Use sector sentiment tool
384 | sector_tool = next(
385 | (t for t in self.tools if t.name == "analyze_sector_sentiment"), None
386 | )
387 |
388 | if sector_tool:
389 | circuit_breaker = await circuit_manager.get_or_create("market_data")
390 |
391 | async def get_sectors():
392 | return await sector_tool.ainvoke({})
393 |
394 | sector_data = await circuit_breaker.call(get_sectors)
395 |
396 | if "sector_rotation" in sector_data:
397 | state["sector_rotation"] = sector_data["sector_rotation"]
398 |
399 | # Extract leading sectors
400 | leading = sector_data["sector_rotation"].get("leading_sectors", [])
401 | if leading and state.get("sector_filter"):
402 | # Prioritize screening in leading sectors
403 | state["sector_filter"] = leading[0].get("name", "")
404 |
405 | except Exception as e:
406 | logger.error(f"Error analyzing sectors: {e}")
407 |
408 | state["api_calls_made"] = state.get("api_calls_made", 0) + 1
409 | return {"sector_rotation": state.get("sector_rotation", {})}
410 |
411 | async def _run_screening(self, state: MarketAnalysisState) -> dict[str, Any]:
412 | """Run multi-strategy screening."""
413 | try:
414 | # Determine which screening strategy based on market regime
415 | strategy = state.get("screening_strategy", "momentum")
416 |
417 | # Adjust strategy based on regime
418 | if state.get("market_regime") == "bearish" and strategy == "momentum":
419 | strategy = "mean_reversion"
420 |
421 | # Get appropriate screening tool
422 | tool_map = {
423 | "momentum": "get_maverick_stocks",
424 | "supply_demand_breakout": "get_supply_demand_breakouts",
425 | "bearish": "get_maverick_bear_stocks",
426 | }
427 |
428 | tool_name = tool_map.get(strategy, "get_maverick_stocks")
429 | screening_tool = next((t for t in self.tools if t.name == tool_name), None)
430 |
431 | if screening_tool:
432 | circuit_breaker = await circuit_manager.get_or_create("screening")
433 |
434 | async def run_screen():
435 | return await screening_tool.ainvoke(
436 | {"limit": state.get("max_results", 20)}
437 | )
438 |
439 | results = await circuit_breaker.call(run_screen)
440 |
441 | if "stocks" in results:
442 | symbols = [s.get("symbol") for s in results["stocks"]]
443 | scores = {
444 | s.get("symbol"): s.get("combined_score", 0)
445 | for s in results["stocks"]
446 | }
447 |
448 | state["screened_symbols"] = symbols
449 | state["screening_scores"] = scores
450 | state["cache_hits"] += 1
451 |
452 | except Exception as e:
453 | logger.error(f"Error running screening: {e}")
454 | state["cache_misses"] += 1
455 |
456 | state["api_calls_made"] = state.get("api_calls_made", 0) + 1
457 | return {
458 | "screened_symbols": state.get("screened_symbols", []),
459 | "screening_scores": state.get("screening_scores", {}),
460 | }
461 |
462 | async def _assess_risks(self, state: MarketAnalysisState) -> dict[str, Any]:
463 | """Assess risks for screened symbols."""
464 | symbols = state.get("screened_symbols", [])[:5] # Top 5 only
465 |
466 | if not symbols:
467 | return {}
468 |
469 | try:
470 | # Get risk metrics tool
471 | risk_tool = next(
472 | (t for t in self.tools if isinstance(t, RiskMetricsTool)), None
473 | )
474 |
475 | if risk_tool and len(symbols) > 1:
476 | # Calculate portfolio risk metrics
477 | risk_data = await risk_tool.ainvoke(
478 | {"symbols": symbols, "lookback_days": 252}
479 | )
480 |
481 | # Store risk assessment
482 | state["conversation_context"]["risk_metrics"] = risk_data
483 |
484 | except Exception as e:
485 | logger.error(f"Error assessing risks: {e}")
486 |
487 | return {}
488 |
489 | async def analyze_market(
490 | self,
491 | query: str,
492 | session_id: str,
493 | screening_strategy: str = "momentum",
494 | max_results: int = 20,
495 | **kwargs,
496 | ) -> dict[str, Any]:
497 | """
498 | Analyze market with specific screening parameters.
499 |
500 | Enhanced with caching, circuit breakers, and comprehensive analysis.
501 | """
502 | start_time = datetime.now()
503 |
504 | # Check cache first
505 | cached = self._check_enhanced_cache(query, session_id, screening_strategy)
506 | if cached:
507 | return cached
508 |
509 | # Prepare initial state
510 | initial_state = {
511 | "messages": [HumanMessage(content=query)],
512 | "persona": self.persona.name,
513 | "session_id": session_id,
514 | "screening_strategy": screening_strategy,
515 | "max_results": max_results,
516 | "timestamp": datetime.now(),
517 | "token_count": 0,
518 | "error": None,
519 | "analyzed_stocks": {},
520 | "key_price_levels": {},
521 | "last_analysis_time": {},
522 | "conversation_context": {},
523 | "execution_time_ms": None,
524 | "api_calls_made": 0,
525 | "cache_hits": 0,
526 | "cache_misses": 0,
527 | }
528 |
529 | # Update with any additional parameters
530 | initial_state.update(kwargs)
531 |
532 | # Run the analysis
533 | result = await self.ainvoke(query, session_id, initial_state=initial_state)
534 |
535 | # Calculate execution time
536 | execution_time = (datetime.now() - start_time).total_seconds() * 1000
537 | result["execution_time_ms"] = execution_time
538 |
539 | # Extract and cache results
540 | analysis_results = self._extract_enhanced_results(result)
541 |
542 | # Create same cache key as used in _check_enhanced_cache
543 | query_hash = hashlib.md5(query.lower().encode()).hexdigest()[:8]
544 | cache_key = f"{screening_strategy}_{query_hash}"
545 |
546 | self.conversation_store.save_analysis(
547 | session_id=session_id,
548 | symbol=cache_key,
549 | analysis_type=f"{screening_strategy}_analysis",
550 | data=analysis_results,
551 | )
552 |
553 | return analysis_results
554 |
555 | def _check_enhanced_cache(
556 | self, query: str, session_id: str, strategy: str
557 | ) -> dict[str, Any] | None:
558 | """Check for cached analysis with strategy awareness."""
559 | # Create a hash of the query to use as cache key
560 | query_hash = hashlib.md5(query.lower().encode()).hexdigest()[:8]
561 | cache_key = f"{strategy}_{query_hash}"
562 |
563 | cached = self.conversation_store.get_analysis(
564 | session_id=session_id,
565 | symbol=cache_key,
566 | analysis_type=f"{strategy}_analysis",
567 | )
568 |
569 | if cached and cached.get("data"):
570 | # Check cache age based on strategy
571 | timestamp = datetime.fromisoformat(cached["timestamp"])
572 | age_minutes = (datetime.now() - timestamp).total_seconds() / 60
573 |
574 | # Different cache durations for different strategies
575 | cache_durations = {
576 | "momentum": 15, # 15 minutes for fast-moving
577 | "trending": 60, # 1 hour for trend following
578 | "mean_reversion": 30, # 30 minutes
579 | }
580 |
581 | max_age = cache_durations.get(strategy, 30)
582 |
583 | if age_minutes < max_age:
584 | logger.info(f"Using cached {strategy} analysis")
585 | return cached["data"] # type: ignore
586 |
587 | return None
588 |
589 | def _extract_enhanced_results(self, result: dict[str, Any]) -> dict[str, Any]:
590 | """Extract comprehensive results from agent output."""
591 | state = result.get("state", {})
592 |
593 | # Get final message content
594 | messages = result.get("messages", [])
595 | content = messages[-1].content if messages else ""
596 |
597 | return {
598 | "status": "success",
599 | "timestamp": datetime.now().isoformat(),
600 | "query_type": "professional_market_analysis",
601 | "execution_metrics": {
602 | "execution_time_ms": result.get("execution_time_ms", 0),
603 | "api_calls": state.get("api_calls_made", 0),
604 | "cache_hits": state.get("cache_hits", 0),
605 | "cache_misses": state.get("cache_misses", 0),
606 | },
607 | "market_analysis": {
608 | "regime": state.get("market_regime", "unknown"),
609 | "sector_rotation": state.get("sector_rotation", {}),
610 | "breadth": state.get("market_breadth", {}),
611 | "sentiment": state.get("sentiment_indicators", {}),
612 | },
613 | "screening_results": {
614 | "strategy": state.get("screening_strategy", "momentum"),
615 | "symbols": state.get("screened_symbols", [])[:20],
616 | "scores": state.get("screening_scores", {}),
617 | "count": len(state.get("screened_symbols", [])),
618 | "metadata": state.get("symbol_metadata", {}),
619 | },
620 | "risk_assessment": state.get("conversation_context", {}).get(
621 | "risk_metrics", {}
622 | ),
623 | "recommendations": {
624 | "summary": content,
625 | "persona_adjusted": True,
626 | "risk_level": self.persona.name,
627 | "position_sizing": f"Max {self.persona.position_size_max * 100:.1f}% per position",
628 | },
629 | }
630 |
```
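A hedged end-to-end sketch for the agent above. The `ChatOpenAI` client and model name are illustrative assumptions (any LangChain-compatible chat model should work), and the session id is arbitrary; the result keys come from `_extract_enhanced_results`.

```python
# Usage sketch for MarketAnalysisAgent; the LLM class and model name are
# assumptions, not prescribed by the module.
import asyncio

from langchain_openai import ChatOpenAI

from maverick_mcp.agents.market_analysis import MarketAnalysisAgent


async def main():
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    agent = MarketAnalysisAgent(llm=llm, persona="moderate")

    result = await agent.analyze_market(
        query="Find momentum setups in leading sectors",
        session_id="demo-session-1",
        screening_strategy="momentum",
        max_results=10,
    )
    print(result["screening_results"]["symbols"])


asyncio.run(main())
```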
--------------------------------------------------------------------------------
/maverick_mcp/tests/test_stock_data_enhanced.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive tests for the enhanced stock data provider with SQLAlchemy integration.
3 | """
4 |
5 | from datetime import datetime, timedelta
6 | from decimal import Decimal
7 | from unittest.mock import MagicMock, patch
8 |
9 | import pandas as pd
10 | import pytest
11 | from sqlalchemy import create_engine
12 | from sqlalchemy.orm import sessionmaker
13 |
14 | from maverick_mcp.data.models import (
15 | Base,
16 | MaverickBearStocks,
17 | MaverickStocks,
18 | PriceCache,
19 | Stock,
20 | SupplyDemandBreakoutStocks,
21 | )
22 | from maverick_mcp.providers.stock_data import EnhancedStockDataProvider
23 |
24 |
25 | @pytest.fixture(scope="module")
26 | def test_db():
27 | """Create a test database for the tests."""
28 | # Use in-memory SQLite for tests
29 | engine = create_engine("sqlite:///:memory:")
30 | Base.metadata.create_all(engine)
31 |
32 | yield engine
33 |
34 | engine.dispose()
35 |
36 |
37 | @pytest.fixture
38 | def db_session(test_db):
39 | """Create a database session for each test."""
40 | # Clear all data before each test
41 | Base.metadata.drop_all(bind=test_db)
42 | Base.metadata.create_all(bind=test_db)
43 |
44 | SessionLocal = sessionmaker(bind=test_db)
45 | session = SessionLocal()
46 |
47 | yield session
48 |
49 | session.rollback()
50 | session.close()
51 |
52 |
53 | @pytest.fixture
54 | def provider():
55 | """Create an instance of the enhanced provider."""
56 | return EnhancedStockDataProvider()
57 |
58 |
59 | @pytest.fixture
60 | def sample_stock(db_session):
61 | """Create a sample stock in the database."""
62 | stock = Stock(
63 | ticker_symbol="AAPL",
64 | company_name="Apple Inc.",
65 | sector="Technology",
66 | industry="Consumer Electronics",
67 | exchange="NASDAQ",
68 | currency="USD",
69 | )
70 | db_session.add(stock)
71 | db_session.commit()
72 | return stock
73 |
74 |
75 | @pytest.fixture
76 | def sample_price_data(db_session, sample_stock):
77 | """Create sample price data in the database."""
78 | prices = []
79 | base_date = datetime(2024, 1, 1).date()
80 |
81 | for i in range(5):
82 | price = PriceCache(
83 | stock_id=sample_stock.stock_id,
84 | date=base_date + timedelta(days=i),
85 | open_price=Decimal(f"{150 + i}.00"),
86 | high_price=Decimal(f"{155 + i}.00"),
87 | low_price=Decimal(f"{149 + i}.00"),
88 | close_price=Decimal(f"{152 + i}.00"),
89 | volume=1000000 + i * 10000,
90 | )
91 | prices.append(price)
92 |
93 | db_session.add_all(prices)
94 | db_session.commit()
95 | return prices
96 |
97 |
98 | @pytest.fixture
99 | def mock_yfinance():
100 | """Mock yfinance responses."""
101 | with patch("maverick_mcp.providers.stock_data.yf") as mock_yf:
102 | # Mock ticker
103 | mock_ticker = MagicMock()
104 | mock_yf.Ticker.return_value = mock_ticker
105 |
106 | # Mock history data
107 | dates = pd.date_range("2024-01-01", periods=5, freq="D")
108 | mock_df = pd.DataFrame(
109 | {
110 | "Open": [150.0, 151.0, 152.0, 153.0, 154.0],
111 | "High": [155.0, 156.0, 157.0, 158.0, 159.0],
112 | "Low": [149.0, 150.0, 151.0, 152.0, 153.0],
113 | "Close": [152.0, 153.0, 154.0, 155.0, 156.0],
114 | "Volume": [1000000, 1010000, 1020000, 1030000, 1040000],
115 | },
116 | index=dates,
117 | )
118 | mock_ticker.history.return_value = mock_df
119 |
120 | # Mock info
121 | mock_ticker.info = {
122 | "longName": "Apple Inc.",
123 | "sector": "Technology",
124 | "industry": "Consumer Electronics",
125 | "exchange": "NASDAQ",
126 | "currency": "USD",
127 | "country": "United States",
128 | "previousClose": 151.0,
129 | "quoteType": "EQUITY",
130 | }
131 |
132 | # Mock other attributes
133 | mock_ticker.news = []
134 | mock_ticker.earnings = pd.DataFrame()
135 | mock_ticker.earnings_dates = pd.DataFrame()
136 | mock_ticker.earnings_trend = {}
137 | mock_ticker.recommendations = pd.DataFrame()
138 |
139 | yield mock_yf
140 |
141 |
142 | class TestEnhancedStockDataProvider:
143 | """Test the enhanced stock data provider."""
144 |
145 | def test_singleton_pattern(self):
146 | """Test that provider follows singleton pattern."""
147 | provider1 = EnhancedStockDataProvider()
148 | provider2 = EnhancedStockDataProvider()
149 | assert provider1 is provider2
150 |
151 | def test_get_db_session(self, provider, monkeypatch):
152 | """Test database session retrieval."""
153 | mock_session = MagicMock()
154 | mock_get_db = MagicMock(return_value=iter([mock_session]))
155 |
156 | monkeypatch.setattr("maverick_mcp.providers.stock_data.get_db", mock_get_db)
157 |
158 | session = provider._get_db_session()
159 | assert session == mock_session
160 |
161 | def test_get_or_create_stock_existing(self, provider, db_session, sample_stock):
162 | """Test getting an existing stock."""
163 | stock = provider._get_or_create_stock(db_session, "AAPL")
164 | assert stock.stock_id == sample_stock.stock_id
165 | assert stock.ticker_symbol == "AAPL"
166 |
167 | def test_get_or_create_stock_new(self, provider, db_session, mock_yfinance):
168 | """Test creating a new stock."""
169 | stock = provider._get_or_create_stock(db_session, "GOOGL")
170 | assert stock.ticker_symbol == "GOOGL"
171 | assert stock.company_name == "Apple Inc." # From mock
172 | assert stock.sector == "Technology"
173 |
174 | # Verify it was saved
175 | found = db_session.query(Stock).filter_by(ticker_symbol="GOOGL").first()
176 | assert found is not None
177 |
178 | def test_get_cached_price_data(
179 | self, provider, db_session, sample_stock, sample_price_data, monkeypatch
180 | ):
181 | """Test retrieving cached price data."""
182 |
183 | # Mock the get_db function to return our test session
184 | def mock_get_db():
185 | yield db_session
186 |
187 | monkeypatch.setattr("maverick_mcp.providers.stock_data.get_db", mock_get_db)
188 |
189 | df = provider._get_cached_price_data(
190 | db_session, "AAPL", "2024-01-01", "2024-01-05"
191 | )
192 |
193 | assert not df.empty
194 | assert len(df) == 5
195 | assert df.index[0] == pd.Timestamp("2024-01-01")
196 | assert df["Close"].iloc[0] == 152.0
197 |
198 | def test_get_cached_price_data_partial_range(
199 | self, provider, db_session, sample_stock, sample_price_data
200 | ):
201 | """Test retrieving cached data for partial range."""
202 | df = provider._get_cached_price_data(
203 | db_session, "AAPL", "2024-01-02", "2024-01-04"
204 | )
205 |
206 | assert not df.empty
207 | assert len(df) == 3
208 | assert df.index[0] == pd.Timestamp("2024-01-02")
209 | assert df.index[-1] == pd.Timestamp("2024-01-04")
210 |
211 | def test_get_cached_price_data_no_data(self, provider, db_session):
212 | """Test retrieving cached data when none exists."""
213 | df = provider._get_cached_price_data(
214 | db_session, "TSLA", "2024-01-01", "2024-01-05"
215 | )
216 |
217 | assert df is None
218 |
219 | def test_cache_price_data(self, provider, db_session, sample_stock):
220 | """Test caching price data."""
221 | # Create test DataFrame
222 | dates = pd.date_range("2024-02-01", periods=3, freq="D")
223 | df = pd.DataFrame(
224 | {
225 | "Open": [160.0, 161.0, 162.0],
226 | "High": [165.0, 166.0, 167.0],
227 | "Low": [159.0, 160.0, 161.0],
228 | "Close": [162.0, 163.0, 164.0],
229 | "Volume": [2000000, 2100000, 2200000],
230 | },
231 | index=dates,
232 | )
233 |
234 | provider._cache_price_data(db_session, "AAPL", df)
235 |
236 | # Verify data was cached
237 | prices = (
238 | db_session.query(PriceCache)
239 | .filter(
240 | PriceCache.stock_id == sample_stock.stock_id,
241 | PriceCache.date >= dates[0].date(),
242 | )
243 | .all()
244 | )
245 |
246 | assert len(prices) == 3
247 | assert prices[0].close_price == Decimal("162.00")
248 |
249 | def test_get_stock_data_with_cache(
250 | self, provider, db_session, sample_stock, sample_price_data, monkeypatch
251 | ):
252 | """Test getting stock data with cache hit."""
253 |
254 | # Mock the get_db function to return our test session
255 | def mock_get_db():
256 | yield db_session
257 |
258 | monkeypatch.setattr("maverick_mcp.providers.stock_data.get_db", mock_get_db)
259 |
260 | df = provider.get_stock_data("AAPL", "2024-01-01", "2024-01-05", use_cache=True)
261 |
262 | assert not df.empty
263 | assert len(df) == 5
264 | assert df["Close"].iloc[0] == 152.0
265 |
266 | def test_get_stock_data_without_cache(self, provider, mock_yfinance):
267 | """Test getting stock data without cache."""
268 | df = provider.get_stock_data(
269 | "AAPL", "2024-01-01", "2024-01-05", use_cache=False
270 | )
271 |
272 | assert not df.empty
273 | assert len(df) == 5
274 | assert df["Close"].iloc[0] == 152.0
275 |
276 | def test_get_stock_data_cache_miss(
277 | self, provider, db_session, mock_yfinance, monkeypatch
278 | ):
279 | """Test getting stock data with cache miss."""
280 | # Mock the session getter
281 | monkeypatch.setattr(provider, "_get_db_session", lambda: db_session)
282 |
283 | df = provider.get_stock_data("TSLA", "2024-01-01", "2024-01-05", use_cache=True)
284 |
285 | assert not df.empty
286 | assert len(df) == 5
287 | # Data should come from mock yfinance
288 | assert df["Close"].iloc[0] == 152.0
289 |
290 | def test_get_stock_data_non_daily_interval(self, provider, mock_yfinance):
291 | """Test that non-daily intervals bypass cache."""
292 | df = provider.get_stock_data("AAPL", interval="1wk", period="1mo")
293 |
294 | assert not df.empty
295 | # Should call yfinance directly
296 | mock_yfinance.Ticker.return_value.history.assert_called_with(
297 | period="1mo", interval="1wk"
298 | )
299 |
300 |
301 | class TestMaverickRecommendations:
302 | """Test maverick screening recommendation methods."""
303 |
304 | @pytest.fixture
305 | def sample_maverick_stocks(self, db_session):
306 | """Create sample maverick stocks."""
307 | stocks = []
308 | for i in range(3):
309 | stock = MaverickStocks(
310 | id=i + 1, # Add explicit ID for SQLite
311 | stock=f"STOCK{i}",
312 | close=100.0 + i * 10,
313 | volume=1000000,
314 | momentum_score=95.0 - i * 5,
315 | adr_pct=3.0 + i * 0.5,
316 | pat="Cup&Handle" if i == 0 else "Base",
317 | sqz="active" if i < 2 else "neutral",
318 | consolidation="yes" if i == 0 else "no",
319 | entry=f"{102.0 + i * 10}",
320 | combined_score=95 - i * 5,
321 | compression_score=90 - i * 3,
322 | pattern_detected=1,
323 | )
324 | stocks.append(stock)
325 |
326 | db_session.add_all(stocks)
327 | db_session.commit()
328 | return stocks
329 |
330 | def test_get_maverick_recommendations(
331 | self, provider, db_session, sample_maverick_stocks, monkeypatch
332 | ):
333 | """Test getting maverick recommendations."""
334 | monkeypatch.setattr(provider, "_get_db_session", lambda: db_session)
335 |
336 | recommendations = provider.get_maverick_recommendations(limit=2)
337 |
338 | assert len(recommendations) == 2
339 | assert recommendations[0]["stock"] == "STOCK0"
340 | assert recommendations[0]["combined_score"] == 95
341 | assert recommendations[0]["recommendation_type"] == "maverick_bullish"
342 | assert "reason" in recommendations[0]
343 | assert "Exceptional combined score" in recommendations[0]["reason"]
344 |
345 | def test_get_maverick_recommendations_with_min_score(
346 | self, provider, db_session, sample_maverick_stocks, monkeypatch
347 | ):
348 | """Test getting maverick recommendations with minimum score filter."""
349 | monkeypatch.setattr(provider, "_get_db_session", lambda: db_session)
350 |
351 | recommendations = provider.get_maverick_recommendations(limit=10, min_score=90)
352 |
353 | assert len(recommendations) == 2 # Only STOCK0 and STOCK1 have score >= 90
354 | assert all(rec["combined_score"] >= 90 for rec in recommendations)
355 |
356 | @pytest.fixture
357 | def sample_bear_stocks(self, db_session):
358 | """Create sample bear stocks."""
359 | stocks = []
360 | for i in range(3):
361 | stock = MaverickBearStocks(
362 | id=i + 1, # Add explicit ID for SQLite
363 | stock=f"BEAR{i}",
364 | close=50.0 - i * 5,
365 | volume=500000,
366 | momentum_score=30.0 - i * 5,
367 | rsi_14=28.0 - i * 3,
368 | macd=-0.5 - i * 0.1,
369 | adr_pct=4.0 + i * 0.5,
370 | atr_contraction=i < 2,
371 | big_down_vol=i == 0,
372 | score=90 - i * 5,
373 | sqz="red" if i < 2 else "neutral",
374 | )
375 | stocks.append(stock)
376 |
377 | db_session.add_all(stocks)
378 | db_session.commit()
379 | return stocks
380 |
381 | def test_get_maverick_bear_recommendations(
382 | self, provider, db_session, sample_bear_stocks, monkeypatch
383 | ):
384 | """Test getting bear recommendations."""
385 | monkeypatch.setattr(provider, "_get_db_session", lambda: db_session)
386 |
387 | recommendations = provider.get_maverick_bear_recommendations(limit=2)
388 |
389 | assert len(recommendations) == 2
390 | assert recommendations[0]["stock"] == "BEAR0"
391 | assert recommendations[0]["score"] == 90
392 | assert recommendations[0]["recommendation_type"] == "maverick_bearish"
393 | assert "reason" in recommendations[0]
394 | assert "Exceptional bear score" in recommendations[0]["reason"]
395 |
396 | @pytest.fixture
397 | def sample_trending_stocks(self, db_session):
398 | """Create sample trending stocks."""
399 | stocks = []
400 | for i in range(3):
401 | stock = SupplyDemandBreakoutStocks(
402 | id=i + 1, # Add explicit ID for SQLite
403 | stock=f"MNRV{i}",
404 | close=200.0 + i * 10,
405 | volume=2000000,
406 | ema_21=195.0 + i * 9,
407 | sma_50=190.0 + i * 8,
408 | sma_150=185.0 + i * 7,
409 | sma_200=180.0 + i * 6,
410 | momentum_score=92.0 - i * 2,
411 | adr_pct=2.8 + i * 0.2,
412 | pat="Base" if i == 0 else "Flag",
413 | sqz="neutral",
414 | consolidation="yes" if i < 2 else "no",
415 | entry=f"{202.0 + i * 10}",
416 | )
417 | stocks.append(stock)
418 |
419 | db_session.add_all(stocks)
420 | db_session.commit()
421 | return stocks
422 |
423 | def test_get_trending_recommendations(
424 | self, provider, db_session, sample_trending_stocks, monkeypatch
425 | ):
426 | """Test getting trending recommendations."""
427 | monkeypatch.setattr(provider, "_get_db_session", lambda: db_session)
428 |
429 | recommendations = provider.get_trending_recommendations(limit=2)
430 |
431 | assert len(recommendations) == 2
432 | assert recommendations[0]["stock"] == "MNRV0"
433 | assert recommendations[0]["momentum_score"] == 92.0
434 | assert recommendations[0]["recommendation_type"] == "trending_stage2"
435 | assert "reason" in recommendations[0]
436 | assert "Uptrend" in recommendations[0]["reason"]
437 |
438 | def test_get_all_screening_recommendations(self, provider, monkeypatch):
439 | """Test getting all screening recommendations."""
440 | mock_results = {
441 | "maverick_stocks": [
442 | {"stock": "AAPL", "combined_score": 95, "momentum_score": 90}
443 | ],
444 | "maverick_bear_stocks": [
445 | {"stock": "BEAR", "score": 88, "momentum_score": 25}
446 | ],
447 | "trending_stocks": [{"stock": "MSFT", "momentum_score": 91}],
448 | }
449 |
450 | monkeypatch.setattr(
451 | "maverick_mcp.providers.stock_data.get_latest_maverick_screening",
452 | lambda: mock_results,
453 | )
454 |
455 | results = provider.get_all_screening_recommendations()
456 |
457 | assert "maverick_stocks" in results
458 | assert "maverick_bear_stocks" in results
459 | assert "trending_stocks" in results
460 |
461 | # Check that reasons were added
462 | assert (
463 | results["maverick_stocks"][0]["recommendation_type"] == "maverick_bullish"
464 | )
465 | assert "reason" in results["maverick_stocks"][0]
466 |
467 | assert (
468 | results["maverick_bear_stocks"][0]["recommendation_type"]
469 | == "maverick_bearish"
470 | )
471 | assert "reason" in results["maverick_bear_stocks"][0]
472 |
473 | assert results["trending_stocks"][0]["recommendation_type"] == "trending_stage2"
474 | assert "reason" in results["trending_stocks"][0]
475 |
476 |
477 | class TestBackwardCompatibility:
478 | """Test backward compatibility with original StockDataProvider."""
479 |
480 | def test_get_stock_info(self, provider, mock_yfinance):
481 | """Test get_stock_info method."""
482 | info = provider.get_stock_info("AAPL")
483 | assert info["longName"] == "Apple Inc."
484 | assert info["sector"] == "Technology"
485 |
486 | def test_get_realtime_data(self, provider, mock_yfinance):
487 | """Test get_realtime_data method."""
488 | data = provider.get_realtime_data("AAPL")
489 |
490 | assert data is not None
491 | assert data["symbol"] == "AAPL"
492 | assert data["price"] == 156.0 # Last close from mock
493 | assert data["change"] == 5.0 # 156 - 151 (previousClose)
494 | assert data["change_percent"] == pytest.approx(3.31, rel=0.01)
495 |
496 | def test_get_all_realtime_data(self, provider, mock_yfinance):
497 | """Test get_all_realtime_data method."""
498 | results = provider.get_all_realtime_data(["AAPL", "GOOGL"])
499 |
500 | assert len(results) == 2
501 | assert "AAPL" in results
502 | assert "GOOGL" in results
503 | assert results["AAPL"]["price"] == 156.0
504 |
505 | def test_is_market_open(self, provider, monkeypatch):
506 | """Test is_market_open method."""
507 | import pytz
508 |
509 | # Mock a weekday at 10 AM Eastern
510 | mock_now = datetime(2024, 1, 2, 10, 0, 0) # Tuesday
511 | mock_now = pytz.timezone("US/Eastern").localize(mock_now)
512 |
513 | monkeypatch.setattr(
514 | "maverick_mcp.providers.stock_data.datetime",
515 | MagicMock(now=MagicMock(return_value=mock_now)),
516 | )
517 |
518 | assert provider.is_market_open() is True
519 |
520 | # Mock a weekend
521 | mock_now = datetime(2024, 1, 6, 10, 0, 0) # Saturday
522 | mock_now = pytz.timezone("US/Eastern").localize(mock_now)
523 |
524 | monkeypatch.setattr(
525 | "maverick_mcp.providers.stock_data.datetime",
526 | MagicMock(now=MagicMock(return_value=mock_now)),
527 | )
528 |
529 | assert provider.is_market_open() is False
530 |
531 | def test_get_news(self, provider, mock_yfinance):
532 | """Test get_news method."""
533 | mock_news = [
534 | {
535 | "title": "Apple News 1",
536 | "publisher": "Reuters",
537 | "link": "https://example.com/1",
538 | "providerPublishTime": 1704134400, # 2024-01-01 timestamp
539 | "type": "STORY",
540 | }
541 | ]
542 | mock_yfinance.Ticker.return_value.news = mock_news
543 |
544 | df = provider.get_news("AAPL", limit=5)
545 |
546 | assert not df.empty
547 | assert len(df) == 1
548 | assert df["title"].iloc[0] == "Apple News 1"
549 | assert isinstance(df["providerPublishTime"].iloc[0], pd.Timestamp)
550 |
551 | def test_get_earnings(self, provider, mock_yfinance):
552 | """Test get_earnings method."""
553 | result = provider.get_earnings("AAPL")
554 |
555 | assert "earnings" in result
556 | assert "earnings_dates" in result
557 | assert "earnings_trend" in result
558 |
559 | def test_get_recommendations(self, provider, mock_yfinance):
560 | """Test get_recommendations method."""
561 | df = provider.get_recommendations("AAPL")
562 |
563 | assert isinstance(df, pd.DataFrame)
564 | assert list(df.columns) == ["firm", "toGrade", "fromGrade", "action"]
565 |
566 | def test_is_etf(self, provider, mock_yfinance):
567 | """Test is_etf method."""
568 | # Test regular stock
569 | assert provider.is_etf("AAPL") is False
570 |
571 | # Test ETF
572 | mock_yfinance.Ticker.return_value.info["quoteType"] = "ETF"
573 | assert provider.is_etf("SPY") is True
574 |
575 | # Test by symbol pattern
576 | assert provider.is_etf("QQQ") is True
577 |
578 |
579 | class TestErrorHandling:
580 | """Test error handling in the enhanced provider."""
581 |
582 | def test_get_stock_data_error_handling(self, provider, mock_yfinance, monkeypatch):
583 | """Test error handling in get_stock_data."""
584 | # Mock an exception for all yfinance calls
585 | mock_yfinance.Ticker.return_value.history.side_effect = Exception("API Error")
586 |
587 | # Also mock the database session to ensure no cache is used
588 | def mock_get_db_session():
589 | raise Exception("Database error")
590 |
591 | monkeypatch.setattr(provider, "_get_db_session", mock_get_db_session)
592 |
593 |         # With caching disabled and yfinance raising, the provider should fall back to an empty DataFrame
594 | df = provider.get_stock_data(
595 | "AAPL", "2024-01-01", "2024-01-05", use_cache=False
596 | )
597 |
598 | assert df.empty
599 | assert list(df.columns) == ["Open", "High", "Low", "Close", "Volume"]
600 |
601 | def test_get_cached_price_data_error_handling(
602 | self, provider, db_session, monkeypatch
603 | ):
604 | """Test error handling in _get_cached_price_data."""
605 |
606 | # Mock a database error
607 | def mock_get_price_data(*args, **kwargs):
608 | raise Exception("Database error")
609 |
610 | monkeypatch.setattr(PriceCache, "get_price_data", mock_get_price_data)
611 |
612 | result = provider._get_cached_price_data(
613 | db_session, "AAPL", "2024-01-01", "2024-01-05"
614 | )
615 | assert result is None
616 |
617 | def test_cache_price_data_error_handling(self, provider, db_session, monkeypatch):
618 | """Test error handling in _cache_price_data."""
619 |
620 | # Mock a database error
621 | def mock_bulk_insert(*args, **kwargs):
622 | raise Exception("Insert error")
623 |
624 | monkeypatch.setattr(
625 | "maverick_mcp.providers.stock_data.bulk_insert_price_data", mock_bulk_insert
626 | )
627 |
628 | dates = pd.date_range("2024-01-01", periods=3, freq="D")
629 | df = pd.DataFrame(
630 | {
631 | "Open": [150.0, 151.0, 152.0],
632 | "High": [155.0, 156.0, 157.0],
633 | "Low": [149.0, 150.0, 151.0],
634 | "Close": [152.0, 153.0, 154.0],
635 | "Volume": [1000000, 1010000, 1020000],
636 | },
637 | index=dates,
638 | )
639 |
640 | # Should not raise exception
641 | provider._cache_price_data(db_session, "AAPL", df)
642 |
643 | def test_get_maverick_recommendations_error_handling(self, provider, monkeypatch):
644 | """Test error handling in get_maverick_recommendations."""
645 | # Mock a database session that throws when used
646 | mock_session = MagicMock()
647 | mock_session.query.side_effect = Exception("Database query error")
648 | mock_session.close = MagicMock()
649 |
650 | monkeypatch.setattr(provider, "_get_db_session", lambda: mock_session)
651 |
652 | recommendations = provider.get_maverick_recommendations()
653 | assert recommendations == []
654 |
```
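The provider tests above center on a read-through cache: `get_stock_data` consults `_get_cached_price_data` first, falls back to yfinance on a miss, and persists fetched rows via `_cache_price_data`. Below is a minimal sketch of that flow under those assumptions; `get_stock_data_sketch` and `fetch_from_yfinance` are illustrative names, and the real provider also degrades to an empty OHLCV frame when both paths fail (see `TestErrorHandling`).

```python
# Read-through cache sketch mirroring the behavior these tests assert.
import pandas as pd
import yfinance as yf


def fetch_from_yfinance(symbol: str, start: str, end: str) -> pd.DataFrame:
    # Plain yfinance fetch standing in for the provider's internal download path.
    return yf.Ticker(symbol).history(start=start, end=end)


def get_stock_data_sketch(
    provider, session, symbol: str, start: str, end: str, use_cache: bool = True
) -> pd.DataFrame:
    if use_cache:
        cached = provider._get_cached_price_data(session, symbol, start, end)
        if cached is not None and not cached.empty:
            return cached  # cache hit: no network round trip
    df = fetch_from_yfinance(symbol, start, end)
    if use_cache and not df.empty:
        provider._cache_price_data(session, symbol, df)  # write back for next call
    return df
```

Note that per `test_get_stock_data_non_daily_interval`, only daily bars take this cached path; other intervals go straight to yfinance.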
--------------------------------------------------------------------------------
/tests/test_agents_router_mcp.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Tests for the new MCP tools in the agents router.
3 |
4 | Tests the orchestrated_analysis, deep_research_financial, and
5 | compare_multi_agent_analysis MCP tools for Claude Desktop integration.
6 | """
7 |
8 | import uuid
9 | from unittest.mock import AsyncMock, MagicMock, patch
10 |
11 | import pytest
12 |
13 | from maverick_mcp.api.routers.agents import (
14 | compare_multi_agent_analysis,
15 | deep_research_financial,
16 | get_or_create_agent,
17 | list_available_agents,
18 | orchestrated_analysis,
19 | )
20 |
21 |
22 | @pytest.fixture
23 | def mock_supervisor_agent():
24 | """Mock SupervisorAgent for testing."""
25 | agent = MagicMock()
26 | agent.orchestrate_analysis = AsyncMock(
27 | return_value={
28 | "status": "success",
29 | "synthesis": "Comprehensive analysis completed",
30 | "agents_used": ["market", "research"],
31 | "execution_time_ms": 4500,
32 | "synthesis_confidence": 0.87,
33 | "agent_results": {
34 | "market": {"summary": "Strong momentum", "confidence": 0.85},
35 | "research": {"summary": "Solid fundamentals", "confidence": 0.88},
36 | },
37 | "key_recommendations": ["Focus on large-cap tech", "Monitor earnings"],
38 | "confidence": 0.87,
39 | }
40 | )
41 | return agent
42 |
43 |
44 | @pytest.fixture
45 | def mock_research_agent():
46 | """Mock DeepResearchAgent for testing."""
47 | agent = MagicMock()
48 | agent.conduct_research = AsyncMock(
49 | return_value={
50 | "status": "success",
51 | "research_findings": [
52 | {
53 | "insight": "Strong revenue growth",
54 | "confidence": 0.9,
55 | "source": "earnings-report",
56 | },
57 | {
58 | "insight": "Expanding market share",
59 | "confidence": 0.85,
60 | "source": "market-analysis",
61 | },
62 | ],
63 | "sources_analyzed": 42,
64 | "research_confidence": 0.88,
65 | "validation_checks_passed": 35,
66 | "total_sources_processed": 50,
67 | "content_summaries": [
68 | "Financial performance strong",
69 | "Market position improving",
70 | ],
71 | "citations": [
72 | {"url": "https://example.com/report1", "title": "Q3 Earnings Analysis"},
73 | {"url": "https://example.com/report2", "title": "Market Share Report"},
74 | ],
75 | "execution_time_ms": 6500,
76 | }
77 | )
78 | return agent
79 |
80 |
81 | @pytest.fixture
82 | def mock_market_agent():
83 | """Mock MarketAnalysisAgent for testing."""
84 | agent = MagicMock()
85 | agent.analyze_market = AsyncMock(
86 | return_value={
87 | "status": "success",
88 | "summary": "Top momentum stocks identified",
89 | "screened_symbols": ["AAPL", "MSFT", "NVDA"],
90 | "confidence": 0.82,
91 | "results": {
92 | "screening_scores": {"AAPL": 0.92, "MSFT": 0.88, "NVDA": 0.95},
93 | "sector_performance": {"Technology": 0.15, "Healthcare": 0.08},
94 | },
95 | "execution_time_ms": 2100,
96 | }
97 | )
98 | return agent
99 |
100 |
101 | class TestOrchestratedAnalysis:
102 | """Test orchestrated_analysis MCP tool."""
103 |
104 | @pytest.mark.asyncio
105 | async def test_orchestrated_analysis_success(self, mock_supervisor_agent):
106 | """Test successful orchestrated analysis."""
107 | with patch(
108 | "maverick_mcp.api.routers.agents.get_or_create_agent",
109 | return_value=mock_supervisor_agent,
110 | ):
111 | result = await orchestrated_analysis(
112 | query="Analyze tech sector opportunities",
113 | persona="moderate",
114 | routing_strategy="llm_powered",
115 | max_agents=3,
116 | parallel_execution=True,
117 | )
118 |
119 | assert result["status"] == "success"
120 | assert result["agent_type"] == "supervisor_orchestrated"
121 | assert result["persona"] == "moderate"
122 | assert result["routing_strategy"] == "llm_powered"
123 | assert "agents_used" in result
124 | assert "synthesis_confidence" in result
125 | assert "execution_time_ms" in result
126 |
127 | mock_supervisor_agent.orchestrate_analysis.assert_called_once()
128 |
129 | @pytest.mark.asyncio
130 | async def test_orchestrated_analysis_with_session_id(self, mock_supervisor_agent):
131 | """Test orchestrated analysis with provided session ID."""
132 | session_id = "test-session-123"
133 |
134 | with patch(
135 | "maverick_mcp.api.routers.agents.get_or_create_agent",
136 | return_value=mock_supervisor_agent,
137 | ):
138 | result = await orchestrated_analysis(
139 | query="Market analysis", session_id=session_id
140 | )
141 |
142 | assert result["session_id"] == session_id
143 | call_args = mock_supervisor_agent.orchestrate_analysis.call_args
144 | assert call_args[1]["session_id"] == session_id
145 |
146 | @pytest.mark.asyncio
147 | async def test_orchestrated_analysis_generates_session_id(
148 | self, mock_supervisor_agent
149 | ):
150 | """Test orchestrated analysis generates session ID when not provided."""
151 | with patch(
152 | "maverick_mcp.api.routers.agents.get_or_create_agent",
153 | return_value=mock_supervisor_agent,
154 | ):
155 | result = await orchestrated_analysis(query="Market analysis")
156 |
157 | assert "session_id" in result
158 | # Should be a valid UUID format
159 | uuid.UUID(result["session_id"])
160 |
161 | @pytest.mark.asyncio
162 | async def test_orchestrated_analysis_error_handling(self, mock_supervisor_agent):
163 | """Test orchestrated analysis error handling."""
164 | mock_supervisor_agent.orchestrate_analysis.side_effect = Exception(
165 | "Orchestration failed"
166 | )
167 |
168 | with patch(
169 | "maverick_mcp.api.routers.agents.get_or_create_agent",
170 | return_value=mock_supervisor_agent,
171 | ):
172 | result = await orchestrated_analysis(query="Test error handling")
173 |
174 | assert result["status"] == "error"
175 | assert result["agent_type"] == "supervisor_orchestrated"
176 | assert "error" in result
177 | assert "Orchestration failed" in result["error"]
178 |
179 | @pytest.mark.asyncio
180 | async def test_orchestrated_analysis_persona_variations(
181 | self, mock_supervisor_agent
182 | ):
183 | """Test orchestrated analysis with different personas."""
184 | personas = ["conservative", "moderate", "aggressive", "day_trader"]
185 |
186 | with patch(
187 | "maverick_mcp.api.routers.agents.get_or_create_agent",
188 | return_value=mock_supervisor_agent,
189 | ):
190 | for persona in personas:
191 | result = await orchestrated_analysis(
192 | query="Test persona", persona=persona
193 | )
194 |
195 | assert result["persona"] == persona
196 | # Check agent was created with correct persona
197 | call_args = mock_supervisor_agent.orchestrate_analysis.call_args
198 | assert call_args is not None
199 |
200 |
201 | class TestDeepResearchFinancial:
202 | """Test deep_research_financial MCP tool."""
203 |
204 | @pytest.mark.asyncio
205 | async def test_deep_research_success(self, mock_research_agent):
206 | """Test successful deep research."""
207 | with patch(
208 | "maverick_mcp.api.routers.agents.get_or_create_agent",
209 | return_value=mock_research_agent,
210 | ):
211 | result = await deep_research_financial(
212 | research_topic="AAPL competitive analysis",
213 | persona="moderate",
214 | research_depth="comprehensive",
215 | focus_areas=["fundamentals", "competitive_landscape"],
216 | timeframe="90d",
217 | )
218 |
219 | assert result["status"] == "success"
220 | assert result["agent_type"] == "deep_research"
221 | assert result["research_topic"] == "AAPL competitive analysis"
222 | assert result["research_depth"] == "comprehensive"
223 | assert "fundamentals" in result["focus_areas"]
224 | assert "competitive_landscape" in result["focus_areas"]
225 | assert result["sources_analyzed"] == 42
226 | assert result["research_confidence"] == 0.88
227 |
228 | mock_research_agent.conduct_research.assert_called_once()
229 |
230 | @pytest.mark.asyncio
231 | async def test_deep_research_default_focus_areas(self, mock_research_agent):
232 | """Test deep research with default focus areas."""
233 | with patch(
234 | "maverick_mcp.api.routers.agents.get_or_create_agent",
235 | return_value=mock_research_agent,
236 | ):
237 | result = await deep_research_financial(
238 | research_topic="Tesla analysis",
239 | focus_areas=None, # Should use defaults
240 | )
241 |
242 | expected_defaults = [
243 | "fundamentals",
244 | "market_sentiment",
245 | "competitive_landscape",
246 | ]
247 | assert result["focus_areas"] == expected_defaults
248 |
249 | call_args = mock_research_agent.conduct_research.call_args
250 | assert call_args[1]["focus_areas"] == expected_defaults
251 |
252 | @pytest.mark.asyncio
253 | async def test_deep_research_depth_variations(self, mock_research_agent):
254 | """Test deep research with different depth levels."""
255 | depth_levels = ["basic", "standard", "comprehensive", "exhaustive"]
256 |
257 | with patch(
258 | "maverick_mcp.api.routers.agents.get_or_create_agent",
259 | return_value=mock_research_agent,
260 | ):
261 | for depth in depth_levels:
262 | result = await deep_research_financial(
263 | research_topic="Test research", research_depth=depth
264 | )
265 |
266 | assert result["research_depth"] == depth
267 |
268 | @pytest.mark.asyncio
269 | async def test_deep_research_error_handling(self, mock_research_agent):
270 | """Test deep research error handling."""
271 | mock_research_agent.conduct_research.side_effect = Exception(
272 | "Research API failed"
273 | )
274 |
275 | with patch(
276 | "maverick_mcp.api.routers.agents.get_or_create_agent",
277 | return_value=mock_research_agent,
278 | ):
279 | result = await deep_research_financial(research_topic="Error test")
280 |
281 | assert result["status"] == "error"
282 | assert result["agent_type"] == "deep_research"
283 | assert "Research API failed" in result["error"]
284 |
285 | @pytest.mark.asyncio
286 | async def test_deep_research_timeframe_handling(self, mock_research_agent):
287 | """Test deep research with different timeframes."""
288 | timeframes = ["7d", "30d", "90d", "1y"]
289 |
290 | with patch(
291 | "maverick_mcp.api.routers.agents.get_or_create_agent",
292 | return_value=mock_research_agent,
293 | ):
294 | for timeframe in timeframes:
295 | await deep_research_financial(
296 | research_topic="Timeframe test", timeframe=timeframe
297 | )
298 |
299 | call_args = mock_research_agent.conduct_research.call_args
300 | assert call_args[1]["timeframe"] == timeframe
301 |
302 |
303 | class TestCompareMultiAgentAnalysis:
304 | """Test compare_multi_agent_analysis MCP tool."""
305 |
306 | @pytest.mark.asyncio
307 | async def test_compare_agents_success(
308 | self, mock_market_agent, mock_supervisor_agent
309 | ):
310 | """Test successful multi-agent comparison."""
311 |
312 | def get_agent_mock(agent_type, persona):
313 | if agent_type == "market":
314 | return mock_market_agent
315 | elif agent_type == "supervisor":
316 | return mock_supervisor_agent
317 | else:
318 | raise ValueError(f"Unknown agent type: {agent_type}")
319 |
320 | with patch(
321 | "maverick_mcp.api.routers.agents.get_or_create_agent",
322 | side_effect=get_agent_mock,
323 | ):
324 | result = await compare_multi_agent_analysis(
325 | query="Analyze NVDA stock potential",
326 | agent_types=["market", "supervisor"],
327 | persona="moderate",
328 | )
329 |
330 | assert result["status"] == "success"
331 | assert result["persona"] == "moderate"
332 | assert "comparison" in result
333 | assert "market" in result["comparison"]
334 | assert "supervisor" in result["comparison"]
335 | assert "execution_times_ms" in result
336 |
337 | # Both agents should have been called
338 | mock_market_agent.analyze_market.assert_called_once()
339 | mock_supervisor_agent.orchestrate_analysis.assert_called_once()
340 |
341 | @pytest.mark.asyncio
342 | async def test_compare_agents_default_types(
343 | self, mock_market_agent, mock_supervisor_agent
344 | ):
345 | """Test comparison with default agent types."""
346 |
347 | def get_agent_mock(agent_type, persona):
348 | return (
349 | mock_market_agent if agent_type == "market" else mock_supervisor_agent
350 | )
351 |
352 | with patch(
353 | "maverick_mcp.api.routers.agents.get_or_create_agent",
354 | side_effect=get_agent_mock,
355 | ):
356 | result = await compare_multi_agent_analysis(
357 | query="Default comparison test",
358 | agent_types=None, # Should use defaults
359 | )
360 |
361 | # Should use default agent types ["market", "supervisor"]
362 | assert "market" in result["agents_compared"]
363 | assert "supervisor" in result["agents_compared"]
364 |
365 | @pytest.mark.asyncio
366 | async def test_compare_agents_with_failure(
367 | self, mock_market_agent, mock_supervisor_agent
368 | ):
369 | """Test comparison with one agent failing."""
370 | mock_market_agent.analyze_market.side_effect = Exception("Market agent failed")
371 |
372 | def get_agent_mock(agent_type, persona):
373 | return (
374 | mock_market_agent if agent_type == "market" else mock_supervisor_agent
375 | )
376 |
377 | with patch(
378 | "maverick_mcp.api.routers.agents.get_or_create_agent",
379 | side_effect=get_agent_mock,
380 | ):
381 | result = await compare_multi_agent_analysis(
382 | query="Failure handling test", agent_types=["market", "supervisor"]
383 | )
384 |
385 | assert result["status"] == "success" # Overall success
386 | assert "comparison" in result
387 | assert "error" in result["comparison"]["market"]
388 | assert result["comparison"]["market"]["status"] == "failed"
389 | # Supervisor should still succeed
390 | assert "summary" in result["comparison"]["supervisor"]
391 |
392 | @pytest.mark.asyncio
393 | async def test_compare_agents_session_id_handling(
394 | self, mock_market_agent, mock_supervisor_agent
395 | ):
396 | """Test session ID handling in agent comparison."""
397 | session_id = "compare-test-456"
398 |
399 | def get_agent_mock(agent_type, persona):
400 | return (
401 | mock_market_agent if agent_type == "market" else mock_supervisor_agent
402 | )
403 |
404 | with patch(
405 | "maverick_mcp.api.routers.agents.get_or_create_agent",
406 | side_effect=get_agent_mock,
407 | ):
408 | await compare_multi_agent_analysis(
409 | query="Session ID test", session_id=session_id
410 | )
411 |
412 | # Check session IDs were properly formatted for each agent
413 | market_call_args = mock_market_agent.analyze_market.call_args
414 | assert market_call_args[1]["session_id"] == f"{session_id}_market"
415 |
416 | supervisor_call_args = mock_supervisor_agent.orchestrate_analysis.call_args
417 | assert supervisor_call_args[1]["session_id"] == f"{session_id}_supervisor"
418 |
419 |
420 | class TestGetOrCreateAgent:
421 | """Test agent factory function."""
422 |
423 | @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
424 | def test_create_supervisor_agent(self):
425 | """Test creating supervisor agent."""
426 | with patch("maverick_mcp.api.routers.agents.SupervisorAgent") as mock_class:
427 | mock_instance = MagicMock()
428 | mock_class.return_value = mock_instance
429 |
430 | agent = get_or_create_agent("supervisor", "moderate")
431 |
432 | assert agent == mock_instance
433 | mock_class.assert_called_once()
434 |
435 | @patch.dict(
436 | "os.environ",
437 | {
438 | "OPENAI_API_KEY": "test-key",
439 | "EXA_API_KEY": "exa-key",
440 | "TAVILY_API_KEY": "tavily-key",
441 | },
442 | )
443 | def test_create_deep_research_agent_with_api_keys(self):
444 | """Test creating deep research agent with API keys."""
445 | with patch("maverick_mcp.api.routers.agents.DeepResearchAgent") as mock_class:
446 | mock_instance = MagicMock()
447 | mock_class.return_value = mock_instance
448 |
449 | get_or_create_agent("deep_research", "moderate")
450 |
451 | # Should pass API keys to constructor
452 | call_args = mock_class.call_args
453 | assert call_args[1]["exa_api_key"] == "exa-key"
454 | assert call_args[1]["tavily_api_key"] == "tavily-key"
455 |
456 |     @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}, clear=True)  # clear=True so ambient EXA/TAVILY keys can't leak into the assertions below
457 | def test_create_deep_research_agent_without_api_keys(self):
458 | """Test creating deep research agent without optional API keys."""
459 | with patch("maverick_mcp.api.routers.agents.DeepResearchAgent") as mock_class:
460 | mock_instance = MagicMock()
461 | mock_class.return_value = mock_instance
462 |
463 | get_or_create_agent("deep_research", "moderate")
464 |
465 | # Should pass None for missing API keys
466 | call_args = mock_class.call_args
467 | assert call_args[1]["exa_api_key"] is None
468 | assert call_args[1]["tavily_api_key"] is None
469 |
470 | def test_agent_caching(self):
471 | """Test agent instance caching."""
472 | with patch("maverick_mcp.api.routers.agents.MarketAnalysisAgent") as mock_class:
473 | mock_instance = MagicMock()
474 | mock_class.return_value = mock_instance
475 |
476 | # First call should create agent
477 | agent1 = get_or_create_agent("market", "moderate")
478 | # Second call should return cached agent
479 | agent2 = get_or_create_agent("market", "moderate")
480 |
481 | assert agent1 == agent2 == mock_instance
482 | # Constructor should only be called once due to caching
483 | mock_class.assert_called_once()
484 |
485 | def test_different_personas_create_different_agents(self):
486 | """Test different personas create separate cached agents."""
487 | with patch("maverick_mcp.api.routers.agents.MarketAnalysisAgent") as mock_class:
488 | mock_class.return_value = MagicMock()
489 |
490 | agent_moderate = get_or_create_agent("market", "moderate")
491 | agent_aggressive = get_or_create_agent("market", "aggressive")
492 |
493 | # Should create separate instances for different personas
494 | assert agent_moderate != agent_aggressive
495 | assert mock_class.call_count == 2
496 |
497 | def test_invalid_agent_type(self):
498 | """Test handling of invalid agent type."""
499 | with pytest.raises(ValueError, match="Unknown agent type"):
500 | get_or_create_agent("invalid_agent_type", "moderate")
501 |
502 |
503 | class TestListAvailableAgents:
504 | """Test list_available_agents MCP tool."""
505 |
506 | def test_list_available_agents_structure(self):
507 | """Test the structure of available agents list."""
508 | result = list_available_agents()
509 |
510 | assert result["status"] == "success"
511 | assert "agents" in result
512 | assert "orchestrated_tools" in result
513 | assert "features" in result
514 |
515 | def test_active_agents_listed(self):
516 | """Test that active agents are properly listed."""
517 | result = list_available_agents()
518 | agents = result["agents"]
519 |
520 | # Check new orchestrated agents
521 | assert "supervisor_orchestrated" in agents
522 | assert agents["supervisor_orchestrated"]["status"] == "active"
523 | assert (
524 | "Multi-agent orchestration"
525 | in agents["supervisor_orchestrated"]["description"]
526 | )
527 |
528 | assert "deep_research" in agents
529 | assert agents["deep_research"]["status"] == "active"
530 | assert (
531 | "comprehensive financial research"
532 | in agents["deep_research"]["description"].lower()
533 | )
534 |
535 | def test_orchestrated_tools_listed(self):
536 | """Test that orchestrated tools are listed."""
537 | result = list_available_agents()
538 | tools = result["orchestrated_tools"]
539 |
540 | assert "orchestrated_analysis" in tools
541 | assert "deep_research_financial" in tools
542 | assert "compare_multi_agent_analysis" in tools
543 |
544 | def test_personas_supported(self):
545 | """Test that all personas are supported."""
546 | result = list_available_agents()
547 |
548 | expected_personas = ["conservative", "moderate", "aggressive", "day_trader"]
549 |
550 | # Check supervisor agent supports all personas
551 | supervisor_personas = result["agents"]["supervisor_orchestrated"]["personas"]
552 | assert all(persona in supervisor_personas for persona in expected_personas)
553 |
554 | # Check research agent supports all personas
555 | research_personas = result["agents"]["deep_research"]["personas"]
556 | assert all(persona in research_personas for persona in expected_personas)
557 |
558 | def test_capabilities_documented(self):
559 | """Test that agent capabilities are documented."""
560 | result = list_available_agents()
561 | agents = result["agents"]
562 |
563 | # Supervisor capabilities
564 | supervisor_caps = agents["supervisor_orchestrated"]["capabilities"]
565 | assert "Intelligent query routing" in supervisor_caps
566 | assert "Multi-agent coordination" in supervisor_caps
567 |
568 | # Research capabilities
569 | research_caps = agents["deep_research"]["capabilities"]
570 | assert "Multi-provider web search" in research_caps
571 | assert "AI-powered content analysis" in research_caps
572 |
573 | def test_new_features_documented(self):
574 | """Test that new orchestration features are documented."""
575 | result = list_available_agents()
576 | features = result["features"]
577 |
578 | assert "multi_agent_orchestration" in features
579 | assert "web_search_research" in features
580 | assert "intelligent_routing" in features
581 |
582 |
583 | @pytest.mark.integration
584 | class TestAgentRouterIntegration:
585 | """Integration tests for agent router MCP tools."""
586 |
587 | @pytest.mark.asyncio
588 | async def test_end_to_end_orchestrated_workflow(self):
589 | """Test complete orchestrated analysis workflow."""
590 | # This would be a full integration test with real or more sophisticated mocks
591 | # Testing the complete flow: query -> classification -> agent execution -> synthesis
592 | pass
593 |
594 | @pytest.mark.asyncio
595 | async def test_research_agent_with_supervisor_integration(self):
596 | """Test research agent working with supervisor."""
597 | # Test how research agent integrates with supervisor routing
598 | pass
599 |
600 | @pytest.mark.asyncio
601 | async def test_error_propagation_across_agents(self):
602 | """Test how errors propagate through the orchestration system."""
603 | pass
604 |
605 |
606 | if __name__ == "__main__":
607 | # Run tests
608 | pytest.main([__file__, "-v", "--tb=short"])
609 |
```
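`TestGetOrCreateAgent` implies that `get_or_create_agent` memoizes one agent instance per `(agent_type, persona)` pair and raises `ValueError` for unknown types. A sketch of a cache consistent with those assertions; the router's actual internals may differ, and the agent classes are stubbed out here.

```python
# Keyed factory cache consistent with TestGetOrCreateAgent; illustrative only.
_agent_cache: dict[tuple[str, str], object] = {}


def get_or_create_agent_sketch(agent_type: str, persona: str) -> object:
    key = (agent_type, persona)  # one instance per (type, persona) pair
    if key in _agent_cache:
        return _agent_cache[key]  # repeat calls hit the cache, as the tests assert
    if agent_type not in {"market", "supervisor", "deep_research"}:
        raise ValueError(f"Unknown agent type: {agent_type}")
    agent = object()  # stand-in for MarketAnalysisAgent / SupervisorAgent / DeepResearchAgent
    _agent_cache[key] = agent
    return agent
```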