This is page 18 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│   ├── claude-code-review.yml
│   └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│   ├── 001_initial_schema.py
│   ├── 003_add_performance_indexes.py
│   ├── 006_rename_metadata_columns.py
│   ├── 008_performance_optimization_indexes.py
│   ├── 009_rename_to_supply_demand.py
│   ├── 010_self_contained_schema.py
│   ├── 011_remove_proprietary_terms.py
│   ├── 013_add_backtest_persistence_models.py
│   ├── 014_add_portfolio_models.py
│   ├── 08e3945a0c93_merge_heads.py
│   ├── 9374a5c9b679_merge_heads_for_testing.py
│   ├── abf9b9afb134_merge_multiple_heads.py
│   ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│   ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│   ├── f0696e2cac15_add_essential_performance_indexes.py
│   └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │   ├── __init__.py
│ │   ├── insomnia_export.py
│ │   └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │   ├── __init__.py
│ │   ├── dtos.py
│ │   └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │   ├── __init__.py
│ │   └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│   ├── __init__.py
│   ├── agents
│   │ ├── __init__.py
│   │ ├── market_analyzer.py
│   │ ├── optimizer_agent.py
│   │ ├── strategy_selector.py
│   │ └── validator_agent.py
│   ├── backtesting_workflow.py
│   └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│   ├── test_agent_errors.py
│   ├── test_logging.py
│   ├── test_parallel_screening.py
│   └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│   ├── new_router_template.py
│   ├── new_tool_template.py
│   ├── screening_strategy_template.py
│   └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/config/tool_estimation.py:
--------------------------------------------------------------------------------
```python
1 | """Centralised tool usage estimation configuration."""
2 |
3 | from __future__ import annotations
4 |
5 | from enum import Enum
6 | from typing import Any
7 |
8 | from pydantic import BaseModel, ConfigDict, Field, field_validator
9 |
10 |
11 | class EstimationBasis(str, Enum):
12 | """Describes how a tool estimate was derived."""
13 |
14 | EMPIRICAL = "empirical"
15 | CONSERVATIVE = "conservative"
16 | HEURISTIC = "heuristic"
17 | SIMULATED = "simulated"
18 |
19 |
20 | class ToolComplexity(str, Enum):
21 | """Qualitative complexity buckets used for monitoring and reporting."""
22 |
23 | SIMPLE = "simple"
24 | STANDARD = "standard"
25 | COMPLEX = "complex"
26 | PREMIUM = "premium"
27 |
28 |
29 | class ToolEstimate(BaseModel):
30 | """Static estimate describing expected LLM usage for a tool."""
31 |
32 | model_config = ConfigDict(frozen=True)
33 |
34 | llm_calls: int = Field(ge=0)
35 | total_tokens: int = Field(ge=0)
36 | confidence: float = Field(ge=0.0, le=1.0)
37 | based_on: EstimationBasis
38 | complexity: ToolComplexity
39 | notes: str | None = None
40 |
41 | @field_validator("llm_calls", "total_tokens")
42 | @classmethod
43 | def _non_negative(cls, value: int) -> int:
44 | if value < 0:
45 | raise ValueError("Estimates must be non-negative")
46 | return value
47 |
48 |
49 | class MonitoringThresholds(BaseModel):
50 | """Thresholds for triggering alerting logic."""
51 |
52 | llm_calls_warning: int = 15
53 | llm_calls_critical: int = 25
54 | tokens_warning: int = 20_000
55 | tokens_critical: int = 35_000
56 | variance_warning: float = 0.5
57 | variance_critical: float = 1.0
58 |
59 | model_config = ConfigDict(validate_assignment=True)
60 |
61 | @field_validator(
62 | "llm_calls_warning",
63 | "llm_calls_critical",
64 | "tokens_warning",
65 | "tokens_critical",
66 | )
67 | @classmethod
68 | def _positive(cls, value: int) -> int:
69 | if value <= 0:
70 | raise ValueError("Monitoring thresholds must be positive")
71 | return value
72 |
73 |
74 | class ToolEstimationConfig(BaseModel):
75 | """Container for all tool estimates used across the service."""
76 |
77 | model_config = ConfigDict(arbitrary_types_allowed=True)
78 |
79 | default_confidence: float = 0.75
80 | monitoring: MonitoringThresholds = Field(default_factory=MonitoringThresholds)
81 | simple_default: ToolEstimate = Field(
82 | default_factory=lambda: ToolEstimate(
83 | llm_calls=1,
84 | total_tokens=600,
85 | confidence=0.85,
86 | based_on=EstimationBasis.EMPIRICAL,
87 | complexity=ToolComplexity.SIMPLE,
88 | notes="Baseline simple operation",
89 | )
90 | )
91 | standard_default: ToolEstimate = Field(
92 | default_factory=lambda: ToolEstimate(
93 | llm_calls=3,
94 | total_tokens=4000,
95 | confidence=0.75,
96 | based_on=EstimationBasis.HEURISTIC,
97 | complexity=ToolComplexity.STANDARD,
98 | notes="Baseline standard analysis",
99 | )
100 | )
101 | complex_default: ToolEstimate = Field(
102 | default_factory=lambda: ToolEstimate(
103 | llm_calls=6,
104 | total_tokens=9000,
105 | confidence=0.7,
106 | based_on=EstimationBasis.SIMULATED,
107 | complexity=ToolComplexity.COMPLEX,
108 | notes="Baseline complex workflow",
109 | )
110 | )
111 | premium_default: ToolEstimate = Field(
112 | default_factory=lambda: ToolEstimate(
113 | llm_calls=10,
114 | total_tokens=15000,
115 | confidence=0.65,
116 | based_on=EstimationBasis.CONSERVATIVE,
117 | complexity=ToolComplexity.PREMIUM,
118 | notes="Baseline premium orchestration",
119 | )
120 | )
121 | unknown_tool_estimate: ToolEstimate = Field(
122 | default_factory=lambda: ToolEstimate(
123 | llm_calls=3,
124 | total_tokens=5000,
125 | confidence=0.3,
126 | based_on=EstimationBasis.CONSERVATIVE,
127 | complexity=ToolComplexity.STANDARD,
128 | notes="Fallback estimate for unknown tools",
129 | )
130 | )
131 | tool_estimates: dict[str, ToolEstimate] = Field(default_factory=dict)
132 |
133 | def model_post_init(self, _context: Any) -> None: # noqa: D401
134 | if not self.tool_estimates:
135 | self.tool_estimates = _build_default_estimates(self)
136 | else:
137 | normalised: dict[str, ToolEstimate] = {}
138 | for key, estimate in self.tool_estimates.items():
139 | normalised[key.lower()] = estimate
140 | self.tool_estimates = normalised
141 |
142 | def get_estimate(self, tool_name: str) -> ToolEstimate:
143 | key = tool_name.lower()
144 | return self.tool_estimates.get(key, self.unknown_tool_estimate)
145 |
146 | def get_default_for_complexity(self, complexity: ToolComplexity) -> ToolEstimate:
147 | mapping = {
148 | ToolComplexity.SIMPLE: self.simple_default,
149 | ToolComplexity.STANDARD: self.standard_default,
150 | ToolComplexity.COMPLEX: self.complex_default,
151 | ToolComplexity.PREMIUM: self.premium_default,
152 | }
153 | return mapping[complexity]
154 |
155 | def get_tools_by_complexity(self, complexity: ToolComplexity) -> list[str]:
156 | return sorted(
157 | [
158 | name
159 | for name, estimate in self.tool_estimates.items()
160 | if estimate.complexity == complexity
161 | ]
162 | )
163 |
164 | def get_summary_stats(self) -> dict[str, Any]:
165 | if not self.tool_estimates:
166 | return {}
167 |
168 | total_tools = len(self.tool_estimates)
169 | by_complexity: dict[str, int] = {c.value: 0 for c in ToolComplexity}
170 | basis_distribution: dict[str, int] = {b.value: 0 for b in EstimationBasis}
171 | llm_total = 0
172 | token_total = 0
173 | confidence_total = 0.0
174 |
175 | for estimate in self.tool_estimates.values():
176 | by_complexity[estimate.complexity.value] += 1
177 | basis_distribution[estimate.based_on.value] += 1
178 | llm_total += estimate.llm_calls
179 | token_total += estimate.total_tokens
180 | confidence_total += estimate.confidence
181 |
182 | return {
183 | "total_tools": total_tools,
184 | "by_complexity": by_complexity,
185 | "avg_llm_calls": llm_total / total_tools,
186 | "avg_tokens": token_total / total_tools,
187 | "avg_confidence": confidence_total / total_tools,
188 | "basis_distribution": basis_distribution,
189 | }
190 |
191 | def should_alert(
192 | self, tool_name: str, actual_llm_calls: int, actual_tokens: int
193 | ) -> tuple[bool, str]:
194 | estimate = self.get_estimate(tool_name)
195 | thresholds = self.monitoring
196 | alerts: list[str] = []
197 |
198 | if actual_llm_calls >= thresholds.llm_calls_critical:
199 | alerts.append(
200 | f"Critical: LLM calls ({actual_llm_calls}) exceeded threshold ({thresholds.llm_calls_critical})"
201 | )
202 | elif actual_llm_calls >= thresholds.llm_calls_warning:
203 | alerts.append(
204 | f"Warning: LLM calls ({actual_llm_calls}) exceeded threshold ({thresholds.llm_calls_warning})"
205 | )
206 |
207 | if actual_tokens >= thresholds.tokens_critical:
208 | alerts.append(
209 | f"Critical: Token usage ({actual_tokens}) exceeded threshold ({thresholds.tokens_critical})"
210 | )
211 | elif actual_tokens >= thresholds.tokens_warning:
212 | alerts.append(
213 | f"Warning: Token usage ({actual_tokens}) exceeded threshold ({thresholds.tokens_warning})"
214 | )
215 |
216 | expected_llm = estimate.llm_calls
217 | expected_tokens = estimate.total_tokens
218 |
219 | llm_variance = (
220 | float("inf")
221 | if expected_llm == 0 and actual_llm_calls > 0
222 | else ((actual_llm_calls - expected_llm) / max(expected_llm, 1))
223 | )
224 | token_variance = (
225 | float("inf")
226 | if expected_tokens == 0 and actual_tokens > 0
227 | else ((actual_tokens - expected_tokens) / max(expected_tokens, 1))
228 | )
229 |
230 | if llm_variance == float("inf") or llm_variance > thresholds.variance_critical:
231 | alerts.append("Critical: LLM call variance exceeded acceptable range")
232 | elif llm_variance > thresholds.variance_warning:
233 | alerts.append("Warning: LLM call variance elevated")
234 |
235 | if (
236 | token_variance == float("inf")
237 | or token_variance > thresholds.variance_critical
238 | ):
239 | alerts.append("Critical: Token variance exceeded acceptable range")
240 | elif token_variance > thresholds.variance_warning:
241 | alerts.append("Warning: Token variance elevated")
242 |
243 | message = "; ".join(alerts)
244 | return (bool(alerts), message)
245 |
246 |
247 | def _build_default_estimates(config: ToolEstimationConfig) -> dict[str, ToolEstimate]:
248 | data: dict[str, dict[str, Any]] = {
249 | "get_stock_price": {
250 | "llm_calls": 0,
251 | "total_tokens": 200,
252 | "confidence": 0.92,
253 | "based_on": EstimationBasis.EMPIRICAL,
254 | "complexity": ToolComplexity.SIMPLE,
255 | "notes": "Direct market data lookup",
256 | },
257 | "get_company_info": {
258 | "llm_calls": 1,
259 | "total_tokens": 600,
260 | "confidence": 0.88,
261 | "based_on": EstimationBasis.EMPIRICAL,
262 | "complexity": ToolComplexity.SIMPLE,
263 | "notes": "Cached profile summary",
264 | },
265 | "get_stock_info": {
266 | "llm_calls": 1,
267 | "total_tokens": 550,
268 | "confidence": 0.87,
269 | "based_on": EstimationBasis.EMPIRICAL,
270 | "complexity": ToolComplexity.SIMPLE,
271 | "notes": "Quote lookup",
272 | },
273 | "calculate_sma": {
274 | "llm_calls": 0,
275 | "total_tokens": 180,
276 | "confidence": 0.9,
277 | "based_on": EstimationBasis.EMPIRICAL,
278 | "complexity": ToolComplexity.SIMPLE,
279 | "notes": "Local technical calculation",
280 | },
281 | "get_market_hours": {
282 | "llm_calls": 0,
283 | "total_tokens": 120,
284 | "confidence": 0.95,
285 | "based_on": EstimationBasis.EMPIRICAL,
286 | "complexity": ToolComplexity.SIMPLE,
287 | "notes": "Static schedule lookup",
288 | },
289 | "get_chart_links": {
290 | "llm_calls": 1,
291 | "total_tokens": 500,
292 | "confidence": 0.85,
293 | "based_on": EstimationBasis.HEURISTIC,
294 | "complexity": ToolComplexity.SIMPLE,
295 | "notes": "Generates chart URLs",
296 | },
297 | "list_available_agents": {
298 | "llm_calls": 1,
299 | "total_tokens": 800,
300 | "confidence": 0.82,
301 | "based_on": EstimationBasis.HEURISTIC,
302 | "complexity": ToolComplexity.SIMPLE,
303 | "notes": "Lists registered AI agents",
304 | },
305 | "clear_cache": {
306 | "llm_calls": 0,
307 | "total_tokens": 100,
308 | "confidence": 0.9,
309 | "based_on": EstimationBasis.EMPIRICAL,
310 | "complexity": ToolComplexity.SIMPLE,
311 | "notes": "Invalidates cache entries",
312 | },
313 | "get_cached_price_data": {
314 | "llm_calls": 0,
315 | "total_tokens": 150,
316 | "confidence": 0.86,
317 | "based_on": EstimationBasis.EMPIRICAL,
318 | "complexity": ToolComplexity.SIMPLE,
319 | "notes": "Reads cached OHLC data",
320 | },
321 | "get_watchlist": {
322 | "llm_calls": 1,
323 | "total_tokens": 650,
324 | "confidence": 0.84,
325 | "based_on": EstimationBasis.EMPIRICAL,
326 | "complexity": ToolComplexity.SIMPLE,
327 | "notes": "Fetches saved watchlists",
328 | },
329 | "generate_dev_token": {
330 | "llm_calls": 1,
331 | "total_tokens": 700,
332 | "confidence": 0.82,
333 | "based_on": EstimationBasis.HEURISTIC,
334 | "complexity": ToolComplexity.SIMPLE,
335 | "notes": "Generates development API token",
336 | },
337 | "get_rsi_analysis": {
338 | "llm_calls": 2,
339 | "total_tokens": 3000,
340 | "confidence": 0.78,
341 | "based_on": EstimationBasis.EMPIRICAL,
342 | "complexity": ToolComplexity.STANDARD,
343 | "notes": "RSI interpretation",
344 | },
345 | "get_macd_analysis": {
346 | "llm_calls": 3,
347 | "total_tokens": 3200,
348 | "confidence": 0.74,
349 | "based_on": EstimationBasis.EMPIRICAL,
350 | "complexity": ToolComplexity.STANDARD,
351 | "notes": "MACD indicator narrative",
352 | },
353 | "get_support_resistance": {
354 | "llm_calls": 4,
355 | "total_tokens": 3400,
356 | "confidence": 0.72,
357 | "based_on": EstimationBasis.HEURISTIC,
358 | "complexity": ToolComplexity.STANDARD,
359 | "notes": "Support/resistance summary",
360 | },
361 | "fetch_stock_data": {
362 | "llm_calls": 1,
363 | "total_tokens": 2600,
364 | "confidence": 0.8,
365 | "based_on": EstimationBasis.EMPIRICAL,
366 | "complexity": ToolComplexity.STANDARD,
367 | "notes": "Aggregates OHLC data",
368 | },
369 | "get_maverick_stocks": {
370 | "llm_calls": 4,
371 | "total_tokens": 4500,
372 | "confidence": 0.73,
373 | "based_on": EstimationBasis.SIMULATED,
374 | "complexity": ToolComplexity.STANDARD,
375 | "notes": "Retrieves screening candidates",
376 | },
377 | "get_news_sentiment": {
378 | "llm_calls": 3,
379 | "total_tokens": 4800,
380 | "confidence": 0.76,
381 | "based_on": EstimationBasis.EMPIRICAL,
382 | "complexity": ToolComplexity.STANDARD,
383 | "notes": "Summarises latest news sentiment",
384 | },
385 | "get_economic_calendar": {
386 | "llm_calls": 2,
387 | "total_tokens": 2800,
388 | "confidence": 0.79,
389 | "based_on": EstimationBasis.EMPIRICAL,
390 | "complexity": ToolComplexity.STANDARD,
391 | "notes": "Economic calendar summary",
392 | },
393 | "get_full_technical_analysis": {
394 | "llm_calls": 6,
395 | "total_tokens": 9200,
396 | "confidence": 0.72,
397 | "based_on": EstimationBasis.EMPIRICAL,
398 | "complexity": ToolComplexity.COMPLEX,
399 | "notes": "Comprehensive technical package",
400 | },
401 | "risk_adjusted_analysis": {
402 | "llm_calls": 5,
403 | "total_tokens": 8800,
404 | "confidence": 0.7,
405 | "based_on": EstimationBasis.HEURISTIC,
406 | "complexity": ToolComplexity.COMPLEX,
407 | "notes": "Risk-adjusted metrics",
408 | },
409 | "compare_tickers": {
410 | "llm_calls": 6,
411 | "total_tokens": 9400,
412 | "confidence": 0.71,
413 | "based_on": EstimationBasis.SIMULATED,
414 | "complexity": ToolComplexity.COMPLEX,
415 | "notes": "Ticker comparison",
416 | },
417 | "portfolio_correlation_analysis": {
418 | "llm_calls": 5,
419 | "total_tokens": 8700,
420 | "confidence": 0.72,
421 | "based_on": EstimationBasis.SIMULATED,
422 | "complexity": ToolComplexity.COMPLEX,
423 | "notes": "Portfolio correlation study",
424 | },
425 | "get_market_overview": {
426 | "llm_calls": 4,
427 | "total_tokens": 7800,
428 | "confidence": 0.74,
429 | "based_on": EstimationBasis.HEURISTIC,
430 | "complexity": ToolComplexity.COMPLEX,
431 | "notes": "Market breadth overview",
432 | },
433 | "get_all_screening_recommendations": {
434 | "llm_calls": 5,
435 | "total_tokens": 8200,
436 | "confidence": 0.7,
437 | "based_on": EstimationBasis.SIMULATED,
438 | "complexity": ToolComplexity.COMPLEX,
439 | "notes": "Bulk screening results",
440 | },
441 | "analyze_market_with_agent": {
442 | "llm_calls": 10,
443 | "total_tokens": 14000,
444 | "confidence": 0.65,
445 | "based_on": EstimationBasis.CONSERVATIVE,
446 | "complexity": ToolComplexity.PREMIUM,
447 | "notes": "Multi-agent orchestration",
448 | },
449 | "get_agent_streaming_analysis": {
450 | "llm_calls": 12,
451 | "total_tokens": 16000,
452 | "confidence": 0.6,
453 | "based_on": EstimationBasis.CONSERVATIVE,
454 | "complexity": ToolComplexity.PREMIUM,
455 | "notes": "Streaming agent analysis",
456 | },
457 | "compare_personas_analysis": {
458 | "llm_calls": 9,
459 | "total_tokens": 12000,
460 | "confidence": 0.62,
461 | "based_on": EstimationBasis.HEURISTIC,
462 | "complexity": ToolComplexity.PREMIUM,
463 | "notes": "Persona comparison",
464 | },
465 | }
466 |
467 | estimates = {name: ToolEstimate(**details) for name, details in data.items()}
468 | return estimates
469 |
470 |
471 | _config: ToolEstimationConfig | None = None
472 |
473 |
474 | def get_tool_estimation_config() -> ToolEstimationConfig:
475 | """Return the singleton tool estimation configuration."""
476 |
477 | global _config
478 | if _config is None:
479 | _config = ToolEstimationConfig()
480 | return _config
481 |
482 |
483 | def get_tool_estimate(tool_name: str) -> ToolEstimate:
484 | """Convenience helper returning the estimate for ``tool_name``."""
485 |
486 | return get_tool_estimation_config().get_estimate(tool_name)
487 |
488 |
489 | def should_alert_for_usage(
490 | tool_name: str, llm_calls: int, total_tokens: int
491 | ) -> tuple[bool, str]:
492 | """Check whether actual usage deviates enough to raise an alert."""
493 |
494 | return get_tool_estimation_config().should_alert(tool_name, llm_calls, total_tokens)
495 |
496 |
497 | class ToolCostEstimator:
498 | """Legacy cost estimator retained for backwards compatibility."""
499 |
500 | BASE_COSTS = {
501 | "search": {"simple": 1, "moderate": 3, "complex": 5, "very_complex": 8},
502 | "analysis": {"simple": 2, "moderate": 4, "complex": 7, "very_complex": 12},
503 | "data": {"simple": 1, "moderate": 2, "complex": 4, "very_complex": 6},
504 | "research": {"simple": 3, "moderate": 6, "complex": 10, "very_complex": 15},
505 | }
506 |
507 | MULTIPLIERS = {
508 | "batch_size": {"small": 1.0, "medium": 1.5, "large": 2.0},
509 | "time_sensitivity": {"normal": 1.0, "urgent": 1.3, "real_time": 1.5},
510 | }
511 |
512 | @classmethod
513 | def estimate_tool_cost(
514 | cls,
515 | tool_name: str,
516 | category: str,
517 | complexity: str = "moderate",
518 | additional_params: dict[str, Any] | None = None,
519 | ) -> int:
520 | additional_params = additional_params or {}
521 | base_cost = cls.BASE_COSTS.get(category, {}).get(complexity, 3)
522 |
523 | batch_size = additional_params.get("batch_size", 1)
524 | if batch_size <= 10:
525 | batch_multiplier = cls.MULTIPLIERS["batch_size"]["small"]
526 | elif batch_size <= 50:
527 | batch_multiplier = cls.MULTIPLIERS["batch_size"]["medium"]
528 | else:
529 | batch_multiplier = cls.MULTIPLIERS["batch_size"]["large"]
530 |
531 | time_sensitivity = additional_params.get("time_sensitivity", "normal")
532 | time_multiplier = cls.MULTIPLIERS["time_sensitivity"].get(time_sensitivity, 1.0)
533 |
534 | total_cost = base_cost * batch_multiplier * time_multiplier
535 |
536 | if "portfolio" in tool_name.lower():
537 | total_cost *= 1.2
538 | elif "screening" in tool_name.lower():
539 | total_cost *= 1.1
540 | elif "real_time" in tool_name.lower():
541 | total_cost *= 1.3
542 |
543 | return max(1, int(total_cost))
544 |
545 |
546 | tool_cost_estimator = ToolCostEstimator()
547 |
548 |
549 | def estimate_tool_cost(
550 | tool_name: str,
551 | category: str = "analysis",
552 | complexity: str = "moderate",
553 | **kwargs: Any,
554 | ) -> int:
555 | """Convenience wrapper around :class:`ToolCostEstimator`."""
556 |
557 | return tool_cost_estimator.estimate_tool_cost(
558 | tool_name, category, complexity, kwargs
559 | )
560 |
```
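
A minimal usage sketch for the helpers above. Every name comes from this file; the usage figures passed to `should_alert_for_usage` are illustrative assumptions, not recorded telemetry.

```python
from maverick_mcp.config.tool_estimation import (
    ToolComplexity,
    estimate_tool_cost,
    get_tool_estimate,
    get_tool_estimation_config,
    should_alert_for_usage,
)

# Static estimate for a registered tool (lookup is case-insensitive;
# unknown names fall back to the conservative default estimate).
estimate = get_tool_estimate("get_rsi_analysis")
print(estimate.llm_calls, estimate.total_tokens, estimate.complexity)

# Compare hypothetical actual usage against the absolute and variance
# thresholds; here the token count trips the warning and the LLM-call
# variance trips the critical check.
alert, message = should_alert_for_usage(
    "get_rsi_analysis", llm_calls=12, total_tokens=21_000
)
if alert:
    print(message)

# Aggregate reporting across all registered estimates.
config = get_tool_estimation_config()
print(config.get_summary_stats()["avg_tokens"])
print(config.get_tools_by_complexity(ToolComplexity.PREMIUM))

# Legacy credit-style estimate retained for backwards compatibility.
print(estimate_tool_cost("portfolio_correlation_analysis", category="research"))
```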
--------------------------------------------------------------------------------
/tests/test_rate_limiting_enhanced.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Test suite for enhanced rate limiting middleware.
3 |
4 | Tests various rate limiting scenarios including:
5 | - Different user types (anonymous, authenticated, premium)
6 | - Different endpoint tiers
7 | - Multiple rate limiting strategies
8 | - Monitoring and alerting
9 | """
10 |
11 | import time
12 | from unittest.mock import AsyncMock, MagicMock, patch
13 |
14 | import pytest
15 | import redis.asyncio as redis
16 | from fastapi import FastAPI, Request
17 | from fastapi.testclient import TestClient
18 |
19 | from maverick_mcp.api.middleware.rate_limiting_enhanced import (
20 | EndpointClassification,
21 | EnhancedRateLimitMiddleware,
22 | RateLimitConfig,
23 | RateLimiter,
24 | RateLimitStrategy,
25 | RateLimitTier,
26 | rate_limit,
27 | )
28 | from maverick_mcp.exceptions import RateLimitError
29 |
30 |
31 | @pytest.fixture
32 | def rate_limit_config():
33 | """Create test rate limit configuration."""
34 | return RateLimitConfig(
35 | public_limit=100,
36 | auth_limit=5,
37 | data_limit=20,
38 | data_limit_anonymous=5,
39 | analysis_limit=10,
40 | analysis_limit_anonymous=2,
41 | bulk_limit_per_hour=5,
42 | admin_limit=10,
43 | premium_multiplier=5.0,
44 | enterprise_multiplier=10.0,
45 | default_strategy=RateLimitStrategy.SLIDING_WINDOW,
46 | burst_multiplier=1.5,
47 | window_size_seconds=60,
48 | token_refill_rate=1.0,
49 | max_tokens=10,
50 | log_violations=True,
51 | alert_threshold=3,
52 | )
53 |
54 |
55 | @pytest.fixture
56 | def rate_limiter(rate_limit_config):
57 | """Create rate limiter instance."""
58 | return RateLimiter(rate_limit_config)
59 |
60 |
61 | @pytest.fixture
62 | async def mock_redis():
63 | """Create mock Redis client."""
64 | mock = AsyncMock(spec=redis.Redis)
65 |
66 | # Mock pipeline
67 | mock_pipeline = AsyncMock()
68 | mock_pipeline.execute = AsyncMock(return_value=[None, 0, None, None])
69 | mock.pipeline = MagicMock(return_value=mock_pipeline)
70 |
71 | # Mock other methods
72 | mock.zrange = AsyncMock(return_value=[])
73 | mock.hgetall = AsyncMock(return_value={})
74 | mock.incr = AsyncMock(return_value=1)
75 |
76 | return mock
77 |
78 |
79 | @pytest.fixture
80 | def test_app():
81 | """Create test FastAPI app."""
82 | app = FastAPI()
83 |
84 | @app.get("/health")
85 | async def health():
86 | return {"status": "ok"}
87 |
88 | @app.post("/api/auth/login")
89 | async def login():
90 | return {"token": "test"}
91 |
92 | @app.get("/api/data/stock/{symbol}")
93 | async def get_stock(symbol: str):
94 | return {"symbol": symbol, "price": 100}
95 |
96 | @app.post("/api/screening/bulk")
97 | async def bulk_screening():
98 | return {"stocks": ["AAPL", "GOOGL", "MSFT"]}
99 |
100 | @app.get("/api/admin/users")
101 | async def admin_users():
102 | return {"users": []}
103 |
104 | return app
105 |
106 |
107 | class TestEndpointClassification:
108 | """Test endpoint classification."""
109 |
110 | def test_classify_public_endpoints(self):
111 | """Test classification of public endpoints."""
112 | assert (
113 | EndpointClassification.classify_endpoint("/health") == RateLimitTier.PUBLIC
114 | )
115 | assert (
116 | EndpointClassification.classify_endpoint("/api/docs")
117 | == RateLimitTier.PUBLIC
118 | )
119 | assert (
120 | EndpointClassification.classify_endpoint("/api/openapi.json")
121 | == RateLimitTier.PUBLIC
122 | )
123 |
124 | def test_classify_auth_endpoints(self):
125 | """Test classification of authentication endpoints."""
126 | assert (
127 | EndpointClassification.classify_endpoint("/api/auth/login")
128 | == RateLimitTier.AUTHENTICATION
129 | )
130 | assert (
131 | EndpointClassification.classify_endpoint("/api/auth/signup")
132 | == RateLimitTier.AUTHENTICATION
133 | )
134 | assert (
135 | EndpointClassification.classify_endpoint("/api/auth/refresh")
136 | == RateLimitTier.AUTHENTICATION
137 | )
138 |
139 | def test_classify_data_endpoints(self):
140 | """Test classification of data retrieval endpoints."""
141 | assert (
142 | EndpointClassification.classify_endpoint("/api/data/stock/AAPL")
143 | == RateLimitTier.DATA_RETRIEVAL
144 | )
145 | assert (
146 | EndpointClassification.classify_endpoint("/api/stock/quote")
147 | == RateLimitTier.DATA_RETRIEVAL
148 | )
149 | assert (
150 | EndpointClassification.classify_endpoint("/api/market/movers")
151 | == RateLimitTier.DATA_RETRIEVAL
152 | )
153 |
154 | def test_classify_analysis_endpoints(self):
155 | """Test classification of analysis endpoints."""
156 | assert (
157 | EndpointClassification.classify_endpoint("/api/technical/indicators")
158 | == RateLimitTier.ANALYSIS
159 | )
160 | assert (
161 | EndpointClassification.classify_endpoint("/api/screening/maverick")
162 | == RateLimitTier.ANALYSIS
163 | )
164 | assert (
165 | EndpointClassification.classify_endpoint("/api/portfolio/optimize")
166 | == RateLimitTier.ANALYSIS
167 | )
168 |
169 | def test_classify_bulk_endpoints(self):
170 | """Test classification of bulk operation endpoints."""
171 | assert (
172 | EndpointClassification.classify_endpoint("/api/screening/bulk")
173 | == RateLimitTier.BULK_OPERATION
174 | )
175 | assert (
176 | EndpointClassification.classify_endpoint("/api/data/bulk")
177 | == RateLimitTier.BULK_OPERATION
178 | )
179 | assert (
180 | EndpointClassification.classify_endpoint("/api/portfolio/batch")
181 | == RateLimitTier.BULK_OPERATION
182 | )
183 |
184 | def test_classify_admin_endpoints(self):
185 | """Test classification of administrative endpoints."""
186 | assert (
187 | EndpointClassification.classify_endpoint("/api/admin/users")
188 | == RateLimitTier.ADMINISTRATIVE
189 | )
190 | assert (
191 | EndpointClassification.classify_endpoint("/api/admin/system")
192 | == RateLimitTier.ADMINISTRATIVE
193 | )
194 | assert (
195 | EndpointClassification.classify_endpoint("/api/users/admin/delete")
196 | == RateLimitTier.ADMINISTRATIVE
197 | )
198 |
199 | def test_default_classification(self):
200 | """Test default classification for unknown endpoints."""
201 | assert (
202 | EndpointClassification.classify_endpoint("/api/unknown")
203 | == RateLimitTier.DATA_RETRIEVAL
204 | )
205 | assert (
206 | EndpointClassification.classify_endpoint("/random/path")
207 | == RateLimitTier.DATA_RETRIEVAL
208 | )
209 |
210 |
211 | class TestRateLimiter:
212 | """Test rate limiter core functionality."""
213 |
214 | @pytest.mark.asyncio
215 | async def test_sliding_window_allows_requests(self, rate_limiter, mock_redis):
216 | """Test sliding window allows requests within limit."""
217 | with patch(
218 | "maverick_mcp.data.performance.redis_manager.get_client",
219 | return_value=mock_redis,
220 | ):
221 | is_allowed, info = await rate_limiter.check_rate_limit(
222 | key="test_user",
223 | tier=RateLimitTier.DATA_RETRIEVAL,
224 | limit=10,
225 | window_seconds=60,
226 | strategy=RateLimitStrategy.SLIDING_WINDOW,
227 | )
228 |
229 | assert is_allowed is True
230 | assert info["limit"] == 10
231 | assert info["remaining"] == 9
232 | assert "burst_limit" in info
233 |
234 | @pytest.mark.asyncio
235 | async def test_sliding_window_blocks_excess(self, rate_limiter, mock_redis):
236 | """Test sliding window blocks requests over limit."""
237 | # Mock pipeline to return high count
238 | mock_pipeline = AsyncMock()
239 | mock_pipeline.execute = AsyncMock(return_value=[None, 15, None, None])
240 | mock_redis.pipeline = MagicMock(return_value=mock_pipeline)
241 |
242 | with patch(
243 | "maverick_mcp.data.performance.redis_manager.get_client",
244 | return_value=mock_redis,
245 | ):
246 | is_allowed, info = await rate_limiter.check_rate_limit(
247 | key="test_user",
248 | tier=RateLimitTier.DATA_RETRIEVAL,
249 | limit=10,
250 | window_seconds=60,
251 | strategy=RateLimitStrategy.SLIDING_WINDOW,
252 | )
253 |
254 | assert is_allowed is False
255 | assert info["remaining"] == 0
256 | assert info["retry_after"] > 0
257 |
258 | @pytest.mark.asyncio
259 | async def test_token_bucket_allows_requests(self, rate_limiter, mock_redis):
260 | """Test token bucket allows requests with tokens."""
261 | mock_redis.hgetall = AsyncMock(
262 | return_value={"tokens": "5.0", "last_refill": str(time.time())}
263 | )
264 |
265 | with patch(
266 | "maverick_mcp.data.performance.redis_manager.get_client",
267 | return_value=mock_redis,
268 | ):
269 | is_allowed, info = await rate_limiter.check_rate_limit(
270 | key="test_user",
271 | tier=RateLimitTier.DATA_RETRIEVAL,
272 | limit=10,
273 | window_seconds=60,
274 | strategy=RateLimitStrategy.TOKEN_BUCKET,
275 | )
276 |
277 | assert is_allowed is True
278 | assert "tokens" in info
279 | assert "refill_rate" in info
280 |
281 | @pytest.mark.asyncio
282 | async def test_token_bucket_blocks_no_tokens(self, rate_limiter, mock_redis):
283 | """Test token bucket blocks requests without tokens."""
284 | mock_redis.hgetall = AsyncMock(
285 | return_value={"tokens": "0.5", "last_refill": str(time.time())}
286 | )
287 |
288 | with patch(
289 | "maverick_mcp.data.performance.redis_manager.get_client",
290 | return_value=mock_redis,
291 | ):
292 | is_allowed, info = await rate_limiter.check_rate_limit(
293 | key="test_user",
294 | tier=RateLimitTier.DATA_RETRIEVAL,
295 | limit=10,
296 | window_seconds=60,
297 | strategy=RateLimitStrategy.TOKEN_BUCKET,
298 | )
299 |
300 | assert is_allowed is False
301 | assert info["retry_after"] > 0
302 |
303 | @pytest.mark.asyncio
304 | async def test_fixed_window_allows_requests(self, rate_limiter, mock_redis):
305 | """Test fixed window allows requests within limit."""
306 | mock_pipeline = AsyncMock()
307 | mock_pipeline.execute = AsyncMock(return_value=[5, None])
308 | mock_redis.pipeline = MagicMock(return_value=mock_pipeline)
309 |
310 | with patch(
311 | "maverick_mcp.data.performance.redis_manager.get_client",
312 | return_value=mock_redis,
313 | ):
314 | is_allowed, info = await rate_limiter.check_rate_limit(
315 | key="test_user",
316 | tier=RateLimitTier.DATA_RETRIEVAL,
317 | limit=10,
318 | window_seconds=60,
319 | strategy=RateLimitStrategy.FIXED_WINDOW,
320 | )
321 |
322 | assert is_allowed is True
323 | assert info["current_count"] == 5
324 |
325 | @pytest.mark.asyncio
326 | async def test_local_fallback_rate_limiting(self, rate_limiter):
327 |         """Test local rate limiting when Redis is unavailable."""
328 | with patch(
329 | "maverick_mcp.data.performance.redis_manager.get_client", return_value=None
330 | ):
331 | # First few requests should pass
332 | for _i in range(5):
333 | is_allowed, info = await rate_limiter.check_rate_limit(
334 | key="test_user",
335 | tier=RateLimitTier.DATA_RETRIEVAL,
336 | limit=5,
337 | window_seconds=60,
338 | )
339 | assert is_allowed is True
340 | assert info["fallback"] is True
341 |
342 | # Next request should be blocked
343 | is_allowed, info = await rate_limiter.check_rate_limit(
344 | key="test_user",
345 | tier=RateLimitTier.DATA_RETRIEVAL,
346 | limit=5,
347 | window_seconds=60,
348 | )
349 | assert is_allowed is False
350 |
351 | def test_violation_recording(self, rate_limiter):
352 | """Test violation count recording."""
353 | tier = RateLimitTier.DATA_RETRIEVAL
354 | assert rate_limiter.get_violation_count("user1", tier=tier) == 0
355 |
356 | rate_limiter.record_violation("user1", tier=tier)
357 | assert rate_limiter.get_violation_count("user1", tier=tier) == 1
358 |
359 | rate_limiter.record_violation("user1", tier=tier)
360 | assert rate_limiter.get_violation_count("user1", tier=tier) == 2
361 |
362 | # Different tiers maintain independent counters
363 | other_tier = RateLimitTier.ANALYSIS
364 | assert rate_limiter.get_violation_count("user1", tier=other_tier) == 0
365 |
366 |
367 | class TestEnhancedRateLimitMiddleware:
368 | """Test enhanced rate limit middleware integration."""
369 |
370 | @pytest.fixture
371 | def middleware_app(self, test_app, rate_limit_config):
372 | """Create app with rate limit middleware."""
373 | test_app.add_middleware(EnhancedRateLimitMiddleware, config=rate_limit_config)
374 | return test_app
375 |
376 | @pytest.fixture
377 | def client(self, middleware_app):
378 | """Create test client."""
379 | return TestClient(middleware_app)
380 |
381 | def test_bypass_health_check(self, client):
382 | """Test health check endpoint bypasses rate limiting."""
383 | # Should always succeed
384 | for _ in range(10):
385 | response = client.get("/health")
386 | assert response.status_code == 200
387 | assert "X-RateLimit-Limit" not in response.headers
388 |
389 | @patch("maverick_mcp.data.performance.redis_manager.get_client")
390 | def test_anonymous_rate_limiting(self, mock_get_client, client, mock_redis):
391 | """Test rate limiting for anonymous users."""
392 | mock_get_client.return_value = mock_redis
393 |
394 | # Configure mock to allow first 5 requests
395 | call_count = 0
396 |
397 | def mock_execute():
398 | nonlocal call_count
399 | call_count += 1
400 | if call_count <= 5:
401 | return [None, call_count - 1, None, None]
402 | else:
403 | return [None, 10, None, None] # Over limit
404 |
405 | mock_pipeline = AsyncMock()
406 | mock_pipeline.execute = AsyncMock(side_effect=mock_execute)
407 | mock_redis.pipeline = MagicMock(return_value=mock_pipeline)
408 | mock_redis.zrange = AsyncMock(return_value=[(b"1", time.time())])
409 |
410 | # First 5 requests should succeed
411 | for _i in range(5):
412 | response = client.get("/api/data/stock/AAPL")
413 | assert response.status_code == 200
414 | assert "X-RateLimit-Limit" in response.headers
415 | assert "X-RateLimit-Remaining" in response.headers
416 |
417 | # 6th request should be rate limited
418 | response = client.get("/api/data/stock/AAPL")
419 | assert response.status_code == 429
420 | assert "Rate limit exceeded" in response.json()["error"]
421 | assert "Retry-After" in response.headers
422 |
423 | def test_authenticated_user_headers(self, client):
424 | """Test authenticated users get proper headers."""
425 | # Mock authenticated request
426 | request = MagicMock(spec=Request)
427 | request.state.user_id = "123"
428 | request.state.user_context = {"role": "user"}
429 |
430 | # Headers should be added to response
431 | # This would be tested in integration tests with actual auth
432 |
433 | def test_premium_user_multiplier(self, client):
434 | """Test premium users get higher limits."""
435 | # Mock premium user request
436 | request = MagicMock(spec=Request)
437 | request.state.user_id = "123"
438 | request.state.user_context = {"role": "premium"}
439 |
440 | # Premium users should have 5x the limit
441 | # This would be tested in integration tests
442 |
443 | def test_endpoint_tier_headers(self, client):
444 | """Test different endpoints return tier information."""
445 | with patch(
446 | "maverick_mcp.data.performance.redis_manager.get_client", return_value=None
447 | ):
448 | # Test auth endpoint
449 | response = client.post("/api/auth/login")
450 | if "X-RateLimit-Tier" in response.headers:
451 | assert response.headers["X-RateLimit-Tier"] == "authentication"
452 |
453 | # Test data endpoint
454 | response = client.get("/api/data/stock/AAPL")
455 | if "X-RateLimit-Tier" in response.headers:
456 | assert response.headers["X-RateLimit-Tier"] == "data_retrieval"
457 |
458 | # Test bulk endpoint
459 | response = client.post("/api/screening/bulk")
460 | if "X-RateLimit-Tier" in response.headers:
461 | assert response.headers["X-RateLimit-Tier"] == "bulk_operation"
462 |
463 |
464 | class TestRateLimitDecorator:
465 | """Test function-level rate limiting decorator."""
466 |
467 | @pytest.mark.asyncio
468 | async def test_decorator_allows_requests(self):
469 | """Test decorator allows requests within limit."""
470 | call_count = 0
471 |
472 | @rate_limit(requests_per_minute=5)
473 | async def test_function(request: Request):
474 | nonlocal call_count
475 | call_count += 1
476 | return {"count": call_count}
477 |
478 | # Mock request
479 | request = MagicMock(spec=Request)
480 | request.state.user_id = "test_user"
481 |
482 | with patch(
483 | "maverick_mcp.data.performance.redis_manager.get_client", return_value=None
484 | ):
485 | # Should allow first few calls
486 | for i in range(5):
487 | result = await test_function(request)
488 | assert result["count"] == i + 1
489 |
490 | @pytest.mark.asyncio
491 | async def test_decorator_blocks_excess(self):
492 | """Test decorator blocks excessive requests."""
493 |
494 | @rate_limit(requests_per_minute=2)
495 | async def test_function(request: Request):
496 | return {"success": True}
497 |
498 | # Mock request with proper attributes for rate limiting
499 | request = MagicMock()
500 | request.state = MagicMock()
501 | request.state.user_id = "test_user"
502 | request.url = MagicMock() # Required for rate limiting detection
503 |
504 | with patch(
505 | "maverick_mcp.data.performance.redis_manager.get_client", return_value=None
506 | ):
507 | # First 2 should succeed
508 | await test_function(request)
509 | await test_function(request)
510 |
511 | # 3rd should raise exception
512 | with pytest.raises(RateLimitError) as exc_info:
513 | await test_function(request)
514 |
515 | assert "Rate limit exceeded" in str(exc_info.value)
516 |
517 | @pytest.mark.asyncio
518 | async def test_decorator_without_request(self):
519 | """Test decorator works without request object."""
520 |
521 | @rate_limit(requests_per_minute=5)
522 | async def test_function(value: int):
523 | return value * 2
524 |
525 | # Should work without rate limiting
526 | result = await test_function(5)
527 | assert result == 10
528 |
529 |
530 | class TestMonitoringIntegration:
531 | """Test monitoring and alerting integration."""
532 |
533 | @pytest.mark.asyncio
534 | async def test_violation_monitoring(self, rate_limiter, rate_limit_config):
535 | """Test violations are recorded for monitoring."""
536 | # Record multiple violations
537 | for _i in range(rate_limit_config.alert_threshold + 1):
538 | rate_limiter.record_violation("bad_user", tier=RateLimitTier.DATA_RETRIEVAL)
539 |
540 | # Check violation count
541 | assert (
542 | rate_limiter.get_violation_count(
543 | "bad_user", tier=RateLimitTier.DATA_RETRIEVAL
544 | )
545 | > rate_limit_config.alert_threshold
546 | )
547 |
548 | @pytest.mark.asyncio
549 | async def test_cleanup_task(self, rate_limiter, mock_redis):
550 | """Test periodic cleanup of old data."""
551 | mock_redis.scan = AsyncMock(
552 | return_value=(
553 | 0,
554 | [
555 | "rate_limit:sw:test1",
556 | "rate_limit:sw:test2",
557 | ],
558 | )
559 | )
560 | mock_redis.type = AsyncMock(return_value="zset")
561 | mock_redis.zremrangebyscore = AsyncMock()
562 | mock_redis.zcard = AsyncMock(return_value=0)
563 | mock_redis.delete = AsyncMock()
564 |
565 | with patch(
566 | "maverick_mcp.data.performance.redis_manager.get_client",
567 | return_value=mock_redis,
568 | ):
569 | await rate_limiter.cleanup_old_data(older_than_hours=1)
570 |
571 | # Should have called delete for empty keys
572 | assert mock_redis.delete.called
573 |
```
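
A wiring sketch based only on the names exercised in the tests above: `EnhancedRateLimitMiddleware` takes a `RateLimitConfig` via `config=`, and `rate_limit` guards individual async callables. The limit values are illustrative, and the sketch assumes the remaining `RateLimitConfig` fields have sensible defaults (the test fixture sets all of them explicitly).

```python
from fastapi import FastAPI, Request

from maverick_mcp.api.middleware.rate_limiting_enhanced import (
    EnhancedRateLimitMiddleware,
    RateLimitConfig,
    RateLimitStrategy,
    rate_limit,
)

app = FastAPI()

# Tier limits are applied per endpoint classification (see the
# EndpointClassification tests above); anonymous callers get the
# stricter *_anonymous limits.
app.add_middleware(
    EnhancedRateLimitMiddleware,
    config=RateLimitConfig(
        data_limit=20,
        data_limit_anonymous=5,
        default_strategy=RateLimitStrategy.SLIDING_WINDOW,
        window_size_seconds=60,
    ),
)


@rate_limit(requests_per_minute=30)
async def fetch_quote(request: Request, symbol: str) -> dict:
    # The decorator resolves the caller identity from request.state and,
    # as the tests show, raises RateLimitError once the budget is spent.
    return {"symbol": symbol}
```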
--------------------------------------------------------------------------------
/tests/test_security_headers.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive Security Headers Tests for Maverick MCP.
3 |
4 | Tests security headers configuration, middleware implementation,
5 | environment-specific headers, and CSP/HSTS policies.
6 | """
7 |
8 | import os
9 | from unittest.mock import MagicMock, patch
10 |
11 | import pytest
12 | from fastapi import FastAPI
13 | from fastapi.testclient import TestClient
14 |
15 | from maverick_mcp.api.middleware.security import (
16 | SecurityHeadersMiddleware as APISecurityHeadersMiddleware,
17 | )
18 | from maverick_mcp.config.security import (
19 | SecurityConfig,
20 | SecurityHeadersConfig,
21 | )
22 | from maverick_mcp.config.security_utils import (
23 | SecurityHeadersMiddleware,
24 | apply_security_headers_to_fastapi,
25 | )
26 |
27 |
28 | class TestSecurityHeadersConfig:
29 | """Test security headers configuration."""
30 |
31 | def test_security_headers_default_values(self):
32 | """Test security headers have secure default values."""
33 | config = SecurityHeadersConfig()
34 |
35 | assert config.x_content_type_options == "nosniff"
36 | assert config.x_frame_options == "DENY"
37 | assert config.x_xss_protection == "1; mode=block"
38 | assert config.referrer_policy == "strict-origin-when-cross-origin"
39 | assert "geolocation=()" in config.permissions_policy
40 |
41 | def test_hsts_header_generation(self):
42 | """Test HSTS header value generation."""
43 | config = SecurityHeadersConfig()
44 |
45 | hsts_header = config.hsts_header_value
46 |
47 | assert f"max-age={config.hsts_max_age}" in hsts_header
48 | assert "includeSubDomains" in hsts_header
49 | assert "preload" not in hsts_header # Default is False
50 |
51 | def test_hsts_header_with_preload(self):
52 | """Test HSTS header with preload enabled."""
53 | config = SecurityHeadersConfig(hsts_preload=True)
54 |
55 | hsts_header = config.hsts_header_value
56 |
57 | assert "preload" in hsts_header
58 |
59 | def test_hsts_header_without_subdomains(self):
60 | """Test HSTS header without subdomains."""
61 | config = SecurityHeadersConfig(hsts_include_subdomains=False)
62 |
63 | hsts_header = config.hsts_header_value
64 |
65 | assert "includeSubDomains" not in hsts_header
66 |
67 | def test_csp_header_generation(self):
68 | """Test CSP header value generation."""
69 | config = SecurityHeadersConfig()
70 |
71 | csp_header = config.csp_header_value
72 |
73 | # Check required directives
74 | assert "default-src 'self'" in csp_header
75 | assert "script-src 'self' 'unsafe-inline'" in csp_header
76 | assert "style-src 'self' 'unsafe-inline'" in csp_header
77 | assert "object-src 'none'" in csp_header
78 | assert "connect-src 'self'" in csp_header
79 | assert "frame-src 'none'" in csp_header
80 | assert "base-uri 'self'" in csp_header
81 | assert "form-action 'self'" in csp_header
82 |
83 | def test_csp_custom_directives(self):
84 | """Test CSP with custom directives."""
85 | config = SecurityHeadersConfig(
86 | csp_script_src=["'self'", "https://trusted.com"],
87 | csp_connect_src=["'self'", "https://api.trusted.com"],
88 | )
89 |
90 | csp_header = config.csp_header_value
91 |
92 | assert "script-src 'self' https://trusted.com" in csp_header
93 | assert "connect-src 'self' https://api.trusted.com" in csp_header
94 |
95 | def test_permissions_policy_default(self):
96 | """Test permissions policy default configuration."""
97 | config = SecurityHeadersConfig()
98 |
99 | permissions = config.permissions_policy
100 |
101 | assert "geolocation=()" in permissions
102 | assert "microphone=()" in permissions
103 | assert "camera=()" in permissions
104 | assert "usb=()" in permissions
105 | assert "magnetometer=()" in permissions
106 |
107 |
108 | class TestSecurityHeadersMiddleware:
109 | """Test security headers middleware implementation."""
110 |
111 | def test_middleware_adds_headers(self):
112 | """Test that middleware adds security headers to responses."""
113 | app = FastAPI()
114 |
115 | # Create mock security config
116 | mock_config = MagicMock()
117 | mock_config.get_security_headers.return_value = {
118 | "X-Content-Type-Options": "nosniff",
119 | "X-Frame-Options": "DENY",
120 | "X-XSS-Protection": "1; mode=block",
121 | "Content-Security-Policy": "default-src 'self'",
122 | }
123 |
124 | app.add_middleware(SecurityHeadersMiddleware, security_config=mock_config)
125 |
126 | @app.get("/test")
127 | async def test_endpoint():
128 | return {"message": "test"}
129 |
130 | client = TestClient(app)
131 | response = client.get("/test")
132 |
133 | assert response.headers["X-Content-Type-Options"] == "nosniff"
134 | assert response.headers["X-Frame-Options"] == "DENY"
135 | assert response.headers["X-XSS-Protection"] == "1; mode=block"
136 | assert response.headers["Content-Security-Policy"] == "default-src 'self'"
137 |
138 | def test_middleware_uses_default_config(self):
139 | """Test that middleware uses default security config when none provided."""
140 | app = FastAPI()
141 |
142 | with patch(
143 | "maverick_mcp.config.security_utils.get_security_config"
144 | ) as mock_get_config:
145 | mock_config = MagicMock()
146 | mock_config.get_security_headers.return_value = {"X-Frame-Options": "DENY"}
147 | mock_get_config.return_value = mock_config
148 |
149 | app.add_middleware(SecurityHeadersMiddleware)
150 |
151 | @app.get("/test")
152 | async def test_endpoint():
153 | return {"message": "test"}
154 |
155 | client = TestClient(app)
156 | response = client.get("/test")
157 |
158 | mock_get_config.assert_called_once()
159 | assert response.headers["X-Frame-Options"] == "DENY"
160 |
161 | def test_api_middleware_integration(self):
162 | """Test API security headers middleware integration."""
163 | app = FastAPI()
164 | app.add_middleware(APISecurityHeadersMiddleware)
165 |
166 | @app.get("/test")
167 | async def test_endpoint():
168 | return {"message": "test"}
169 |
170 | client = TestClient(app)
171 | response = client.get("/test")
172 |
173 | # Should have basic security headers
174 | assert "X-Content-Type-Options" in response.headers
175 | assert "X-Frame-Options" in response.headers
176 |
177 |
178 | class TestEnvironmentSpecificHeaders:
179 | """Test environment-specific security headers."""
180 |
181 | def test_hsts_in_production(self):
182 | """Test HSTS header is included in production."""
183 | with patch.dict(os.environ, {"ENVIRONMENT": "production"}, clear=False):
184 | config = SecurityConfig()
185 | headers = config.get_security_headers()
186 |
187 | assert "Strict-Transport-Security" in headers
188 | assert "max-age=" in headers["Strict-Transport-Security"]
189 |
190 | def test_hsts_in_development(self):
191 | """Test HSTS header is not included in development."""
192 | with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
193 | config = SecurityConfig(force_https=False)
194 | headers = config.get_security_headers()
195 |
196 | assert "Strict-Transport-Security" not in headers
197 |
198 | def test_hsts_with_force_https(self):
199 | """Test HSTS header is included when HTTPS is forced."""
200 | with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
201 | config = SecurityConfig(force_https=True)
202 | headers = config.get_security_headers()
203 |
204 | assert "Strict-Transport-Security" in headers
205 |
206 | def test_production_security_validation(self):
207 | """Test production security validation."""
208 | with patch.dict(os.environ, {"ENVIRONMENT": "production"}, clear=False):
209 | with patch(
210 | "maverick_mcp.config.security._get_cors_origins"
211 | ) as mock_origins:
212 | mock_origins.return_value = ["https://app.maverick-mcp.com"]
213 |
214 | with patch("logging.getLogger") as mock_logger:
215 | mock_logger_instance = MagicMock()
216 | mock_logger.return_value = mock_logger_instance
217 |
218 | # Test with HTTPS not forced (should warn)
219 | SecurityConfig(force_https=False)
220 |
221 | # Should log warning about HTTPS
222 | mock_logger_instance.warning.assert_called()
223 |
224 | def test_development_security_permissive(self):
225 | """Test development security is more permissive."""
226 | with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
227 | config = SecurityConfig()
228 |
229 | assert config.is_development() is True
230 | assert config.is_production() is False
231 |
232 |
233 | class TestCSPConfiguration:
234 | """Test Content Security Policy configuration."""
235 |
236 | def test_csp_avoids_checkout_domains(self):
237 | """Test CSP excludes third-party checkout provider domains."""
238 | config = SecurityHeadersConfig()
239 |
240 | assert config.csp_script_src == ["'self'", "'unsafe-inline'"]
241 | assert config.csp_connect_src == ["'self'"]
242 | assert config.csp_frame_src == ["'none'"]
243 |
244 | def test_csp_blocks_inline_scripts_by_default(self):
245 | """Test CSP configuration for inline scripts."""
246 | config = SecurityHeadersConfig()
247 | csp = config.csp_header_value
248 |
249 | # Note: Current config allows 'unsafe-inline' for compatibility
250 | # In a more secure setup, this should use nonces or hashes
251 | assert "'unsafe-inline'" in csp
252 |
253 | def test_csp_blocks_object_embedding(self):
254 | """Test CSP blocks object embedding."""
255 | config = SecurityHeadersConfig()
256 | csp = config.csp_header_value
257 |
258 | assert "object-src 'none'" in csp
259 |
260 | def test_csp_restricts_base_uri(self):
261 | """Test CSP restricts base URI."""
262 | config = SecurityHeadersConfig()
263 | csp = config.csp_header_value
264 |
265 | assert "base-uri 'self'" in csp
266 |
267 | def test_csp_restricts_form_action(self):
268 | """Test CSP restricts form actions."""
269 | config = SecurityHeadersConfig()
270 | csp = config.csp_header_value
271 |
272 | assert "form-action 'self'" in csp
273 |
274 | def test_csp_image_sources(self):
275 | """Test CSP allows necessary image sources."""
276 | config = SecurityHeadersConfig()
277 | csp = config.csp_header_value
278 |
279 | assert "img-src 'self' data: https:" in csp
280 |
281 | def test_csp_custom_configuration(self):
282 | """Test CSP with custom configuration."""
283 | custom_config = SecurityHeadersConfig(
284 | csp_default_src=["'self'", "https://trusted.com"],
285 | csp_script_src=["'self'"],
286 | csp_style_src=["'self'"], # Remove unsafe-inline from styles too
287 | csp_object_src=["'none'"],
288 | )
289 |
290 | csp = custom_config.csp_header_value
291 |
292 | assert "default-src 'self' https://trusted.com" in csp
293 | assert "script-src 'self'" in csp
294 | # Since we removed unsafe-inline from style-src, it shouldn't be in CSP
295 | assert "style-src 'self'" in csp
296 | assert "'unsafe-inline'" not in csp
297 |
298 |
299 | class TestXFrameOptionsConfiguration:
300 | """Test X-Frame-Options configuration."""
301 |
302 | def test_frame_options_deny_default(self):
303 | """Test X-Frame-Options defaults to DENY."""
304 |         config = SecurityHeadersConfig()
305 |         headers = SecurityConfig(headers=config).get_security_headers()
306 |
307 | assert headers["X-Frame-Options"] == "DENY"
308 |
309 | def test_frame_options_sameorigin(self):
310 | """Test X-Frame-Options can be set to SAMEORIGIN."""
311 | config = SecurityHeadersConfig(x_frame_options="SAMEORIGIN")
312 | security_config = SecurityConfig(headers=config)
313 | headers = security_config.get_security_headers()
314 |
315 | assert headers["X-Frame-Options"] == "SAMEORIGIN"
316 |
317 | def test_frame_options_allow_from(self):
318 | """Test X-Frame-Options with ALLOW-FROM directive."""
319 | config = SecurityHeadersConfig(x_frame_options="ALLOW-FROM https://trusted.com")
320 | security_config = SecurityConfig(headers=config)
321 | headers = security_config.get_security_headers()
322 |
323 | assert headers["X-Frame-Options"] == "ALLOW-FROM https://trusted.com"
324 |
325 |
326 | class TestReferrerPolicyConfiguration:
327 | """Test Referrer-Policy configuration."""
328 |
329 | def test_referrer_policy_default(self):
330 | """Test Referrer-Policy default value."""
331 |         config = SecurityHeadersConfig()
332 |         headers = SecurityConfig(headers=config).get_security_headers()
333 |
334 | assert headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
335 |
336 | def test_referrer_policy_custom(self):
337 | """Test custom Referrer-Policy."""
338 | config = SecurityHeadersConfig(referrer_policy="no-referrer")
339 | security_config = SecurityConfig(headers=config)
340 | headers = security_config.get_security_headers()
341 |
342 | assert headers["Referrer-Policy"] == "no-referrer"
343 |
344 |
345 | class TestPermissionsPolicyConfiguration:
346 | """Test Permissions-Policy configuration."""
347 |
348 | def test_permissions_policy_blocks_dangerous_features(self):
349 | """Test Permissions-Policy blocks dangerous browser features."""
350 |         config = SecurityHeadersConfig()
351 |         headers = SecurityConfig(headers=config).get_security_headers()
352 |
353 | permissions = headers["Permissions-Policy"]
354 |
355 | assert "geolocation=()" in permissions
356 | assert "microphone=()" in permissions
357 | assert "camera=()" in permissions
358 | assert "usb=()" in permissions
359 |
360 | def test_permissions_policy_custom(self):
361 | """Test custom Permissions-Policy configuration."""
362 | custom_policy = "geolocation=(self), camera=(), microphone=()"
363 | config = SecurityHeadersConfig(permissions_policy=custom_policy)
364 | security_config = SecurityConfig(headers=config)
365 | headers = security_config.get_security_headers()
366 |
367 | assert headers["Permissions-Policy"] == custom_policy
368 |
369 |
370 | class TestSecurityHeadersIntegration:
371 | """Test security headers integration with application."""
372 |
373 | def test_all_headers_applied(self):
374 | """Test that all security headers are applied to responses."""
375 | app = FastAPI()
376 | apply_security_headers_to_fastapi(app)
377 |
378 | @app.get("/test")
379 | async def test_endpoint():
380 | return {"message": "test"}
381 |
382 | client = TestClient(app)
383 | response = client.get("/test")
384 |
385 | # Check all expected headers are present
386 | expected_headers = [
387 | "X-Content-Type-Options",
388 | "X-Frame-Options",
389 | "X-XSS-Protection",
390 | "Referrer-Policy",
391 | "Permissions-Policy",
392 | "Content-Security-Policy",
393 | ]
394 |
395 | for header in expected_headers:
396 | assert header in response.headers
397 |
398 | def test_headers_on_error_responses(self):
399 | """Test security headers are included on error responses."""
400 | app = FastAPI()
401 | apply_security_headers_to_fastapi(app)
402 |
403 | @app.get("/error")
404 | async def error_endpoint():
405 | from fastapi import HTTPException
406 |
407 | raise HTTPException(status_code=500, detail="Test error")
408 |
409 | client = TestClient(app)
410 | response = client.get("/error")
411 |
412 | # Even on errors, security headers should be present
413 | assert response.status_code == 500
414 | assert "X-Frame-Options" in response.headers
415 | assert "X-Content-Type-Options" in response.headers
416 |
417 | def test_headers_on_different_methods(self):
418 | """Test security headers on different HTTP methods."""
419 | app = FastAPI()
420 | apply_security_headers_to_fastapi(app)
421 |
422 | @app.get("/test")
423 | async def get_endpoint():
424 | return {"method": "GET"}
425 |
426 | @app.post("/test")
427 | async def post_endpoint():
428 | return {"method": "POST"}
429 |
430 | @app.put("/test")
431 | async def put_endpoint():
432 | return {"method": "PUT"}
433 |
434 | client = TestClient(app)
435 |
436 | methods = [(client.get, "/test"), (client.post, "/test"), (client.put, "/test")]
437 |
438 | for method_func, path in methods:
439 | response = method_func(path)
440 | assert "X-Frame-Options" in response.headers
441 | assert "Content-Security-Policy" in response.headers
442 |
443 | def test_headers_override_existing(self):
444 | """Test security headers override any existing headers."""
445 | app = FastAPI()
446 | apply_security_headers_to_fastapi(app)
447 |
448 | @app.get("/test")
449 | async def test_endpoint():
450 | from fastapi import Response
451 |
452 | response = Response(content='{"message": "test"}')
453 | response.headers["X-Frame-Options"] = "ALLOWALL" # Insecure value
454 | return response
455 |
456 | client = TestClient(app)
457 | response = client.get("/test")
458 |
459 | # Security middleware should override the insecure value
460 | assert response.headers["X-Frame-Options"] == "DENY"
461 |
462 |
463 | class TestSecurityHeadersValidation:
464 | """Test security headers validation and best practices."""
465 |
466 | def test_no_server_header_disclosure(self):
467 | """Test that server information is not disclosed."""
468 | app = FastAPI()
469 | apply_security_headers_to_fastapi(app)
470 |
471 | @app.get("/test")
472 | async def test_endpoint():
473 | return {"message": "test"}
474 |
475 | client = TestClient(app)
476 | response = client.get("/test")
477 |
478 | # Should not disclose server information
479 | server_header = response.headers.get("Server", "")
480 | assert "uvicorn" not in server_header.lower()
481 |
482 | def test_no_powered_by_header(self):
483 | """Test that X-Powered-By header is not present."""
484 | app = FastAPI()
485 | apply_security_headers_to_fastapi(app)
486 |
487 | @app.get("/test")
488 | async def test_endpoint():
489 | return {"message": "test"}
490 |
491 | client = TestClient(app)
492 | response = client.get("/test")
493 |
494 | assert "X-Powered-By" not in response.headers
495 |
496 | def test_content_type_nosniff(self):
497 | """Test X-Content-Type-Options prevents MIME sniffing."""
498 | app = FastAPI()
499 | apply_security_headers_to_fastapi(app)
500 |
501 | @app.get("/test")
502 | async def test_endpoint():
503 | return {"message": "test"}
504 |
505 | client = TestClient(app)
506 | response = client.get("/test")
507 |
508 | assert response.headers["X-Content-Type-Options"] == "nosniff"
509 |
510 | def test_xss_protection_enabled(self):
511 | """Test X-XSS-Protection is properly configured."""
512 | app = FastAPI()
513 | apply_security_headers_to_fastapi(app)
514 |
515 | @app.get("/test")
516 | async def test_endpoint():
517 | return {"message": "test"}
518 |
519 | client = TestClient(app)
520 | response = client.get("/test")
521 |
522 | xss_protection = response.headers["X-XSS-Protection"]
523 | assert "1" in xss_protection
524 | assert "mode=block" in xss_protection
525 |
526 |
527 | class TestSecurityHeadersPerformance:
528 | """Test security headers don't impact performance significantly."""
529 |
530 | def test_headers_middleware_performance(self):
531 | """Test security headers middleware performance."""
532 | app = FastAPI()
533 | apply_security_headers_to_fastapi(app)
534 |
535 | @app.get("/test")
536 | async def test_endpoint():
537 | return {"message": "test"}
538 |
539 | client = TestClient(app)
540 |
541 | # Make multiple requests to test performance
542 | import time
543 |
544 | start_time = time.time()
545 |
546 | for _ in range(100):
547 | response = client.get("/test")
548 | assert response.status_code == 200
549 |
550 | end_time = time.time()
551 | total_time = end_time - start_time
552 |
553 | # Should complete 100 requests quickly (less than 5 seconds)
554 | assert total_time < 5.0
555 |
556 | def test_headers_memory_usage(self):
557 | """Test security headers don't cause memory leaks."""
558 | app = FastAPI()
559 | apply_security_headers_to_fastapi(app)
560 |
561 | @app.get("/test")
562 | async def test_endpoint():
563 | return {"message": "test"}
564 |
565 | client = TestClient(app)
566 |
567 | # Make many requests to check for memory leaks
568 | for _ in range(1000):
569 | response = client.get("/test")
570 | assert "X-Frame-Options" in response.headers
571 |
572 | # If we reach here without memory issues, test passes
573 |
574 |
575 | if __name__ == "__main__":
576 | pytest.main([__file__, "-v"])
577 |
```
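The note in `test_csp_blocks_inline_scripts_by_default` observes that a stricter setup would replace `'unsafe-inline'` with nonces or hashes. Below is a minimal sketch of the nonce approach using plain Starlette middleware; the `NonceCSPMiddleware` class and its header string are illustrative assumptions, not part of maverick_mcp:

```python
import secrets

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware


class NonceCSPMiddleware(BaseHTTPMiddleware):
    """Attach a fresh per-request script nonce instead of 'unsafe-inline'."""

    async def dispatch(self, request: Request, call_next):
        nonce = secrets.token_urlsafe(16)
        # Templates can read request.state.csp_nonce to tag inline <script> blocks.
        request.state.csp_nonce = nonce
        response = await call_next(request)
        response.headers["Content-Security-Policy"] = (
            f"default-src 'self'; script-src 'self' 'nonce-{nonce}'; object-src 'none'"
        )
        return response


app = FastAPI()
app.add_middleware(NonceCSPMiddleware)
```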
--------------------------------------------------------------------------------
/maverick_mcp/utils/parallel_research.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Parallel Research Execution Utilities
3 |
4 | This module provides infrastructure for spawning and managing parallel research
5 | subagents for comprehensive financial analysis.
6 | """
7 |
8 | import asyncio
9 | import logging
10 | import time
11 | from collections.abc import Callable
12 | from typing import Any
13 |
14 | from ..agents.circuit_breaker import circuit_breaker
15 | from ..config.settings import get_settings
16 | from .orchestration_logging import (
17 | get_orchestration_logger,
18 | log_agent_execution,
19 | log_method_call,
20 | log_parallel_execution,
21 | log_performance_metrics,
22 | log_resource_usage,
23 | )
24 |
25 | logger = logging.getLogger(__name__)
26 | settings = get_settings()
27 |
28 |
29 | class ParallelResearchConfig:
30 | """Configuration for parallel research operations."""
31 |
32 | def __init__(
33 | self,
34 | max_concurrent_agents: int = 6, # OPTIMIZATION: Increased from 4 for better parallelism
35 | timeout_per_agent: int = 60, # OPTIMIZATION: Reduced from 180s to prevent blocking
36 | enable_fallbacks: bool = False, # Disabled by default for speed
37 | rate_limit_delay: float = 0.05, # OPTIMIZATION: Minimal delay (50ms) for API rate limits only
38 | batch_size: int = 3, # OPTIMIZATION: Batch size for task grouping
39 | use_worker_pool: bool = True, # OPTIMIZATION: Enable worker pool pattern
40 | ):
41 | self.max_concurrent_agents = max_concurrent_agents
42 | self.timeout_per_agent = timeout_per_agent
43 | self.enable_fallbacks = enable_fallbacks
44 | self.rate_limit_delay = rate_limit_delay
45 | self.batch_size = batch_size
46 | self.use_worker_pool = use_worker_pool
47 |
48 |
49 | class ResearchTask:
50 | """Represents a single research task for parallel execution."""
51 |
52 | def __init__(
53 | self,
54 | task_id: str,
55 | task_type: str,
56 | target_topic: str,
57 | focus_areas: list[str],
58 | priority: int = 1,
59 | timeout: int | None = None,
60 | ):
61 | self.task_id = task_id
62 | self.task_type = task_type # fundamental, technical, sentiment, competitive
63 | self.target_topic = target_topic
64 | self.focus_areas = focus_areas
65 | self.priority = priority
66 | self.timeout = timeout
67 | self.start_time: float | None = None
68 | self.end_time: float | None = None
69 | self.status: str = "pending" # pending, running, completed, failed
70 | self.result: dict[str, Any] | None = None
71 | self.error: str | None = None
72 |
73 |
74 | class ResearchResult:
75 | """Aggregated results from parallel research execution."""
76 |
77 | def __init__(self):
78 | self.task_results: dict[str, ResearchTask] = {}
79 | self.synthesis: dict[str, Any] | None = None
80 | self.total_execution_time: float = 0.0
81 | self.successful_tasks: int = 0
82 | self.failed_tasks: int = 0
83 | self.parallel_efficiency: float = 0.0
84 |
85 |
86 | class ParallelResearchOrchestrator:
87 | """Orchestrates parallel research agent execution."""
88 |
89 | def __init__(self, config: ParallelResearchConfig | None = None):
90 | self.config = config or ParallelResearchConfig()
91 | self.active_tasks: dict[str, ResearchTask] = {}
92 | # OPTIMIZATION: Use bounded semaphore for better control
93 | self._semaphore = asyncio.BoundedSemaphore(self.config.max_concurrent_agents)
94 | self.orchestration_logger = get_orchestration_logger("ParallelOrchestrator")
95 | # Track active workers for better coordination
96 | self._active_workers = 0
97 | self._worker_lock = asyncio.Lock()
98 |
99 | # Log initialization
100 | self.orchestration_logger.info(
101 | "🎛️ ORCHESTRATOR_INIT",
102 | max_agents=self.config.max_concurrent_agents,
103 | )
104 |
105 | @log_method_call(component="ParallelOrchestrator", include_timing=True)
106 | async def execute_parallel_research(
107 | self,
108 | tasks: list[ResearchTask],
109 | research_executor,
110 | synthesis_callback: Callable[..., Any] | None = None,
111 | ) -> ResearchResult:
112 | """
113 | Execute multiple research tasks in parallel with intelligent coordination.
114 |
115 | Args:
116 | tasks: List of research tasks to execute
117 | research_executor: Function to execute individual research tasks
118 | synthesis_callback: Optional function to synthesize results
119 |
120 | Returns:
121 | ResearchResult with aggregated results and synthesis
122 | """
123 | self.orchestration_logger.set_request_context(
124 | session_id=tasks[0].task_id.split("_")[0] if tasks else "unknown",
125 | task_count=len(tasks),
126 | )
127 |
128 | # Log task overview
129 | self.orchestration_logger.info(
130 | "📋 TASK_OVERVIEW",
131 | task_count=len(tasks),
132 | max_concurrent=self.config.max_concurrent_agents,
133 | )
134 |
135 | start_time = time.time()
136 |
137 | # Create result container
138 | result = ResearchResult()
139 |
140 | with log_parallel_execution(
141 | "ParallelOrchestrator", "research execution", len(tasks)
142 | ) as exec_logger:
143 | try:
144 | # Prepare tasks for execution
145 | prepared_tasks = await self._prepare_tasks(tasks)
146 | exec_logger.info(
147 | "🔧 TASKS_PREPARED", prepared_count=len(prepared_tasks)
148 | )
149 |
150 | # OPTIMIZATION: Use create_task for true parallel execution
151 | # This allows tasks to start immediately without waiting
152 | exec_logger.info("🚀 PARALLEL_EXECUTION_START")
153 |
154 | # Create all tasks immediately for maximum parallelism
155 | running_tasks = []
156 | for task in prepared_tasks:
157 | # Create task immediately without awaiting
158 | task_future = asyncio.create_task(
159 | self._execute_single_task(task, research_executor)
160 | )
161 | running_tasks.append(task_future)
162 |
163 | # OPTIMIZATION: Minimal delay only if absolutely needed for API rate limits
164 | # Reduced from progressive delays to fixed minimal delay
165 | if self.config.rate_limit_delay > 0 and len(running_tasks) < len(
166 | prepared_tasks
167 | ):
168 | await asyncio.sleep(
169 | self.config.rate_limit_delay * 0.1
170 | ) # 10% of original delay
171 |
172 | # Wait for all tasks to complete using asyncio.as_completed for better responsiveness
173 | completed_tasks = []
174 | for task_future in asyncio.as_completed(running_tasks):
175 | try:
176 | result_task = await task_future
177 | completed_tasks.append(result_task)
178 | except Exception as e:
179 | # Handle exceptions without blocking other tasks
180 | completed_tasks.append(e)
181 |
182 | exec_logger.info("🏁 PARALLEL_EXECUTION_COMPLETE")
183 |
184 | # Process results
185 | result = await self._process_task_results(
186 | prepared_tasks, completed_tasks, start_time
187 | )
188 |
189 | # Log performance metrics
190 | log_performance_metrics(
191 | "ParallelOrchestrator",
192 | {
193 | "total_tasks": len(tasks),
194 | "successful_tasks": result.successful_tasks,
195 | "failed_tasks": result.failed_tasks,
196 | "parallel_efficiency": result.parallel_efficiency,
197 | "total_duration": result.total_execution_time,
198 | },
199 | )
200 |
201 | # Synthesize results if callback provided
202 | if synthesis_callback and result.successful_tasks > 0:
203 | exec_logger.info("🧠 SYNTHESIS_START")
204 | try:
205 | synthesis_start = time.time()
206 | result.synthesis = await synthesis_callback(result.task_results)
207 | _ = (
208 | time.time() - synthesis_start
209 | ) # Track duration but not used currently
210 | exec_logger.info("✅ SYNTHESIS_SUCCESS")
211 | except Exception as e:
212 | exec_logger.error("❌ SYNTHESIS_FAILED", error=str(e))
213 | result.synthesis = {"error": f"Synthesis failed: {str(e)}"}
214 | else:
215 | exec_logger.info("⏭️ SYNTHESIS_SKIPPED")
216 |
217 | return result
218 |
219 | except Exception as e:
220 | exec_logger.error("💥 PARALLEL_EXECUTION_FAILED", error=str(e))
221 | result.total_execution_time = time.time() - start_time
222 | return result
223 |
224 | async def _prepare_tasks(self, tasks: list[ResearchTask]) -> list[ResearchTask]:
225 | """Prepare tasks for execution by setting timeouts and priorities."""
226 | prepared = []
227 |
228 | for task in sorted(tasks, key=lambda t: t.priority, reverse=True):
229 | # Set default timeout if not specified
230 | if not task.timeout:
231 | task.timeout = self.config.timeout_per_agent
232 |
233 | # Set task to pending status
234 | task.status = "pending"
235 | self.active_tasks[task.task_id] = task
236 | prepared.append(task)
237 |
238 |         return prepared[: self.config.max_concurrent_agents]  # tasks beyond the cap are dropped (lowest priority first)
239 |
240 | @circuit_breaker("parallel_research_task", failure_threshold=2, recovery_timeout=30)
241 | async def _execute_single_task(
242 | self, task: ResearchTask, research_executor
243 | ) -> ResearchTask:
244 | """Execute a single research task with optimized error handling."""
245 |         # Always acquire the semaphore to bound concurrency; skipping the
246 |         # acquire (e.g., by only testing locked()) would unbalance the
247 |         # BoundedSemaphore.release() in the finally block below.
248 |         acquired = False
249 |         try:
250 |             # Awaiting here does not stall task creation: every task was
251 |             # already scheduled via asyncio.create_task and simply queues.
252 |             await self._semaphore.acquire()
253 |             acquired = True
254 |
255 | task.start_time = time.time()
256 | task.status = "running"
257 |
258 | # Track active worker count
259 | async with self._worker_lock:
260 | self._active_workers += 1
261 |
262 | with log_agent_execution(
263 | task.task_type, task.task_id, task.focus_areas
264 | ) as agent_logger:
265 | try:
266 | agent_logger.info(
267 | "🎯 TASK_EXECUTION_START",
268 | timeout=task.timeout,
269 | priority=task.priority,
270 | )
271 |
272 | # OPTIMIZATION: Use shield to prevent cancellation during critical work
273 | result = await asyncio.shield(
274 | asyncio.wait_for(research_executor(task), timeout=task.timeout)
275 | )
276 |
277 | task.result = result
278 | task.status = "completed"
279 | task.end_time = time.time()
280 |
281 | # Log successful completion
282 | execution_time = task.end_time - task.start_time
283 | agent_logger.info(
284 | "✨ TASK_EXECUTION_SUCCESS",
285 | duration=f"{execution_time:.3f}s",
286 | )
287 |
288 | # Log resource usage if available
289 | if isinstance(result, dict) and "metrics" in result:
290 | log_resource_usage(
291 | f"{task.task_type}Agent",
292 | api_calls=result["metrics"].get("api_calls"),
293 | cache_hits=result["metrics"].get("cache_hits"),
294 | )
295 |
296 | return task
297 |
298 | except TimeoutError:
299 | task.error = f"Task timeout after {task.timeout}s"
300 | task.status = "failed"
301 | agent_logger.error("⏰ TASK_TIMEOUT", timeout=task.timeout)
302 |
303 | except Exception as e:
304 | task.error = str(e)
305 | task.status = "failed"
306 | agent_logger.error("💥 TASK_EXECUTION_FAILED", error=str(e))
307 |
308 | finally:
309 | task.end_time = time.time()
310 | # Track active worker count
311 | async with self._worker_lock:
312 | self._active_workers -= 1
313 |
314 | return task
315 | finally:
316 | # Always release semaphore if acquired
317 | if acquired:
318 | self._semaphore.release()
319 |
320 | async def _process_task_results(
321 | self, tasks: list[ResearchTask], completed_tasks: list[Any], start_time: float
322 | ) -> ResearchResult:
323 | """Process and aggregate results from completed tasks."""
324 | result = ResearchResult()
325 | result.total_execution_time = time.time() - start_time
326 |
327 | for task in tasks:
328 | result.task_results[task.task_id] = task
329 |
330 | if task.status == "completed":
331 | result.successful_tasks += 1
332 | else:
333 | result.failed_tasks += 1
334 |
335 | # Calculate parallel efficiency
336 | if result.total_execution_time > 0:
337 | total_sequential_time = sum(
338 | (task.end_time or 0) - (task.start_time or 0)
339 | for task in tasks
340 | if task.start_time
341 | )
342 | result.parallel_efficiency = (
343 | (total_sequential_time / result.total_execution_time)
344 | if total_sequential_time > 0
345 | else 0.0
346 | )
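            # Worked example of the ratio above: four tasks that each ran ~10s
            # but finished within 12s of wall time report 40 / 12 ≈ 3.33x
            # speedup over sequential execution.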
347 |
348 | logger.info(
349 | f"Parallel research completed: {result.successful_tasks} successful, "
350 | f"{result.failed_tasks} failed, {result.parallel_efficiency:.2f}x speedup"
351 | )
352 |
353 | return result
354 |
355 |
356 | class TaskDistributionEngine:
357 | """Intelligent task distribution for research topics."""
358 |
359 | TASK_TYPES = {
360 | "fundamental": {
361 | "keywords": [
362 | "earnings",
363 | "revenue",
364 | "profit",
365 | "cash flow",
366 | "debt",
367 | "valuation",
368 | ],
369 | "focus_areas": ["financials", "fundamentals", "earnings", "balance_sheet"],
370 | },
371 | "technical": {
372 | "keywords": [
373 | "price",
374 | "chart",
375 | "trend",
376 | "support",
377 | "resistance",
378 | "momentum",
379 | ],
380 | "focus_areas": ["technical_analysis", "chart_patterns", "indicators"],
381 | },
382 | "sentiment": {
383 | "keywords": [
384 | "sentiment",
385 | "news",
386 | "analyst",
387 | "opinion",
388 | "rating",
389 | "recommendation",
390 | ],
391 | "focus_areas": ["market_sentiment", "analyst_ratings", "news_sentiment"],
392 | },
393 | "competitive": {
394 | "keywords": [
395 | "competitor",
396 | "market share",
397 | "industry",
398 | "competitive",
399 | "peers",
400 | ],
401 | "focus_areas": [
402 | "competitive_analysis",
403 | "industry_analysis",
404 | "market_position",
405 | ],
406 | },
407 | }
408 |
409 | @log_method_call(component="TaskDistributionEngine", include_timing=True)
410 | def distribute_research_tasks(
411 | self, topic: str, session_id: str, focus_areas: list[str] | None = None
412 | ) -> list[ResearchTask]:
413 | """
414 | Intelligently distribute a research topic into specialized tasks.
415 |
416 | Args:
417 | topic: Main research topic
418 | session_id: Session identifier for tracking
419 | focus_areas: Optional specific areas to focus on
420 |
421 | Returns:
422 | List of specialized research tasks
423 | """
424 | distribution_logger = get_orchestration_logger("TaskDistributionEngine")
425 | distribution_logger.set_request_context(session_id=session_id)
426 |
427 | distribution_logger.info(
428 | "🎯 TASK_DISTRIBUTION_START",
429 | session_id=session_id,
430 | )
431 |
432 | tasks = []
433 | topic_lower = topic.lower()
434 |
435 | # Determine which task types are relevant
436 | relevant_types = self._analyze_topic_relevance(topic_lower, focus_areas)
437 |
438 | # Log relevance analysis results
439 | distribution_logger.info("🧠 RELEVANCE_ANALYSIS")
440 |
441 | # Create tasks for relevant types
442 | created_tasks = []
443 | for task_type, score in relevant_types.items():
444 | if score > 0.3: # Relevance threshold
445 | task = ResearchTask(
446 | task_id=f"{session_id}_{task_type}",
447 | task_type=task_type,
448 | target_topic=topic,
449 | focus_areas=self.TASK_TYPES[task_type]["focus_areas"],
450 | priority=int(score * 10), # Convert to 1-10 priority
451 | )
452 | tasks.append(task)
453 | created_tasks.append(
454 | {
455 | "type": task_type,
456 | "priority": task.priority,
457 | "score": score,
458 | "focus_areas": task.focus_areas[:3], # Limit for logging
459 | }
460 | )
461 |
462 | # Log created tasks
463 | if created_tasks:
464 | distribution_logger.info(
465 | "✅ TASKS_CREATED",
466 | task_count=len(created_tasks),
467 | )
468 |
469 | # Ensure at least one task (fallback to fundamental analysis)
470 | if not tasks:
471 | distribution_logger.warning(
472 | "⚠️ NO_RELEVANT_TASKS_FOUND - using fallback",
473 | threshold=0.3,
474 | max_score=max(relevant_types.values()) if relevant_types else 0,
475 | )
476 |
477 | fallback_task = ResearchTask(
478 | task_id=f"{session_id}_fundamental",
479 | task_type="fundamental",
480 | target_topic=topic,
481 | focus_areas=["general_analysis"],
482 | priority=5,
483 | )
484 | tasks.append(fallback_task)
485 |
486 | distribution_logger.info(
487 | "🔄 FALLBACK_TASK_CREATED", task_type="fundamental"
488 | )
489 |
490 | # Final summary
491 | task_summary = {
492 | "total_tasks": len(tasks),
493 | "task_types": [t.task_type for t in tasks],
494 | "avg_priority": sum(t.priority for t in tasks) / len(tasks) if tasks else 0,
495 | }
496 |
497 | distribution_logger.info("🎉 TASK_DISTRIBUTION_COMPLETE", **task_summary)
498 |
499 | return tasks
500 |
501 | def _analyze_topic_relevance(
502 | self, topic: str, focus_areas: list[str] | None = None
503 | ) -> dict[str, float]:
504 | """Analyze topic relevance to different research types."""
505 | relevance_scores = {}
506 |
507 | for task_type, config in self.TASK_TYPES.items():
508 | score = 0.0
509 |
510 | # Score based on keywords in topic
511 | keyword_matches = sum(
512 | 1 for keyword in config["keywords"] if keyword in topic
513 | )
514 | score += keyword_matches / len(config["keywords"]) * 0.6
515 |
516 | # Score based on focus areas
517 | if focus_areas:
518 | focus_matches = sum(
519 | 1
520 | for focus in focus_areas
521 | if any(area in focus.lower() for area in config["focus_areas"])
522 | )
523 | score += focus_matches / len(config["focus_areas"]) * 0.4
524 | else:
525 | # Default relevance for common research types
526 | score += {
527 | "fundamental": 0.8,
528 | "sentiment": 0.6,
529 | "technical": 0.4,
530 | "competitive": 0.5,
531 | }.get(task_type, 0.3)
532 |
533 | relevance_scores[task_type] = min(score, 1.0)
534 |
535 | return relevance_scores
536 |
537 |
538 | # Export key classes for easy import
539 | __all__ = [
540 | "ParallelResearchConfig",
541 | "ResearchTask",
542 | "ResearchResult",
543 | "ParallelResearchOrchestrator",
544 | "TaskDistributionEngine",
545 | ]
546 |
```
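A minimal end-to-end sketch of the module above: distribute a topic into typed tasks, then run them through the orchestrator with a stub executor. The stub and topic are illustrative; real callers plug in a research agent. With no explicit focus areas, the flat defaults in `_analyze_topic_relevance` push all four task types over the 0.3 threshold for this topic:

```python
import asyncio

from maverick_mcp.utils.parallel_research import (
    ParallelResearchConfig,
    ParallelResearchOrchestrator,
    ResearchTask,
    TaskDistributionEngine,
)


async def stub_executor(task: ResearchTask) -> dict:
    """Stand-in for a real research agent call."""
    await asyncio.sleep(0.1)
    return {"insights": [f"{task.task_type} findings on {task.target_topic}"]}


async def main() -> None:
    tasks = TaskDistributionEngine().distribute_research_tasks(
        topic="AAPL earnings outlook and analyst sentiment",
        session_id="demo123",
    )
    orchestrator = ParallelResearchOrchestrator(
        ParallelResearchConfig(max_concurrent_agents=4, timeout_per_agent=30)
    )
    result = await orchestrator.execute_parallel_research(tasks, stub_executor)
    print(f"{result.successful_tasks} ok, {result.parallel_efficiency:.2f}x speedup")


asyncio.run(main())
```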
--------------------------------------------------------------------------------
/maverick_mcp/utils/data_chunking.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Data chunking utilities for memory-efficient processing of large datasets.
3 | Provides streaming, batching, and generator-based approaches for handling large DataFrames.
4 | """
5 |
6 | import logging
7 | import math
8 | from collections.abc import Callable, Generator
9 | from typing import Any, Literal
10 |
11 | import numpy as np
12 | import pandas as pd
13 |
14 | from maverick_mcp.utils.memory_profiler import (
15 | force_garbage_collection,
16 | get_dataframe_memory_usage,
17 | memory_context,
18 | optimize_dataframe,
19 | )
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | # Default chunk size configurations
24 | DEFAULT_CHUNK_SIZE_MB = 50.0
25 | MAX_CHUNK_SIZE_MB = 200.0
26 | MIN_ROWS_PER_CHUNK = 100
27 |
28 |
29 | class DataChunker:
30 | """Advanced data chunking utility with multiple strategies."""
31 |
32 | def __init__(
33 | self,
34 | chunk_size_mb: float = DEFAULT_CHUNK_SIZE_MB,
35 | min_rows_per_chunk: int = MIN_ROWS_PER_CHUNK,
36 | optimize_chunks: bool = True,
37 | auto_gc: bool = True,
38 | ):
39 | """Initialize data chunker.
40 |
41 | Args:
42 | chunk_size_mb: Target chunk size in megabytes
43 | min_rows_per_chunk: Minimum rows per chunk
44 | optimize_chunks: Whether to optimize chunk memory usage
45 | auto_gc: Whether to automatically run garbage collection
46 | """
47 | self.chunk_size_mb = min(chunk_size_mb, MAX_CHUNK_SIZE_MB)
48 | self.chunk_size_bytes = int(self.chunk_size_mb * 1024 * 1024)
49 | self.min_rows_per_chunk = min_rows_per_chunk
50 | self.optimize_chunks = optimize_chunks
51 | self.auto_gc = auto_gc
52 |
53 | logger.debug(
54 | f"DataChunker initialized: {self.chunk_size_mb}MB chunks, "
55 | f"min {self.min_rows_per_chunk} rows"
56 | )
57 |
58 | def estimate_chunk_size(self, df: pd.DataFrame) -> tuple[int, int]:
59 | """Estimate optimal chunk size for a DataFrame.
60 |
61 | Args:
62 | df: DataFrame to analyze
63 |
64 | Returns:
65 | Tuple of (rows_per_chunk, estimated_chunks)
66 | """
67 | total_memory = df.memory_usage(deep=True).sum()
68 | memory_per_row = total_memory / len(df) if len(df) > 0 else 0
69 |
70 | if memory_per_row == 0:
71 | return len(df), 1
72 |
73 | # Calculate rows per chunk based on memory target
74 | rows_per_chunk = max(
75 | self.min_rows_per_chunk, int(self.chunk_size_bytes / memory_per_row)
76 | )
77 |
78 | # Ensure we don't exceed the DataFrame size
79 | rows_per_chunk = min(rows_per_chunk, len(df))
80 |
81 | estimated_chunks = math.ceil(len(df) / rows_per_chunk)
82 |
83 | logger.debug(
84 | f"Estimated chunking: {rows_per_chunk} rows/chunk, "
85 | f"{estimated_chunks} chunks total"
86 | )
87 |
88 | return rows_per_chunk, estimated_chunks
89 |
90 | def chunk_by_rows(
91 |         self, df: pd.DataFrame, rows_per_chunk: int | None = None
92 | ) -> Generator[pd.DataFrame, None, None]:
93 | """Chunk DataFrame by number of rows.
94 |
95 | Args:
96 | df: DataFrame to chunk
97 | rows_per_chunk: Rows per chunk (auto-estimated if None)
98 |
99 | Yields:
100 | DataFrame chunks
101 | """
102 | if rows_per_chunk is None:
103 | rows_per_chunk, _ = self.estimate_chunk_size(df)
104 |
105 | total_chunks = math.ceil(len(df) / rows_per_chunk)
106 | logger.debug(
107 | f"Chunking {len(df)} rows into {total_chunks} chunks "
108 | f"of ~{rows_per_chunk} rows each"
109 | )
110 |
111 | for i, start_idx in enumerate(range(0, len(df), rows_per_chunk)):
112 | end_idx = min(start_idx + rows_per_chunk, len(df))
113 | chunk = df.iloc[start_idx:end_idx].copy()
114 |
115 | if self.optimize_chunks:
116 | chunk = optimize_dataframe(chunk)
117 |
118 | logger.debug(
119 | f"Yielding chunk {i + 1}/{total_chunks}: rows {start_idx}-{end_idx - 1}"
120 | )
121 |
122 | yield chunk
123 |
124 | # Cleanup after yielding
125 | if self.auto_gc:
126 | del chunk
127 | if i % 5 == 0: # GC every 5 chunks
128 | force_garbage_collection()
129 |
130 | def chunk_by_memory(self, df: pd.DataFrame) -> Generator[pd.DataFrame, None, None]:
131 | """Chunk DataFrame by memory size.
132 |
133 | Args:
134 | df: DataFrame to chunk
135 |
136 | Yields:
137 | DataFrame chunks
138 | """
139 | total_memory = df.memory_usage(deep=True).sum()
140 |
141 | if total_memory <= self.chunk_size_bytes:
142 | if self.optimize_chunks:
143 | df = optimize_dataframe(df)
144 | yield df
145 | return
146 |
147 | # Use row-based chunking with memory-based estimation
148 | yield from self.chunk_by_rows(df)
149 |
150 | def chunk_by_date(
151 | self,
152 | df: pd.DataFrame,
153 | freq: Literal["D", "W", "M", "Q", "Y"] = "M",
154 |         date_column: str | None = None,
155 | ) -> Generator[pd.DataFrame, None, None]:
156 | """Chunk DataFrame by date periods.
157 |
158 | Args:
159 | df: DataFrame to chunk (must have datetime index or date_column)
160 | freq: Frequency for chunking (D=daily, W=weekly, M=monthly, etc.)
161 | date_column: Name of date column (uses index if None)
162 |
163 | Yields:
164 | DataFrame chunks by date periods
165 | """
166 | if date_column:
167 | if date_column not in df.columns:
168 | raise ValueError(f"Date column '{date_column}' not found")
169 | elif not isinstance(df.index, pd.DatetimeIndex):
170 | raise ValueError(
171 | "DataFrame must have datetime index or specify date_column"
172 | )
173 |
174 | # Group by period
175 | period_groups = df.groupby(
176 | pd.Grouper(key=date_column, freq=freq)
177 | if date_column
178 | else pd.Grouper(freq=freq)
179 | )
180 |
181 | total_periods = len(period_groups)
182 | logger.debug(f"Chunking by {freq} periods: {total_periods} chunks")
183 |
184 | for i, (period, group) in enumerate(period_groups):
185 | if len(group) == 0:
186 | continue
187 |
188 | if self.optimize_chunks:
189 | group = optimize_dataframe(group)
190 |
191 | logger.debug(
192 | f"Yielding period chunk {i + 1}/{total_periods}: "
193 | f"{period} ({len(group)} rows)"
194 | )
195 |
196 | yield group
197 |
198 | if self.auto_gc and i % 3 == 0: # GC every 3 periods
199 | force_garbage_collection()
200 |
201 | def process_in_chunks(
202 | self,
203 | df: pd.DataFrame,
204 | processor: Callable[[pd.DataFrame], Any],
205 |         combiner: Callable[[list], Any] | None = None,
206 | chunk_method: Literal["rows", "memory", "date"] = "memory",
207 | **chunk_kwargs,
208 | ) -> Any:
209 | """Process DataFrame in chunks and combine results.
210 |
211 | Args:
212 | df: DataFrame to process
213 | processor: Function to apply to each chunk
214 | combiner: Function to combine results (default: list)
215 | chunk_method: Chunking method to use
216 | **chunk_kwargs: Additional arguments for chunking method
217 |
218 | Returns:
219 | Combined results
220 | """
221 | results = []
222 |
223 | # Select chunking method
224 | if chunk_method == "rows":
225 | chunk_generator = self.chunk_by_rows(df, **chunk_kwargs)
226 | elif chunk_method == "memory":
227 | chunk_generator = self.chunk_by_memory(df)
228 | elif chunk_method == "date":
229 | chunk_generator = self.chunk_by_date(df, **chunk_kwargs)
230 | else:
231 | raise ValueError(f"Unknown chunk method: {chunk_method}")
232 |
233 | with memory_context("chunk_processing"):
234 | for i, chunk in enumerate(chunk_generator):
235 | try:
236 | with memory_context(f"chunk_{i}"):
237 | result = processor(chunk)
238 | results.append(result)
239 |
240 | except Exception as e:
241 | logger.error(f"Error processing chunk {i}: {e}")
242 | raise
243 |
244 | # Combine results
245 | if combiner:
246 | return combiner(results)
247 | elif results and isinstance(results[0], pd.DataFrame):
248 | # Auto-combine DataFrames
249 | return pd.concat(results, ignore_index=True)
250 | else:
251 | return results
252 |
253 |
254 | class StreamingDataProcessor:
255 | """Streaming data processor for very large datasets."""
256 |
257 | def __init__(self, chunk_size_mb: float = DEFAULT_CHUNK_SIZE_MB):
258 | """Initialize streaming processor.
259 |
260 | Args:
261 | chunk_size_mb: Chunk size in MB
262 | """
263 | self.chunk_size_mb = chunk_size_mb
264 | self.chunker = DataChunker(chunk_size_mb=chunk_size_mb)
265 |
266 | def stream_from_csv(
267 | self,
268 | filepath: str,
269 | processor: Callable[[pd.DataFrame], Any],
270 |         chunksize: int | None = None,
271 | **read_kwargs,
272 | ) -> Generator[Any, None, None]:
273 | """Stream process CSV file in chunks.
274 |
275 | Args:
276 | filepath: Path to CSV file
277 | processor: Function to process each chunk
278 | chunksize: Rows per chunk (auto-estimated if None)
279 | **read_kwargs: Additional arguments for pd.read_csv
280 |
281 | Yields:
282 | Processed results for each chunk
283 | """
284 | # Estimate chunk size if not provided
285 | if chunksize is None:
286 | # Read a sample to estimate memory usage
287 | sample = pd.read_csv(filepath, nrows=1000, **read_kwargs)
288 | memory_per_row = sample.memory_usage(deep=True).sum() / len(sample)
289 | chunksize = max(100, int(self.chunker.chunk_size_bytes / memory_per_row))
290 | del sample
291 | force_garbage_collection()
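            # e.g., with the default 50MB budget (52,428,800 bytes) and rows
            # averaging ~500 bytes, chunksize = int(52_428_800 / 500) = 104857.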
292 |
293 | logger.info(f"Streaming CSV with {chunksize} rows per chunk")
294 |
295 | chunk_reader = pd.read_csv(filepath, chunksize=chunksize, **read_kwargs)
296 |
297 | for i, chunk in enumerate(chunk_reader):
298 | with memory_context(f"csv_chunk_{i}"):
299 | # Optimize chunk if needed
300 | if self.chunker.optimize_chunks:
301 | chunk = optimize_dataframe(chunk)
302 |
303 | result = processor(chunk)
304 | yield result
305 |
306 | # Clean up
307 | del chunk
308 | if i % 5 == 0:
309 | force_garbage_collection()
310 |
311 | def stream_from_database(
312 | self,
313 | query: str,
314 | connection,
315 | processor: Callable[[pd.DataFrame], Any],
316 |         chunksize: int | None = None,
317 | ) -> Generator[Any, None, None]:
318 | """Stream process database query results in chunks.
319 |
320 | Args:
321 | query: SQL query
322 | connection: Database connection
323 | processor: Function to process each chunk
324 | chunksize: Rows per chunk
325 |
326 | Yields:
327 | Processed results for each chunk
328 | """
329 | if chunksize is None:
330 | chunksize = 10000 # Default for database queries
331 |
332 | logger.info(f"Streaming database query with {chunksize} rows per chunk")
333 |
334 | chunk_reader = pd.read_sql(query, connection, chunksize=chunksize)
335 |
336 | for i, chunk in enumerate(chunk_reader):
337 | with memory_context(f"db_chunk_{i}"):
338 | if self.chunker.optimize_chunks:
339 | chunk = optimize_dataframe(chunk)
340 |
341 | result = processor(chunk)
342 | yield result
343 |
344 | del chunk
345 | if i % 3 == 0:
346 | force_garbage_collection()
347 |
348 |
349 | def optimize_dataframe_dtypes(
350 | df: pd.DataFrame, aggressive: bool = False, categorical_threshold: float = 0.5
351 | ) -> pd.DataFrame:
352 | """Optimize DataFrame data types for memory efficiency.
353 |
354 | Args:
355 | df: DataFrame to optimize
356 | aggressive: Use aggressive optimizations (may lose precision)
357 | categorical_threshold: Threshold for categorical conversion
358 |
359 | Returns:
360 | Optimized DataFrame
361 | """
362 | logger.debug(f"Optimizing DataFrame dtypes: {df.shape}")
363 |
364 | initial_memory = df.memory_usage(deep=True).sum()
365 | df_opt = df.copy()
366 |
367 | for col in df_opt.columns:
368 | col_type = df_opt[col].dtype
369 |
370 | try:
371 | if col_type == "object":
372 | # Convert string columns to categorical if beneficial
373 | unique_count = df_opt[col].nunique()
374 | total_count = len(df_opt[col])
375 |
376 | if unique_count / total_count < categorical_threshold:
377 | df_opt[col] = df_opt[col].astype("category")
378 | logger.debug(f"Converted {col} to categorical")
379 |
380 | elif "int" in str(col_type):
381 | # Downcast integers
382 | c_min = df_opt[col].min()
383 | c_max = df_opt[col].max()
384 |
385 | if c_min >= np.iinfo(np.int8).min and c_max <= np.iinfo(np.int8).max:
386 | df_opt[col] = df_opt[col].astype(np.int8)
387 | elif (
388 | c_min >= np.iinfo(np.int16).min and c_max <= np.iinfo(np.int16).max
389 | ):
390 | df_opt[col] = df_opt[col].astype(np.int16)
391 | elif (
392 | c_min >= np.iinfo(np.int32).min and c_max <= np.iinfo(np.int32).max
393 | ):
394 | df_opt[col] = df_opt[col].astype(np.int32)
395 |
396 | elif "float" in str(col_type) and col_type == "float64":
397 | # Downcast float64 to float32 if no precision loss
398 | if aggressive:
399 | # Check if conversion preserves data
400 | temp = df_opt[col].astype(np.float32)
401 | if np.allclose(
402 | df_opt[col].fillna(0), temp.fillna(0), rtol=1e-6, equal_nan=True
403 | ):
404 | df_opt[col] = temp
405 | logger.debug(f"Converted {col} to float32")
406 |
407 | except Exception as e:
408 | logger.debug(f"Could not optimize column {col}: {e}")
409 | continue
410 |
411 | final_memory = df_opt.memory_usage(deep=True).sum()
412 | memory_saved = initial_memory - final_memory
413 |
414 | if memory_saved > 0:
415 | logger.info(
416 | f"DataFrame optimization saved {memory_saved / (1024**2):.2f}MB "
417 | f"({memory_saved / initial_memory * 100:.1f}% reduction)"
418 | )
419 |
420 | return df_opt
421 |
422 |
423 | def create_memory_efficient_dataframe(
424 |     data: dict | list, optimize: bool = True, categorical_columns: list[str] | None = None
425 | ) -> pd.DataFrame:
426 | """Create a memory-efficient DataFrame from data.
427 |
428 | Args:
429 | data: Data to create DataFrame from
430 | optimize: Whether to optimize dtypes
431 | categorical_columns: Columns to convert to categorical
432 |
433 | Returns:
434 | Memory-optimized DataFrame
435 | """
436 | with memory_context("creating_dataframe"):
437 | df = pd.DataFrame(data)
438 |
439 | if categorical_columns:
440 | for col in categorical_columns:
441 | if col in df.columns:
442 | df[col] = df[col].astype("category")
443 |
444 | if optimize:
445 | df = optimize_dataframe_dtypes(df)
446 |
447 | return df
448 |
449 |
450 | def batch_process_large_dataframe(
451 | df: pd.DataFrame,
452 | operation: Callable,
453 |     batch_size: int | None = None,
454 | combine_results: bool = True,
455 | ) -> Any:
456 | """Process large DataFrame in batches to manage memory.
457 |
458 | Args:
459 | df: Large DataFrame to process
460 | operation: Function to apply to each batch
461 | batch_size: Size of each batch (auto-estimated if None)
462 | combine_results: Whether to combine batch results
463 |
464 | Returns:
465 | Combined results or list of batch results
466 | """
467 | chunker = DataChunker()
468 |
469 | if batch_size:
470 | chunk_generator = chunker.chunk_by_rows(df, batch_size)
471 | else:
472 | chunk_generator = chunker.chunk_by_memory(df)
473 |
474 | results = []
475 |
476 | with memory_context("batch_processing"):
477 | for i, batch in enumerate(chunk_generator):
478 | logger.debug(f"Processing batch {i + 1}")
479 |
480 | with memory_context(f"batch_{i}"):
481 | result = operation(batch)
482 | results.append(result)
483 |
484 | if combine_results and results:
485 | if isinstance(results[0], pd.DataFrame):
486 | return pd.concat(results, ignore_index=True)
487 | elif isinstance(results[0], int | float):
488 | return sum(results)
489 | elif isinstance(results[0], list):
490 | return [item for sublist in results for item in sublist]
491 |
492 | return results
493 |
494 |
495 | class LazyDataFrame:
496 | """Lazy evaluation wrapper for large DataFrames."""
497 |
498 | def __init__(self, data_source: str | pd.DataFrame, chunk_size_mb: float = 50.0):
499 | """Initialize lazy DataFrame.
500 |
501 | Args:
502 | data_source: File path or DataFrame
503 | chunk_size_mb: Chunk size for processing
504 | """
505 | self.data_source = data_source
506 | self.chunker = DataChunker(chunk_size_mb=chunk_size_mb)
507 | self._cached_info = None
508 |
509 | def get_info(self) -> dict[str, Any]:
510 | """Get DataFrame information without loading full data."""
511 | if self._cached_info:
512 | return self._cached_info
513 |
514 | if isinstance(self.data_source, str):
515 | # Read just the header and a sample
516 | sample = pd.read_csv(self.data_source, nrows=100)
517 |             with open(self.data_source) as f:  # avoid leaking the file handle
518 |                 total_rows = sum(1 for _ in f) - 1  # subtract the header row
519 | self._cached_info = {
520 | "columns": sample.columns.tolist(),
521 | "dtypes": sample.dtypes.to_dict(),
522 | "estimated_rows": total_rows,
523 | "sample_memory_mb": sample.memory_usage(deep=True).sum() / (1024**2),
524 | }
525 | else:
526 | self._cached_info = get_dataframe_memory_usage(self.data_source)
527 |
528 | return self._cached_info
529 |
530 | def apply_chunked(self, operation: Callable) -> Any:
531 | """Apply operation in chunks."""
532 | if isinstance(self.data_source, str):
533 | processor = StreamingDataProcessor(self.chunker.chunk_size_mb)
534 | results = list(processor.stream_from_csv(self.data_source, operation))
535 | else:
536 | results = self.chunker.process_in_chunks(self.data_source, operation)
537 |
538 | return results
539 |
540 | def to_optimized_dataframe(self) -> pd.DataFrame:
541 | """Load and optimize the full DataFrame."""
542 | if isinstance(self.data_source, str):
543 | df = pd.read_csv(self.data_source)
544 | else:
545 | df = self.data_source.copy()
546 |
547 | return optimize_dataframe_dtypes(df)
548 |
549 |
550 | # Utility functions for common operations
551 |
552 |
553 | def chunked_concat(
554 | dataframes: list[pd.DataFrame], chunk_size: int = 10
555 | ) -> pd.DataFrame:
556 | """Concatenate DataFrames in chunks to manage memory.
557 |
558 | Args:
559 | dataframes: List of DataFrames to concatenate
560 | chunk_size: Number of DataFrames to concat at once
561 |
562 | Returns:
563 | Concatenated DataFrame
564 | """
565 | if not dataframes:
566 | return pd.DataFrame()
567 |
568 | if len(dataframes) <= chunk_size:
569 | return pd.concat(dataframes, ignore_index=True)
570 |
571 | # Process in chunks
572 | results = []
573 | for i in range(0, len(dataframes), chunk_size):
574 | chunk = dataframes[i : i + chunk_size]
575 | with memory_context(f"concat_chunk_{i // chunk_size}"):
576 | result = pd.concat(chunk, ignore_index=True)
577 | results.append(result)
578 |
579 | # Clean up chunk
580 | for df in chunk:
581 | del df
582 | force_garbage_collection()
583 |
584 | # Final concatenation
585 | with memory_context("final_concat"):
586 | final_result = pd.concat(results, ignore_index=True)
587 |
588 | return final_result
589 |
590 |
591 | def memory_efficient_groupby(
592 | df: pd.DataFrame, group_col: str, agg_func: Callable, chunk_size_mb: float = 50.0
593 | ) -> pd.DataFrame:
594 | """Perform memory-efficient groupby operations.
595 |
596 | Args:
597 | df: DataFrame to group
598 | group_col: Column to group by
599 | agg_func: Aggregation function
600 | chunk_size_mb: Chunk size in MB
601 |
602 | Returns:
603 | Aggregated DataFrame
604 | """
605 | if group_col not in df.columns:
606 | raise ValueError(f"Group column '{group_col}' not found")
607 |
608 | chunker = DataChunker(chunk_size_mb=chunk_size_mb)
609 |     # Note: chunked groupby is only exact for decomposable aggregations (sum, min, max, count).
610 |
611 | def process_chunk(chunk):
612 | return chunk.groupby(group_col).apply(agg_func).reset_index()
613 |
614 | results = chunker.process_in_chunks(df, process_chunk)
615 |
616 | # Combine and re-aggregate results
617 | combined = pd.concat(results, ignore_index=True)
618 | final_result = combined.groupby(group_col).apply(agg_func).reset_index()
619 |
620 | return final_result
621 |
```
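A short usage sketch of the utilities above: estimate a chunk size from a memory budget, process chunks with automatic DataFrame recombination, and shrink dtypes. The synthetic frame and numbers are illustrative:

```python
import numpy as np
import pandas as pd

from maverick_mcp.utils.data_chunking import DataChunker, optimize_dataframe_dtypes

df = pd.DataFrame(
    {
        "symbol": np.random.choice(["AAPL", "MSFT", "NVDA"], size=500_000),
        "close": np.random.uniform(50, 500, size=500_000),
        "volume": np.random.randint(1_000, 1_000_000, size=500_000),
    }
)

chunker = DataChunker(chunk_size_mb=25.0)
rows_per_chunk, n_chunks = chunker.estimate_chunk_size(df)
print(f"~{rows_per_chunk} rows/chunk across {n_chunks} chunks")

# Per-chunk transform; DataFrame results are concatenated automatically.
enriched = chunker.process_in_chunks(
    df,
    processor=lambda chunk: chunk.assign(dollar_volume=chunk["close"] * chunk["volume"]),
    chunk_method="rows",
)

# Low-cardinality strings become categoricals; wide ints are downcast.
compact = optimize_dataframe_dtypes(df)
```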
--------------------------------------------------------------------------------
/maverick_mcp/workflows/agents/market_analyzer.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Market Analyzer Agent for intelligent market regime detection.
3 |
4 | This agent analyzes market conditions to determine the current market regime
5 | (trending, ranging, volatile, etc.) and provides context for strategy selection.
6 | """
7 |
8 | import logging
9 | import math
10 | from datetime import datetime, timedelta
11 | from typing import Any
12 |
13 | import numpy as np
14 | import pandas as pd
15 | import pandas_ta as ta
16 |
17 | from maverick_mcp.data.cache import CacheManager
18 | from maverick_mcp.providers.stock_data import EnhancedStockDataProvider
19 | from maverick_mcp.workflows.state import BacktestingWorkflowState
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class MarketAnalyzerAgent:
25 | """Intelligent market regime analyzer for backtesting workflows."""
26 |
27 | def __init__(
28 | self,
29 | data_provider: EnhancedStockDataProvider | None = None,
30 | cache_manager: CacheManager | None = None,
31 | ):
32 | """Initialize market analyzer agent.
33 |
34 | Args:
35 | data_provider: Stock data provider instance
36 | cache_manager: Cache manager for performance optimization
37 | """
38 | self.data_provider = data_provider or EnhancedStockDataProvider()
39 | self.cache = cache_manager or CacheManager()
40 |
41 | # Market regime detection thresholds
42 | self.TREND_THRESHOLD = 0.15 # 15% for strong trend
43 | self.VOLATILITY_THRESHOLD = 0.02 # 2% daily volatility threshold
44 | self.VOLUME_THRESHOLD = 1.5 # 1.5x average volume for high volume
45 |
46 | # Analysis periods for different regimes
47 | self.SHORT_PERIOD = 20 # Short-term trend analysis
48 | self.MEDIUM_PERIOD = 50 # Medium-term trend analysis
49 | self.LONG_PERIOD = 200 # Long-term trend analysis
50 |
51 | logger.info("MarketAnalyzerAgent initialized")
52 |
53 | async def analyze_market_regime(
54 | self, state: BacktestingWorkflowState
55 | ) -> BacktestingWorkflowState:
56 | """Analyze market regime and update state.
57 |
58 | Args:
59 | state: Current workflow state
60 |
61 | Returns:
62 | Updated state with market regime analysis
63 | """
64 | start_time = datetime.now()
65 |
66 | try:
67 | logger.info(f"Analyzing market regime for {state['symbol']}")
68 |
69 | # Get historical data for analysis
70 | extended_start = self._calculate_extended_start_date(state["start_date"])
71 | price_data = await self._get_price_data(
72 | state["symbol"], extended_start, state["end_date"]
73 | )
74 |
75 | if price_data is None or len(price_data) < self.LONG_PERIOD:
76 | raise ValueError(
77 | f"Insufficient data for market regime analysis: {state['symbol']}"
78 | )
79 |
80 | # Perform comprehensive market analysis
81 | regime_analysis = self._perform_regime_analysis(price_data)
82 |
83 | # Update state with analysis results
84 | state["market_regime"] = regime_analysis["regime"]
85 | state["regime_confidence"] = regime_analysis["confidence"]
86 | state["regime_indicators"] = regime_analysis["indicators"]
87 | state["volatility_percentile"] = regime_analysis["volatility_percentile"]
88 | state["trend_strength"] = regime_analysis["trend_strength"]
89 | state["market_conditions"] = regime_analysis["market_conditions"]
90 | state["volume_profile"] = regime_analysis["volume_profile"]
91 | state["support_resistance_levels"] = regime_analysis["support_resistance"]
92 |
93 | # Record execution time
94 | execution_time = (datetime.now() - start_time).total_seconds() * 1000
95 | state["regime_analysis_time_ms"] = execution_time
96 |
97 | # Update workflow status
98 | state["workflow_status"] = "selecting_strategies"
99 | state["current_step"] = "market_analysis_completed"
100 | state["steps_completed"].append("market_regime_analysis")
101 |
102 | logger.info(
103 | f"Market regime analysis completed for {state['symbol']}: "
104 | f"{state['market_regime']} (confidence: {state['regime_confidence']:.2f})"
105 | )
106 |
107 | return state
108 |
109 | except Exception as e:
110 | error_info = {
111 | "step": "market_regime_analysis",
112 | "error": str(e),
113 | "timestamp": datetime.now().isoformat(),
114 | "symbol": state["symbol"],
115 | }
116 | state["errors_encountered"].append(error_info)
117 |
118 | # Set fallback regime
119 | state["market_regime"] = "unknown"
120 | state["regime_confidence"] = 0.0
121 | state["fallback_strategies_used"].append("regime_detection_fallback")
122 |
123 | logger.error(f"Market regime analysis failed for {state['symbol']}: {e}")
124 | return state
125 |
126 | def _calculate_extended_start_date(self, start_date: str) -> str:
127 | """Calculate extended start date to ensure sufficient data for analysis."""
128 | start_dt = datetime.strptime(start_date, "%Y-%m-%d")
129 | # Add extra buffer for technical indicators
130 | extended_start = start_dt - timedelta(days=self.LONG_PERIOD + 50)
131 | return extended_start.strftime("%Y-%m-%d")
132 |
133 | async def _get_price_data(
134 | self, symbol: str, start_date: str, end_date: str
135 | ) -> pd.DataFrame | None:
136 | """Get price data with caching."""
137 | cache_key = f"market_analysis:{symbol}:{start_date}:{end_date}"
138 |
139 | # Try cache first
140 | cached_data = await self.cache.get(cache_key)
141 | if cached_data is not None:
142 | return pd.DataFrame(cached_data)
143 |
144 | try:
145 | # Fetch from provider
146 | data = self.data_provider.get_stock_data(
147 | symbol=symbol, start_date=start_date, end_date=end_date, interval="1d"
148 | )
149 |
150 | if data is not None and not data.empty:
151 | # Cache for 30 minutes
152 | await self.cache.set(cache_key, data.to_dict(), ttl=1800)
153 | return data
154 |
155 | return None
156 |
157 | except Exception as e:
158 | logger.error(f"Failed to fetch price data for {symbol}: {e}")
159 | return None
160 |
161 | def _perform_regime_analysis(self, data: pd.DataFrame) -> dict[str, Any]:
162 | """Perform comprehensive market regime analysis."""
163 | # Ensure column names are lowercase
164 | data.columns = [col.lower() for col in data.columns]
165 |
166 | # Calculate technical indicators
167 | close = data["close"]
168 | high = data["high"]
169 | low = data["low"]
170 | volume = data["volume"]
171 |
172 | # Trend analysis
173 | trend_analysis = self._analyze_trend(close)
174 |
175 | # Volatility analysis
176 | volatility_analysis = self._analyze_volatility(close)
177 |
178 | # Volume analysis
179 | volume_analysis = self._analyze_volume(volume, close)
180 |
181 | # Support/resistance analysis
182 | support_resistance = self._identify_support_resistance(high, low, close)
183 |
184 | # Market structure analysis
185 | market_structure = self._analyze_market_structure(high, low, close)
186 |
187 | # Determine overall regime
188 | regime_info = self._classify_regime(
189 | trend_analysis, volatility_analysis, volume_analysis, market_structure
190 | )
191 |
192 | return {
193 | "regime": regime_info["regime"],
194 | "confidence": regime_info["confidence"],
195 | "indicators": {
196 | "trend_slope": trend_analysis["slope"],
197 | "trend_r2": trend_analysis["r_squared"],
198 | "volatility_20d": volatility_analysis["volatility_20d"],
199 | "volume_ratio": volume_analysis["volume_ratio"],
200 | "rsi_14": trend_analysis["rsi"],
201 | "adx": trend_analysis["adx"],
202 | },
203 | "volatility_percentile": volatility_analysis["percentile"],
204 | "trend_strength": trend_analysis["strength"],
205 | "market_conditions": {
206 | "trend_direction": trend_analysis["direction"],
207 | "trend_consistency": trend_analysis["consistency"],
208 | "volatility_regime": volatility_analysis["regime"],
209 | "volume_regime": volume_analysis["regime"],
210 | "market_structure": market_structure["structure_type"],
211 | },
212 | "volume_profile": volume_analysis["profile"],
213 | "support_resistance": support_resistance,
214 | }
215 |
216 | def _analyze_trend(self, close: pd.Series) -> dict[str, Any]:
217 | """Analyze trend characteristics."""
218 | # Calculate moving averages
219 | ma_20 = ta.sma(close, length=self.SHORT_PERIOD)
220 | ma_50 = ta.sma(close, length=self.MEDIUM_PERIOD)
221 | ma_200 = ta.sma(close, length=self.LONG_PERIOD)
222 |
223 | # Calculate trend slope using linear regression
224 | recent_data = close.tail(self.MEDIUM_PERIOD).reset_index(drop=True)
225 | x = np.arange(len(recent_data))
226 |
227 | if len(recent_data) > 1:
228 | slope, intercept = np.polyfit(x, recent_data, 1)
229 | y_pred = slope * x + intercept
230 | r_squared = 1 - (
231 | np.sum((recent_data - y_pred) ** 2)
232 | / np.sum((recent_data - np.mean(recent_data)) ** 2)
233 | )
234 | else:
235 | slope = 0
236 | r_squared = 0
237 |
238 | # Normalize slope by price for comparability
239 | normalized_slope = slope / close.iloc[-1] if close.iloc[-1] != 0 else 0
240 |
241 | # Calculate RSI and ADX for trend strength
242 | rsi = ta.rsi(close, length=14).iloc[-1] if len(close) >= 14 else 50
243 |         # Only the close series is available in this helper, so it stands
244 |         # in for high and low as well. pandas-ta's adx expects Series
245 |         # inputs; if the indicator cannot be computed it returns None, in
246 |         # which case a neutral ADX default of 25 is used instead.
247 |         adx_result = ta.adx(high=close, low=close, close=close, length=14)
248 |         adx = (
249 |             adx_result.iloc[-1, 0]
250 |             if adx_result is not None and len(adx_result) > 0
251 |             else 25
252 |         )
253 |
254 |
255 | # Determine trend direction and strength
256 | if normalized_slope > 0.001: # 0.1% daily trend
257 | direction = "bullish"
258 | strength = min(abs(normalized_slope) * 1000, 1.0) # Cap at 1.0
259 | elif normalized_slope < -0.001:
260 | direction = "bearish"
261 | strength = min(abs(normalized_slope) * 1000, 1.0)
262 | else:
263 | direction = "sideways"
264 | strength = 0.2 # Low strength for sideways
265 |
266 | # Calculate trend consistency
267 | ma_alignment = 0
268 |         if len(ma_20) > 0 and len(ma_50) > 0 and len(ma_200) > 0:
269 |             current_price = close.iloc[-1]
270 |             if current_price > ma_20.iloc[-1] > ma_50.iloc[-1] > ma_200.iloc[-1]:
271 |                 # Bullish alignment: price above upward-stacked MAs
272 |                 ma_alignment = 1.0
273 |             elif current_price < ma_20.iloc[-1] < ma_50.iloc[-1] < ma_200.iloc[-1]:
274 |                 # Bearish alignment: price below downward-stacked MAs
275 |                 ma_alignment = -1.0
276 |             else:
277 |                 ma_alignment = 0.0  # Mixed alignment
278 |
279 | consistency = (abs(ma_alignment) + r_squared) / 2
280 |
281 | return {
282 | "slope": normalized_slope,
283 | "r_squared": r_squared,
284 | "direction": direction,
285 | "strength": strength,
286 | "consistency": consistency,
287 | "rsi": rsi,
288 | "adx": adx,
289 | }
290 |
291 | def _analyze_volatility(self, close: pd.Series) -> dict[str, Any]:
292 | """Analyze volatility characteristics."""
293 | # Calculate various volatility measures
294 | returns = close.pct_change().dropna()
295 |
296 | volatility_5d = (
297 | returns.tail(5).std() * math.sqrt(252) if len(returns) >= 5 else 0
298 | )
299 | volatility_20d = (
300 | returns.tail(20).std() * math.sqrt(252) if len(returns) >= 20 else 0
301 | )
302 | volatility_60d = (
303 | returns.tail(60).std() * math.sqrt(252) if len(returns) >= 60 else 0
304 | )
305 |
306 | # Calculate historical volatility percentile
307 | historical_vol = returns.rolling(20).std() * math.sqrt(252)
308 | if len(historical_vol.dropna()) > 0:
309 | current_vol = historical_vol.iloc[-1]
310 | percentile = (historical_vol < current_vol).sum() / len(
311 | historical_vol.dropna()
312 | )
313 | else:
314 | percentile = 0.5
315 |
316 | # Classify volatility regime
317 | if volatility_20d > 0.4: # > 40% annualized
318 | regime = "high_volatility"
319 | elif volatility_20d > 0.2: # 20-40% annualized
320 | regime = "medium_volatility"
321 | else:
322 | regime = "low_volatility"
323 |
324 | return {
325 | "volatility_5d": volatility_5d,
326 | "volatility_20d": volatility_20d,
327 | "volatility_60d": volatility_60d,
328 | "percentile": percentile,
329 | "regime": regime,
330 | }
331 |
332 | def _analyze_volume(self, volume: pd.Series, close: pd.Series) -> dict[str, Any]:
333 | """Analyze volume characteristics."""
334 | # Calculate volume moving averages
335 | volume_ma_20 = volume.rolling(20).mean()
336 |
337 | # Current volume ratio vs average
338 | current_volume = volume.iloc[-1] if len(volume) > 0 else 0
339 | avg_volume = volume_ma_20.iloc[-1] if len(volume_ma_20.dropna()) > 0 else 1
340 | volume_ratio = current_volume / avg_volume if avg_volume > 0 else 1
341 |
342 | # Volume trend
343 | recent_volume = volume.tail(10)
344 | volume_trend = (
345 | "increasing"
346 | if recent_volume.iloc[-1] > recent_volume.mean()
347 | else "decreasing"
348 | )
349 |
350 | # Price-volume relationship
351 | price_change = close.pct_change().tail(10)
352 | volume_change = volume.pct_change().tail(10)
353 |
354 | correlation = price_change.corr(volume_change) if len(price_change) >= 2 else 0
355 |
356 | # Volume regime classification
357 | if volume_ratio > 2.0:
358 | regime = "high_volume"
359 | elif volume_ratio > 1.5:
360 | regime = "elevated_volume"
361 | elif volume_ratio < 0.5:
362 | regime = "low_volume"
363 | else:
364 | regime = "normal_volume"
365 |
366 | return {
367 | "volume_ratio": volume_ratio,
368 | "volume_trend": volume_trend,
369 | "price_volume_correlation": correlation,
370 | "regime": regime,
371 | "profile": {
372 | "current_vs_20d": volume_ratio,
373 | "trend_direction": volume_trend,
374 | "price_correlation": correlation,
375 | },
376 | }
377 |
378 | def _identify_support_resistance(
379 | self, high: pd.Series, low: pd.Series, close: pd.Series
380 | ) -> list[float]:
381 | """Identify key support and resistance levels."""
382 | levels = []
383 |
384 | try:
385 | # Recent price range
386 | recent_data = close.tail(50) if len(close) >= 50 else close
387 | price_range = recent_data.max() - recent_data.min()
388 |
389 | # Identify local peaks and troughs
390 | try:
391 | from scipy.signal import find_peaks
392 |
393 | # Find resistance levels (peaks)
394 | peaks, _ = find_peaks(
395 | high.values, distance=5, prominence=price_range * 0.02
396 | )
397 | resistance_levels = high.iloc[peaks].tolist()
398 |
399 | # Find support levels (troughs)
400 | troughs, _ = find_peaks(
401 | -low.values, distance=5, prominence=price_range * 0.02
402 | )
403 | support_levels = low.iloc[troughs].tolist()
404 | except ImportError:
405 | logger.warning("scipy not available, using simple peak detection")
406 | # Fallback to simple method
407 | resistance_levels = [recent_data.max()]
408 | support_levels = [recent_data.min()]
409 |
410 | # Combine and filter levels
411 | all_levels = resistance_levels + support_levels
412 |
413 | # Remove levels too close to each other
414 | filtered_levels = []
415 | for level in sorted(all_levels):
416 | if not any(
417 | abs(level - existing) < price_range * 0.01
418 | for existing in filtered_levels
419 | ):
420 | filtered_levels.append(level)
421 |
422 |             # Keep the 10 highest-priced of the filtered levels
423 |             levels = sorted(filtered_levels)[-10:]
424 |
425 | except Exception as e:
426 | logger.warning(f"Failed to calculate support/resistance levels: {e}")
427 | # Fallback to simple levels
428 | current_price = close.iloc[-1]
429 | levels = [
430 | current_price * 0.95, # 5% below
431 | current_price * 1.05, # 5% above
432 | ]
433 |
434 | return levels
435 |
436 | def _analyze_market_structure(
437 | self, high: pd.Series, low: pd.Series, close: pd.Series
438 | ) -> dict[str, Any]:
439 | """Analyze market structure patterns."""
440 | try:
441 | # Calculate recent highs and lows
442 | lookback = min(20, len(close))
443 | recent_highs = high.tail(lookback)
444 | recent_lows = low.tail(lookback)
445 |
446 |             # Count bars that set a 3-bar local high/low (rough swing-structure proxy)
447 |             higher_highs = (recent_highs.rolling(3).max() == recent_highs).sum()
448 |             higher_lows = (recent_lows.rolling(3).min() == recent_lows).sum()
449 |
450 | # Classify structure
451 | if higher_highs > lookback * 0.3 and higher_lows > lookback * 0.3:
452 | structure_type = "uptrend_structure"
453 | elif higher_highs < lookback * 0.1 and higher_lows < lookback * 0.1:
454 | structure_type = "downtrend_structure"
455 | else:
456 | structure_type = "ranging_structure"
457 |
458 | return {
459 | "structure_type": structure_type,
460 | "higher_highs": higher_highs,
461 | "higher_lows": higher_lows,
462 | }
463 |
464 | except Exception as e:
465 | logger.warning(f"Failed to analyze market structure: {e}")
466 | return {
467 | "structure_type": "unknown_structure",
468 | "higher_highs": 0,
469 | "higher_lows": 0,
470 | }
471 |
472 | def _classify_regime(
473 | self,
474 | trend_analysis: dict,
475 | volatility_analysis: dict,
476 | volume_analysis: dict,
477 | market_structure: dict,
478 | ) -> dict[str, Any]:
479 | """Classify overall market regime based on component analyses."""
480 |
481 | # Initialize scoring system
482 | regime_scores = {
483 | "trending": 0.0,
484 | "ranging": 0.0,
485 | "volatile": 0.0,
486 | "low_volume": 0.0,
487 | }
488 |
489 | # Trend scoring
490 | if trend_analysis["strength"] > 0.6 and trend_analysis["consistency"] > 0.6:
491 | regime_scores["trending"] += 0.4
492 |
493 | if trend_analysis["adx"] > 25: # Strong trend
494 | regime_scores["trending"] += 0.2
495 |
496 | # Ranging scoring
497 | if (
498 | trend_analysis["strength"] < 0.3
499 | and trend_analysis["direction"] == "sideways"
500 | ):
501 | regime_scores["ranging"] += 0.4
502 |
503 | if market_structure["structure_type"] == "ranging_structure":
504 | regime_scores["ranging"] += 0.2
505 |
506 | # Volatility scoring
507 | if volatility_analysis["regime"] == "high_volatility":
508 | regime_scores["volatile"] += 0.3
509 |
510 | if volatility_analysis["percentile"] > 0.8: # High volatility percentile
511 | regime_scores["volatile"] += 0.2
512 |
513 | # Volume scoring
514 | if volume_analysis["regime"] == "low_volume":
515 | regime_scores["low_volume"] += 0.3
516 |
517 | # Determine primary regime
518 | primary_regime = max(regime_scores.items(), key=lambda x: x[1])
519 | regime_name = primary_regime[0]
520 |
521 | # Combine regimes for complex cases
522 | if regime_scores["volatile"] > 0.3 and regime_scores["trending"] > 0.3:
523 | regime_name = "volatile_trending"
524 | elif regime_scores["low_volume"] > 0.2 and regime_scores["ranging"] > 0.3:
525 | regime_name = "low_volume_ranging"
526 |
527 | # Calculate confidence based on score spread
528 | sorted_scores = sorted(regime_scores.values(), reverse=True)
529 | confidence = (
530 | sorted_scores[0] - sorted_scores[1]
531 | if len(sorted_scores) > 1
532 | else sorted_scores[0]
533 | )
534 | confidence = min(max(confidence, 0.1), 0.95) # Clamp between 0.1 and 0.95
535 |
536 | return {
537 | "regime": regime_name,
538 | "confidence": confidence,
539 | "scores": regime_scores,
540 | }
541 |
```
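
A minimal standalone sketch of the volatility-regime math from `_analyze_volatility` above, run on synthetic prices rather than provider data. The 20%/40% annualized thresholds and the rolling-percentile calculation mirror the method; everything else (the seeded random walk, the bar count) is illustrative:

```python
import numpy as np
import pandas as pd

# Synthetic daily closes: a 250-bar random walk with ~1% daily moves.
rng = np.random.default_rng(42)
close = pd.Series(100 * np.cumprod(1 + rng.normal(0, 0.01, 250)))

returns = close.pct_change().dropna()
vol_20d = returns.tail(20).std() * np.sqrt(252)  # annualized 20-day volatility

# Percentile of current volatility within its own rolling history.
historical_vol = returns.rolling(20).std() * np.sqrt(252)
current_vol = historical_vol.iloc[-1]
percentile = (historical_vol < current_vol).sum() / len(historical_vol.dropna())

if vol_20d > 0.4:
    regime = "high_volatility"
elif vol_20d > 0.2:
    regime = "medium_volatility"
else:
    regime = "low_volatility"

print(f"20d vol={vol_20d:.1%} percentile={percentile:.2f} regime={regime}")
```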
--------------------------------------------------------------------------------
/maverick_mcp/monitoring/status_dashboard.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Status Dashboard for Backtesting System Health Monitoring.
3 |
4 | This module provides a comprehensive dashboard that aggregates health status
5 | from all components and provides real-time metrics visualization data.
6 | """
7 |
8 | import logging
9 | import time
10 | from datetime import UTC, datetime, timedelta
11 | from typing import Any
12 |
13 | from maverick_mcp.config.settings import get_settings
14 | from maverick_mcp.utils.circuit_breaker import get_all_circuit_breaker_status
15 |
16 | logger = logging.getLogger(__name__)
17 | settings = get_settings()
18 |
19 | # Dashboard refresh interval (seconds)
20 | DASHBOARD_REFRESH_INTERVAL = 30
21 |
22 | # Historical data retention (hours)
23 | HISTORICAL_DATA_RETENTION = 24
24 |
25 |
26 | class StatusDashboard:
27 | """Comprehensive status dashboard for the backtesting system."""
28 |
29 | def __init__(self):
30 | self.start_time = time.time()
31 | self.historical_data = []
32 | self.last_update = None
33 | self.alert_thresholds = {
34 | "cpu_usage": 80.0,
35 | "memory_usage": 85.0,
36 | "disk_usage": 90.0,
37 | "response_time_ms": 5000.0,
38 | "failure_rate": 0.1,
39 | }
40 |
41 | async def get_dashboard_data(self) -> dict[str, Any]:
42 | """Get comprehensive dashboard data."""
43 | try:
44 | from maverick_mcp.api.routers.health_enhanced import (
45 | _get_detailed_health_status,
46 | )
47 |
48 | # Get current health status
49 | health_status = await _get_detailed_health_status()
50 |
51 | # Get circuit breaker status
52 | circuit_breaker_status = get_all_circuit_breaker_status()
53 |
54 | # Calculate metrics
55 | metrics = await self._calculate_metrics(
56 | health_status, circuit_breaker_status
57 | )
58 |
59 | # Get alerts
60 | alerts = self._generate_alerts(health_status, metrics)
61 |
62 | # Build dashboard data
63 | dashboard_data = {
64 | "overview": self._build_overview(health_status),
65 | "components": self._build_component_summary(health_status),
66 | "circuit_breakers": self._build_circuit_breaker_summary(
67 | circuit_breaker_status
68 | ),
69 | "resources": self._build_resource_summary(health_status),
70 | "metrics": metrics,
71 | "alerts": alerts,
72 | "historical": self._get_historical_data(),
73 | "metadata": {
74 | "last_updated": datetime.now(UTC).isoformat(),
75 | "uptime_seconds": time.time() - self.start_time,
76 | "dashboard_version": "1.0.0",
77 | "auto_refresh_interval": DASHBOARD_REFRESH_INTERVAL,
78 | },
79 | }
80 |
81 | # Update historical data
82 | self._update_historical_data(health_status, metrics)
83 |
84 | self.last_update = datetime.now(UTC)
85 | return dashboard_data
86 |
87 | except Exception as e:
88 | logger.error(f"Failed to generate dashboard data: {e}")
89 | return self._get_error_dashboard(str(e))
90 |
91 | def _build_overview(self, health_status: dict[str, Any]) -> dict[str, Any]:
92 | """Build overview section of the dashboard."""
93 | components = health_status.get("components", {})
94 | checks_summary = health_status.get("checks_summary", {})
95 |
96 | total_components = len(components)
97 | healthy_components = checks_summary.get("healthy", 0)
98 | degraded_components = checks_summary.get("degraded", 0)
99 | unhealthy_components = checks_summary.get("unhealthy", 0)
100 |
101 | # Calculate health percentage
102 | health_percentage = (
103 | (healthy_components / total_components * 100) if total_components > 0 else 0
104 | )
105 |
106 | return {
107 | "overall_status": health_status.get("status", "unknown"),
108 | "health_percentage": round(health_percentage, 1),
109 | "total_components": total_components,
110 | "component_breakdown": {
111 | "healthy": healthy_components,
112 | "degraded": degraded_components,
113 | "unhealthy": unhealthy_components,
114 | },
115 | "uptime_seconds": health_status.get("uptime_seconds", 0),
116 | "version": health_status.get("version", "unknown"),
117 | }
118 |
119 | def _build_component_summary(self, health_status: dict[str, Any]) -> dict[str, Any]:
120 | """Build component summary with status and response times."""
121 | components = health_status.get("components", {})
122 |
123 | component_summary = {}
124 | for name, status in components.items():
125 | component_summary[name] = {
126 | "status": status.status,
127 | "response_time_ms": status.response_time_ms,
128 | "last_check": status.last_check,
129 | "has_error": status.error is not None,
130 | "error_message": status.error,
131 | }
132 |
133 | return component_summary
134 |
135 | def _build_circuit_breaker_summary(
136 | self, circuit_breaker_status: dict[str, Any]
137 | ) -> dict[str, Any]:
138 | """Build circuit breaker summary."""
139 | summary = {
140 | "total_breakers": len(circuit_breaker_status),
141 | "states": {"closed": 0, "open": 0, "half_open": 0},
142 | "breakers": {},
143 | }
144 |
145 | for name, status in circuit_breaker_status.items():
146 | state = status.get("state", "unknown")
147 | if state in summary["states"]:
148 | summary["states"][state] += 1
149 |
150 | metrics = status.get("metrics", {})
151 | summary["breakers"][name] = {
152 | "state": state,
153 | "failure_count": status.get("consecutive_failures", 0),
154 | "success_rate": metrics.get("success_rate", 0),
155 | "avg_response_time": metrics.get("avg_response_time", 0),
156 | "total_calls": metrics.get("total_calls", 0),
157 | }
158 |
159 | return summary
160 |
161 | def _build_resource_summary(self, health_status: dict[str, Any]) -> dict[str, Any]:
162 | """Build resource usage summary."""
163 | resource_usage = health_status.get("resource_usage", {})
164 |
165 | return {
166 | "cpu_percent": resource_usage.get("cpu_percent", 0),
167 | "memory_percent": resource_usage.get("memory_percent", 0),
168 | "disk_percent": resource_usage.get("disk_percent", 0),
169 | "memory_used_mb": resource_usage.get("memory_used_mb", 0),
170 | "memory_total_mb": resource_usage.get("memory_total_mb", 0),
171 | "disk_used_gb": resource_usage.get("disk_used_gb", 0),
172 | "disk_total_gb": resource_usage.get("disk_total_gb", 0),
173 | "load_average": resource_usage.get("load_average", []),
174 | }
175 |
176 | async def _calculate_metrics(
177 | self, health_status: dict[str, Any], circuit_breaker_status: dict[str, Any]
178 | ) -> dict[str, Any]:
179 | """Calculate performance and availability metrics."""
180 | components = health_status.get("components", {})
181 | resource_usage = health_status.get("resource_usage", {})
182 |
183 | # Calculate average response time
184 | response_times = [
185 | comp.response_time_ms
186 | for comp in components.values()
187 | if comp.response_time_ms is not None
188 | ]
189 | avg_response_time = (
190 | sum(response_times) / len(response_times) if response_times else 0
191 | )
192 |
193 | # Calculate availability
194 | total_components = len(components)
195 | available_components = sum(
196 | 1 for comp in components.values() if comp.status in ["healthy", "degraded"]
197 | )
198 | availability_percentage = (
199 | (available_components / total_components * 100)
200 | if total_components > 0
201 | else 0
202 | )
203 |
204 | # Calculate circuit breaker metrics
205 | total_breakers = len(circuit_breaker_status)
206 | closed_breakers = sum(
207 | 1 for cb in circuit_breaker_status.values() if cb.get("state") == "closed"
208 | )
209 | breaker_health = (
210 | (closed_breakers / total_breakers * 100) if total_breakers > 0 else 100
211 | )
212 |
213 | # Get resource metrics
214 | cpu_usage = resource_usage.get("cpu_percent", 0)
215 | memory_usage = resource_usage.get("memory_percent", 0)
216 | disk_usage = resource_usage.get("disk_percent", 0)
217 |
218 | # Calculate system health score (0-100)
219 | health_score = self._calculate_health_score(
220 | availability_percentage,
221 | breaker_health,
222 | cpu_usage,
223 | memory_usage,
224 | avg_response_time,
225 | )
226 |
227 | return {
228 | "availability_percentage": round(availability_percentage, 2),
229 | "average_response_time_ms": round(avg_response_time, 2),
230 | "circuit_breaker_health": round(breaker_health, 2),
231 | "system_health_score": round(health_score, 1),
232 | "resource_utilization": {
233 | "cpu_percent": cpu_usage,
234 | "memory_percent": memory_usage,
235 | "disk_percent": disk_usage,
236 | },
237 | "performance_indicators": {
238 | "total_components": total_components,
239 | "available_components": available_components,
240 | "response_times_collected": len(response_times),
241 | "circuit_breakers_closed": closed_breakers,
242 | "circuit_breakers_total": total_breakers,
243 | },
244 | }
245 |
246 | def _calculate_health_score(
247 | self,
248 | availability: float,
249 | breaker_health: float,
250 | cpu_usage: float,
251 | memory_usage: float,
252 | response_time: float,
253 | ) -> float:
254 | """Calculate overall system health score (0-100)."""
255 | # Weighted scoring
256 | weights = {
257 | "availability": 0.3,
258 | "breaker_health": 0.25,
259 | "cpu_performance": 0.2,
260 | "memory_performance": 0.15,
261 | "response_time": 0.1,
262 | }
263 |
264 | # Calculate individual scores (higher is better)
265 | availability_score = availability # Already 0-100
266 |
267 | breaker_score = breaker_health # Already 0-100
268 |
269 | # CPU score (invert usage - lower usage is better)
270 | cpu_score = max(0, 100 - cpu_usage)
271 |
272 | # Memory score (invert usage - lower usage is better)
273 | memory_score = max(0, 100 - memory_usage)
274 |
275 | # Response time score (lower is better, scale to 0-100)
276 |         if response_time <= 100:
277 |             response_score = 100
278 |         elif response_time <= 1000:
279 |             response_score = (
280 |                 100 - (response_time - 100) / 45
281 |             )  # Linear decay from 100 at 100ms to 80 at 1000ms
282 |         else:
283 |             response_score = max(
284 |                 0, 100 - response_time / 50
285 |             )  # Slower decay, continuous at 1000ms, reaching 0 at 5000ms
286 |
287 | # Calculate weighted score
288 | health_score = (
289 | availability_score * weights["availability"]
290 | + breaker_score * weights["breaker_health"]
291 | + cpu_score * weights["cpu_performance"]
292 | + memory_score * weights["memory_performance"]
293 | + response_score * weights["response_time"]
294 | )
295 |
296 | return min(100, max(0, health_score))
297 |
298 | def _generate_alerts(
299 | self, health_status: dict[str, Any], metrics: dict[str, Any]
300 | ) -> list[dict[str, Any]]:
301 | """Generate alerts based on health status and metrics."""
302 | alerts = []
303 |
304 | # Check overall system health
305 | if health_status.get("status") == "unhealthy":
306 | alerts.append(
307 | {
308 | "severity": "critical",
309 | "type": "system_health",
310 | "title": "System Unhealthy",
311 | "message": "One or more critical components are unhealthy",
312 | "timestamp": datetime.now(UTC).isoformat(),
313 | }
314 | )
315 | elif health_status.get("status") == "degraded":
316 | alerts.append(
317 | {
318 | "severity": "warning",
319 | "type": "system_health",
320 | "title": "System Degraded",
321 | "message": "System is operating with reduced functionality",
322 | "timestamp": datetime.now(UTC).isoformat(),
323 | }
324 | )
325 |
326 | # Check resource usage
327 | resource_usage = health_status.get("resource_usage", {})
328 |
329 | if resource_usage.get("cpu_percent", 0) > self.alert_thresholds["cpu_usage"]:
330 | alerts.append(
331 | {
332 | "severity": "warning",
333 | "type": "resource_usage",
334 | "title": "High CPU Usage",
335 | "message": f"CPU usage is {resource_usage.get('cpu_percent')}%, above threshold of {self.alert_thresholds['cpu_usage']}%",
336 | "timestamp": datetime.now(UTC).isoformat(),
337 | }
338 | )
339 |
340 | if (
341 | resource_usage.get("memory_percent", 0)
342 | > self.alert_thresholds["memory_usage"]
343 | ):
344 | alerts.append(
345 | {
346 | "severity": "warning",
347 | "type": "resource_usage",
348 | "title": "High Memory Usage",
349 | "message": f"Memory usage is {resource_usage.get('memory_percent')}%, above threshold of {self.alert_thresholds['memory_usage']}%",
350 | "timestamp": datetime.now(UTC).isoformat(),
351 | }
352 | )
353 |
354 | if resource_usage.get("disk_percent", 0) > self.alert_thresholds["disk_usage"]:
355 | alerts.append(
356 | {
357 | "severity": "critical",
358 | "type": "resource_usage",
359 | "title": "High Disk Usage",
360 | "message": f"Disk usage is {resource_usage.get('disk_percent')}%, above threshold of {self.alert_thresholds['disk_usage']}%",
361 | "timestamp": datetime.now(UTC).isoformat(),
362 | }
363 | )
364 |
365 | # Check response times
366 | avg_response_time = metrics.get("average_response_time_ms", 0)
367 | if avg_response_time > self.alert_thresholds["response_time_ms"]:
368 | alerts.append(
369 | {
370 | "severity": "warning",
371 | "type": "performance",
372 | "title": "Slow Response Times",
373 | "message": f"Average response time is {avg_response_time:.1f}ms, above threshold of {self.alert_thresholds['response_time_ms']}ms",
374 | "timestamp": datetime.now(UTC).isoformat(),
375 | }
376 | )
377 |
378 | # Check circuit breakers
379 | circuit_breakers = health_status.get("circuit_breakers", {})
380 | for name, breaker in circuit_breakers.items():
381 | if breaker.state == "open":
382 | alerts.append(
383 | {
384 | "severity": "critical",
385 | "type": "circuit_breaker",
386 | "title": f"Circuit Breaker Open: {name}",
387 | "message": f"Circuit breaker for {name} is open due to failures",
388 | "timestamp": datetime.now(UTC).isoformat(),
389 | }
390 | )
391 | elif breaker.state == "half_open":
392 | alerts.append(
393 | {
394 | "severity": "info",
395 | "type": "circuit_breaker",
396 | "title": f"Circuit Breaker Testing: {name}",
397 | "message": f"Circuit breaker for {name} is testing recovery",
398 | "timestamp": datetime.now(UTC).isoformat(),
399 | }
400 | )
401 |
402 | return alerts
403 |
404 | def _update_historical_data(
405 | self, health_status: dict[str, Any], metrics: dict[str, Any]
406 | ):
407 | """Update historical data for trending."""
408 | timestamp = datetime.now(UTC)
409 |
410 | # Add current data point
411 | data_point = {
412 | "timestamp": timestamp.isoformat(),
413 | "health_score": metrics.get("system_health_score", 0),
414 | "availability": metrics.get("availability_percentage", 0),
415 | "response_time": metrics.get("average_response_time_ms", 0),
416 | "cpu_usage": health_status.get("resource_usage", {}).get("cpu_percent", 0),
417 | "memory_usage": health_status.get("resource_usage", {}).get(
418 | "memory_percent", 0
419 | ),
420 | "circuit_breaker_health": metrics.get("circuit_breaker_health", 100),
421 | }
422 |
423 | self.historical_data.append(data_point)
424 |
425 | # Clean up old data
426 | cutoff_time = timestamp - timedelta(hours=HISTORICAL_DATA_RETENTION)
427 | self.historical_data = [
428 | point
429 | for point in self.historical_data
430 | if datetime.fromisoformat(point["timestamp"].replace("Z", "+00:00"))
431 | > cutoff_time
432 | ]
433 |
434 | def _get_historical_data(self) -> dict[str, Any]:
435 | """Get historical data for trending charts."""
436 | if not self.historical_data:
437 | return {"data": [], "summary": {"points": 0, "timespan_hours": 0}}
438 |
439 | # Calculate summary
440 | summary = {
441 | "points": len(self.historical_data),
442 | "timespan_hours": HISTORICAL_DATA_RETENTION,
443 | "avg_health_score": sum(p["health_score"] for p in self.historical_data)
444 | / len(self.historical_data),
445 | "avg_availability": sum(p["availability"] for p in self.historical_data)
446 | / len(self.historical_data),
447 | "avg_response_time": sum(p["response_time"] for p in self.historical_data)
448 | / len(self.historical_data),
449 | }
450 |
451 |         # Downsample evenly to roughly 100 points for visualization
452 |         data = self.historical_data
453 |         if len(data) > 100:
454 |             step = len(data) // 100
455 |             data = data[::step]
456 |
457 | return {
458 | "data": data,
459 | "summary": summary,
460 | }
461 |
462 | def _get_error_dashboard(self, error_message: str) -> dict[str, Any]:
463 | """Get minimal dashboard data when there's an error."""
464 | return {
465 | "overview": {
466 | "overall_status": "error",
467 | "health_percentage": 0,
468 | "error": error_message,
469 | },
470 | "components": {},
471 | "circuit_breakers": {},
472 | "resources": {},
473 | "metrics": {},
474 | "alerts": [
475 | {
476 | "severity": "critical",
477 | "type": "dashboard_error",
478 | "title": "Dashboard Error",
479 | "message": f"Failed to generate dashboard data: {error_message}",
480 | "timestamp": datetime.now(UTC).isoformat(),
481 | }
482 | ],
483 | "historical": {"data": [], "summary": {"points": 0, "timespan_hours": 0}},
484 | "metadata": {
485 | "last_updated": datetime.now(UTC).isoformat(),
486 | "dashboard_version": "1.0.0",
487 | "error": True,
488 | },
489 | }
490 |
491 | def get_alert_summary(self) -> dict[str, Any]:
492 |         """Get a summary of current alerts (currently a zeroed stub)."""
493 |         try:
494 |             # Stub: returns zeroed counts rather than inspecting live alert data
495 | return {
496 | "total_alerts": 0,
497 | "critical": 0,
498 | "warning": 0,
499 | "info": 0,
500 | "last_check": datetime.now(UTC).isoformat(),
501 | }
502 | except Exception as e:
503 | logger.error(f"Failed to get alert summary: {e}")
504 | return {
505 | "total_alerts": 1,
506 | "critical": 1,
507 | "warning": 0,
508 | "info": 0,
509 | "error": str(e),
510 | "last_check": datetime.now(UTC).isoformat(),
511 | }
512 |
513 |
514 | # Global dashboard instance
515 | _dashboard = StatusDashboard()
516 |
517 |
518 | def get_status_dashboard() -> StatusDashboard:
519 | """Get the global status dashboard instance."""
520 | return _dashboard
521 |
522 |
523 | async def get_dashboard_data() -> dict[str, Any]:
524 | """Get dashboard data (convenience function)."""
525 | return await _dashboard.get_dashboard_data()
526 |
527 |
528 | def get_dashboard_metadata() -> dict[str, Any]:
529 | """Get dashboard metadata."""
530 | return {
531 | "version": "1.0.0",
532 | "last_updated": _dashboard.last_update.isoformat()
533 | if _dashboard.last_update
534 | else None,
535 | "uptime_seconds": time.time() - _dashboard.start_time,
536 | "refresh_interval": DASHBOARD_REFRESH_INTERVAL,
537 | "retention_hours": HISTORICAL_DATA_RETENTION,
538 | }
539 |
```
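
A hedged usage sketch for the module above: pull one dashboard snapshot through the module-level helpers. It assumes the surrounding maverick_mcp services (settings, the enhanced health router, circuit breakers) are importable and configured in your environment; the printed keys mirror what `get_dashboard_data()` returns:

```python
import asyncio

from maverick_mcp.monitoring.status_dashboard import (
    get_dashboard_data,
    get_dashboard_metadata,
)


async def main() -> None:
    # One dashboard snapshot; structure mirrors StatusDashboard.get_dashboard_data().
    data = await get_dashboard_data()
    overview = data["overview"]
    print(f"status={overview['overall_status']}")
    for alert in data["alerts"]:
        print(f"[{alert['severity']}] {alert['title']}: {alert['message']}")
    print(get_dashboard_metadata())


if __name__ == "__main__":
    asyncio.run(main())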
--------------------------------------------------------------------------------
/maverick_mcp/workflows/backtesting_workflow.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Intelligent Backtesting Workflow using LangGraph.
3 |
4 | This workflow orchestrates market regime analysis, strategy selection, parameter optimization,
5 | and validation to provide intelligent, confidence-scored backtesting recommendations.
6 | """
7 |
8 | import logging
9 | from datetime import datetime, timedelta
10 | from typing import Any
11 |
12 | from langchain_core.messages import HumanMessage
13 | from langgraph.graph import END, StateGraph
14 |
15 | from maverick_mcp.workflows.agents import (
16 | MarketAnalyzerAgent,
17 | OptimizerAgent,
18 | StrategySelectorAgent,
19 | ValidatorAgent,
20 | )
21 | from maverick_mcp.workflows.state import BacktestingWorkflowState
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | class BacktestingWorkflow:
27 | """Intelligent backtesting workflow orchestrator."""
28 |
29 | def __init__(
30 | self,
31 | market_analyzer: MarketAnalyzerAgent | None = None,
32 | strategy_selector: StrategySelectorAgent | None = None,
33 | optimizer: OptimizerAgent | None = None,
34 | validator: ValidatorAgent | None = None,
35 | ):
36 | """Initialize backtesting workflow.
37 |
38 | Args:
39 | market_analyzer: Market regime analysis agent
40 | strategy_selector: Strategy selection agent
41 | optimizer: Parameter optimization agent
42 | validator: Results validation agent
43 | """
44 | self.market_analyzer = market_analyzer or MarketAnalyzerAgent()
45 | self.strategy_selector = strategy_selector or StrategySelectorAgent()
46 | self.optimizer = optimizer or OptimizerAgent()
47 | self.validator = validator or ValidatorAgent()
48 |
49 | # Build the workflow graph
50 | self.workflow = self._build_workflow_graph()
51 |
52 | logger.info("BacktestingWorkflow initialized")
53 |
54 | def _build_workflow_graph(self) -> StateGraph:
55 | """Build the LangGraph workflow."""
56 | # Define the workflow graph
57 | workflow = StateGraph(BacktestingWorkflowState)
58 |
59 | # Add nodes for each step
60 | workflow.add_node("initialize", self._initialize_workflow)
61 | workflow.add_node("analyze_market_regime", self._analyze_market_regime_node)
62 | workflow.add_node("select_strategies", self._select_strategies_node)
63 | workflow.add_node("optimize_parameters", self._optimize_parameters_node)
64 | workflow.add_node("validate_results", self._validate_results_node)
65 | workflow.add_node("finalize_workflow", self._finalize_workflow)
66 |
67 | # Define the workflow flow
68 | workflow.set_entry_point("initialize")
69 |
70 | # Sequential workflow with conditional routing
71 | workflow.add_edge("initialize", "analyze_market_regime")
72 | workflow.add_conditional_edges(
73 | "analyze_market_regime",
74 | self._should_proceed_after_market_analysis,
75 | {
76 | "continue": "select_strategies",
77 | "fallback": "finalize_workflow",
78 | },
79 | )
80 | workflow.add_conditional_edges(
81 | "select_strategies",
82 | self._should_proceed_after_strategy_selection,
83 | {
84 | "continue": "optimize_parameters",
85 | "fallback": "finalize_workflow",
86 | },
87 | )
88 | workflow.add_conditional_edges(
89 | "optimize_parameters",
90 | self._should_proceed_after_optimization,
91 | {
92 | "continue": "validate_results",
93 | "fallback": "finalize_workflow",
94 | },
95 | )
96 | workflow.add_edge("validate_results", "finalize_workflow")
97 | workflow.add_edge("finalize_workflow", END)
98 |
99 | return workflow.compile()
100 |
101 | async def run_intelligent_backtest(
102 | self,
103 | symbol: str,
104 | start_date: str | None = None,
105 | end_date: str | None = None,
106 | initial_capital: float = 10000.0,
107 | requested_strategy: str | None = None,
108 | ) -> dict[str, Any]:
109 | """Run intelligent backtesting workflow.
110 |
111 | Args:
112 | symbol: Stock symbol to analyze
113 | start_date: Start date (YYYY-MM-DD), defaults to 1 year ago
114 | end_date: End date (YYYY-MM-DD), defaults to today
115 | initial_capital: Starting capital for backtest
116 | requested_strategy: User-requested strategy (optional)
117 |
118 | Returns:
119 | Comprehensive backtesting results with recommendations
120 | """
121 | start_time = datetime.now()
122 |
123 | try:
124 | logger.info(f"Starting intelligent backtest workflow for {symbol}")
125 |
126 | # Set default date range if not provided
127 | if not end_date:
128 | end_date = datetime.now().strftime("%Y-%m-%d")
129 | if not start_date:
130 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
131 |
132 | # Initialize workflow state
133 | initial_state = self._create_initial_state(
134 | symbol=symbol,
135 | start_date=start_date,
136 | end_date=end_date,
137 | initial_capital=initial_capital,
138 | requested_strategy=requested_strategy,
139 | )
140 |
141 | # Run the workflow
142 | final_state = await self.workflow.ainvoke(initial_state)
143 |
144 | # Convert state to results dictionary
145 | results = self._format_results(final_state)
146 |
147 | # Add execution metadata
148 | total_execution_time = (datetime.now() - start_time).total_seconds() * 1000
149 | results["execution_metadata"] = {
150 | "total_execution_time_ms": total_execution_time,
151 |                 "workflow_completed": final_state["workflow_status"] == "completed",
152 |                 "steps_completed": final_state["steps_completed"],
153 |                 "errors_encountered": final_state["errors_encountered"],
154 |                 "fallback_strategies_used": final_state["fallback_strategies_used"],
155 | }
156 |
157 |             logger.info(
158 |                 f"Intelligent backtest completed for {symbol} in {total_execution_time:.0f}ms: "
159 |                 f"{final_state['recommended_strategy']} recommended with {final_state['recommendation_confidence']:.1%} confidence"
160 |             )
161 |
162 | return results
163 |
164 | except Exception as e:
165 | logger.error(f"Intelligent backtest failed for {symbol}: {e}")
166 | return {
167 | "symbol": symbol,
168 | "error": str(e),
169 | "execution_metadata": {
170 | "total_execution_time_ms": (
171 | datetime.now() - start_time
172 | ).total_seconds()
173 | * 1000,
174 | "workflow_completed": False,
175 | },
176 | }
177 |
178 | def _create_initial_state(
179 | self,
180 | symbol: str,
181 | start_date: str,
182 | end_date: str,
183 | initial_capital: float,
184 | requested_strategy: str | None,
185 | ) -> BacktestingWorkflowState:
186 | """Create initial workflow state."""
187 | return BacktestingWorkflowState(
188 | # Base agent state
189 | messages=[
190 | HumanMessage(content=f"Analyze backtesting opportunities for {symbol}")
191 | ],
192 | session_id=f"backtest_{symbol}_{datetime.now().isoformat()}",
193 | persona="intelligent_backtesting_agent",
194 | timestamp=datetime.now(),
195 | token_count=0,
196 | error=None,
197 | analyzed_stocks={},
198 | key_price_levels={},
199 | last_analysis_time={},
200 | conversation_context={},
201 | execution_time_ms=None,
202 | api_calls_made=0,
203 | cache_hits=0,
204 | cache_misses=0,
205 | # Input parameters
206 | symbol=symbol,
207 | start_date=start_date,
208 | end_date=end_date,
209 | initial_capital=initial_capital,
210 | requested_strategy=requested_strategy,
211 | # Market regime analysis (initialized)
212 | market_regime="unknown",
213 | regime_confidence=0.0,
214 | regime_indicators={},
215 | regime_analysis_time_ms=0.0,
216 | volatility_percentile=0.0,
217 | trend_strength=0.0,
218 | market_conditions={},
219 | sector_performance={},
220 | correlation_to_market=0.0,
221 | volume_profile={},
222 | support_resistance_levels=[],
223 | # Strategy selection (initialized)
224 | candidate_strategies=[],
225 | strategy_rankings={},
226 | selected_strategies=[],
227 | strategy_selection_reasoning="",
228 | strategy_selection_confidence=0.0,
229 | # Parameter optimization (initialized)
230 | optimization_config={},
231 | parameter_grids={},
232 | optimization_results={},
233 | best_parameters={},
234 | optimization_time_ms=0.0,
235 | optimization_iterations=0,
236 | # Validation (initialized)
237 | walk_forward_results={},
238 | monte_carlo_results={},
239 | out_of_sample_performance={},
240 | robustness_score={},
241 | validation_warnings=[],
242 | # Final recommendations (initialized)
243 | final_strategy_ranking=[],
244 | recommended_strategy="",
245 | recommended_parameters={},
246 | recommendation_confidence=0.0,
247 | risk_assessment={},
248 | # Performance metrics (initialized)
249 | comparative_metrics={},
250 | benchmark_comparison={},
251 | risk_adjusted_performance={},
252 | drawdown_analysis={},
253 | # Workflow control (initialized)
254 | workflow_status="initializing",
255 | current_step="initialization",
256 | steps_completed=[],
257 | total_execution_time_ms=0.0,
258 | # Error handling (initialized)
259 | errors_encountered=[],
260 | fallback_strategies_used=[],
261 | data_quality_issues=[],
262 | # Caching (initialized)
263 | cached_results={},
264 | cache_hit_rate=0.0,
265 | # Advanced analysis (initialized)
266 | regime_transition_analysis={},
267 | multi_timeframe_analysis={},
268 | correlation_analysis={},
269 | macroeconomic_context={},
270 | )
271 |
272 | async def _initialize_workflow(
273 | self, state: BacktestingWorkflowState
274 | ) -> BacktestingWorkflowState:
275 | """Initialize the workflow and validate inputs."""
276 |         logger.info(f"Initializing backtesting workflow for {state['symbol']}")
277 |
278 |         # Validate inputs
279 |         if not state["symbol"]:
280 |             state["errors_encountered"].append(
281 |                 {
282 |                     "step": "initialization",
283 |                     "error": "Symbol is required",
284 |                     "timestamp": datetime.now().isoformat(),
285 |                 }
286 |             )
287 |             state["workflow_status"] = "failed"
288 |             return state
289 |
290 |         # Update workflow state
291 |         state["workflow_status"] = "analyzing_regime"
292 |         state["current_step"] = "initialization_completed"
293 |         state["steps_completed"].append("initialization")
294 |
295 |         logger.info(f"Workflow initialized for {state['symbol']}")
296 |         return state
297 |
298 | async def _analyze_market_regime_node(
299 | self, state: BacktestingWorkflowState
300 | ) -> BacktestingWorkflowState:
301 | """Market regime analysis node."""
302 | return await self.market_analyzer.analyze_market_regime(state)
303 |
304 | async def _select_strategies_node(
305 | self, state: BacktestingWorkflowState
306 | ) -> BacktestingWorkflowState:
307 | """Strategy selection node."""
308 | return await self.strategy_selector.select_strategies(state)
309 |
310 | async def _optimize_parameters_node(
311 | self, state: BacktestingWorkflowState
312 | ) -> BacktestingWorkflowState:
313 | """Parameter optimization node."""
314 | return await self.optimizer.optimize_parameters(state)
315 |
316 | async def _validate_results_node(
317 | self, state: BacktestingWorkflowState
318 | ) -> BacktestingWorkflowState:
319 | """Results validation node."""
320 | return await self.validator.validate_strategies(state)
321 |
322 | async def _finalize_workflow(
323 | self, state: BacktestingWorkflowState
324 | ) -> BacktestingWorkflowState:
325 | """Finalize the workflow and prepare results."""
326 |         if state["workflow_status"] != "completed":
327 |             # Handle incomplete workflow
328 |             if not state["recommended_strategy"] and state["best_parameters"]:
329 |                 # Select first available strategy as fallback
330 |                 state["recommended_strategy"] = list(state["best_parameters"].keys())[0]
331 |                 state["recommended_parameters"] = state["best_parameters"][
332 |                     state["recommended_strategy"]
333 |                 ]
334 |                 state["recommendation_confidence"] = 0.3
335 |                 state["fallback_strategies_used"].append("incomplete_workflow_fallback")
336 |
337 |         state["current_step"] = "workflow_finalized"
338 |         logger.info(f"Workflow finalized for {state['symbol']}")
339 |         return state
340 |
341 | def _should_proceed_after_market_analysis(
342 | self, state: BacktestingWorkflowState
343 | ) -> str:
344 | """Decide whether to proceed after market analysis."""
345 |         if state["errors_encountered"] and any(
346 |             "market_regime_analysis" in err.get("step", "")
347 |             for err in state["errors_encountered"]
348 |         ):
349 |             return "fallback"
350 |         if state["market_regime"] == "unknown" and state["regime_confidence"] < 0.1:
351 |             return "fallback"
352 |         return "continue"
353 |
354 | def _should_proceed_after_strategy_selection(
355 | self, state: BacktestingWorkflowState
356 | ) -> str:
357 | """Decide whether to proceed after strategy selection."""
358 |         if not state["selected_strategies"]:
359 |             return "fallback"
360 |         if state["strategy_selection_confidence"] < 0.2:
361 |             return "fallback"
362 |         return "continue"
363 |
364 | def _should_proceed_after_optimization(
365 | self, state: BacktestingWorkflowState
366 | ) -> str:
367 | """Decide whether to proceed after optimization."""
368 |         if not state["best_parameters"]:
369 |             return "fallback"
370 |         return "continue"
371 |
372 | def _format_results(self, state: BacktestingWorkflowState) -> dict[str, Any]:
373 | """Format final results for output."""
374 |         return {
375 |             "symbol": state["symbol"],
376 |             "period": {
377 |                 "start_date": state["start_date"],
378 |                 "end_date": state["end_date"],
379 |                 "initial_capital": state["initial_capital"],
380 |             },
381 |             "market_analysis": {
382 |                 "regime": state["market_regime"],
383 |                 "regime_confidence": state["regime_confidence"],
384 |                 "regime_indicators": state["regime_indicators"],
385 |                 "volatility_percentile": state["volatility_percentile"],
386 |                 "trend_strength": state["trend_strength"],
387 |                 "market_conditions": state["market_conditions"],
388 |                 "support_resistance_levels": state["support_resistance_levels"],
389 |             },
390 |             "strategy_selection": {
391 |                 "selected_strategies": state["selected_strategies"],
392 |                 "strategy_rankings": state["strategy_rankings"],
393 |                 "selection_reasoning": state["strategy_selection_reasoning"],
394 |                 "selection_confidence": state["strategy_selection_confidence"],
395 |                 "candidate_strategies": state["candidate_strategies"],
396 |             },
397 |             "optimization": {
398 |                 "optimization_config": state["optimization_config"],
399 |                 "best_parameters": state["best_parameters"],
400 |                 "optimization_iterations": state["optimization_iterations"],
401 |                 "optimization_time_ms": state["optimization_time_ms"],
402 |             },
403 |             "validation": {
404 |                 "robustness_scores": state["robustness_score"],
405 |                 "validation_warnings": state["validation_warnings"],
406 |                 "out_of_sample_performance": state["out_of_sample_performance"],
407 |             },
408 |             "recommendation": {
409 |                 "recommended_strategy": state["recommended_strategy"],
410 |                 "recommended_parameters": state["recommended_parameters"],
411 |                 "recommendation_confidence": state["recommendation_confidence"],
412 |                 "final_strategy_ranking": state["final_strategy_ranking"],
413 |                 "risk_assessment": state["risk_assessment"],
414 |             },
415 |             "performance_analysis": {
416 |                 "comparative_metrics": state["comparative_metrics"],
417 |                 "benchmark_comparison": state["benchmark_comparison"],
418 |                 "risk_adjusted_performance": state["risk_adjusted_performance"],
419 |             },
420 |         }
421 |
422 | async def run_quick_analysis(
423 | self,
424 | symbol: str,
425 | start_date: str | None = None,
426 | end_date: str | None = None,
427 | initial_capital: float = 10000.0,
428 | ) -> dict[str, Any]:
429 | """Run quick analysis with market regime detection and basic strategy recommendations.
430 |
431 | This is a faster alternative that skips parameter optimization and validation
432 | for rapid insights.
433 |
434 | Args:
435 | symbol: Stock symbol to analyze
436 | start_date: Start date (YYYY-MM-DD)
437 | end_date: End date (YYYY-MM-DD)
438 | initial_capital: Starting capital
439 |
440 | Returns:
441 | Quick analysis results with strategy recommendations
442 | """
443 | start_time = datetime.now()
444 |
445 | try:
446 | logger.info(f"Running quick analysis for {symbol}")
447 |
448 | # Set default dates
449 | if not end_date:
450 | end_date = datetime.now().strftime("%Y-%m-%d")
451 | if not start_date:
452 | start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
453 |
454 | # Create initial state
455 | state = self._create_initial_state(
456 | symbol=symbol,
457 | start_date=start_date,
458 | end_date=end_date,
459 | initial_capital=initial_capital,
460 | requested_strategy=None,
461 | )
462 |
463 | # Run market analysis
464 | state = await self.market_analyzer.analyze_market_regime(state)
465 |
466 | # Run strategy selection
467 | if state["market_regime"] != "unknown" or state["regime_confidence"] > 0.3:
468 | state = await self.strategy_selector.select_strategies(state)
469 |
470 | # Format quick results
471 | execution_time = (datetime.now() - start_time).total_seconds() * 1000
472 |
473 | return {
474 | "symbol": symbol,
475 | "analysis_type": "quick_analysis",
476 | "market_regime": {
477 | "regime": state["market_regime"],
478 | "confidence": state["regime_confidence"],
479 | "trend_strength": state["trend_strength"],
480 | "volatility_percentile": state["volatility_percentile"],
481 | },
482 | "recommended_strategies": state["selected_strategies"][:3], # Top 3
483 | "strategy_fitness": {
484 | strategy: state["strategy_rankings"].get(strategy, 0)
485 | for strategy in state["selected_strategies"][:3]
486 | },
487 | "market_conditions": state["market_conditions"],
488 | "selection_reasoning": state["strategy_selection_reasoning"],
489 | "execution_time_ms": execution_time,
490 | "data_quality": {
491 | "errors": len(state["errors_encountered"]),
492 | "warnings": state["data_quality_issues"],
493 | },
494 | }
495 |
496 | except Exception as e:
497 | logger.error(f"Quick analysis failed for {symbol}: {e}")
498 | return {
499 | "symbol": symbol,
500 | "analysis_type": "quick_analysis",
501 | "error": str(e),
502 | "execution_time_ms": (datetime.now() - start_time).total_seconds()
503 | * 1000,
504 | }
505 |
506 | def get_workflow_status(self, state: BacktestingWorkflowState) -> dict[str, Any]:
507 | """Get current workflow status and progress."""
508 | total_steps = 5 # initialize, analyze, select, optimize, validate
509 |         completed_steps = len(state["steps_completed"])
510 |
511 |         return {
512 |             "workflow_status": state["workflow_status"],
513 |             "current_step": state["current_step"],
514 |             "progress_percentage": (completed_steps / total_steps) * 100,
515 |             "steps_completed": state["steps_completed"],
516 |             "errors_count": len(state["errors_encountered"]),
517 |             "warnings_count": len(state["validation_warnings"]),
518 |             "execution_time_ms": state["total_execution_time_ms"],
519 |             "recommended_strategy": state["recommended_strategy"] or "TBD",
520 |             "recommendation_confidence": state["recommendation_confidence"],
521 | }
522 |
```
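
A hedged usage sketch for the workflow above, covering both the quick and full paths. It assumes the agents' data-provider and LLM dependencies are configured in your environment; the symbol and capital are illustrative:

```python
import asyncio

from maverick_mcp.workflows.backtesting_workflow import BacktestingWorkflow


async def main() -> None:
    workflow = BacktestingWorkflow()

    # Fast path: regime detection plus a strategy shortlist only.
    quick = await workflow.run_quick_analysis(symbol="AAPL")
    print(quick.get("market_regime"), quick.get("recommended_strategies"))

    # Full path: regime -> selection -> optimization -> validation.
    full = await workflow.run_intelligent_backtest(
        symbol="AAPL", initial_capital=10000.0
    )
    rec = full.get("recommendation", {})
    print(rec.get("recommended_strategy"), rec.get("recommendation_confidence"))


if __name__ == "__main__":
    asyncio.run(main())
```

Since `run_quick_analysis` skips the optimization and validation nodes, it suits interactive use where a regime read and a strategy shortlist are enough; the full run is the one that produces confidence-scored parameter recommendations.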