This is page 22 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_security_cors.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive CORS Security Tests for Maverick MCP.
3 |
4 | Tests CORS configuration, validation, origin blocking, wildcard security,
5 | and environment-specific behaviors.
6 | """
7 |
8 | import os
9 | from unittest.mock import MagicMock, patch
10 |
11 | import pytest
12 | from fastapi import FastAPI
13 | from fastapi.middleware.cors import CORSMiddleware
14 | from fastapi.testclient import TestClient
15 |
16 | from maverick_mcp.config.security import (
17 | CORSConfig,
18 | SecurityConfig,
19 | validate_security_config,
20 | )
21 | from maverick_mcp.config.security_utils import (
22 | apply_cors_to_fastapi,
23 | check_security_config,
24 | get_safe_cors_config,
25 | )
26 |
27 |
28 | class TestCORSConfiguration:
29 | """Test CORS configuration validation and creation."""
30 |
31 | def test_cors_config_valid_origins(self):
32 | """Test CORS config creation with valid origins."""
33 | config = CORSConfig(
34 | allowed_origins=["https://example.com", "https://app.example.com"],
35 | allow_credentials=True,
36 | )
37 |
38 | assert config.allowed_origins == [
39 | "https://example.com",
40 | "https://app.example.com",
41 | ]
42 | assert config.allow_credentials is True
43 |
44 | def test_cors_config_wildcard_with_credentials_raises_error(self):
45 | """Test that wildcard origins with credentials raises validation error."""
46 | with pytest.raises(
47 | ValueError,
48 | match="CORS Security Error.*wildcard origin.*serious security vulnerability",
49 | ):
50 | CORSConfig(allowed_origins=["*"], allow_credentials=True)
51 |
52 | def test_cors_config_wildcard_without_credentials_warns(self):
53 | """Test that wildcard origins without credentials logs warning."""
54 | with patch("logging.getLogger") as mock_logger:
55 | mock_logger_instance = MagicMock()
56 | mock_logger.return_value = mock_logger_instance
57 |
58 | config = CORSConfig(allowed_origins=["*"], allow_credentials=False)
59 |
60 | assert config.allowed_origins == ["*"]
61 | assert config.allow_credentials is False
62 | mock_logger_instance.warning.assert_called_once()
63 |
64 | def test_cors_config_multiple_origins_with_wildcard_fails(self):
65 | """Test that mixed origins including wildcard with credentials fails."""
66 | with pytest.raises(ValueError, match="CORS Security Error"):
67 | CORSConfig(
68 | allowed_origins=["https://example.com", "*"], allow_credentials=True
69 | )
70 |
71 | def test_cors_config_default_values(self):
72 | """Test CORS config default values are secure."""
73 | with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
74 | with patch(
75 | "maverick_mcp.config.security._get_cors_origins"
76 | ) as mock_origins:
77 | mock_origins.return_value = ["http://localhost:3000"]
78 |
79 | config = CORSConfig()
80 |
81 | assert config.allow_credentials is True
82 | assert "GET" in config.allowed_methods
83 | assert "POST" in config.allowed_methods
84 | assert "Authorization" in config.allowed_headers
85 | assert "Content-Type" in config.allowed_headers
86 | assert config.max_age == 86400
87 |
88 | def test_cors_config_expose_headers(self):
89 | """Test that proper headers are exposed to clients."""
90 | config = CORSConfig()
91 |
92 | expected_exposed = [
93 | "X-Process-Time",
94 | "X-RateLimit-Limit",
95 | "X-RateLimit-Remaining",
96 | "X-RateLimit-Reset",
97 | "X-Request-ID",
98 | ]
99 |
100 | for header in expected_exposed:
101 | assert header in config.exposed_headers
102 |
103 |
104 | class TestCORSEnvironmentConfiguration:
105 | """Test environment-specific CORS configuration."""
106 |
107 | def test_production_cors_origins(self):
108 | """Test production CORS origins are restrictive."""
109 | with patch.dict(os.environ, {"ENVIRONMENT": "production"}, clear=True):
110 | with patch(
111 | "maverick_mcp.config.security._get_cors_origins"
112 | ) as mock_origins:
113 | mock_origins.return_value = [
114 | "https://app.maverick-mcp.com",
115 | "https://maverick-mcp.com",
116 | ]
117 |
118 | config = SecurityConfig()
119 |
120 | assert "localhost" not in str(config.cors.allowed_origins).lower()
121 | assert "127.0.0.1" not in str(config.cors.allowed_origins).lower()
122 | assert all(
123 | origin.startswith("https://")
124 | for origin in config.cors.allowed_origins
125 | )
126 |
127 | def test_development_cors_origins(self):
128 | """Test development CORS origins include localhost."""
129 | with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=True):
130 | with patch(
131 | "maverick_mcp.config.security._get_cors_origins"
132 | ) as mock_origins:
133 | mock_origins.return_value = [
134 | "http://localhost:3000",
135 | "http://127.0.0.1:3000",
136 | ]
137 |
138 | config = SecurityConfig()
139 |
140 | localhost_found = any(
141 | "localhost" in origin for origin in config.cors.allowed_origins
142 | )
143 | assert localhost_found
144 |
145 | def test_staging_cors_origins(self):
146 | """Test staging CORS origins are appropriate."""
147 | with patch.dict(os.environ, {"ENVIRONMENT": "staging"}, clear=True):
148 | with patch(
149 | "maverick_mcp.config.security._get_cors_origins"
150 | ) as mock_origins:
151 | mock_origins.return_value = [
152 | "https://staging.maverick-mcp.com",
153 | "http://localhost:3000",
154 | ]
155 |
156 | config = SecurityConfig()
157 |
158 | staging_found = any(
159 | "staging" in origin for origin in config.cors.allowed_origins
160 | )
161 | assert staging_found
162 |
163 | def test_custom_cors_origins_from_env(self):
164 | """Test custom CORS origins from environment variable."""
165 | custom_origins = "https://custom1.com,https://custom2.com"
166 |
167 | with patch.dict(os.environ, {"CORS_ORIGINS": custom_origins}, clear=False):
168 | with patch(
169 | "maverick_mcp.config.security._get_cors_origins"
170 | ) as mock_origins:
171 | mock_origins.return_value = [
172 | "https://custom1.com",
173 | "https://custom2.com",
174 | ]
175 |
176 | config = SecurityConfig()
177 |
178 | assert "https://custom1.com" in config.cors.allowed_origins
179 | assert "https://custom2.com" in config.cors.allowed_origins
180 |
181 |
182 | class TestCORSValidation:
183 | """Test CORS security validation."""
184 |
185 | def test_validate_security_config_valid_cors(self):
186 | """Test security validation passes with valid CORS config."""
187 | with patch("maverick_mcp.config.security.get_security_config") as mock_config:
188 | mock_security_config = MagicMock()
189 | mock_security_config.cors.allowed_origins = ["https://example.com"]
190 | mock_security_config.cors.allow_credentials = True
191 | mock_security_config.is_production.return_value = False
192 | mock_security_config.force_https = True
193 | mock_security_config.headers.x_frame_options = "DENY"
194 | mock_config.return_value = mock_security_config
195 |
196 | result = validate_security_config()
197 |
198 | assert result["valid"] is True
199 | assert len(result["issues"]) == 0
200 |
201 | def test_validate_security_config_wildcard_with_credentials(self):
202 | """Test security validation fails with wildcard + credentials."""
203 | with patch("maverick_mcp.config.security.get_security_config") as mock_config:
204 | mock_security_config = MagicMock()
205 | mock_security_config.cors.allowed_origins = ["*"]
206 | mock_security_config.cors.allow_credentials = True
207 | mock_security_config.is_production.return_value = False
208 | mock_security_config.force_https = True
209 | mock_security_config.headers.x_frame_options = "DENY"
210 | mock_config.return_value = mock_security_config
211 |
212 | result = validate_security_config()
213 |
214 | assert result["valid"] is False
215 | assert any(
216 | "Wildcard CORS origins with credentials enabled" in issue
217 | for issue in result["issues"]
218 | )
219 |
220 | def test_validate_security_config_production_wildcards(self):
221 | """Test security validation fails with wildcards in production."""
222 | with patch("maverick_mcp.config.security.get_security_config") as mock_config:
223 | mock_security_config = MagicMock()
224 | mock_security_config.cors.allowed_origins = ["*"]
225 | mock_security_config.cors.allow_credentials = False
226 | mock_security_config.is_production.return_value = True
227 | mock_security_config.force_https = True
228 | mock_security_config.headers.x_frame_options = "DENY"
229 | mock_config.return_value = mock_security_config
230 |
231 | result = validate_security_config()
232 |
233 | assert result["valid"] is False
234 | assert any(
235 | "Wildcard CORS origins in production" in issue
236 | for issue in result["issues"]
237 | )
238 |
239 | def test_validate_security_config_production_localhost_warning(self):
240 | """Test security validation warns about localhost in production."""
241 | with patch("maverick_mcp.config.security.get_security_config") as mock_config:
242 | mock_security_config = MagicMock()
243 | mock_security_config.cors.allowed_origins = [
244 | "https://app.com",
245 | "http://localhost:3000",
246 | ]
247 | mock_security_config.cors.allow_credentials = True
248 | mock_security_config.is_production.return_value = True
249 | mock_security_config.force_https = True
250 | mock_security_config.headers.x_frame_options = "DENY"
251 | mock_config.return_value = mock_security_config
252 |
253 | result = validate_security_config()
254 |
255 | assert result["valid"] is True # Warning, not error
256 | assert any("localhost" in warning.lower() for warning in result["warnings"])
257 |
258 |
259 | class TestCORSMiddlewareIntegration:
260 | """Test CORS middleware integration with FastAPI."""
261 |
262 | def create_test_app(self, security_config=None):
263 | """Create a test FastAPI app with CORS applied."""
264 | app = FastAPI()
265 |
266 | if security_config:
267 | with patch(
268 | "maverick_mcp.config.security_utils.get_security_config",
269 | return_value=security_config,
270 | ):
271 | apply_cors_to_fastapi(app)
272 | else:
273 | apply_cors_to_fastapi(app)
274 |
275 | @app.get("/test")
276 | async def test_endpoint():
277 | return {"message": "test"}
278 |
279 | @app.post("/test")
280 | async def test_post_endpoint():
281 | return {"message": "post test"}
282 |
283 | return app
284 |
285 | def test_cors_middleware_allows_configured_origins(self):
286 | """Test that CORS middleware allows configured origins."""
287 | # Create mock security config
288 | mock_config = MagicMock()
289 | mock_config.get_cors_middleware_config.return_value = {
290 | "allow_origins": ["https://allowed.com"],
291 | "allow_credentials": True,
292 | "allow_methods": ["GET", "POST"],
293 | "allow_headers": ["Content-Type", "Authorization"],
294 | "expose_headers": [],
295 | "max_age": 86400,
296 | }
297 |
298 | # Mock validation to pass
299 | with patch(
300 | "maverick_mcp.config.security_utils.validate_security_config"
301 | ) as mock_validate:
302 | mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
303 |
304 | app = self.create_test_app(mock_config)
305 | client = TestClient(app)
306 |
307 | # Test preflight request
308 | response = client.options(
309 | "/test",
310 | headers={
311 | "Origin": "https://allowed.com",
312 | "Access-Control-Request-Method": "POST",
313 | "Access-Control-Request-Headers": "Content-Type",
314 | },
315 | )
316 |
317 | assert response.status_code == 200
318 | assert (
319 | response.headers.get("Access-Control-Allow-Origin")
320 | == "https://allowed.com"
321 | )
322 | assert "POST" in response.headers.get("Access-Control-Allow-Methods", "")
323 |
324 | def test_cors_middleware_blocks_unauthorized_origins(self):
325 | """Test that CORS middleware blocks unauthorized origins."""
326 | mock_config = MagicMock()
327 | mock_config.get_cors_middleware_config.return_value = {
328 | "allow_origins": ["https://allowed.com"],
329 | "allow_credentials": True,
330 | "allow_methods": ["GET", "POST"],
331 | "allow_headers": ["Content-Type"],
332 | "expose_headers": [],
333 | "max_age": 86400,
334 | }
335 |
336 | with patch(
337 | "maverick_mcp.config.security_utils.validate_security_config"
338 | ) as mock_validate:
339 | mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
340 |
341 | app = self.create_test_app(mock_config)
342 | client = TestClient(app)
343 |
344 | # Test request from unauthorized origin
345 | response = client.get(
346 | "/test", headers={"Origin": "https://unauthorized.com"}
347 | )
348 |
349 | # The request should succeed (CORS is browser-enforced)
350 | # but the CORS headers should not allow the unauthorized origin
351 | assert response.status_code == 200
352 | cors_origin = response.headers.get("Access-Control-Allow-Origin")
353 | assert cors_origin != "https://unauthorized.com"
354 |
355 | def test_cors_middleware_credentials_handling(self):
356 | """Test CORS middleware credentials handling."""
357 | mock_config = MagicMock()
358 | mock_config.get_cors_middleware_config.return_value = {
359 | "allow_origins": ["https://allowed.com"],
360 | "allow_credentials": True,
361 | "allow_methods": ["GET", "POST"],
362 | "allow_headers": ["Content-Type"],
363 | "expose_headers": [],
364 | "max_age": 86400,
365 | }
366 |
367 | with patch(
368 | "maverick_mcp.config.security_utils.validate_security_config"
369 | ) as mock_validate:
370 | mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
371 |
372 | app = self.create_test_app(mock_config)
373 | client = TestClient(app)
374 |
375 | response = client.options(
376 | "/test",
377 | headers={
378 | "Origin": "https://allowed.com",
379 | "Access-Control-Request-Method": "POST",
380 | },
381 | )
382 |
383 | assert response.headers.get("Access-Control-Allow-Credentials") == "true"
384 |
385 | def test_cors_middleware_exposed_headers(self):
386 | """Test that CORS middleware exposes configured headers."""
387 | mock_config = MagicMock()
388 | mock_config.get_cors_middleware_config.return_value = {
389 | "allow_origins": ["https://allowed.com"],
390 | "allow_credentials": True,
391 | "allow_methods": ["GET"],
392 | "allow_headers": ["Content-Type"],
393 | "expose_headers": ["X-Custom-Header", "X-Rate-Limit"],
394 | "max_age": 86400,
395 | }
396 |
397 | with patch(
398 | "maverick_mcp.config.security_utils.validate_security_config"
399 | ) as mock_validate:
400 | mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
401 |
402 | app = self.create_test_app(mock_config)
403 | client = TestClient(app)
404 |
405 | response = client.get("/test", headers={"Origin": "https://allowed.com"})
406 |
407 | exposed_headers = response.headers.get("Access-Control-Expose-Headers", "")
408 | assert "X-Custom-Header" in exposed_headers
409 | assert "X-Rate-Limit" in exposed_headers
410 |
411 |
412 | class TestCORSSecurityValidation:
413 | """Test CORS security validation and safety measures."""
414 |
415 | def test_apply_cors_fails_with_invalid_config(self):
416 | """Test that applying CORS fails with invalid configuration."""
417 | app = FastAPI()
418 |
419 | # Mock invalid configuration
420 | with patch(
421 | "maverick_mcp.config.security_utils.validate_security_config"
422 | ) as mock_validate:
423 | mock_validate.return_value = {
424 | "valid": False,
425 | "issues": ["Wildcard CORS origins with credentials"],
426 | "warnings": [],
427 | }
428 |
429 | with pytest.raises(ValueError, match="Security configuration is invalid"):
430 | apply_cors_to_fastapi(app)
431 |
432 | def test_get_safe_cors_config_production_fallback(self):
433 | """Test safe CORS config fallback for production."""
434 | with patch(
435 | "maverick_mcp.config.security_utils.validate_security_config"
436 | ) as mock_validate:
437 | mock_validate.return_value = {
438 | "valid": False,
439 | "issues": ["Invalid config"],
440 | "warnings": [],
441 | }
442 |
443 | with patch(
444 | "maverick_mcp.config.security_utils.get_security_config"
445 | ) as mock_config:
446 | mock_security_config = MagicMock()
447 | mock_security_config.is_production.return_value = True
448 | mock_config.return_value = mock_security_config
449 |
450 | safe_config = get_safe_cors_config()
451 |
452 | assert safe_config["allow_origins"] == ["https://maverick-mcp.com"]
453 | assert safe_config["allow_credentials"] is True
454 | assert "localhost" not in str(safe_config["allow_origins"])
455 |
456 | def test_get_safe_cors_config_development_fallback(self):
457 | """Test safe CORS config fallback for development."""
458 | with patch(
459 | "maverick_mcp.config.security_utils.validate_security_config"
460 | ) as mock_validate:
461 | mock_validate.return_value = {
462 | "valid": False,
463 | "issues": ["Invalid config"],
464 | "warnings": [],
465 | }
466 |
467 | with patch(
468 | "maverick_mcp.config.security_utils.get_security_config"
469 | ) as mock_config:
470 | mock_security_config = MagicMock()
471 | mock_security_config.is_production.return_value = False
472 | mock_config.return_value = mock_security_config
473 |
474 | safe_config = get_safe_cors_config()
475 |
476 | assert safe_config["allow_origins"] == ["http://localhost:3000"]
477 | assert safe_config["allow_credentials"] is True
478 |
479 | def test_check_security_config_function(self):
480 | """Test security config check function."""
481 | with patch(
482 | "maverick_mcp.config.security_utils.validate_security_config"
483 | ) as mock_validate:
484 | # Test valid config
485 | mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
486 | assert check_security_config() is True
487 |
488 | # Test invalid config
489 | mock_validate.return_value = {
490 | "valid": False,
491 | "issues": ["Error"],
492 | "warnings": [],
493 | }
494 | assert check_security_config() is False
495 |
496 |
497 | class TestCORSPreflightRequests:
498 | """Test CORS preflight request handling."""
499 |
500 | def test_preflight_request_max_age(self):
501 | """Test CORS preflight max-age header."""
502 | app = FastAPI()
503 | app.add_middleware(
504 | CORSMiddleware,
505 | allow_origins=["https://example.com"],
506 | allow_methods=["GET", "POST"],
507 | allow_headers=["Content-Type"],
508 | max_age=3600,
509 | )
510 |
511 | @app.get("/test")
512 | async def test_endpoint():
513 | return {"message": "test"}
514 |
515 | client = TestClient(app)
516 |
517 | response = client.options(
518 | "/test",
519 | headers={
520 | "Origin": "https://example.com",
521 | "Access-Control-Request-Method": "GET",
522 | },
523 | )
524 |
525 | assert response.headers.get("Access-Control-Max-Age") == "3600"
526 |
527 | def test_preflight_request_methods(self):
528 | """Test CORS preflight allowed methods."""
529 | app = FastAPI()
530 | app.add_middleware(
531 | CORSMiddleware,
532 | allow_origins=["https://example.com"],
533 | allow_methods=["GET", "POST", "PUT"],
534 | allow_headers=["Content-Type"],
535 | )
536 |
537 | @app.get("/test")
538 | async def test_endpoint():
539 | return {"message": "test"}
540 |
541 | client = TestClient(app)
542 |
543 | response = client.options(
544 | "/test",
545 | headers={
546 | "Origin": "https://example.com",
547 | "Access-Control-Request-Method": "PUT",
548 | },
549 | )
550 |
551 | allowed_methods = response.headers.get("Access-Control-Allow-Methods", "")
552 | assert "PUT" in allowed_methods
553 | assert "GET" in allowed_methods
554 | assert "POST" in allowed_methods
555 |
556 | def test_preflight_request_headers(self):
557 | """Test CORS preflight allowed headers."""
558 | app = FastAPI()
559 | app.add_middleware(
560 | CORSMiddleware,
561 | allow_origins=["https://example.com"],
562 | allow_methods=["POST"],
563 | allow_headers=["Content-Type", "Authorization", "X-Custom"],
564 | )
565 |
566 | @app.post("/test")
567 | async def test_endpoint():
568 | return {"message": "test"}
569 |
570 | client = TestClient(app)
571 |
572 | response = client.options(
573 | "/test",
574 | headers={
575 | "Origin": "https://example.com",
576 | "Access-Control-Request-Method": "POST",
577 | "Access-Control-Request-Headers": "Content-Type, Authorization",
578 | },
579 | )
580 |
581 | allowed_headers = response.headers.get("Access-Control-Allow-Headers", "")
582 | assert "Content-Type" in allowed_headers
583 | assert "Authorization" in allowed_headers
584 |
585 |
586 | class TestCORSEdgeCases:
587 | """Test CORS edge cases and security scenarios."""
588 |
589 | def test_cors_with_vary_header(self):
590 | """Test that CORS responses include Vary header."""
591 | app = FastAPI()
592 | app.add_middleware(
593 | CORSMiddleware,
594 | allow_origins=["https://example.com"],
595 | allow_methods=["GET"],
596 | allow_headers=["Content-Type"],
597 | )
598 |
599 | @app.get("/test")
600 | async def test_endpoint():
601 | return {"message": "test"}
602 |
603 | client = TestClient(app)
604 |
605 | response = client.get("/test", headers={"Origin": "https://example.com"})
606 |
607 | vary_header = response.headers.get("Vary", "")
608 | assert "Origin" in vary_header
609 |
610 | def test_cors_null_origin_handling(self):
611 | """Test CORS handling of null origin (file:// protocol)."""
612 | app = FastAPI()
613 | app.add_middleware(
614 | CORSMiddleware,
615 | allow_origins=["null"], # Sometimes needed for file:// protocol
616 | allow_methods=["GET"],
617 | allow_headers=["Content-Type"],
618 | )
619 |
620 | @app.get("/test")
621 | async def test_endpoint():
622 | return {"message": "test"}
623 |
624 | client = TestClient(app)
625 |
626 | response = client.get("/test", headers={"Origin": "null"})
627 |
628 | # Should handle null origin appropriately
629 | assert response.status_code == 200
630 |
631 |     def test_cors_origin_matching_is_case_sensitive(self):
632 |         """Test that CORS origin matching is case-sensitive (as it should be)."""
633 | app = FastAPI()
634 | app.add_middleware(
635 | CORSMiddleware,
636 | allow_origins=["https://Example.com"], # Capital E
637 | allow_methods=["GET"],
638 | allow_headers=["Content-Type"],
639 | )
640 |
641 | @app.get("/test")
642 | async def test_endpoint():
643 | return {"message": "test"}
644 |
645 | client = TestClient(app)
646 |
647 | # Test with different case
648 | response = client.get(
649 | "/test",
650 | headers={"Origin": "https://example.com"}, # lowercase e
651 | )
652 |
653 | # Should not match due to case sensitivity
654 | cors_origin = response.headers.get("Access-Control-Allow-Origin")
655 | assert cors_origin != "https://example.com"
656 |
657 |
658 | if __name__ == "__main__":
659 | pytest.main([__file__, "-v"])
660 |
```
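The core invariant these tests pin down is that a wildcard origin (`*`) must never be combined with `allow_credentials=True`. As a minimal sketch of that rule, here is an illustrative dataclass-based policy check; it is not the project's actual `CORSConfig` (the implementation in `maverick_mcp/config/security.py` is not shown on this page), and the error message is only aligned with the regexes the tests match.

```python
# Illustrative sketch only -- not the maverick_mcp implementation.
from dataclasses import dataclass, field
import logging

logger = logging.getLogger(__name__)


@dataclass
class CorsPolicySketch:
    """Hypothetical policy object mirroring the invariant tested above."""

    allowed_origins: list[str] = field(
        default_factory=lambda: ["http://localhost:3000"]
    )
    allow_credentials: bool = True

    def __post_init__(self) -> None:
        if "*" in self.allowed_origins:
            if self.allow_credentials:
                # Credentialed requests against a wildcard origin defeat the
                # origin check entirely, so fail fast at construction time.
                raise ValueError(
                    "CORS Security Error: wildcard origin with credentials "
                    "is a serious security vulnerability"
                )
            # Wildcard without credentials is legal but worth flagging.
            logger.warning("Wildcard CORS origin configured without credentials")
```

Failing at construction time, as the tests expect, keeps a misconfigured policy from ever reaching the middleware.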
--------------------------------------------------------------------------------
/tests/core/test_technical_analysis.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for maverick_mcp.core.technical_analysis module.
3 |
4 | This module contains comprehensive tests for all technical analysis functions
5 | to ensure accurate financial calculations and proper error handling.
6 | """
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import pytest
11 |
12 | from maverick_mcp.core.technical_analysis import (
13 | add_technical_indicators,
14 | analyze_bollinger_bands,
15 | analyze_macd,
16 | analyze_rsi,
17 | analyze_stochastic,
18 | analyze_trend,
19 | analyze_volume,
20 | calculate_atr,
21 | generate_outlook,
22 | identify_chart_patterns,
23 | identify_resistance_levels,
24 | identify_support_levels,
25 | )
26 |
27 |
28 | class TestTechnicalIndicators:
29 | """Test the add_technical_indicators function."""
30 |
31 | def test_add_technical_indicators_basic(self):
32 | """Test basic technical indicators calculation."""
33 | # Create sample data with enough data points for all indicators
34 | dates = pd.date_range("2024-01-01", periods=100, freq="D")
35 | data = {
36 | "Date": dates,
37 | "Open": np.random.uniform(100, 200, 100),
38 | "High": np.random.uniform(150, 250, 100),
39 | "Low": np.random.uniform(50, 150, 100),
40 | "Close": np.random.uniform(100, 200, 100),
41 | "Volume": np.random.randint(1000000, 10000000, 100),
42 | }
43 | df = pd.DataFrame(data)
44 | df = df.set_index("Date")
45 |
46 | # Add some realistic price movement
47 | for i in range(1, len(df)):
48 | df.loc[df.index[i], "Close"] = df.iloc[i - 1]["Close"] * np.random.uniform(
49 | 0.98, 1.02
50 | )
51 | df.loc[df.index[i], "High"] = max(
52 | df.iloc[i]["Open"], df.iloc[i]["Close"]
53 | ) * np.random.uniform(1.0, 1.02)
54 | df.loc[df.index[i], "Low"] = min(
55 | df.iloc[i]["Open"], df.iloc[i]["Close"]
56 | ) * np.random.uniform(0.98, 1.0)
57 |
58 | result = add_technical_indicators(df)
59 |
60 | # Check that all expected indicators are added
61 | expected_indicators = [
62 | "ema_21",
63 | "sma_50",
64 | "sma_200",
65 | "rsi",
66 | "macd_12_26_9",
67 | "macds_12_26_9",
68 | "macdh_12_26_9",
69 | "sma_20",
70 | "bbu_20_2.0",
71 | "bbl_20_2.0",
72 | "stdev",
73 | "atr",
74 | "stochk_14_3_3",
75 | "stochd_14_3_3",
76 | "adx_14",
77 | ]
78 |
79 | for indicator in expected_indicators:
80 | assert indicator in result.columns
81 |
82 | # Check that indicators have reasonable values (not all NaN)
83 | assert not result["rsi"].iloc[-10:].isna().all()
84 | assert not result["ema_21"].iloc[-10:].isna().all()
85 | assert not result["sma_50"].iloc[-10:].isna().all()
86 |
87 | def test_add_technical_indicators_column_case_insensitive(self):
88 | """Test that the function handles different column case properly."""
89 | data = {
90 | "OPEN": [100, 101, 102],
91 | "HIGH": [105, 106, 107],
92 | "LOW": [95, 96, 97],
93 | "CLOSE": [103, 104, 105],
94 | "VOLUME": [1000000, 1100000, 1200000],
95 | }
96 | df = pd.DataFrame(data)
97 |
98 | result = add_technical_indicators(df)
99 |
100 | # Check that columns are normalized to lowercase
101 | assert "close" in result.columns
102 | assert "high" in result.columns
103 | assert "low" in result.columns
104 |
105 | def test_add_technical_indicators_insufficient_data(self):
106 | """Test behavior with insufficient data."""
107 | data = {
108 | "Open": [100],
109 | "High": [105],
110 | "Low": [95],
111 | "Close": [103],
112 | "Volume": [1000000],
113 | }
114 | df = pd.DataFrame(data)
115 |
116 | result = add_technical_indicators(df)
117 |
118 | # Should handle insufficient data gracefully
119 | assert "rsi" in result.columns
120 | assert pd.isna(result["rsi"].iloc[0]) # Should be NaN for insufficient data
121 |
122 | def test_add_technical_indicators_empty_dataframe(self):
123 | """Test behavior with empty dataframe."""
124 | df = pd.DataFrame()
125 |
126 | with pytest.raises(KeyError):
127 | add_technical_indicators(df)
128 |
129 | @pytest.mark.parametrize(
130 | "bb_columns",
131 | [
132 | ("BBM_20_2.0", "BBU_20_2.0", "BBL_20_2.0"),
133 | ("BBM_20_2", "BBU_20_2", "BBL_20_2"),
134 | ],
135 | )
136 | def test_add_technical_indicators_supports_bbands_column_aliases(
137 | self, monkeypatch, bb_columns
138 | ):
139 | """Ensure Bollinger Band column name variations are handled."""
140 |
141 | index = pd.date_range("2024-01-01", periods=40, freq="D")
142 | base_series = np.linspace(100, 140, len(index))
143 | data = {
144 | "open": base_series,
145 | "high": base_series + 1,
146 | "low": base_series - 1,
147 | "close": base_series,
148 | "volume": np.full(len(index), 1_000_000),
149 | }
150 | df = pd.DataFrame(data, index=index)
151 |
152 | mid_column, upper_column, lower_column = bb_columns
153 |
154 | def fake_bbands(close, *args, **kwargs):
155 | band_values = pd.Series(base_series, index=close.index)
156 | return pd.DataFrame(
157 | {
158 | mid_column: band_values,
159 | upper_column: band_values + 2,
160 | lower_column: band_values - 2,
161 | }
162 | )
163 |
164 | monkeypatch.setattr(
165 | "maverick_mcp.core.technical_analysis.ta.bbands",
166 | fake_bbands,
167 | )
168 |
169 | result = add_technical_indicators(df)
170 |
171 | np.testing.assert_allclose(result["sma_20"], base_series)
172 | np.testing.assert_allclose(result["bbu_20_2.0"], base_series + 2)
173 | np.testing.assert_allclose(result["bbl_20_2.0"], base_series - 2)
174 |
175 |
176 | class TestSupportResistanceLevels:
177 | """Test support and resistance level identification."""
178 |
179 | @pytest.fixture
180 | def sample_data(self):
181 | """Create sample price data for testing."""
182 | data = {
183 | "high": [105, 110, 108, 115, 112, 120, 118, 125, 122, 130] * 5,
184 | "low": [95, 100, 98, 105, 102, 110, 108, 115, 112, 120] * 5,
185 | "close": [100, 105, 103, 110, 107, 115, 113, 120, 117, 125] * 5,
186 | }
187 | return pd.DataFrame(data)
188 |
189 | def test_identify_support_levels(self, sample_data):
190 | """Test support level identification."""
191 | support_levels = identify_support_levels(sample_data)
192 |
193 | assert isinstance(support_levels, list)
194 | assert len(support_levels) > 0
195 | assert all(
196 | isinstance(level, float | int | np.number) for level in support_levels
197 | )
198 | assert support_levels == sorted(support_levels) # Should be sorted
199 |
200 | def test_identify_resistance_levels(self, sample_data):
201 | """Test resistance level identification."""
202 | resistance_levels = identify_resistance_levels(sample_data)
203 |
204 | assert isinstance(resistance_levels, list)
205 | assert len(resistance_levels) > 0
206 | assert all(
207 | isinstance(level, float | int | np.number) for level in resistance_levels
208 | )
209 | assert resistance_levels == sorted(resistance_levels) # Should be sorted
210 |
211 | def test_support_resistance_with_small_dataset(self):
212 | """Test with dataset smaller than 30 days."""
213 | data = {
214 | "high": [105, 110, 108],
215 | "low": [95, 100, 98],
216 | "close": [100, 105, 103],
217 | }
218 | df = pd.DataFrame(data)
219 |
220 | support_levels = identify_support_levels(df)
221 | resistance_levels = identify_resistance_levels(df)
222 |
223 | assert len(support_levels) > 0
224 | assert len(resistance_levels) > 0
225 |
226 |
227 | class TestTrendAnalysis:
228 | """Test trend analysis functionality."""
229 |
230 | @pytest.fixture
231 | def trending_data(self):
232 | """Create data with clear upward trend."""
233 | dates = pd.date_range("2024-01-01", periods=60, freq="D")
234 | close_prices = np.linspace(100, 150, 60) # Clear upward trend
235 |
236 | data = {
237 | "close": close_prices,
238 | "high": close_prices * 1.02,
239 | "low": close_prices * 0.98,
240 | "volume": np.random.randint(1000000, 2000000, 60),
241 | }
242 | df = pd.DataFrame(data, index=dates)
243 | return add_technical_indicators(df)
244 |
245 | def test_analyze_trend_uptrend(self, trending_data):
246 | """Test trend analysis with upward trending data."""
247 | trend_strength = analyze_trend(trending_data)
248 |
249 | assert isinstance(trend_strength, int)
250 | assert 0 <= trend_strength <= 7
251 | assert trend_strength > 3 # Should detect strong uptrend
252 |
253 | def test_analyze_trend_empty_dataframe(self):
254 | """Test trend analysis with empty dataframe."""
255 | df = pd.DataFrame({"close": []})
256 |
257 | trend_strength = analyze_trend(df)
258 |
259 | assert trend_strength == 0
260 |
261 | def test_analyze_trend_missing_indicators(self):
262 | """Test trend analysis with missing indicators."""
263 | data = {
264 | "close": [100, 101, 102, 103, 104],
265 | }
266 | df = pd.DataFrame(data)
267 |
268 | trend_strength = analyze_trend(df)
269 |
270 | assert trend_strength == 0 # Should handle missing indicators gracefully
271 |
272 |
273 | class TestRSIAnalysis:
274 | """Test RSI analysis functionality."""
275 |
276 | @pytest.fixture
277 | def rsi_data(self):
278 | """Create data with RSI indicator."""
279 | data = {
280 | "close": [100, 105, 103, 110, 107, 115, 113, 120, 117, 125],
281 | "rsi": [50, 55, 52, 65, 60, 70, 68, 75, 72, 80],
282 | }
283 | return pd.DataFrame(data)
284 |
285 | def test_analyze_rsi_overbought(self, rsi_data):
286 | """Test RSI analysis with overbought conditions."""
287 | result = analyze_rsi(rsi_data)
288 |
289 | assert result["current"] == 80.0
290 | assert result["signal"] == "overbought"
291 | assert "overbought" in result["description"]
292 |
293 | def test_analyze_rsi_oversold(self):
294 | """Test RSI analysis with oversold conditions."""
295 | data = {
296 | "close": [100, 95, 90, 85, 80],
297 | "rsi": [50, 40, 30, 25, 20],
298 | }
299 | df = pd.DataFrame(data)
300 |
301 | result = analyze_rsi(df)
302 |
303 | assert result["current"] == 20.0
304 | assert result["signal"] == "oversold"
305 |
306 | def test_analyze_rsi_bullish(self):
307 | """Test RSI analysis with bullish conditions."""
308 | data = {
309 | "close": [100, 105, 110],
310 | "rsi": [50, 55, 60],
311 | }
312 | df = pd.DataFrame(data)
313 |
314 | result = analyze_rsi(df)
315 |
316 | assert result["current"] == 60.0
317 | assert result["signal"] == "bullish"
318 |
319 | def test_analyze_rsi_bearish(self):
320 | """Test RSI analysis with bearish conditions."""
321 | data = {
322 | "close": [100, 95, 90],
323 | "rsi": [50, 45, 40],
324 | }
325 | df = pd.DataFrame(data)
326 |
327 | result = analyze_rsi(df)
328 |
329 | assert result["current"] == 40.0
330 | assert result["signal"] == "bearish"
331 |
332 | def test_analyze_rsi_empty_dataframe(self):
333 | """Test RSI analysis with empty dataframe."""
334 | df = pd.DataFrame()
335 |
336 | result = analyze_rsi(df)
337 |
338 | assert result["current"] is None
339 | assert result["signal"] == "unavailable"
340 |
341 | def test_analyze_rsi_missing_column(self):
342 | """Test RSI analysis without RSI column."""
343 | data = {"close": [100, 105, 110]}
344 | df = pd.DataFrame(data)
345 |
346 | result = analyze_rsi(df)
347 |
348 | assert result["current"] is None
349 | assert result["signal"] == "unavailable"
350 |
351 | def test_analyze_rsi_nan_values(self):
352 | """Test RSI analysis with NaN values."""
353 | data = {
354 | "close": [100, 105, 110],
355 | "rsi": [50, 55, np.nan],
356 | }
357 | df = pd.DataFrame(data)
358 |
359 | result = analyze_rsi(df)
360 |
361 | assert result["current"] is None
362 | assert result["signal"] == "unavailable"
363 |
364 |
365 | class TestMACDAnalysis:
366 | """Test MACD analysis functionality."""
367 |
368 | @pytest.fixture
369 | def macd_data(self):
370 | """Create data with MACD indicators."""
371 | data = {
372 | "macd_12_26_9": [1.5, 2.0, 2.5, 3.0, 2.8],
373 | "macds_12_26_9": [1.0, 1.8, 2.2, 2.7, 3.2],
374 | "macdh_12_26_9": [0.5, 0.2, 0.3, 0.3, -0.4],
375 | }
376 | return pd.DataFrame(data)
377 |
378 |     def test_analyze_macd_bearish(self, macd_data):
379 |         """Test MACD analysis with bearish signals (MACD below signal line)."""
380 | result = analyze_macd(macd_data)
381 |
382 | assert result["macd"] == 2.8
383 | assert result["signal"] == 3.2
384 | assert result["histogram"] == -0.4
385 | assert result["indicator"] == "bearish" # macd < signal and histogram < 0
386 |
387 | def test_analyze_macd_crossover_detection(self):
388 | """Test MACD crossover detection."""
389 | data = {
390 | "macd_12_26_9": [1.0, 2.0, 3.0],
391 | "macds_12_26_9": [2.0, 1.8, 2.5],
392 | "macdh_12_26_9": [-1.0, 0.2, 0.5],
393 | }
394 | df = pd.DataFrame(data)
395 |
396 | result = analyze_macd(df)
397 |
398 | # Check that crossover detection works (test the logic rather than specific result)
399 | assert "crossover" in result
400 | assert result["crossover"] in [
401 | "bullish crossover detected",
402 | "bearish crossover detected",
403 | "no recent crossover",
404 | ]
405 |
406 | def test_analyze_macd_missing_data(self):
407 | """Test MACD analysis with missing data."""
408 | data = {
409 | "macd_12_26_9": [np.nan],
410 | "macds_12_26_9": [np.nan],
411 | "macdh_12_26_9": [np.nan],
412 | }
413 | df = pd.DataFrame(data)
414 |
415 | result = analyze_macd(df)
416 |
417 | assert result["macd"] is None
418 | assert result["indicator"] == "unavailable"
419 |
420 |
421 | class TestStochasticAnalysis:
422 | """Test Stochastic Oscillator analysis."""
423 |
424 | @pytest.fixture
425 | def stoch_data(self):
426 | """Create data with Stochastic indicators."""
427 | data = {
428 | "stochk_14_3_3": [20, 30, 40, 50, 60],
429 | "stochd_14_3_3": [25, 35, 45, 55, 65],
430 | }
431 | return pd.DataFrame(data)
432 |
433 | def test_analyze_stochastic_bearish(self, stoch_data):
434 | """Test Stochastic analysis with bearish signal."""
435 | result = analyze_stochastic(stoch_data)
436 |
437 | assert result["k"] == 60.0
438 | assert result["d"] == 65.0
439 | assert result["signal"] == "bearish" # k < d
440 |
441 | def test_analyze_stochastic_overbought(self):
442 | """Test Stochastic analysis with overbought conditions."""
443 | data = {
444 | "stochk_14_3_3": [85],
445 | "stochd_14_3_3": [83],
446 | }
447 | df = pd.DataFrame(data)
448 |
449 | result = analyze_stochastic(df)
450 |
451 | assert result["signal"] == "overbought"
452 |
453 | def test_analyze_stochastic_oversold(self):
454 | """Test Stochastic analysis with oversold conditions."""
455 | data = {
456 | "stochk_14_3_3": [15],
457 | "stochd_14_3_3": [18],
458 | }
459 | df = pd.DataFrame(data)
460 |
461 | result = analyze_stochastic(df)
462 |
463 | assert result["signal"] == "oversold"
464 |
465 | def test_analyze_stochastic_crossover(self):
466 | """Test Stochastic crossover detection."""
467 | data = {
468 | "stochk_14_3_3": [30, 45],
469 | "stochd_14_3_3": [40, 35],
470 | }
471 | df = pd.DataFrame(data)
472 |
473 | result = analyze_stochastic(df)
474 |
475 | assert result["crossover"] == "bullish crossover detected"
476 |
477 |
478 | class TestBollingerBands:
479 | """Test Bollinger Bands analysis."""
480 |
481 | @pytest.fixture
482 | def bb_data(self):
483 | """Create data with Bollinger Bands."""
484 | data = {
485 | "close": [100, 105, 110, 108, 112],
486 | "bbu_20_2.0": [115, 116, 117, 116, 118],
487 | "bbl_20_2.0": [85, 86, 87, 86, 88],
488 | "sma_20": [100, 101, 102, 101, 103],
489 | }
490 | return pd.DataFrame(data)
491 |
492 | def test_analyze_bollinger_bands_above_middle(self, bb_data):
493 | """Test Bollinger Bands with price above middle band."""
494 | result = analyze_bollinger_bands(bb_data)
495 |
496 | assert result["upper_band"] == 118.0
497 | assert result["middle_band"] == 103.0
498 | assert result["lower_band"] == 88.0
499 | assert result["position"] == "above middle band"
500 | assert result["signal"] == "bullish"
501 |
502 | def test_analyze_bollinger_bands_above_upper(self):
503 | """Test Bollinger Bands with price above upper band."""
504 | data = {
505 | "close": [120],
506 | "bbu_20_2.0": [115],
507 | "bbl_20_2.0": [85],
508 | "sma_20": [100],
509 | }
510 | df = pd.DataFrame(data)
511 |
512 | result = analyze_bollinger_bands(df)
513 |
514 | assert result["position"] == "above upper band"
515 | assert result["signal"] == "overbought"
516 |
517 | def test_analyze_bollinger_bands_below_lower(self):
518 | """Test Bollinger Bands with price below lower band."""
519 | data = {
520 | "close": [80],
521 | "bbu_20_2.0": [115],
522 | "bbl_20_2.0": [85],
523 | "sma_20": [100],
524 | }
525 | df = pd.DataFrame(data)
526 |
527 | result = analyze_bollinger_bands(df)
528 |
529 | assert result["position"] == "below lower band"
530 | assert result["signal"] == "oversold"
531 |
532 | def test_analyze_bollinger_bands_volatility_calculation(self):
533 | """Test Bollinger Bands volatility calculation."""
534 | # Create data with contracting bands
535 | data = {
536 | "close": [100, 100, 100, 100, 100],
537 | "bbu_20_2.0": [110, 108, 106, 104, 102],
538 | "bbl_20_2.0": [90, 92, 94, 96, 98],
539 | "sma_20": [100, 100, 100, 100, 100],
540 | }
541 | df = pd.DataFrame(data)
542 |
543 | result = analyze_bollinger_bands(df)
544 |
545 | assert "contracting" in result["volatility"]
546 |
547 |
548 | class TestVolumeAnalysis:
549 | """Test volume analysis functionality."""
550 |
551 | @pytest.fixture
552 | def volume_data(self):
553 | """Create data with volume information."""
554 | data = {
555 | "volume": [1000000, 1100000, 1200000, 1500000, 2000000],
556 | "close": [100, 101, 102, 105, 108],
557 | }
558 | return pd.DataFrame(data)
559 |
560 | def test_analyze_volume_high_volume_up_move(self, volume_data):
561 | """Test volume analysis with high volume on up move."""
562 | result = analyze_volume(volume_data)
563 |
564 | assert result["current"] == 2000000
565 | assert result["ratio"] >= 1.4 # More lenient threshold
566 |         # Check that volume analysis is working; the signal may vary with the exact ratio
567 | assert result["description"] in ["above average", "average"]
568 | assert result["signal"] in ["bullish (high volume on up move)", "neutral"]
569 |
570 | def test_analyze_volume_low_volume(self):
571 | """Test volume analysis with low volume."""
572 | data = {
573 | "volume": [1000000, 1100000, 1200000, 1300000, 600000],
574 | "close": [100, 101, 102, 103, 104],
575 | }
576 | df = pd.DataFrame(data)
577 |
578 | result = analyze_volume(df)
579 |
580 | assert result["ratio"] < 0.7
581 | assert result["description"] == "below average"
582 | assert result["signal"] == "weak conviction"
583 |
584 | def test_analyze_volume_insufficient_data(self):
585 | """Test volume analysis with insufficient data."""
586 | data = {
587 | "volume": [1000000],
588 | "close": [100],
589 | }
590 | df = pd.DataFrame(data)
591 |
592 | result = analyze_volume(df)
593 |
594 | # Should still work with single data point
595 | assert result["current"] == 1000000
596 | assert result["average"] == 1000000
597 | assert result["ratio"] == 1.0
598 |
599 | def test_analyze_volume_invalid_data(self):
600 | """Test volume analysis with invalid data."""
601 | data = {
602 | "volume": [np.nan],
603 | "close": [100],
604 | }
605 | df = pd.DataFrame(data)
606 |
607 | result = analyze_volume(df)
608 |
609 | assert result["current"] is None
610 | assert result["signal"] == "unavailable"
611 |
612 |
613 | class TestChartPatterns:
614 | """Test chart pattern identification."""
615 |
616 | def test_identify_chart_patterns_double_bottom(self):
617 | """Test double bottom pattern identification."""
618 | # Create price data with double bottom pattern
619 | prices = [100] * 10 + [90] * 5 + [100] * 10 + [90] * 5 + [100] * 10
620 | data = {
621 | "low": prices,
622 | "high": [p + 10 for p in prices],
623 | "close": [p + 5 for p in prices],
624 | }
625 | df = pd.DataFrame(data)
626 |
627 | patterns = identify_chart_patterns(df)
628 |
629 | # Note: The pattern detection is quite strict, so we just test it runs
630 | assert isinstance(patterns, list)
631 |
632 | def test_identify_chart_patterns_insufficient_data(self):
633 | """Test chart pattern identification with insufficient data."""
634 | data = {
635 | "low": [90, 95, 92],
636 | "high": [100, 105, 102],
637 | "close": [95, 100, 97],
638 | }
639 | df = pd.DataFrame(data)
640 |
641 | patterns = identify_chart_patterns(df)
642 |
643 | assert isinstance(patterns, list)
644 | assert len(patterns) == 0 # Not enough data for patterns
645 |
646 |
647 | class TestATRCalculation:
648 | """Test Average True Range calculation."""
649 |
650 | @pytest.fixture
651 | def atr_data(self):
652 | """Create data for ATR calculation."""
653 | data = {
654 | "High": [105, 110, 108, 115, 112],
655 | "Low": [95, 100, 98, 105, 102],
656 | "Close": [100, 105, 103, 110, 107],
657 | }
658 | return pd.DataFrame(data)
659 |
660 | def test_calculate_atr_basic(self, atr_data):
661 | """Test basic ATR calculation."""
662 | result = calculate_atr(atr_data, period=3)
663 |
664 | assert isinstance(result, pd.Series)
665 | assert len(result) == len(atr_data)
666 | # ATR values should be positive where calculated
667 | assert (result.dropna() >= 0).all()
668 |
669 | def test_calculate_atr_custom_period(self, atr_data):
670 | """Test ATR calculation with custom period."""
671 | result = calculate_atr(atr_data, period=2)
672 |
673 | assert isinstance(result, pd.Series)
674 | assert len(result) == len(atr_data)
675 |
676 | def test_calculate_atr_insufficient_data(self):
677 | """Test ATR calculation with insufficient data."""
678 | data = {
679 | "High": [105],
680 | "Low": [95],
681 | "Close": [100],
682 | }
683 | df = pd.DataFrame(data)
684 |
685 | result = calculate_atr(df)
686 |
687 | assert isinstance(result, pd.Series)
688 | # Should handle insufficient data gracefully
689 |
690 |
691 | class TestOutlookGeneration:
692 | """Test overall outlook generation."""
693 |
694 | def test_generate_outlook_bullish(self):
695 | """Test outlook generation with bullish signals."""
696 | df = pd.DataFrame({"close": [100, 105, 110]})
697 | trend = "uptrend"
698 | rsi_analysis = {"signal": "bullish"}
699 | macd_analysis = {
700 | "indicator": "bullish",
701 | "crossover": "bullish crossover detected",
702 | }
703 | stoch_analysis = {"signal": "bullish"}
704 |
705 | outlook = generate_outlook(
706 | df, trend, rsi_analysis, macd_analysis, stoch_analysis
707 | )
708 |
709 | assert "bullish" in outlook
710 |
711 | def test_generate_outlook_bearish(self):
712 | """Test outlook generation with bearish signals."""
713 | df = pd.DataFrame({"close": [100, 95, 90]})
714 | trend = "downtrend"
715 | rsi_analysis = {"signal": "bearish"}
716 | macd_analysis = {
717 | "indicator": "bearish",
718 | "crossover": "bearish crossover detected",
719 | }
720 | stoch_analysis = {"signal": "bearish"}
721 |
722 | outlook = generate_outlook(
723 | df, trend, rsi_analysis, macd_analysis, stoch_analysis
724 | )
725 |
726 | assert "bearish" in outlook
727 |
728 | def test_generate_outlook_neutral(self):
729 | """Test outlook generation with mixed signals."""
730 | df = pd.DataFrame({"close": [100, 100, 100]})
731 | trend = "sideways"
732 | rsi_analysis = {"signal": "neutral"}
733 | macd_analysis = {"indicator": "neutral", "crossover": "no recent crossover"}
734 | stoch_analysis = {"signal": "neutral"}
735 |
736 | outlook = generate_outlook(
737 | df, trend, rsi_analysis, macd_analysis, stoch_analysis
738 | )
739 |
740 | assert outlook == "neutral"
741 |
742 | def test_generate_outlook_strongly_bullish(self):
743 | """Test outlook generation with very bullish signals."""
744 | df = pd.DataFrame({"close": [100, 105, 110]})
745 | trend = "uptrend"
746 | rsi_analysis = {"signal": "oversold"} # Bullish signal
747 | macd_analysis = {
748 | "indicator": "bullish",
749 | "crossover": "bullish crossover detected",
750 | }
751 | stoch_analysis = {"signal": "oversold"} # Bullish signal
752 |
753 | outlook = generate_outlook(
754 | df, trend, rsi_analysis, macd_analysis, stoch_analysis
755 | )
756 |
757 | assert "strongly bullish" in outlook
758 |
759 |
760 | if __name__ == "__main__":
761 | pytest.main([__file__])
762 |
```
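The assertions above fully pin down the return contract for `analyze_volume`: a dict with `current`, `average`, `ratio`, `description`, and `signal` keys, with NaN input degrading to `current=None` / `signal="unavailable"`. As a reading aid, here is a minimal sketch that satisfies those assertions; the thresholds (>= 1.4 for "above average", < 0.7 for "below average") and the averaging window are inferred from the tests, not copied from the actual implementation in `maverick_mcp`.

```python
import pandas as pd


def analyze_volume_sketch(df: pd.DataFrame, period: int = 20) -> dict:
    """Illustrative sketch of the contract asserted by TestVolumeAnalysis.

    Hypothetical: the thresholds and averaging window are inferred from
    the test assertions, not taken from the library code.
    """
    current = df["volume"].iloc[-1]
    if pd.isna(current):
        return {"current": None, "signal": "unavailable"}

    average = df["volume"].tail(period).mean()
    ratio = current / average if average else 1.0

    if ratio >= 1.4:
        description = "above average"
    elif ratio < 0.7:
        description = "below average"
    else:
        description = "average"

    price_up = len(df) > 1 and df["close"].iloc[-1] > df["close"].iloc[-2]
    if description == "above average" and price_up:
        signal = "bullish (high volume on up move)"
    elif description == "below average":
        signal = "weak conviction"
    else:
        signal = "neutral"

    return {
        "current": int(current),
        "average": float(average),
        "ratio": float(ratio),
        "description": description,
        "signal": signal,
    }


# Mirrors test_analyze_volume_low_volume: ratio ~= 0.58 -> "weak conviction".
df = pd.DataFrame(
    {
        "volume": [1_000_000, 1_100_000, 1_200_000, 1_300_000, 600_000],
        "close": [100, 101, 102, 103, 104],
    }
)
print(analyze_volume_sketch(df)["signal"])  # weak conviction
```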
--------------------------------------------------------------------------------
/tests/performance/test_load.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Load Testing for Concurrent Users and Backtest Operations.
3 |
4 | This test suite covers:
5 | - Concurrent user load testing (10, 50, 100 users)
6 | - Response time and throughput measurement
7 | - Memory usage under concurrent load
8 | - Database performance with multiple connections
9 | - API rate limiting behavior
10 | - Queue management and task distribution
11 | - System stability under sustained load
12 | """
13 |
14 | import asyncio
15 | import logging
16 | import os
17 | import random
18 | import statistics
19 | import time
20 | from dataclasses import dataclass
21 | from typing import Any
22 | from unittest.mock import Mock
23 |
24 | import numpy as np
25 | import pandas as pd
26 | import psutil
27 | import pytest
28 |
29 | from maverick_mcp.backtesting import VectorBTEngine
30 | from maverick_mcp.backtesting.persistence import BacktestPersistenceManager
31 | from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | @dataclass
37 | class LoadTestResult:
38 | """Data class for load test results."""
39 |
40 | concurrent_users: int
41 | total_requests: int
42 | successful_requests: int
43 | failed_requests: int
44 | total_duration: float
45 | avg_response_time: float
46 | min_response_time: float
47 | max_response_time: float
48 | p50_response_time: float
49 | p95_response_time: float
50 | p99_response_time: float
51 | requests_per_second: float
52 | errors_per_second: float
53 | memory_usage_mb: float
54 | cpu_usage_percent: float
55 |
56 |
57 | class LoadTestRunner:
58 | """Load test runner with realistic user simulation."""
59 |
60 | def __init__(self, data_provider):
61 | self.data_provider = data_provider
62 | self.results = []
63 | self.active_requests = 0
64 |
65 | async def simulate_user_session(
66 | self, user_id: int, session_config: dict[str, Any]
67 | ) -> dict[str, Any]:
68 | """Simulate a realistic user session with multiple backtests."""
69 | session_start = time.time()
70 | user_results = []
71 | response_times = []
72 |
73 | symbols = session_config.get("symbols", ["AAPL"])
74 | strategies = session_config.get("strategies", ["sma_cross"])
75 | think_time_range = session_config.get("think_time", (0.5, 2.0))
76 |
77 | engine = VectorBTEngine(data_provider=self.data_provider)
78 |
79 | for symbol in symbols:
80 | for strategy in strategies:
81 | self.active_requests += 1
82 | request_start = time.time()
83 |
84 | try:
85 | parameters = STRATEGY_TEMPLATES.get(strategy, {}).get(
86 | "parameters", {}
87 | )
88 |
89 | result = await engine.run_backtest(
90 | symbol=symbol,
91 | strategy_type=strategy,
92 | parameters=parameters,
93 | start_date="2023-01-01",
94 | end_date="2023-12-31",
95 | )
96 |
97 | request_time = time.time() - request_start
98 | response_times.append(request_time)
99 |
100 | user_results.append(
101 | {
102 | "symbol": symbol,
103 | "strategy": strategy,
104 | "success": True,
105 | "response_time": request_time,
106 | "result_size": len(str(result)),
107 | }
108 | )
109 |
110 | except Exception as e:
111 | request_time = time.time() - request_start
112 | response_times.append(request_time)
113 |
114 | user_results.append(
115 | {
116 | "symbol": symbol,
117 | "strategy": strategy,
118 | "success": False,
119 | "response_time": request_time,
120 | "error": str(e),
121 | }
122 | )
123 |
124 | finally:
125 | self.active_requests -= 1
126 |
127 | # Simulate think time between requests
128 | think_time = random.uniform(*think_time_range)
129 | await asyncio.sleep(think_time)
130 |
131 | session_time = time.time() - session_start
132 |
133 | return {
134 | "user_id": user_id,
135 | "session_time": session_time,
136 | "results": user_results,
137 | "response_times": response_times,
138 | "success_count": sum(1 for r in user_results if r["success"]),
139 | "failure_count": sum(1 for r in user_results if not r["success"]),
140 | }
141 |
142 | def calculate_percentiles(self, response_times: list[float]) -> dict[str, float]:
143 | """Calculate response time percentiles."""
144 | if not response_times:
145 | return {"p50": 0, "p95": 0, "p99": 0}
146 |
147 | sorted_times = sorted(response_times)
148 | return {
149 | "p50": np.percentile(sorted_times, 50),
150 | "p95": np.percentile(sorted_times, 95),
151 | "p99": np.percentile(sorted_times, 99),
152 | }
153 |
154 | async def run_load_test(
155 | self,
156 | concurrent_users: int,
157 | session_config: dict[str, Any],
158 | duration_seconds: int = 60,
159 | ) -> LoadTestResult:
160 | """Run load test with specified concurrent users."""
161 | logger.info(
162 | f"Starting load test: {concurrent_users} concurrent users for {duration_seconds}s"
163 | )
164 |
165 | process = psutil.Process(os.getpid())
166 | initial_memory = process.memory_info().rss / 1024 / 1024 # MB
167 |
168 | start_time = time.time()
169 | all_response_times = []
170 | all_user_results = []
171 |
172 | # Create semaphore to control concurrency
173 | semaphore = asyncio.Semaphore(concurrent_users)
174 |
175 | async def run_user_with_semaphore(user_id: int):
176 | async with semaphore:
177 | return await self.simulate_user_session(user_id, session_config)
178 |
179 | # Generate user sessions
180 | user_tasks = []
181 | for user_id in range(concurrent_users):
182 | task = run_user_with_semaphore(user_id)
183 | user_tasks.append(task)
184 |
185 | # Execute all user sessions concurrently
186 | try:
187 | user_results = await asyncio.wait_for(
188 | asyncio.gather(*user_tasks, return_exceptions=True),
189 |                 timeout=duration_seconds + 30,  # allow extra time beyond the nominal test duration
190 | )
191 | except TimeoutError:
192 | logger.warning(f"Load test timed out after {duration_seconds + 30}s")
193 | user_results = []
194 |
195 | end_time = time.time()
196 | actual_duration = end_time - start_time
197 |
198 | # Process results
199 | successful_sessions = []
200 | failed_sessions = []
201 |
202 | for result in user_results:
203 | if isinstance(result, Exception):
204 | failed_sessions.append(str(result))
205 | elif isinstance(result, dict):
206 | successful_sessions.append(result)
207 | all_response_times.extend(result.get("response_times", []))
208 | all_user_results.extend(result.get("results", []))
209 |
210 | # Calculate metrics
211 | total_requests = len(all_user_results)
212 | successful_requests = sum(
213 | 1 for r in all_user_results if r.get("success", False)
214 | )
215 | failed_requests = total_requests - successful_requests
216 |
217 | # Response time statistics
218 | percentiles = self.calculate_percentiles(all_response_times)
219 | avg_response_time = (
220 | statistics.mean(all_response_times) if all_response_times else 0
221 | )
222 | min_response_time = min(all_response_times) if all_response_times else 0
223 | max_response_time = max(all_response_times) if all_response_times else 0
224 |
225 | # Throughput metrics
226 | requests_per_second = (
227 | total_requests / actual_duration if actual_duration > 0 else 0
228 | )
229 | errors_per_second = (
230 | failed_requests / actual_duration if actual_duration > 0 else 0
231 | )
232 |
233 | # Resource usage
234 | final_memory = process.memory_info().rss / 1024 / 1024
235 | memory_usage = final_memory - initial_memory
236 | cpu_usage = process.cpu_percent()
237 |
238 | result = LoadTestResult(
239 | concurrent_users=concurrent_users,
240 | total_requests=total_requests,
241 | successful_requests=successful_requests,
242 | failed_requests=failed_requests,
243 | total_duration=actual_duration,
244 | avg_response_time=avg_response_time,
245 | min_response_time=min_response_time,
246 | max_response_time=max_response_time,
247 | p50_response_time=percentiles["p50"],
248 | p95_response_time=percentiles["p95"],
249 | p99_response_time=percentiles["p99"],
250 | requests_per_second=requests_per_second,
251 | errors_per_second=errors_per_second,
252 | memory_usage_mb=memory_usage,
253 | cpu_usage_percent=cpu_usage,
254 | )
255 |
256 | logger.info(
257 | f"Load Test Results ({concurrent_users} users):\n"
258 | f" • Total Requests: {total_requests}\n"
259 | f" • Success Rate: {successful_requests / total_requests * 100:.1f}%\n"
260 | f" • Avg Response Time: {avg_response_time:.3f}s\n"
261 | f" • 95th Percentile: {percentiles['p95']:.3f}s\n"
262 | f" • Throughput: {requests_per_second:.1f} req/s\n"
263 | f" • Memory Usage: {memory_usage:.1f}MB\n"
264 | f" • Duration: {actual_duration:.1f}s"
265 | )
266 |
267 | return result
268 |
269 |
270 | class TestLoadTesting:
271 | """Load testing suite for concurrent users."""
272 |
273 | @pytest.fixture
274 | async def optimized_data_provider(self):
275 | """Create optimized data provider for load testing."""
276 | provider = Mock()
277 |
278 | # Pre-generate data for common symbols to reduce computation
279 | symbol_data_cache = {}
280 |
281 | def get_cached_data(symbol: str) -> pd.DataFrame:
282 | """Get or generate cached data for symbol."""
283 | if symbol not in symbol_data_cache:
284 | # Generate deterministic data based on symbol hash
285 | seed = hash(symbol) % 1000
286 | np.random.seed(seed)
287 |
288 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
289 | returns = np.random.normal(0.001, 0.02, len(dates))
290 | prices = 100 * np.cumprod(1 + returns)
291 |
292 | symbol_data_cache[symbol] = pd.DataFrame(
293 | {
294 | "Open": prices * np.random.uniform(0.99, 1.01, len(dates)),
295 | "High": prices * np.random.uniform(1.01, 1.03, len(dates)),
296 | "Low": prices * np.random.uniform(0.97, 0.99, len(dates)),
297 | "Close": prices,
298 | "Volume": np.random.randint(1000000, 10000000, len(dates)),
299 | "Adj Close": prices,
300 | },
301 | index=dates,
302 | )
303 |
304 | # Ensure OHLC constraints
305 | data = symbol_data_cache[symbol]
306 | data["High"] = np.maximum(
307 | data["High"], np.maximum(data["Open"], data["Close"])
308 | )
309 | data["Low"] = np.minimum(
310 | data["Low"], np.minimum(data["Open"], data["Close"])
311 | )
312 |
313 | return symbol_data_cache[symbol].copy()
314 |
315 | provider.get_stock_data.side_effect = get_cached_data
316 | return provider
317 |
318 | async def test_concurrent_users_10(self, optimized_data_provider, benchmark_timer):
319 | """Test load with 10 concurrent users."""
320 | load_runner = LoadTestRunner(optimized_data_provider)
321 |
322 | session_config = {
323 | "symbols": ["AAPL", "GOOGL"],
324 | "strategies": ["sma_cross", "rsi"],
325 | "think_time": (0.1, 0.5), # Faster think time for testing
326 | }
327 |
328 | with benchmark_timer():
329 | result = await load_runner.run_load_test(
330 | concurrent_users=10, session_config=session_config, duration_seconds=30
331 | )
332 |
333 | # Performance assertions for 10 users
334 | assert result.requests_per_second >= 2.0, (
335 | f"Throughput too low: {result.requests_per_second:.1f} req/s"
336 | )
337 | assert result.avg_response_time <= 5.0, (
338 | f"Response time too high: {result.avg_response_time:.2f}s"
339 | )
340 | assert result.p95_response_time <= 10.0, (
341 | f"95th percentile too high: {result.p95_response_time:.2f}s"
342 | )
343 | assert result.successful_requests / result.total_requests >= 0.9, (
344 | "Success rate too low"
345 | )
346 | assert result.memory_usage_mb <= 500, (
347 | f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
348 | )
349 |
350 | return result
351 |
352 | async def test_concurrent_users_50(self, optimized_data_provider, benchmark_timer):
353 | """Test load with 50 concurrent users."""
354 | load_runner = LoadTestRunner(optimized_data_provider)
355 |
356 | session_config = {
357 | "symbols": ["AAPL", "MSFT", "GOOGL"],
358 | "strategies": ["sma_cross", "rsi", "macd"],
359 | "think_time": (0.2, 1.0),
360 | }
361 |
362 | with benchmark_timer():
363 | result = await load_runner.run_load_test(
364 | concurrent_users=50, session_config=session_config, duration_seconds=60
365 | )
366 |
367 | # Performance assertions for 50 users
368 | assert result.requests_per_second >= 5.0, (
369 | f"Throughput too low: {result.requests_per_second:.1f} req/s"
370 | )
371 | assert result.avg_response_time <= 8.0, (
372 | f"Response time too high: {result.avg_response_time:.2f}s"
373 | )
374 | assert result.p95_response_time <= 15.0, (
375 | f"95th percentile too high: {result.p95_response_time:.2f}s"
376 | )
377 | assert result.successful_requests / result.total_requests >= 0.85, (
378 | "Success rate too low"
379 | )
380 | assert result.memory_usage_mb <= 1000, (
381 | f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
382 | )
383 |
384 | return result
385 |
386 | async def test_concurrent_users_100(self, optimized_data_provider, benchmark_timer):
387 | """Test load with 100 concurrent users."""
388 | load_runner = LoadTestRunner(optimized_data_provider)
389 |
390 | session_config = {
391 | "symbols": ["AAPL", "MSFT", "GOOGL", "AMZN"],
392 | "strategies": ["sma_cross", "rsi"], # Reduced strategies for higher load
393 | "think_time": (0.5, 1.5),
394 | }
395 |
396 | with benchmark_timer():
397 | result = await load_runner.run_load_test(
398 | concurrent_users=100, session_config=session_config, duration_seconds=90
399 | )
400 |
401 | # More relaxed performance assertions for 100 users
402 | assert result.requests_per_second >= 3.0, (
403 | f"Throughput too low: {result.requests_per_second:.1f} req/s"
404 | )
405 | assert result.avg_response_time <= 15.0, (
406 | f"Response time too high: {result.avg_response_time:.2f}s"
407 | )
408 | assert result.p95_response_time <= 30.0, (
409 | f"95th percentile too high: {result.p95_response_time:.2f}s"
410 | )
411 | assert result.successful_requests / result.total_requests >= 0.8, (
412 | "Success rate too low"
413 | )
414 | assert result.memory_usage_mb <= 2000, (
415 | f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
416 | )
417 |
418 | return result
419 |
420 | async def test_load_scalability_analysis(self, optimized_data_provider):
421 | """Analyze how performance scales with user load."""
422 | load_runner = LoadTestRunner(optimized_data_provider)
423 |
424 | session_config = {
425 | "symbols": ["AAPL", "GOOGL"],
426 | "strategies": ["sma_cross"],
427 | "think_time": (0.3, 0.7),
428 | }
429 |
430 | user_loads = [5, 10, 20, 40]
431 | scalability_results = []
432 |
433 | for user_count in user_loads:
434 | logger.info(f"Testing scalability with {user_count} users")
435 |
436 | result = await load_runner.run_load_test(
437 | concurrent_users=user_count,
438 | session_config=session_config,
439 | duration_seconds=30,
440 | )
441 |
442 | scalability_results.append(result)
443 |
444 | # Analyze scalability metrics
445 | throughput_efficiency = []
446 | response_time_degradation = []
447 |
448 | baseline_rps = scalability_results[0].requests_per_second
449 | baseline_response_time = scalability_results[0].avg_response_time
450 |
451 | for i, result in enumerate(scalability_results):
452 | expected_rps = baseline_rps * user_loads[i] / user_loads[0]
453 | actual_efficiency = (
454 | result.requests_per_second / expected_rps if expected_rps > 0 else 0
455 | )
456 | throughput_efficiency.append(actual_efficiency)
457 |
458 | response_degradation = (
459 | result.avg_response_time / baseline_response_time
460 | if baseline_response_time > 0
461 | else 1
462 | )
463 | response_time_degradation.append(response_degradation)
464 |
465 | logger.info(
466 | f"Scalability Analysis ({user_loads[i]} users):\n"
467 | f" • RPS: {result.requests_per_second:.2f}\n"
468 | f" • RPS Efficiency: {actual_efficiency:.2%}\n"
469 | f" • Response Time: {result.avg_response_time:.3f}s\n"
470 | f" • Response Degradation: {response_degradation:.2f}x\n"
471 | f" • Memory: {result.memory_usage_mb:.1f}MB"
472 | )
473 |
474 | # Scalability assertions
475 | avg_efficiency = statistics.mean(throughput_efficiency)
476 | max_response_degradation = max(response_time_degradation)
477 |
478 | assert avg_efficiency >= 0.5, (
479 | f"Average throughput efficiency too low: {avg_efficiency:.2%}"
480 | )
481 | assert max_response_degradation <= 5.0, (
482 | f"Response time degradation too high: {max_response_degradation:.1f}x"
483 | )
484 |
485 | return {
486 | "user_loads": user_loads,
487 | "results": scalability_results,
488 | "throughput_efficiency": throughput_efficiency,
489 | "response_time_degradation": response_time_degradation,
490 | "avg_efficiency": avg_efficiency,
491 | }
492 |
493 | async def test_sustained_load_stability(self, optimized_data_provider):
494 | """Test stability under sustained load."""
495 | load_runner = LoadTestRunner(optimized_data_provider)
496 |
497 | session_config = {
498 | "symbols": ["AAPL", "MSFT"],
499 | "strategies": ["sma_cross", "rsi"],
500 | "think_time": (0.5, 1.0),
501 | }
502 |
503 | # Run sustained load for longer duration
504 | result = await load_runner.run_load_test(
505 | concurrent_users=25,
506 | session_config=session_config,
507 | duration_seconds=300, # 5 minutes
508 | )
509 |
510 | # Stability assertions
511 | assert result.errors_per_second <= 0.1, (
512 | f"Error rate too high: {result.errors_per_second:.3f} err/s"
513 | )
514 | assert result.successful_requests / result.total_requests >= 0.95, (
515 | "Success rate degraded over time"
516 | )
517 | assert result.memory_usage_mb <= 800, (
518 | f"Memory usage grew too much: {result.memory_usage_mb:.1f}MB"
519 | )
520 |
521 | # Check for performance consistency (no significant degradation)
522 | assert result.p99_response_time / result.p50_response_time <= 5.0, (
523 | "Response time variance too high"
524 | )
525 |
526 | logger.info(
527 | f"Sustained Load Results (25 users, 5 minutes):\n"
528 | f" • Total Requests: {result.total_requests}\n"
529 | f" • Success Rate: {result.successful_requests / result.total_requests * 100:.2f}%\n"
530 | f" • Avg Throughput: {result.requests_per_second:.2f} req/s\n"
531 | f" • Response Time (50/95/99): {result.p50_response_time:.2f}s/"
532 | f"{result.p95_response_time:.2f}s/{result.p99_response_time:.2f}s\n"
533 | f" • Memory Growth: {result.memory_usage_mb:.1f}MB\n"
534 | f" • Error Rate: {result.errors_per_second:.4f} err/s"
535 | )
536 |
537 | return result
538 |
539 | async def test_database_connection_pooling_under_load(
540 | self, optimized_data_provider, db_session
541 | ):
542 | """Test database connection pooling under concurrent load."""
543 | # Generate backtest results to save to database
544 | engine = VectorBTEngine(data_provider=optimized_data_provider)
545 | test_symbols = ["DB_LOAD_1", "DB_LOAD_2", "DB_LOAD_3"]
546 |
547 | # Pre-generate results for database testing
548 | backtest_results = []
549 | for symbol in test_symbols:
550 | result = await engine.run_backtest(
551 | symbol=symbol,
552 | strategy_type="sma_cross",
553 | parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
554 | start_date="2023-01-01",
555 | end_date="2023-12-31",
556 | )
557 | backtest_results.append(result)
558 |
559 | # Test concurrent database operations
560 | async def concurrent_database_operations(operation_id: int) -> dict[str, Any]:
561 | """Simulate concurrent database save/retrieve operations."""
562 | start_time = time.time()
563 | operations_completed = 0
564 | errors = []
565 |
566 | try:
567 | with BacktestPersistenceManager(session=db_session) as persistence:
568 | # Save operations
569 | for result in backtest_results:
570 | try:
571 | backtest_id = persistence.save_backtest_result(
572 | vectorbt_results=result,
573 | execution_time=2.0,
574 | notes=f"Load test operation {operation_id}",
575 | )
576 | operations_completed += 1
577 |
578 | # Retrieve operation
579 | retrieved = persistence.get_backtest_by_id(backtest_id)
580 | if retrieved:
581 | operations_completed += 1
582 |
583 | except Exception as e:
584 | errors.append(str(e))
585 |
586 | except Exception as e:
587 | errors.append(f"Session error: {str(e)}")
588 |
589 | operation_time = time.time() - start_time
590 |
591 | return {
592 | "operation_id": operation_id,
593 | "operations_completed": operations_completed,
594 | "errors": errors,
595 | "operation_time": operation_time,
596 | }
597 |
598 | # Run concurrent database operations
599 | concurrent_operations = 20
600 | db_tasks = [
601 | concurrent_database_operations(i) for i in range(concurrent_operations)
602 | ]
603 |
604 | start_time = time.time()
605 | db_results = await asyncio.gather(*db_tasks, return_exceptions=True)
606 | total_time = time.time() - start_time
607 |
608 | # Analyze database performance under load
609 | successful_operations = [r for r in db_results if isinstance(r, dict)]
610 | failed_operations = len(db_results) - len(successful_operations)
611 |
612 | total_operations = sum(r["operations_completed"] for r in successful_operations)
613 | total_errors = sum(len(r["errors"]) for r in successful_operations)
614 | avg_operation_time = statistics.mean(
615 | [r["operation_time"] for r in successful_operations]
616 | )
617 |
618 | db_throughput = total_operations / total_time if total_time > 0 else 0
619 | error_rate = total_errors / total_operations if total_operations > 0 else 0
620 |
621 | logger.info(
622 | f"Database Load Test Results:\n"
623 | f" • Concurrent Operations: {concurrent_operations}\n"
624 | f" • Successful Sessions: {len(successful_operations)}\n"
625 | f" • Failed Sessions: {failed_operations}\n"
626 | f" • Total DB Operations: {total_operations}\n"
627 | f" • DB Throughput: {db_throughput:.2f} ops/s\n"
628 | f" • Error Rate: {error_rate:.3%}\n"
629 | f" • Avg Operation Time: {avg_operation_time:.3f}s"
630 | )
631 |
632 | # Database performance assertions
633 | assert len(successful_operations) / len(db_results) >= 0.9, (
634 | "DB session success rate too low"
635 | )
636 | assert error_rate <= 0.05, f"DB error rate too high: {error_rate:.3%}"
637 | assert db_throughput >= 5.0, f"DB throughput too low: {db_throughput:.2f} ops/s"
638 |
639 | return {
640 | "concurrent_operations": concurrent_operations,
641 | "db_throughput": db_throughput,
642 | "error_rate": error_rate,
643 | "avg_operation_time": avg_operation_time,
644 | }
645 |
646 |
647 | if __name__ == "__main__":
648 | # Run load testing suite
649 | pytest.main(
650 | [
651 | __file__,
652 | "-v",
653 | "--tb=short",
654 | "--asyncio-mode=auto",
655 | "--timeout=600", # 10 minute timeout for load tests
656 | "--durations=10",
657 | ]
658 | )
659 |
```
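The metrics packed into `LoadTestResult` reduce to a handful of numpy and statistics calls over the pooled response times, so they are easy to verify in isolation. The snippet below reproduces that arithmetic on synthetic timings (the lognormal parameters are arbitrary, chosen only to mimic a right-skewed latency distribution). One side note: `calculate_percentiles` sorts the list before calling `np.percentile`, which accepts unsorted input, so the sort is harmless but redundant.

```python
import numpy as np

# Synthetic stand-in for all_response_times; parameters are illustrative only.
rng = np.random.default_rng(42)
response_times = rng.lognormal(mean=-0.5, sigma=0.4, size=500)
duration_s = 60.0
failed = 12

p50, p95, p99 = np.percentile(response_times, [50, 95, 99])
rps = len(response_times) / duration_s  # requests_per_second
eps = failed / duration_s               # errors_per_second

print(f"p50={p50:.3f}s  p95={p95:.3f}s  p99={p99:.3f}s")
print(f"throughput={rps:.1f} req/s  errors={eps:.2f} err/s")
```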
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/strategies/ml/ensemble.py:
--------------------------------------------------------------------------------
```python
1 | """Strategy ensemble methods for combining multiple trading strategies."""
2 |
3 | import logging
4 | from typing import Any
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from pandas import DataFrame, Series
9 |
10 | from maverick_mcp.backtesting.strategies.base import Strategy
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class StrategyEnsemble(Strategy):
16 | """Ensemble strategy that combines multiple strategies with dynamic weighting."""
17 |
18 | def __init__(
19 | self,
20 | strategies: list[Strategy],
21 | weighting_method: str = "performance",
22 | lookback_period: int = 50,
23 | rebalance_frequency: int = 20,
24 |         parameters: dict[str, Any] | None = None,
25 | ):
26 | """Initialize strategy ensemble.
27 |
28 | Args:
29 | strategies: List of base strategies to combine
30 | weighting_method: Method for calculating weights ('performance', 'equal', 'volatility')
31 | lookback_period: Period for calculating performance metrics
32 | rebalance_frequency: How often to update weights
33 | parameters: Additional parameters
34 | """
35 | super().__init__(parameters)
36 | self.strategies = strategies
37 | self.weighting_method = weighting_method
38 | self.lookback_period = lookback_period
39 | self.rebalance_frequency = rebalance_frequency
40 |
41 | # Initialize strategy weights
42 | self.weights = np.ones(len(strategies)) / len(strategies)
43 | self.strategy_returns = {}
44 | self.strategy_signals = {}
45 | self.last_rebalance = 0
46 |
47 | @property
48 | def name(self) -> str:
49 | """Get strategy name."""
50 | strategy_names = [s.name for s in self.strategies]
51 | return f"Ensemble({','.join(strategy_names)})"
52 |
53 | @property
54 | def description(self) -> str:
55 | """Get strategy description."""
56 | return f"Dynamic ensemble combining {len(self.strategies)} strategies using {self.weighting_method} weighting"
57 |
58 | def calculate_performance_weights(self, data: DataFrame) -> np.ndarray:
59 | """Calculate performance-based weights for strategies.
60 |
61 | Args:
62 | data: Price data for performance calculation
63 |
64 | Returns:
65 | Array of strategy weights
66 | """
67 | if len(self.strategy_returns) < 2:
68 | return self.weights
69 |
70 | # Calculate Sharpe ratios for each strategy
71 | sharpe_ratios = []
72 | for i, _strategy in enumerate(self.strategies):
73 | if (
74 | i in self.strategy_returns
75 | and len(self.strategy_returns[i]) >= self.lookback_period
76 | ):
77 | returns = pd.Series(self.strategy_returns[i][-self.lookback_period :])
78 | sharpe = returns.mean() / (returns.std() + 1e-8) * np.sqrt(252)
79 | sharpe_ratios.append(max(0, sharpe)) # Ensure non-negative
80 | else:
81 | sharpe_ratios.append(0.1) # Small positive weight for new strategies
82 |
83 | # Convert to weights (softmax-like normalization)
84 | sharpe_array = np.array(sharpe_ratios)
85 | # Fix: Properly check for empty array and zero sum conditions
86 | if sharpe_array.size == 0 or np.sum(sharpe_array) == 0:
87 | weights = np.ones(len(self.strategies)) / len(self.strategies)
88 | else:
89 | # Exponential weighting to emphasize better performers
90 | exp_sharpe = np.exp(sharpe_array * 2)
91 | weights = exp_sharpe / exp_sharpe.sum()
92 |
93 | return weights
94 |
95 | def calculate_volatility_weights(self, data: DataFrame) -> np.ndarray:
96 | """Calculate inverse volatility weights for strategies.
97 |
98 | Args:
99 | data: Price data for volatility calculation
100 |
101 | Returns:
102 | Array of strategy weights
103 | """
104 | if len(self.strategy_returns) < 2:
105 | return self.weights
106 |
107 | # Calculate volatilities for each strategy
108 | volatilities = []
109 | for i, _strategy in enumerate(self.strategies):
110 | if (
111 | i in self.strategy_returns
112 | and len(self.strategy_returns[i]) >= self.lookback_period
113 | ):
114 | returns = pd.Series(self.strategy_returns[i][-self.lookback_period :])
115 | vol = returns.std() * np.sqrt(252)
116 | volatilities.append(max(0.01, vol)) # Minimum volatility
117 | else:
118 | volatilities.append(0.2) # Default volatility assumption
119 |
120 | # Inverse volatility weighting
121 | vol_array = np.array(volatilities)
122 | inv_vol = 1.0 / vol_array
123 | weights = inv_vol / inv_vol.sum()
124 |
125 | return weights
126 |
127 | def update_weights(self, data: DataFrame, current_index: int) -> None:
128 | """Update strategy weights based on recent performance.
129 |
130 | Args:
131 | data: Price data
132 | current_index: Current position in data
133 | """
134 | # Check if it's time to rebalance
135 | if current_index - self.last_rebalance < self.rebalance_frequency:
136 | return
137 |
138 | try:
139 | if self.weighting_method == "performance":
140 | self.weights = self.calculate_performance_weights(data)
141 | elif self.weighting_method == "volatility":
142 | self.weights = self.calculate_volatility_weights(data)
143 | elif self.weighting_method == "equal":
144 | self.weights = np.ones(len(self.strategies)) / len(self.strategies)
145 | else:
146 | logger.warning(f"Unknown weighting method: {self.weighting_method}")
147 |
148 | self.last_rebalance = current_index
149 |
150 | logger.debug(
151 | f"Updated ensemble weights: {dict(zip([s.name for s in self.strategies], self.weights, strict=False))}"
152 | )
153 |
154 | except Exception as e:
155 | logger.error(f"Error updating weights: {e}")
156 |
157 | def generate_individual_signals(
158 | self, data: DataFrame
159 | ) -> dict[int, tuple[Series, Series]]:
160 | """Generate signals from all individual strategies with enhanced error handling.
161 |
162 | Args:
163 | data: Price data
164 |
165 | Returns:
166 | Dictionary mapping strategy index to (entry_signals, exit_signals)
167 | """
168 | signals = {}
169 | failed_strategies = []
170 |
171 | for i, strategy in enumerate(self.strategies):
172 | try:
173 | # Generate signals with timeout protection
174 | entry_signals, exit_signals = strategy.generate_signals(data)
175 |
176 | # Validate signals
177 | if not isinstance(entry_signals, pd.Series) or not isinstance(
178 | exit_signals, pd.Series
179 | ):
180 | raise ValueError(
181 | f"Strategy {strategy.name} returned invalid signal types"
182 | )
183 |
184 | if len(entry_signals) != len(data) or len(exit_signals) != len(data):
185 | raise ValueError(
186 | f"Strategy {strategy.name} returned signals with wrong length"
187 | )
188 |
189 | if not entry_signals.dtype == bool or not exit_signals.dtype == bool:
190 | # Convert to boolean if necessary
191 | entry_signals = entry_signals.astype(bool)
192 | exit_signals = exit_signals.astype(bool)
193 |
194 | signals[i] = (entry_signals, exit_signals)
195 |
196 | # Calculate strategy returns for weight updates (with error handling)
197 | try:
198 | positions = entry_signals.astype(int) - exit_signals.astype(int)
199 | price_returns = data["close"].pct_change()
200 | returns = positions.shift(1) * price_returns
201 |
202 | # Remove invalid returns
203 | valid_returns = returns.dropna()
204 | valid_returns = valid_returns[np.isfinite(valid_returns)]
205 |
206 | if i not in self.strategy_returns:
207 | self.strategy_returns[i] = []
208 |
209 | if len(valid_returns) > 0:
210 | self.strategy_returns[i].extend(valid_returns.tolist())
211 |
212 | # Keep only recent returns for performance calculation
213 | if len(self.strategy_returns[i]) > self.lookback_period * 2:
214 | self.strategy_returns[i] = self.strategy_returns[i][
215 | -self.lookback_period * 2 :
216 | ]
217 |
218 | except Exception as return_error:
219 | logger.debug(
220 | f"Error calculating returns for strategy {strategy.name}: {return_error}"
221 | )
222 |
223 | logger.debug(
224 | f"Strategy {strategy.name}: {entry_signals.sum()} entries, {exit_signals.sum()} exits"
225 | )
226 |
227 | except Exception as e:
228 | logger.error(
229 | f"Error generating signals for strategy {strategy.name}: {e}"
230 | )
231 | failed_strategies.append(i)
232 |
233 | # Create safe fallback signals
234 | try:
235 | signals[i] = (
236 | pd.Series(False, index=data.index),
237 | pd.Series(False, index=data.index),
238 | )
239 | except Exception:
240 | # If even creating empty signals fails, skip this strategy
241 | logger.error(f"Cannot create fallback signals for strategy {i}")
242 | continue
243 |
244 | # Log summary of strategy performance
245 | if failed_strategies:
246 | failed_names = [self.strategies[i].name for i in failed_strategies]
247 | logger.warning(f"Failed strategies: {failed_names}")
248 |
249 | successful_strategies = len(signals) - len(failed_strategies)
250 | logger.info(
251 | f"Successfully generated signals from {successful_strategies}/{len(self.strategies)} strategies"
252 | )
253 |
254 | return signals
255 |
256 | def combine_signals(
257 | self, individual_signals: dict[int, tuple[Series, Series]]
258 | ) -> tuple[Series, Series]:
259 | """Combine individual strategy signals using enhanced weighted voting.
260 |
261 | Args:
262 | individual_signals: Dictionary of individual strategy signals
263 |
264 | Returns:
265 | Tuple of combined (entry_signals, exit_signals)
266 | """
267 | if not individual_signals:
268 | # Return empty series with minimal index when no individual signals available
269 | empty_index = pd.Index([])
270 | return pd.Series(False, index=empty_index), pd.Series(
271 | False, index=empty_index
272 | )
273 |
274 | # Get data index from first strategy
275 | first_signals = next(iter(individual_signals.values()))
276 | data_index = first_signals[0].index
277 |
278 | # Initialize voting arrays
279 | entry_votes = np.zeros(len(data_index))
280 | exit_votes = np.zeros(len(data_index))
281 | total_weights = 0
282 |
283 | # Collect votes with weights and confidence scores
284 | valid_strategies = 0
285 |
286 | for i, (entry_signals, exit_signals) in individual_signals.items():
287 | weight = self.weights[i] if i < len(self.weights) else 0
288 |
289 | if weight > 0:
290 | # Add weighted votes
291 | entry_votes += weight * entry_signals.astype(float)
292 | exit_votes += weight * exit_signals.astype(float)
293 | total_weights += weight
294 | valid_strategies += 1
295 |
296 | if total_weights == 0 or valid_strategies == 0:
297 | logger.warning("No valid strategies with positive weights")
298 | return pd.Series(False, index=data_index), pd.Series(
299 | False, index=data_index
300 | )
301 |
302 | # Normalize votes by total weights
303 | entry_votes = entry_votes / total_weights
304 | exit_votes = exit_votes / total_weights
305 |
306 | # Enhanced voting mechanisms
307 | voting_method = self.parameters.get("voting_method", "weighted")
308 |
309 | if voting_method == "majority":
310 | # Simple majority vote (more than half of strategies agree)
311 | entry_threshold = 0.5
312 | exit_threshold = 0.5
313 | elif voting_method == "supermajority":
314 | # Require 2/3 agreement
315 | entry_threshold = 0.67
316 | exit_threshold = 0.67
317 | elif voting_method == "consensus":
318 | # Require near-unanimous agreement
319 | entry_threshold = 0.8
320 | exit_threshold = 0.8
321 | else: # weighted (default)
322 | entry_threshold = self.parameters.get("entry_threshold", 0.5)
323 | exit_threshold = self.parameters.get("exit_threshold", 0.5)
324 |
325 | # Anti-conflict mechanism: don't signal entry and exit simultaneously
326 | combined_entry = entry_votes > entry_threshold
327 | combined_exit = exit_votes > exit_threshold
328 |
329 | # Resolve conflicts (simultaneous entry and exit signals)
330 | conflicts = combined_entry & combined_exit
331 | # Fix: Check array size and ensure it's not empty before evaluating boolean truth
332 | if conflicts.size > 0 and np.any(conflicts):
333 | logger.debug(f"Resolving {conflicts.sum()} signal conflicts")
334 | # In case of conflict, use the stronger signal
335 | entry_strength = entry_votes[conflicts]
336 | exit_strength = exit_votes[conflicts]
337 |
338 | # Keep only the stronger signal
339 | stronger_entry = entry_strength > exit_strength
340 | combined_entry[conflicts] = stronger_entry
341 | combined_exit[conflicts] = ~stronger_entry
342 |
343 | # Quality filter: require minimum signal strength
344 | min_signal_strength = self.parameters.get("min_signal_strength", 0.1)
345 | weak_entry_signals = (combined_entry) & (entry_votes < min_signal_strength)
346 | weak_exit_signals = (combined_exit) & (exit_votes < min_signal_strength)
347 |
348 | # Fix: Ensure arrays are not empty before boolean indexing
349 | if weak_entry_signals.size > 0:
350 | combined_entry[weak_entry_signals] = False
351 | if weak_exit_signals.size > 0:
352 | combined_exit[weak_exit_signals] = False
353 |
354 | # Convert to pandas Series
355 | combined_entry = pd.Series(combined_entry, index=data_index)
356 | combined_exit = pd.Series(combined_exit, index=data_index)
357 |
358 | return combined_entry, combined_exit
359 |
360 | def generate_signals(self, data: DataFrame) -> tuple[Series, Series]:
361 | """Generate ensemble trading signals.
362 |
363 | Args:
364 | data: Price data with OHLCV columns
365 |
366 | Returns:
367 | Tuple of (entry_signals, exit_signals) as boolean Series
368 | """
369 | try:
370 | # Generate signals from all individual strategies
371 | individual_signals = self.generate_individual_signals(data)
372 |
373 | if not individual_signals:
374 | return pd.Series(False, index=data.index), pd.Series(
375 | False, index=data.index
376 | )
377 |
378 | # Update weights periodically
379 | for idx in range(
380 | self.rebalance_frequency, len(data), self.rebalance_frequency
381 | ):
382 | self.update_weights(data.iloc[:idx], idx)
383 |
384 | # Combine signals
385 | entry_signals, exit_signals = self.combine_signals(individual_signals)
386 |
387 | logger.info(
388 | f"Generated ensemble signals: {entry_signals.sum()} entries, {exit_signals.sum()} exits"
389 | )
390 |
391 | return entry_signals, exit_signals
392 |
393 | except Exception as e:
394 | logger.error(f"Error generating ensemble signals: {e}")
395 | return pd.Series(False, index=data.index), pd.Series(
396 | False, index=data.index
397 | )
398 |
399 | def get_strategy_weights(self) -> dict[str, float]:
400 | """Get current strategy weights.
401 |
402 | Returns:
403 | Dictionary mapping strategy names to weights
404 | """
405 | return dict(zip([s.name for s in self.strategies], self.weights, strict=False))
406 |
407 | def get_strategy_performance(self) -> dict[str, dict[str, float]]:
408 | """Get performance metrics for individual strategies.
409 |
410 | Returns:
411 | Dictionary mapping strategy names to performance metrics
412 | """
413 | performance = {}
414 |
415 | for i, strategy in enumerate(self.strategies):
416 | if i in self.strategy_returns and len(self.strategy_returns[i]) > 0:
417 | returns = pd.Series(self.strategy_returns[i])
418 |
419 | performance[strategy.name] = {
420 | "total_return": returns.sum(),
421 | "annual_return": returns.mean() * 252,
422 | "volatility": returns.std() * np.sqrt(252),
423 | "sharpe_ratio": returns.mean()
424 | / (returns.std() + 1e-8)
425 | * np.sqrt(252),
426 | "max_drawdown": (
427 | returns.cumsum() - returns.cumsum().expanding().max()
428 | ).min(),
429 | "win_rate": (returns > 0).mean(),
430 | "current_weight": self.weights[i],
431 | }
432 | else:
433 | performance[strategy.name] = {
434 | "total_return": 0.0,
435 | "annual_return": 0.0,
436 | "volatility": 0.0,
437 | "sharpe_ratio": 0.0,
438 | "max_drawdown": 0.0,
439 | "win_rate": 0.0,
440 | "current_weight": self.weights[i] if i < len(self.weights) else 0.0,
441 | }
442 |
443 | return performance
444 |
445 | def validate_parameters(self) -> bool:
446 | """Validate ensemble parameters.
447 |
448 | Returns:
449 | True if parameters are valid
450 | """
451 | if not self.strategies:
452 | return False
453 |
454 | if self.weighting_method not in ["performance", "equal", "volatility"]:
455 | return False
456 |
457 | if self.lookback_period <= 0 or self.rebalance_frequency <= 0:
458 | return False
459 |
460 | # Validate individual strategies
461 | for strategy in self.strategies:
462 | if not strategy.validate_parameters():
463 | return False
464 |
465 | return True
466 |
467 | def get_default_parameters(self) -> dict[str, Any]:
468 | """Get default ensemble parameters.
469 |
470 | Returns:
471 | Dictionary of default parameters
472 | """
473 | return {
474 | "weighting_method": "performance",
475 | "lookback_period": 50,
476 | "rebalance_frequency": 20,
477 | "entry_threshold": 0.5,
478 | "exit_threshold": 0.5,
479 | "voting_method": "weighted", # weighted, majority, supermajority, consensus
480 | "min_signal_strength": 0.1, # Minimum signal strength to avoid weak signals
481 | "conflict_resolution": "stronger", # How to resolve entry/exit conflicts
482 | }
483 |
484 | def to_dict(self) -> dict[str, Any]:
485 | """Convert ensemble to dictionary representation.
486 |
487 | Returns:
488 | Dictionary with ensemble details
489 | """
490 | base_dict = super().to_dict()
491 | base_dict.update(
492 | {
493 | "strategies": [s.to_dict() for s in self.strategies],
494 | "current_weights": self.get_strategy_weights(),
495 | "weighting_method": self.weighting_method,
496 | "lookback_period": self.lookback_period,
497 | "rebalance_frequency": self.rebalance_frequency,
498 | }
499 | )
500 |
501 | return base_dict
502 |
503 |
504 | class RiskAdjustedEnsemble(StrategyEnsemble):
505 | """Risk-adjusted ensemble with position sizing and risk management."""
506 |
507 | def __init__(
508 | self,
509 | strategies: list[Strategy],
510 | max_position_size: float = 0.1,
511 | max_correlation: float = 0.7,
512 | risk_target: float = 0.15,
513 | **kwargs,
514 | ):
515 | """Initialize risk-adjusted ensemble.
516 |
517 | Args:
518 | strategies: List of base strategies
519 | max_position_size: Maximum position size per strategy
520 | max_correlation: Maximum correlation between strategies
521 | risk_target: Target portfolio volatility
522 | **kwargs: Additional parameters for base ensemble
523 | """
524 | super().__init__(strategies, **kwargs)
525 | self.max_position_size = max_position_size
526 | self.max_correlation = max_correlation
527 | self.risk_target = risk_target
528 |
529 | def calculate_correlation_matrix(self) -> pd.DataFrame:
530 | """Calculate correlation matrix between strategy returns.
531 |
532 | Returns:
533 | Correlation matrix as DataFrame
534 | """
535 | if len(self.strategy_returns) < 2:
536 | return pd.DataFrame()
537 |
538 | # Create returns DataFrame
539 | min_length = min(
540 | len(returns)
541 | for returns in self.strategy_returns.values()
542 | if len(returns) > 0
543 | )
544 | if min_length == 0:
545 | return pd.DataFrame()
546 |
547 | returns_data = {}
548 | for i, strategy in enumerate(self.strategies):
549 | if (
550 | i in self.strategy_returns
551 | and len(self.strategy_returns[i]) >= min_length
552 | ):
553 | returns_data[strategy.name] = self.strategy_returns[i][-min_length:]
554 |
555 | if not returns_data:
556 | return pd.DataFrame()
557 |
558 | returns_df = pd.DataFrame(returns_data)
559 | return returns_df.corr()
560 |
561 | def adjust_weights_for_correlation(self, weights: np.ndarray) -> np.ndarray:
562 | """Adjust weights to account for strategy correlation.
563 |
564 | Args:
565 | weights: Original weights
566 |
567 | Returns:
568 | Correlation-adjusted weights
569 | """
570 | corr_matrix = self.calculate_correlation_matrix()
571 |
572 | if corr_matrix.empty:
573 | return weights
574 |
575 | try:
576 | # Penalize highly correlated strategies
577 | adjusted_weights = weights.copy()
578 |
579 | for i, strategy_i in enumerate(self.strategies):
580 | for j, strategy_j in enumerate(self.strategies):
581 | if (
582 | i != j
583 | and strategy_i.name in corr_matrix.index
584 | and strategy_j.name in corr_matrix.columns
585 | ):
586 | correlation = abs(
587 | corr_matrix.loc[strategy_i.name, strategy_j.name]
588 | )
589 |
590 | if correlation > self.max_correlation:
591 | # Reduce weight of both strategies
592 | penalty = (correlation - self.max_correlation) * 0.5
593 | adjusted_weights[i] *= 1 - penalty
594 | adjusted_weights[j] *= 1 - penalty
595 |
596 | # Renormalize weights
597 | # Fix: Check array size and sum properly before normalization
598 | if adjusted_weights.size > 0 and np.sum(adjusted_weights) > 0:
599 | adjusted_weights /= adjusted_weights.sum()
600 | else:
601 | adjusted_weights = np.ones(len(self.strategies)) / len(self.strategies)
602 |
603 | return adjusted_weights
604 |
605 | except Exception as e:
606 | logger.error(f"Error adjusting weights for correlation: {e}")
607 | return weights
608 |
609 | def calculate_risk_adjusted_weights(self, data: DataFrame) -> np.ndarray:
610 | """Calculate risk-adjusted weights based on target volatility.
611 |
612 | Args:
613 | data: Price data
614 |
615 | Returns:
616 | Risk-adjusted weights
617 | """
618 | # Start with performance-based weights
619 | base_weights = self.calculate_performance_weights(data)
620 |
621 | # Adjust for correlation
622 | corr_adjusted_weights = self.adjust_weights_for_correlation(base_weights)
623 |
624 | # Apply position size limits
625 | position_adjusted_weights = np.minimum(
626 | corr_adjusted_weights, self.max_position_size
627 | )
628 |
629 | # Renormalize
630 | # Fix: Check array size and sum properly before normalization
631 | if position_adjusted_weights.size > 0 and np.sum(position_adjusted_weights) > 0:
632 | position_adjusted_weights /= position_adjusted_weights.sum()
633 | else:
634 | position_adjusted_weights = np.ones(len(self.strategies)) / len(
635 | self.strategies
636 | )
637 |
638 | return position_adjusted_weights
639 |
640 | def update_weights(self, data: DataFrame, current_index: int) -> None:
641 | """Update risk-adjusted weights.
642 |
643 | Args:
644 | data: Price data
645 | current_index: Current position in data
646 | """
647 | if current_index - self.last_rebalance < self.rebalance_frequency:
648 | return
649 |
650 | try:
651 | self.weights = self.calculate_risk_adjusted_weights(data)
652 | self.last_rebalance = current_index
653 |
654 | logger.debug(
655 | f"Updated risk-adjusted weights: {dict(zip([s.name for s in self.strategies], self.weights, strict=False))}"
656 | )
657 |
658 | except Exception as e:
659 | logger.error(f"Error updating risk-adjusted weights: {e}")
660 |
661 | @property
662 | def name(self) -> str:
663 | """Get strategy name."""
664 | return f"RiskAdjusted{super().name}"
665 |
666 | @property
667 | def description(self) -> str:
668 | """Get strategy description."""
669 | return "Risk-adjusted ensemble with correlation control and position sizing"
670 |
```
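The two weighting schemes in `StrategyEnsemble` are easy to sanity-check by hand. Performance weighting applies a softmax-like transform `exp(2 * sharpe)` so that stronger performers dominate without zeroing out the rest, while volatility weighting is plain inverse-volatility normalization. A worked example with made-up Sharpe and volatility figures:

```python
import numpy as np

# Performance weighting, as in calculate_performance_weights:
sharpes = np.array([1.2, 0.4, 0.0])  # hypothetical per-strategy Sharpe ratios
exp_sharpe = np.exp(sharpes * 2)
perf_weights = exp_sharpe / exp_sharpe.sum()
print(perf_weights.round(3))  # ~= [0.774, 0.156, 0.070]

# Inverse-volatility weighting, as in calculate_volatility_weights:
vols = np.array([0.30, 0.15, 0.20])  # hypothetical annualized volatilities
inv_vol = 1.0 / vols
vol_weights = inv_vol / inv_vol.sum()
print(vol_weights.round(3))  # ~= [0.222, 0.444, 0.333]
```

Each additional unit of Sharpe multiplies a strategy's pre-normalization weight by e^2 (about 7.4), which is why the code clamps negative Sharpe ratios to zero and floors brand-new strategies at 0.1 before the transform.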
--------------------------------------------------------------------------------
/maverick_mcp/agents/technical_analysis.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Technical Analysis Agent with pattern recognition and multi-timeframe analysis.
3 | """
4 |
5 | import logging
6 | from datetime import datetime
7 | from typing import Any
8 |
9 | from langchain_core.messages import HumanMessage
10 | from langchain_core.tools import BaseTool
11 | from langgraph.checkpoint.memory import MemorySaver
12 | from langgraph.graph import END, START, StateGraph
13 |
14 | from maverick_mcp.agents.circuit_breaker import circuit_manager
15 | from maverick_mcp.langchain_tools import get_tool_registry
16 | from maverick_mcp.memory import ConversationStore
17 | from maverick_mcp.tools.risk_management import TechnicalStopsTool
18 | from maverick_mcp.workflows.state import TechnicalAnalysisState
19 |
20 | from .base import PersonaAwareAgent
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | class TechnicalAnalysisAgent(PersonaAwareAgent):
26 | """
27 | Professional technical analysis agent with pattern recognition.
28 |
29 | Features:
30 | - Chart pattern detection (head & shoulders, triangles, flags)
31 | - Multi-timeframe analysis
32 | - Indicator confluence scoring
33 | - Support/resistance clustering
34 | - Volume profile analysis
35 | - LLM-powered technical narratives
36 | """
37 |
38 | def __init__(
39 | self,
40 | llm,
41 | persona: str = "moderate",
42 | ttl_hours: int = 1,
43 | ):
44 | """
45 | Initialize technical analysis agent.
46 |
47 | Args:
48 | llm: Language model
49 | persona: Investor persona
50 | ttl_hours: Cache TTL in hours
51 |         """
52 | 
53 | # Store persona temporarily for tool configuration
54 | self._temp_persona = persona
55 |
56 | # Get technical analysis tools
57 | tools = self._get_technical_tools()
58 |
59 | # Initialize with MemorySaver
60 | super().__init__(
61 | llm=llm,
62 | tools=tools,
63 | persona=persona,
64 | checkpointer=MemorySaver(),
65 | ttl_hours=ttl_hours,
66 | )
67 |
68 | # Initialize conversation store
69 | self.conversation_store = ConversationStore(ttl_hours=ttl_hours)
70 |
71 | def _get_technical_tools(self) -> list[BaseTool]:
72 | """Get comprehensive technical analysis tools."""
73 | registry = get_tool_registry()
74 |
75 | # Core technical tools
76 | technical_tools = [
77 | registry.get_tool("get_technical_indicators"),
78 | registry.get_tool("calculate_support_resistance"),
79 | registry.get_tool("detect_chart_patterns"),
80 | registry.get_tool("calculate_moving_averages"),
81 | registry.get_tool("calculate_oscillators"),
82 | ]
83 |
84 | # Price action tools
85 | price_tools = [
86 | registry.get_tool("get_stock_price"),
87 | registry.get_tool("get_stock_history"),
88 | registry.get_tool("get_intraday_data"),
89 | ]
90 |
91 | # Volume analysis tools
92 | volume_tools = [
93 | registry.get_tool("analyze_volume_profile"),
94 | registry.get_tool("detect_volume_patterns"),
95 | ]
96 |
97 | # Risk tools
98 | risk_tools = [
99 | TechnicalStopsTool(),
100 | ]
101 |
102 | # Combine and filter
103 | all_tools = technical_tools + price_tools + volume_tools + risk_tools
104 | tools = [t for t in all_tools if t is not None]
105 |
106 | # Configure persona for PersonaAwareTools
107 | for tool in tools:
108 | if hasattr(tool, "set_persona"):
109 | tool.set_persona(self._temp_persona)
110 |
111 | if not tools:
112 | logger.warning("No technical tools available, using mock tools")
113 | tools = self._create_mock_tools()
114 |
115 | return tools
116 |
117 | def get_state_schema(self) -> type:
118 | """Return enhanced state schema for technical analysis."""
119 | return TechnicalAnalysisState
120 |
121 | def _build_system_prompt(self) -> str:
122 | """Build comprehensive system prompt for technical analysis."""
123 | base_prompt = super()._build_system_prompt()
124 |
125 | technical_prompt = f"""
126 |
127 | You are a professional technical analyst specializing in pattern recognition and multi-timeframe analysis.
128 | Current date: {datetime.now().strftime("%Y-%m-%d")}
129 |
130 | ## Core Responsibilities:
131 |
132 | 1. **Pattern Recognition**:
133 | - Chart patterns: Head & Shoulders, Triangles, Flags, Wedges
134 | - Candlestick patterns: Doji, Hammer, Engulfing, etc.
135 | - Support/Resistance: Dynamic and static levels
136 | - Trend lines and channels
137 |
138 | 2. **Multi-Timeframe Analysis**:
139 | - Align signals across daily, hourly, and 5-minute charts
140 | - Identify confluences between timeframes
141 | - Spot divergences early
142 | - Time entries based on lower timeframe setups
143 |
144 | 3. **Indicator Analysis**:
145 | - Trend: Moving averages, ADX, MACD
146 | - Momentum: RSI, Stochastic, CCI
147 | - Volume: OBV, Volume Profile, VWAP
148 | - Volatility: Bollinger Bands, ATR, Keltner Channels
149 |
150 | 4. **Trade Setup Construction**:
151 | - Entry points with specific triggers
152 | - Stop loss placement using ATR or structure
153 | - Profit targets based on measured moves
154 | - Risk/Reward ratio calculation
155 |
156 | ## Analysis Framework by Persona:
157 |
158 | **Conservative ({self.persona.name if self.persona.name == "Conservative" else "N/A"})**:
159 | - Wait for confirmed patterns only
160 | - Use wider stops above/below structure
161 | - Target 1.5:1 risk/reward minimum
162 | - Focus on daily/weekly timeframes
163 |
164 | **Moderate ({self.persona.name if self.persona.name == "Moderate" else "N/A"})**:
165 | - Balance pattern quality with opportunity
166 | - Standard ATR-based stops
167 | - Target 2:1 risk/reward
168 | - Use daily/4H timeframes
169 |
170 | **Aggressive ({self.persona.name if self.persona.name == "Aggressive" else "N/A"})**:
171 | - Trade emerging patterns
172 | - Tighter stops for larger positions
173 | - Target 3:1+ risk/reward
174 | - Include intraday timeframes
175 |
176 | **Day Trader ({self.persona.name if self.persona.name == "Day Trader" else "N/A"})**:
177 | - Focus on intraday patterns
178 | - Use tick/volume charts
179 | - Quick scalps with tight stops
180 | - Multiple entries/exits
181 |
182 | ## Technical Analysis Process:
183 |
184 | 1. **Market Structure**: Identify trend direction and strength
185 | 2. **Key Levels**: Map support/resistance zones
186 | 3. **Pattern Search**: Scan for actionable patterns
187 | 4. **Indicator Confluence**: Check for agreement
188 | 5. **Volume Confirmation**: Validate with volume
189 | 6. **Risk Definition**: Calculate stops and targets
190 | 7. **Setup Quality**: Rate A+ to C based on confluence
191 |
192 | Remember to:
193 | - Be specific with price levels
194 | - Explain pattern psychology
195 | - Highlight invalidation levels
196 | - Consider market context
197 | - Provide clear action plans
198 | """
199 |
200 | return base_prompt + technical_prompt
201 |
202 | def _build_graph(self):
203 | """Build enhanced graph with technical analysis nodes."""
204 | workflow = StateGraph(TechnicalAnalysisState)
205 |
206 | # Add specialized nodes with unique names
207 | workflow.add_node("analyze_structure", self._analyze_structure)
208 | workflow.add_node("detect_patterns", self._detect_patterns)
209 | workflow.add_node("analyze_indicators", self._analyze_indicators)
210 | workflow.add_node("construct_trade_setup", self._construct_trade_setup)
211 | workflow.add_node("agent", self._agent_node)
212 |
213 | # Create tool node if tools available
214 | if self.tools:
215 | from langgraph.prebuilt import ToolNode
216 |
217 | tool_node = ToolNode(self.tools)
218 | workflow.add_node("tools", tool_node)
219 |
220 | # Define flow
221 | workflow.add_edge(START, "analyze_structure")
222 | workflow.add_edge("analyze_structure", "detect_patterns")
223 | workflow.add_edge("detect_patterns", "analyze_indicators")
224 | workflow.add_edge("analyze_indicators", "construct_trade_setup")
225 | workflow.add_edge("construct_trade_setup", "agent")
226 |
227 | if self.tools:
228 | workflow.add_conditional_edges(
229 | "agent",
230 | self._should_continue,
231 | {
232 | "continue": "tools",
233 | "end": END,
234 | },
235 | )
236 | workflow.add_edge("tools", "agent")
237 | else:
238 | workflow.add_edge("agent", END)
239 |
240 | return workflow.compile(checkpointer=self.checkpointer)
241 |
242 | async def _analyze_structure(self, state: TechnicalAnalysisState) -> dict[str, Any]:
243 | """Analyze market structure and identify key levels."""
244 | try:
245 | # Get support/resistance tool
246 | sr_tool = next(
247 | (t for t in self.tools if "support_resistance" in t.name), None
248 | )
249 |
250 | if sr_tool and state.get("symbol"):
251 | circuit_breaker = await circuit_manager.get_or_create("technical")
252 |
253 | async def get_levels():
254 | return await sr_tool.ainvoke(
255 | {
256 | "symbol": state["symbol"],
257 | "lookback_days": state.get("lookback_days", 20),
258 | }
259 | )
260 |
261 | levels_data = await circuit_breaker.call(get_levels)
262 |
263 | # Extract support/resistance levels
264 | if isinstance(levels_data, dict):
265 | state["support_levels"] = levels_data.get("support_levels", [])
266 | state["resistance_levels"] = levels_data.get(
267 | "resistance_levels", []
268 | )
269 |
270 | # Determine trend based on structure
271 | if levels_data.get("trend"):
272 | state["trend_direction"] = levels_data["trend"]
273 | else:
274 | # Simple trend determination
275 | current = state.get("current_price", 0)
276 | ma_50 = levels_data.get("ma_50", current)
277 | state["trend_direction"] = (
278 | "bullish" if current > ma_50 else "bearish"
279 | )
280 |
281 | except Exception as e:
282 | logger.error(f"Error analyzing structure: {e}")
283 |
284 | state["api_calls_made"] = state.get("api_calls_made", 0) + 1
285 | return {
286 | "support_levels": state.get("support_levels", []),
287 | "resistance_levels": state.get("resistance_levels", []),
288 | "trend_direction": state.get("trend_direction", "neutral"),
289 | }
290 |
291 | async def _detect_patterns(self, state: TechnicalAnalysisState) -> dict[str, Any]:
292 | """Detect chart patterns."""
293 | try:
294 | # Get pattern detection tool
295 | pattern_tool = next((t for t in self.tools if "pattern" in t.name), None)
296 |
297 | if pattern_tool and state.get("symbol"):
298 | circuit_breaker = await circuit_manager.get_or_create("technical")
299 |
300 | async def detect():
301 | return await pattern_tool.ainvoke(
302 | {
303 | "symbol": state["symbol"],
304 | "timeframe": state.get("timeframe", "1d"),
305 | }
306 | )
307 |
308 | pattern_data = await circuit_breaker.call(detect)
309 |
310 | if isinstance(pattern_data, dict) and "patterns" in pattern_data:
311 | patterns = pattern_data["patterns"]
312 | state["patterns"] = patterns
313 |
314 | # Calculate pattern confidence scores
315 | pattern_confidence = {}
316 | for pattern in patterns:
317 | name = pattern.get("name", "Unknown")
318 | confidence = pattern.get("confidence", 50)
319 | pattern_confidence[name] = confidence
320 |
321 | state["pattern_confidence"] = pattern_confidence
322 |
323 | except Exception as e:
324 | logger.error(f"Error detecting patterns: {e}")
325 |
326 | state["api_calls_made"] = state.get("api_calls_made", 0) + 1
327 | return {
328 | "patterns": state.get("patterns", []),
329 | "pattern_confidence": state.get("pattern_confidence", {}),
330 | }
331 |
332 | async def _analyze_indicators(
333 | self, state: TechnicalAnalysisState
334 | ) -> dict[str, Any]:
335 | """Analyze technical indicators."""
336 | try:
337 | # Get indicators tool
338 | indicators_tool = next(
339 | (t for t in self.tools if "technical_indicators" in t.name), None
340 | )
341 |
342 | if indicators_tool and state.get("symbol"):
343 | circuit_breaker = await circuit_manager.get_or_create("technical")
344 |
345 | indicators = state.get("indicators", ["RSI", "MACD", "BB"])
346 |
347 | async def get_indicators():
348 | return await indicators_tool.ainvoke(
349 | {
350 | "symbol": state["symbol"],
351 | "indicators": indicators,
352 | "period": state.get("lookback_days", 20),
353 | }
354 | )
355 |
356 | indicator_data = await circuit_breaker.call(get_indicators)
357 |
358 | if isinstance(indicator_data, dict):
359 | # Store indicator values
360 | state["indicator_values"] = indicator_data.get("values", {})
361 |
362 | # Generate indicator signals
363 | signals = self._generate_indicator_signals(indicator_data)
364 | state["indicator_signals"] = signals
365 |
366 | # Check for divergences
367 | divergences = self._check_divergences(
368 | state.get("price_history", {}), indicator_data
369 | )
370 | state["divergences"] = divergences
371 |
372 | except Exception as e:
373 | logger.error(f"Error analyzing indicators: {e}")
374 |
375 | state["api_calls_made"] = state.get("api_calls_made", 0) + 1
376 | return {
377 | "indicator_values": state.get("indicator_values", {}),
378 | "indicator_signals": state.get("indicator_signals", {}),
379 | "divergences": state.get("divergences", []),
380 | }
381 |
382 | async def _construct_trade_setup(
383 | self, state: TechnicalAnalysisState
384 | ) -> dict[str, Any]:
385 | """Construct complete trade setup."""
386 | try:
387 | current_price = state.get("current_price", 0)
388 |
389 | if current_price > 0:
390 | # Calculate entry points based on patterns and levels
391 | entry_points = self._calculate_entry_points(state)
392 | state["entry_points"] = entry_points
393 |
394 | # Get stop loss recommendation
395 | stops_tool = next(
396 | (t for t in self.tools if isinstance(t, TechnicalStopsTool)), None
397 | )
398 |
399 | if stops_tool:
400 | stops_data = await stops_tool.ainvoke(
401 | {
402 | "symbol": state["symbol"],
403 | "lookback_days": 20,
404 | }
405 | )
406 |
407 | if isinstance(stops_data, dict):
408 | stop_loss = stops_data.get(
409 | "recommended_stop", current_price * 0.95
410 | )
411 | else:
412 | stop_loss = current_price * 0.95
413 | else:
414 | stop_loss = current_price * 0.95
415 |
416 | state["stop_loss"] = stop_loss
417 |
418 | # Calculate profit targets
419 | risk = current_price - stop_loss
420 | targets = [
421 | current_price + (risk * 1.5), # 1.5R
422 | current_price + (risk * 2.0), # 2R
423 | current_price + (risk * 3.0), # 3R
424 | ]
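                # Editor's illustration (not in the source): entry 100.00 with
                # stop 95.00 gives risk = 5.00, so the targets land at
                # 107.50 / 110.00 / 115.00.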
425 | state["profit_targets"] = targets
426 |
427 | # Calculate risk/reward
428 |                 state["risk_reward_ratio"] = 2.0  # Default assumption; not derived from the stop/targets above
429 |
430 | # Rate setup quality
431 | quality = self._rate_setup_quality(state)
432 | state["setup_quality"] = quality
433 |
434 | # Calculate confidence score
435 | confidence = self._calculate_confidence_score(state)
436 | state["confidence_score"] = confidence
437 |
438 | except Exception as e:
439 | logger.error(f"Error constructing trade setup: {e}")
440 |
441 | return {
442 | "entry_points": state.get("entry_points", []),
443 | "stop_loss": state.get("stop_loss", 0),
444 | "profit_targets": state.get("profit_targets", []),
445 | "risk_reward_ratio": state.get("risk_reward_ratio", 0),
446 | "setup_quality": state.get("setup_quality", "C"),
447 | "confidence_score": state.get("confidence_score", 0),
448 | }
449 |
450 | def _generate_indicator_signals(self, indicator_data: dict) -> dict[str, str]:
451 | """Generate buy/sell/hold signals from indicators."""
452 | signals = {}
453 |
454 | # RSI signals
455 | rsi = indicator_data.get("RSI", {}).get("value", 50)
456 | if rsi < 30:
457 | signals["RSI"] = "buy"
458 | elif rsi > 70:
459 | signals["RSI"] = "sell"
460 | else:
461 | signals["RSI"] = "hold"
462 |
463 | # MACD signals
464 | macd = indicator_data.get("MACD", {})
465 | if macd.get("histogram", 0) > 0 and macd.get("signal_cross", "") == "bullish":
466 | signals["MACD"] = "buy"
467 | elif macd.get("histogram", 0) < 0 and macd.get("signal_cross", "") == "bearish":
468 | signals["MACD"] = "sell"
469 | else:
470 | signals["MACD"] = "hold"
471 |
472 | return signals
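        # Editor's illustration: {"RSI": {"value": 25}, "MACD": {"histogram": 0.5,
        # "signal_cross": "bullish"}} maps to {"RSI": "buy", "MACD": "buy"}.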
473 |
474 | def _check_divergences(
475 | self, price_history: dict, indicator_data: dict
476 | ) -> list[dict[str, Any]]:
477 | """Check for price/indicator divergences."""
478 | divergences: list[dict[str, Any]] = []
479 |
480 | # Simplified divergence detection
481 | # In production, would use more sophisticated analysis
482 |
483 | return divergences
484 |
485 | def _calculate_entry_points(self, state: TechnicalAnalysisState) -> list[float]:
486 | """Calculate optimal entry points."""
487 | current_price = state.get("current_price", 0)
488 | support_levels = state.get("support_levels", [])
489 | patterns = state.get("patterns", [])
490 |
491 | entries = []
492 |
493 | # Pattern-based entries
494 | for pattern in patterns:
495 | if pattern.get("entry_price"):
496 | entries.append(pattern["entry_price"])
497 |
498 | # Support-based entries
499 | for support in support_levels:
500 | if support < current_price:
501 | # Entry just above support
502 | entries.append(support * 1.01)
503 |
504 | # Current price entry if momentum
505 | if state.get("trend_direction") == "bullish":
506 | entries.append(current_price)
507 |
508 |         return sorted(set(entries))[:3]  # Three lowest unique entries (sorted ascending)
509 |
510 | def _rate_setup_quality(self, state: TechnicalAnalysisState) -> str:
511 | """Rate the quality of the trade setup."""
512 | score = 0
513 |
514 | # Pattern quality
515 | if state.get("patterns"):
516 | max_confidence = max(p.get("confidence", 0) for p in state["patterns"])
517 | if max_confidence > 80:
518 | score += 30
519 | elif max_confidence > 60:
520 | score += 20
521 | else:
522 | score += 10
523 |
524 | # Indicator confluence
525 | signals = state.get("indicator_signals", {})
526 | buy_signals = sum(1 for s in signals.values() if s == "buy")
527 | if buy_signals >= 3:
528 | score += 30
529 | elif buy_signals >= 2:
530 | score += 20
531 | else:
532 | score += 10
533 |
534 | # Risk/Reward
535 | rr = state.get("risk_reward_ratio", 0)
536 | if rr >= 3:
537 | score += 20
538 | elif rr >= 2:
539 | score += 15
540 | else:
541 | score += 5
542 |
543 | # Volume confirmation (would check in real implementation)
544 | score += 10
545 |
546 | # Market alignment (would check in real implementation)
547 | score += 10
548 |
549 | # Convert score to grade
550 | if score >= 85:
551 | return "A+"
552 | elif score >= 75:
553 | return "A"
554 | elif score >= 65:
555 | return "B"
556 | else:
557 | return "C"
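        # Editor's illustration: pattern confidence 85 (+30), two buy signals
        # (+20), R:R of 2.0 (+15), plus the two placeholder +10s totals 85 -> "A+".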
558 |
559 | def _calculate_confidence_score(self, state: TechnicalAnalysisState) -> float:
560 | """Calculate overall confidence score for the setup."""
561 | factors = []
562 |
563 | # Pattern confidence
564 | if state.get("pattern_confidence"):
565 | factors.append(max(state["pattern_confidence"].values()) / 100)
566 |
567 | # Indicator agreement
568 | signals = state.get("indicator_signals", {})
569 | if signals:
570 | buy_count = sum(1 for s in signals.values() if s == "buy")
571 | factors.append(buy_count / len(signals))
572 |
573 | # Setup quality
574 | quality_scores = {"A+": 1.0, "A": 0.85, "B": 0.70, "C": 0.50}
575 | factors.append(quality_scores.get(state.get("setup_quality", "C"), 0.5))
576 |
577 | # Average confidence
578 | return round(sum(factors) / len(factors) * 100, 1) if factors else 50.0
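        # Editor's illustration: factors 0.80 (pattern), 0.67 (2 of 3 buy
        # signals) and 0.85 (grade "A") average to ~0.772 -> 77.2.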
579 |
580 | async def analyze_stock(
581 | self,
582 | symbol: str,
583 | timeframe: str = "1d",
584 | indicators: list[str] | None = None,
585 | **kwargs,
586 | ) -> dict[str, Any]:
587 | """
588 | Perform comprehensive technical analysis on a stock.
589 |
590 | Args:
591 | symbol: Stock symbol
592 | timeframe: Chart timeframe
593 | indicators: List of indicators to analyze
594 | **kwargs: Additional parameters
595 |
596 | Returns:
597 | Complete technical analysis with trade setup
598 | """
599 | start_time = datetime.now()
600 |
601 | # Default indicators
602 | if indicators is None:
603 | indicators = ["RSI", "MACD", "BB", "EMA", "VWAP"]
604 |
605 | # Prepare query
606 | query = f"Analyze {symbol} on {timeframe} timeframe with focus on patterns and trade setup"
607 |
608 | # Initial state
609 | initial_state = {
610 | "messages": [HumanMessage(content=query)],
611 | "symbol": symbol,
612 | "timeframe": timeframe,
613 | "indicators": indicators,
614 | "lookback_days": kwargs.get("lookback_days", 20),
615 | "pattern_detection": True,
616 | "multi_timeframe": kwargs.get("multi_timeframe", False),
617 | "persona": self.persona.name,
618 | "session_id": kwargs.get(
619 | "session_id", f"{symbol}_{datetime.now().timestamp()}"
620 | ),
621 | "timestamp": datetime.now(),
622 | "api_calls_made": 0,
623 | }
624 |
625 | # Run analysis
626 | result = await self.ainvoke(
627 | query, initial_state["session_id"], initial_state=initial_state
628 | )
629 |
630 | # Calculate execution time
631 | execution_time = (datetime.now() - start_time).total_seconds() * 1000
632 |
633 | # Extract results
634 | return self._format_analysis_results(result, execution_time)
635 |
636 | def _format_analysis_results(
637 | self, result: dict[str, Any], execution_time: float
638 | ) -> dict[str, Any]:
639 | """Format technical analysis results."""
640 | state = result.get("state", {})
641 | messages = result.get("messages", [])
642 |
643 | return {
644 | "status": "success",
645 | "timestamp": datetime.now().isoformat(),
646 | "execution_time_ms": execution_time,
647 | "symbol": state.get("symbol", ""),
648 | "analysis": {
649 | "market_structure": {
650 | "trend": state.get("trend_direction", "neutral"),
651 | "support_levels": state.get("support_levels", []),
652 | "resistance_levels": state.get("resistance_levels", []),
653 | },
654 | "patterns": {
655 | "detected": state.get("patterns", []),
656 | "confidence": state.get("pattern_confidence", {}),
657 | },
658 | "indicators": {
659 | "values": state.get("indicator_values", {}),
660 | "signals": state.get("indicator_signals", {}),
661 | "divergences": state.get("divergences", []),
662 | },
663 | "trade_setup": {
664 | "entries": state.get("entry_points", []),
665 | "stop_loss": state.get("stop_loss", 0),
666 | "targets": state.get("profit_targets", []),
667 | "risk_reward": state.get("risk_reward_ratio", 0),
668 | "quality": state.get("setup_quality", "C"),
669 | "confidence": state.get("confidence_score", 0),
670 | },
671 | },
672 | "recommendation": messages[-1].content if messages else "",
673 | "persona_adjusted": True,
674 | "risk_profile": self.persona.name,
675 | }
676 |
677 | def _create_mock_tools(self) -> list:
678 | """Create mock tools for testing."""
679 | from langchain_core.tools import tool
680 |
681 | @tool
682 | def mock_technical_indicators(symbol: str, indicators: list[str]) -> dict:
683 | """Mock technical indicators tool."""
684 | return {
685 | "RSI": {"value": 45, "trend": "neutral"},
686 | "MACD": {"histogram": 0.5, "signal_cross": "bullish"},
687 | "BB": {"upper": 150, "middle": 145, "lower": 140},
688 | }
689 |
690 | @tool
691 | def mock_support_resistance(symbol: str) -> dict:
692 | """Mock support/resistance tool."""
693 | return {
694 | "support_levels": [140, 135, 130],
695 | "resistance_levels": [150, 155, 160],
696 | "trend": "bullish",
697 | }
698 |
699 | return [mock_technical_indicators, mock_support_resistance]
700 |
```
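A minimal driver for the agent above, as a sketch: the class name `TechnicalAnalysisAgent`, its import path export, and the `llm`/`persona` constructor arguments are assumptions (the definition sits earlier in this file); only `analyze_stock` and its result shape are taken from the code shown.

```python
# Editor's sketch; the class name and constructor signature are assumed.
import asyncio

from maverick_mcp.agents.technical_analysis import TechnicalAnalysisAgent  # assumed export


async def main():
    agent = TechnicalAnalysisAgent(llm=my_llm, persona="Moderate")  # my_llm: placeholder LLM
    result = await agent.analyze_stock("AAPL", timeframe="1d", indicators=["RSI", "MACD"])
    setup = result["analysis"]["trade_setup"]
    print(setup["quality"], setup["confidence"], setup["targets"])


asyncio.run(main())
```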
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/ab_testing.py:
--------------------------------------------------------------------------------
```python
1 | """A/B testing framework for comparing ML model performance."""
2 |
3 | import logging
4 | import random
5 | from datetime import datetime
6 | from typing import Any
7 |
8 | import numpy as np
9 | import pandas as pd
10 | from scipy import stats
11 | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
12 |
13 | from .model_manager import ModelManager
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class ABTestGroup:
19 | """Represents a group in an A/B test."""
20 |
21 | def __init__(
22 | self,
23 | group_id: str,
24 | model_id: str,
25 | model_version: str,
26 | traffic_allocation: float,
27 | description: str = "",
28 | ):
29 | """Initialize A/B test group.
30 |
31 | Args:
32 | group_id: Unique identifier for the group
33 | model_id: Model identifier
34 | model_version: Model version
35 | traffic_allocation: Fraction of traffic allocated to this group (0-1)
36 | description: Description of the group
37 | """
38 | self.group_id = group_id
39 | self.model_id = model_id
40 | self.model_version = model_version
41 | self.traffic_allocation = traffic_allocation
42 | self.description = description
43 | self.created_at = datetime.now()
44 |
45 | # Performance tracking
46 | self.predictions: list[Any] = []
47 | self.actual_values: list[Any] = []
48 | self.prediction_timestamps: list[datetime] = []
49 | self.prediction_confidence: list[float] = []
50 |
51 | def add_prediction(
52 | self,
53 | prediction: Any,
54 | actual: Any,
55 | confidence: float = 1.0,
56 | timestamp: datetime | None = None,
57 | ):
58 | """Add a prediction result to the group.
59 |
60 | Args:
61 | prediction: Model prediction
62 | actual: Actual value
63 | confidence: Prediction confidence score
64 | timestamp: Prediction timestamp
65 | """
66 | self.predictions.append(prediction)
67 | self.actual_values.append(actual)
68 | self.prediction_confidence.append(confidence)
69 | self.prediction_timestamps.append(timestamp or datetime.now())
70 |
71 | def get_metrics(self) -> dict[str, float]:
72 | """Calculate performance metrics for the group.
73 |
74 | Returns:
75 | Dictionary of performance metrics
76 | """
77 | if not self.predictions or not self.actual_values:
78 | return {}
79 |
80 | try:
81 | predictions = np.array(self.predictions)
82 | actuals = np.array(self.actual_values)
83 |
84 | metrics = {
85 | "sample_count": len(predictions),
86 | "accuracy": accuracy_score(actuals, predictions),
87 | "precision": precision_score(
88 | actuals, predictions, average="weighted", zero_division=0
89 | ),
90 | "recall": recall_score(
91 | actuals, predictions, average="weighted", zero_division=0
92 | ),
93 | "f1_score": f1_score(
94 | actuals, predictions, average="weighted", zero_division=0
95 | ),
96 | "avg_confidence": np.mean(self.prediction_confidence),
97 | }
98 |
99 | # Add confusion matrix for binary/multiclass
100 | unique_labels = np.unique(np.concatenate([predictions, actuals]))
101 | if len(unique_labels) <= 10: # Reasonable number of classes
102 | from sklearn.metrics import confusion_matrix
103 |
104 | cm = confusion_matrix(actuals, predictions, labels=unique_labels)
105 | metrics["confusion_matrix"] = cm.tolist()
106 | metrics["unique_labels"] = unique_labels.tolist()
107 |
108 | return metrics
109 |
110 | except Exception as e:
111 | logger.error(f"Error calculating metrics for group {self.group_id}: {e}")
112 | return {"error": str(e)}
113 |
114 | def to_dict(self) -> dict[str, Any]:
115 | """Convert group to dictionary representation."""
116 | return {
117 | "group_id": self.group_id,
118 | "model_id": self.model_id,
119 | "model_version": self.model_version,
120 | "traffic_allocation": self.traffic_allocation,
121 | "description": self.description,
122 | "created_at": self.created_at.isoformat(),
123 | "metrics": self.get_metrics(),
124 | }
125 |
126 |
127 | class ABTest:
128 | """Manages an A/B test between different model versions."""
129 |
130 | def __init__(
131 | self,
132 | test_id: str,
133 | name: str,
134 | description: str = "",
135 | random_seed: int | None = None,
136 | ):
137 | """Initialize A/B test.
138 |
139 | Args:
140 | test_id: Unique identifier for the test
141 | name: Human-readable name for the test
142 | description: Description of the test
143 | random_seed: Random seed for reproducible traffic splitting
144 | """
145 | self.test_id = test_id
146 | self.name = name
147 | self.description = description
148 | self.created_at = datetime.now()
149 | self.started_at: datetime | None = None
150 | self.ended_at: datetime | None = None
151 | self.status = "created" # created, running, completed, cancelled
152 |
153 | # Groups in the test
154 | self.groups: dict[str, ABTestGroup] = {}
155 |
156 | # Traffic allocation
157 | self.traffic_splitter = TrafficSplitter(random_seed)
158 |
159 | # Test configuration
160 | self.min_samples_per_group = 100
161 | self.confidence_level = 0.95
162 | self.minimum_detectable_effect = 0.05
163 |
164 | def add_group(
165 | self,
166 | group_id: str,
167 | model_id: str,
168 | model_version: str,
169 | traffic_allocation: float,
170 | description: str = "",
171 | ) -> bool:
172 | """Add a group to the A/B test.
173 |
174 | Args:
175 | group_id: Unique identifier for the group
176 | model_id: Model identifier
177 | model_version: Model version
178 | traffic_allocation: Fraction of traffic (0-1)
179 | description: Description of the group
180 |
181 | Returns:
182 | True if successful
183 | """
184 | if self.status != "created":
185 | logger.error(
186 | f"Cannot add group to test {self.test_id} - test already started"
187 | )
188 | return False
189 |
190 | if group_id in self.groups:
191 | logger.error(f"Group {group_id} already exists in test {self.test_id}")
192 | return False
193 |
194 | # Validate traffic allocation
195 | current_total = sum(g.traffic_allocation for g in self.groups.values())
196 | if (
197 | current_total + traffic_allocation > 1.0001
198 | ): # Small tolerance for floating point
199 | logger.error(
200 | f"Traffic allocation would exceed 100%: {current_total + traffic_allocation}"
201 | )
202 | return False
203 |
204 | group = ABTestGroup(
205 | group_id=group_id,
206 | model_id=model_id,
207 | model_version=model_version,
208 | traffic_allocation=traffic_allocation,
209 | description=description,
210 | )
211 |
212 | self.groups[group_id] = group
213 | self.traffic_splitter.update_allocation(
214 | {gid: g.traffic_allocation for gid, g in self.groups.items()}
215 | )
216 |
217 | logger.info(f"Added group {group_id} to test {self.test_id}")
218 | return True
219 |
220 | def start_test(self) -> bool:
221 | """Start the A/B test.
222 |
223 | Returns:
224 | True if successful
225 | """
226 | if self.status != "created":
227 | logger.error(
228 | f"Cannot start test {self.test_id} - invalid status: {self.status}"
229 | )
230 | return False
231 |
232 | if len(self.groups) < 2:
233 | logger.error(f"Cannot start test {self.test_id} - need at least 2 groups")
234 | return False
235 |
236 | # Validate traffic allocation sums to approximately 1.0
237 | total_allocation = sum(g.traffic_allocation for g in self.groups.values())
238 | if abs(total_allocation - 1.0) > 0.01:
239 | logger.error(f"Traffic allocation does not sum to 1.0: {total_allocation}")
240 | return False
241 |
242 | self.status = "running"
243 | self.started_at = datetime.now()
244 | logger.info(f"Started A/B test {self.test_id} with {len(self.groups)} groups")
245 | return True
246 |
247 | def assign_traffic(self, user_id: str | None = None) -> str | None:
248 | """Assign traffic to a group.
249 |
250 | Args:
251 | user_id: User identifier for consistent assignment
252 |
253 | Returns:
254 | Group ID or None if test not running
255 | """
256 | if self.status != "running":
257 | return None
258 |
259 | return self.traffic_splitter.assign_group(user_id)
260 |
261 | def record_prediction(
262 | self,
263 | group_id: str,
264 | prediction: Any,
265 | actual: Any,
266 | confidence: float = 1.0,
267 | timestamp: datetime | None = None,
268 | ) -> bool:
269 | """Record a prediction result for a group.
270 |
271 | Args:
272 | group_id: Group identifier
273 | prediction: Model prediction
274 | actual: Actual value
275 | confidence: Prediction confidence
276 | timestamp: Prediction timestamp
277 |
278 | Returns:
279 | True if successful
280 | """
281 | if group_id not in self.groups:
282 | logger.error(f"Group {group_id} not found in test {self.test_id}")
283 | return False
284 |
285 | self.groups[group_id].add_prediction(prediction, actual, confidence, timestamp)
286 | return True
287 |
288 | def get_results(self) -> dict[str, Any]:
289 | """Get current A/B test results.
290 |
291 | Returns:
292 | Dictionary with test results
293 | """
294 | results = {
295 | "test_id": self.test_id,
296 | "name": self.name,
297 | "description": self.description,
298 | "status": self.status,
299 | "created_at": self.created_at.isoformat(),
300 | "started_at": self.started_at.isoformat() if self.started_at else None,
301 | "ended_at": self.ended_at.isoformat() if self.ended_at else None,
302 | "groups": {},
303 | "statistical_analysis": {},
304 | }
305 |
306 | # Group results
307 | for group_id, group in self.groups.items():
308 | results["groups"][group_id] = group.to_dict()
309 |
310 | # Statistical analysis
311 | if len(self.groups) >= 2:
312 | results["statistical_analysis"] = self._perform_statistical_analysis()
313 |
314 | return results
315 |
316 | def _perform_statistical_analysis(self) -> dict[str, Any]:
317 | """Perform statistical analysis of A/B test results.
318 |
319 | Returns:
320 | Statistical analysis results
321 | """
322 | analysis = {
323 | "ready_for_analysis": True,
324 | "sample_size_adequate": True,
325 | "statistical_significance": {},
326 | "effect_sizes": {},
327 | "recommendations": [],
328 | }
329 |
330 | # Check sample sizes
331 | sample_sizes = {
332 | group_id: len(group.predictions) for group_id, group in self.groups.items()
333 | }
334 |
335 | min_samples = min(sample_sizes.values()) if sample_sizes else 0
336 | if min_samples < self.min_samples_per_group:
337 | analysis["ready_for_analysis"] = False
338 | analysis["sample_size_adequate"] = False
339 | analysis["recommendations"].append(
340 | f"Need at least {self.min_samples_per_group} samples per group (current min: {min_samples})"
341 | )
342 |
343 | if not analysis["ready_for_analysis"]:
344 | return analysis
345 |
346 | # Pairwise comparisons
347 | group_ids = list(self.groups.keys())
348 | for i, group_a_id in enumerate(group_ids):
349 | for group_b_id in group_ids[i + 1 :]:
350 | comparison_key = f"{group_a_id}_vs_{group_b_id}"
351 |
352 | try:
353 | group_a = self.groups[group_a_id]
354 | group_b = self.groups[group_b_id]
355 |
356 | # Compare accuracy scores
357 | accuracy_a = accuracy_score(
358 | group_a.actual_values, group_a.predictions
359 | )
360 | accuracy_b = accuracy_score(
361 | group_b.actual_values, group_b.predictions
362 | )
363 |
364 | # Perform statistical test
365 | # For classification accuracy, we can use a proportion test
366 | n_correct_a = sum(
367 | np.array(group_a.predictions) == np.array(group_a.actual_values)
368 | )
369 | n_correct_b = sum(
370 | np.array(group_b.predictions) == np.array(group_b.actual_values)
371 | )
372 | n_total_a = len(group_a.predictions)
373 | n_total_b = len(group_b.predictions)
374 |
375 | # Two-proportion z-test
376 | p_combined = (n_correct_a + n_correct_b) / (n_total_a + n_total_b)
377 | se = np.sqrt(
378 | p_combined * (1 - p_combined) * (1 / n_total_a + 1 / n_total_b)
379 | )
380 |
381 | if se > 0:
382 | z_score = (accuracy_a - accuracy_b) / se
383 | p_value = 2 * (1 - stats.norm.cdf(abs(z_score)))
384 |
385 | # Effect size (Cohen's h for proportions)
386 | h = 2 * (
387 | np.arcsin(np.sqrt(accuracy_a))
388 | - np.arcsin(np.sqrt(accuracy_b))
389 | )
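                            # Editor's illustration: 120/200 vs 100/200 correct
                            # gives p_combined = 0.55, se ~= 0.0497, z ~= 2.01,
                            # p ~= 0.044 (significant at 95%) and Cohen's
                            # h ~= 0.20, a "small" effect by Cohen's conventions.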
390 |
391 | analysis["statistical_significance"][comparison_key] = {
392 | "accuracy_a": accuracy_a,
393 | "accuracy_b": accuracy_b,
394 | "difference": accuracy_a - accuracy_b,
395 | "z_score": z_score,
396 | "p_value": p_value,
397 | "significant": p_value < (1 - self.confidence_level),
398 | "effect_size_h": h,
399 | }
400 |
401 | # Recommendations based on results
402 | if p_value < (1 - self.confidence_level):
403 | if accuracy_a > accuracy_b:
404 | analysis["recommendations"].append(
405 | f"Group {group_a_id} significantly outperforms {group_b_id} "
406 | f"(p={p_value:.4f}, effect_size={h:.4f})"
407 | )
408 | else:
409 | analysis["recommendations"].append(
410 | f"Group {group_b_id} significantly outperforms {group_a_id} "
411 | f"(p={p_value:.4f}, effect_size={h:.4f})"
412 | )
413 | else:
414 | analysis["recommendations"].append(
415 | f"No significant difference between {group_a_id} and {group_b_id} "
416 | f"(p={p_value:.4f})"
417 | )
418 |
419 | except Exception as e:
420 | logger.error(
421 | f"Error in statistical analysis for {comparison_key}: {e}"
422 | )
423 | analysis["statistical_significance"][comparison_key] = {
424 | "error": str(e)
425 | }
426 |
427 | return analysis
428 |
429 | def stop_test(self, reason: str = "completed") -> bool:
430 | """Stop the A/B test.
431 |
432 | Args:
433 | reason: Reason for stopping
434 |
435 | Returns:
436 | True if successful
437 | """
438 | if self.status != "running":
439 | logger.error(f"Cannot stop test {self.test_id} - not running")
440 | return False
441 |
442 | self.status = "completed" if reason == "completed" else "cancelled"
443 | self.ended_at = datetime.now()
444 | logger.info(f"Stopped A/B test {self.test_id}: {reason}")
445 | return True
446 |
447 | def to_dict(self) -> dict[str, Any]:
448 | """Convert test to dictionary representation."""
449 | return {
450 | "test_id": self.test_id,
451 | "name": self.name,
452 | "description": self.description,
453 | "status": self.status,
454 | "created_at": self.created_at.isoformat(),
455 | "started_at": self.started_at.isoformat() if self.started_at else None,
456 | "ended_at": self.ended_at.isoformat() if self.ended_at else None,
457 | "groups": {gid: g.to_dict() for gid, g in self.groups.items()},
458 | "configuration": {
459 | "min_samples_per_group": self.min_samples_per_group,
460 | "confidence_level": self.confidence_level,
461 | "minimum_detectable_effect": self.minimum_detectable_effect,
462 | },
463 | }
464 |
465 |
466 | class TrafficSplitter:
467 | """Handles traffic splitting for A/B tests."""
468 |
469 | def __init__(self, random_seed: int | None = None):
470 | """Initialize traffic splitter.
471 |
472 | Args:
473 | random_seed: Random seed for reproducible splitting
474 | """
475 |         self._rng = random.Random(random_seed)  # dedicated RNG, seeded once
476 | self.group_allocation: dict[str, float] = {}
477 | self.cumulative_allocation: list[tuple[str, float]] = []
478 |
479 | def update_allocation(self, allocation: dict[str, float]):
480 | """Update group traffic allocation.
481 |
482 | Args:
483 | allocation: Dictionary mapping group_id to allocation fraction
484 | """
485 | self.group_allocation = allocation.copy()
486 |
487 | # Create cumulative distribution for sampling
488 | self.cumulative_allocation = []
489 | cumulative = 0.0
490 |
491 | for group_id, fraction in allocation.items():
492 | cumulative += fraction
493 | self.cumulative_allocation.append((group_id, cumulative))
494 |
495 | def assign_group(self, user_id: str | None = None) -> str | None:
496 | """Assign a user to a group.
497 |
498 | Args:
499 | user_id: User identifier for consistent assignment
500 |
501 | Returns:
502 | Group ID or None if no groups configured
503 | """
504 | if not self.cumulative_allocation:
505 | return None
506 |
507 | # Generate random value
508 | if user_id is not None:
509 | # Hash user_id for consistent assignment
510 | import hashlib
511 |
512 | hash_object = hashlib.md5(user_id.encode())
513 | hash_int = int(hash_object.hexdigest(), 16)
514 | rand_value = (hash_int % 10000) / 10000.0 # Normalize to [0, 1)
515 |         else:
516 |             # Draw from the instance RNG seeded once in __init__; re-seeding
517 |             # the global RNG per call would repeat the same value every time.
518 |             rand_value = self._rng.random()
519 |
520 | # Find group based on cumulative allocation
521 | for group_id, cumulative_threshold in self.cumulative_allocation:
522 | if rand_value <= cumulative_threshold:
523 | return group_id
524 |
525 | # Fallback to last group
526 | return self.cumulative_allocation[-1][0] if self.cumulative_allocation else None
527 |
528 |
529 | class ABTestManager:
530 | """Manages multiple A/B tests."""
531 |
532 | def __init__(self, model_manager: ModelManager):
533 | """Initialize A/B test manager.
534 |
535 | Args:
536 | model_manager: Model manager instance
537 | """
538 | self.model_manager = model_manager
539 | self.tests: dict[str, ABTest] = {}
540 |
541 | def create_test(
542 | self,
543 | test_id: str,
544 | name: str,
545 | description: str = "",
546 | random_seed: int | None = None,
547 | ) -> ABTest:
548 | """Create a new A/B test.
549 |
550 | Args:
551 | test_id: Unique identifier for the test
552 | name: Human-readable name
553 | description: Description
554 | random_seed: Random seed for reproducible splitting
555 |
556 | Returns:
557 | ABTest instance
558 | """
559 | if test_id in self.tests:
560 | raise ValueError(f"Test {test_id} already exists")
561 |
562 | test = ABTest(test_id, name, description, random_seed)
563 | self.tests[test_id] = test
564 | logger.info(f"Created A/B test {test_id}: {name}")
565 | return test
566 |
567 | def get_test(self, test_id: str) -> ABTest | None:
568 | """Get an A/B test by ID.
569 |
570 | Args:
571 | test_id: Test identifier
572 |
573 | Returns:
574 | ABTest instance or None
575 | """
576 | return self.tests.get(test_id)
577 |
578 | def list_tests(self, status_filter: str | None = None) -> list[dict[str, Any]]:
579 | """List all A/B tests.
580 |
581 | Args:
582 | status_filter: Filter by status (created, running, completed, cancelled)
583 |
584 | Returns:
585 | List of test summaries
586 | """
587 | tests = []
588 | for test in self.tests.values():
589 | if status_filter is None or test.status == status_filter:
590 | tests.append(
591 | {
592 | "test_id": test.test_id,
593 | "name": test.name,
594 | "status": test.status,
595 | "groups_count": len(test.groups),
596 | "created_at": test.created_at.isoformat(),
597 | "started_at": test.started_at.isoformat()
598 | if test.started_at
599 | else None,
600 | }
601 | )
602 |
603 | # Sort by creation time (newest first)
604 | tests.sort(key=lambda x: x["created_at"], reverse=True)
605 | return tests
606 |
607 | def run_model_comparison(
608 | self,
609 | test_name: str,
610 | model_versions: list[tuple[str, str]], # (model_id, version)
611 | test_data: pd.DataFrame,
612 | feature_extractor: Any,
613 | target_extractor: Any,
614 | traffic_allocation: list[float] | None = None,
615 | test_duration_hours: int = 24,
616 | ) -> str:
617 | """Run a model comparison A/B test.
618 |
619 | Args:
620 | test_name: Name for the test
621 | model_versions: List of (model_id, version) tuples to compare
622 | test_data: Test data for evaluation
623 | feature_extractor: Function to extract features
624 | target_extractor: Function to extract targets
625 | traffic_allocation: Custom traffic allocation (defaults to equal split)
626 |             test_duration_hours: Intended duration (informational; this runner scores the full dataset synchronously)
627 |
628 | Returns:
629 | Test ID
630 | """
631 | # Generate unique test ID
632 | test_id = f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
633 |
634 | # Create test
635 | test = self.create_test(
636 | test_id=test_id,
637 | name=test_name,
638 | description=f"Comparing {len(model_versions)} model versions",
639 | )
640 |
641 | # Default equal traffic allocation
642 | if traffic_allocation is None:
643 | allocation_per_group = 1.0 / len(model_versions)
644 | traffic_allocation = [allocation_per_group] * len(model_versions)
645 |
646 | # Add groups
647 | for i, (model_id, version) in enumerate(model_versions):
648 | group_id = f"group_{i}_{model_id}_{version}"
649 | test.add_group(
650 | group_id=group_id,
651 | model_id=model_id,
652 | model_version=version,
653 | traffic_allocation=traffic_allocation[i],
654 | description=f"Model {model_id} version {version}",
655 | )
656 |
657 | # Start test
658 | test.start_test()
659 |
660 | # Extract features and targets
661 | features = feature_extractor(test_data)
662 | targets = target_extractor(test_data)
663 |
664 |         # Route each test row to its assigned group and simulate one prediction
665 | for _, row in features.iterrows():
666 | # Assign traffic
667 | group_id = test.assign_traffic(str(row.name)) # Use row index as user_id
668 | if group_id is None:
669 | continue
670 |
671 | # Get corresponding group's model
672 | group = test.groups[group_id]
673 | model_version = self.model_manager.load_model(
674 | group.model_id, group.model_version
675 | )
676 |
677 | if model_version is None or model_version.model is None:
678 | logger.warning(f"Could not load model for group {group_id}")
679 | continue
680 |
681 | try:
682 | # Make prediction
683 | X = row.values.reshape(1, -1)
684 | if model_version.scaler is not None:
685 | X = model_version.scaler.transform(X)
686 |
687 | prediction = model_version.model.predict(X)[0]
688 |
689 | # Get confidence if available
690 | confidence = 1.0
691 | if hasattr(model_version.model, "predict_proba"):
692 | proba = model_version.model.predict_proba(X)[0]
693 | confidence = max(proba)
694 |
695 | # Get actual value
696 | actual = targets.loc[row.name]
697 |
698 | # Record prediction
699 | test.record_prediction(group_id, prediction, actual, confidence)
700 |
701 | except Exception as e:
702 | logger.warning(f"Error making prediction for group {group_id}: {e}")
703 |
704 | logger.info(f"Completed model comparison test {test_id}")
705 | return test_id
706 |
707 | def get_test_summary(self) -> dict[str, Any]:
708 | """Get summary of all A/B tests.
709 |
710 | Returns:
711 | Summary dictionary
712 | """
713 | total_tests = len(self.tests)
714 | status_counts = {}
715 |
716 | for test in self.tests.values():
717 | status_counts[test.status] = status_counts.get(test.status, 0) + 1
718 |
719 | recent_tests = sorted(
720 | [
721 | {
722 | "test_id": test.test_id,
723 | "name": test.name,
724 | "status": test.status,
725 | "created_at": test.created_at.isoformat(),
726 | }
727 | for test in self.tests.values()
728 | ],
729 | key=lambda x: x["created_at"],
730 | reverse=True,
731 | )[:10]
732 |
733 | return {
734 | "total_tests": total_tests,
735 | "status_distribution": status_counts,
736 | "recent_tests": recent_tests,
737 | }
738 |
```
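A self-contained sketch of the flow above, driving `ABTest` directly with synthetic outcomes (the group names and the 60%/50% accuracies are illustrative; everything else comes from the classes shown):

```python
# Editor's sketch: champion answers correctly 60% of the time, challenger 50%.
import random

from maverick_mcp.backtesting.ab_testing import ABTest

test = ABTest("demo", "champion vs challenger", random_seed=42)
test.add_group("champion", "clf", "v1", traffic_allocation=0.5)
test.add_group("challenger", "clf", "v2", traffic_allocation=0.5)
test.start_test()

rng = random.Random(0)
for i in range(400):
    group_id = test.assign_traffic(user_id=f"user_{i}")  # md5-hashed, consistent per user
    actual = rng.randint(0, 1)
    p_correct = 0.6 if group_id == "champion" else 0.5
    prediction = actual if rng.random() < p_correct else 1 - actual
    test.record_prediction(group_id, prediction, actual)

test.stop_test()
print(test.get_results()["statistical_analysis"].get("recommendations"))
```

With roughly 200 samples per group this clears `min_samples_per_group`, so `get_results()` runs the two-proportion z-test and reports whether the ten-point accuracy gap is significant.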