# wshobson/maverick-mcp
tokens: 45773/50000 6/435 files (page 22/39)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 22 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .dockerignore
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── config.yml
│   │   ├── feature_request.md
│   │   ├── question.md
│   │   └── security_report.md
│   ├── pull_request_template.md
│   └── workflows
│       ├── claude-code-review.yml
│       └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│   ├── launch.json
│   └── settings.json
├── alembic
│   ├── env.py
│   ├── script.py.mako
│   └── versions
│       ├── 001_initial_schema.py
│       ├── 003_add_performance_indexes.py
│       ├── 006_rename_metadata_columns.py
│       ├── 008_performance_optimization_indexes.py
│       ├── 009_rename_to_supply_demand.py
│       ├── 010_self_contained_schema.py
│       ├── 011_remove_proprietary_terms.py
│       ├── 013_add_backtest_persistence_models.py
│       ├── 014_add_portfolio_models.py
│       ├── 08e3945a0c93_merge_heads.py
│       ├── 9374a5c9b679_merge_heads_for_testing.py
│       ├── abf9b9afb134_merge_multiple_heads.py
│       ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│       ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│       ├── f0696e2cac15_add_essential_performance_indexes.py
│       └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│   ├── api
│   │   └── backtesting.md
│   ├── BACKTESTING.md
│   ├── COST_BASIS_SPECIFICATION.md
│   ├── deep_research_agent.md
│   ├── exa_research_testing_strategy.md
│   ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│   ├── PORTFOLIO.md
│   ├── SETUP_SELF_CONTAINED.md
│   └── speed_testing_framework.md
├── examples
│   ├── complete_speed_validation.py
│   ├── deep_research_integration.py
│   ├── llm_optimization_example.py
│   ├── llm_speed_demo.py
│   ├── monitoring_example.py
│   ├── parallel_research_example.py
│   ├── speed_optimization_demo.py
│   └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── circuit_breaker.py
│   │   ├── deep_research.py
│   │   ├── market_analysis.py
│   │   ├── optimized_research.py
│   │   ├── supervisor.py
│   │   └── technical_analysis.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── api_server.py
│   │   ├── connection_manager.py
│   │   ├── dependencies
│   │   │   ├── __init__.py
│   │   │   ├── stock_analysis.py
│   │   │   └── technical_analysis.py
│   │   ├── error_handling.py
│   │   ├── inspector_compatible_sse.py
│   │   ├── inspector_sse.py
│   │   ├── middleware
│   │   │   ├── error_handling.py
│   │   │   ├── mcp_logging.py
│   │   │   ├── rate_limiting_enhanced.py
│   │   │   └── security.py
│   │   ├── openapi_config.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── agents.py
│   │   │   ├── backtesting.py
│   │   │   ├── data_enhanced.py
│   │   │   ├── data.py
│   │   │   ├── health_enhanced.py
│   │   │   ├── health_tools.py
│   │   │   ├── health.py
│   │   │   ├── intelligent_backtesting.py
│   │   │   ├── introspection.py
│   │   │   ├── mcp_prompts.py
│   │   │   ├── monitoring.py
│   │   │   ├── news_sentiment_enhanced.py
│   │   │   ├── performance.py
│   │   │   ├── portfolio.py
│   │   │   ├── research.py
│   │   │   ├── screening_ddd.py
│   │   │   ├── screening_parallel.py
│   │   │   ├── screening.py
│   │   │   ├── technical_ddd.py
│   │   │   ├── technical_enhanced.py
│   │   │   ├── technical.py
│   │   │   └── tool_registry.py
│   │   ├── server.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── base_service.py
│   │   │   ├── market_service.py
│   │   │   ├── portfolio_service.py
│   │   │   ├── prompt_service.py
│   │   │   └── resource_service.py
│   │   ├── simple_sse.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── insomnia_export.py
│   │       └── postman_export.py
│   ├── application
│   │   ├── __init__.py
│   │   ├── commands
│   │   │   └── __init__.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_dto.py
│   │   ├── queries
│   │   │   ├── __init__.py
│   │   │   └── get_technical_analysis.py
│   │   └── screening
│   │       ├── __init__.py
│   │       ├── dtos.py
│   │       └── queries.py
│   ├── backtesting
│   │   ├── __init__.py
│   │   ├── ab_testing.py
│   │   ├── analysis.py
│   │   ├── batch_processing_stub.py
│   │   ├── batch_processing.py
│   │   ├── model_manager.py
│   │   ├── optimization.py
│   │   ├── persistence.py
│   │   ├── retraining_pipeline.py
│   │   ├── strategies
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── ml
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adaptive.py
│   │   │   │   ├── ensemble.py
│   │   │   │   ├── feature_engineering.py
│   │   │   │   └── regime_aware.py
│   │   │   ├── ml_strategies.py
│   │   │   ├── parser.py
│   │   │   └── templates.py
│   │   ├── strategy_executor.py
│   │   ├── vectorbt_engine.py
│   │   └── visualization.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── database_self_contained.py
│   │   ├── database.py
│   │   ├── llm_optimization_config.py
│   │   ├── logging_settings.py
│   │   ├── plotly_config.py
│   │   ├── security_utils.py
│   │   ├── security.py
│   │   ├── settings.py
│   │   ├── technical_constants.py
│   │   ├── tool_estimation.py
│   │   └── validation.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── technical_analysis.py
│   │   └── visualization.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cache_manager.py
│   │   ├── cache.py
│   │   ├── django_adapter.py
│   │   ├── health.py
│   │   ├── models.py
│   │   ├── performance.py
│   │   ├── session_management.py
│   │   └── validation.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── optimization.py
│   ├── dependencies.py
│   ├── domain
│   │   ├── __init__.py
│   │   ├── entities
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis.py
│   │   ├── events
│   │   │   └── __init__.py
│   │   ├── portfolio.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   ├── entities.py
│   │   │   ├── services.py
│   │   │   └── value_objects.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_service.py
│   │   ├── stock_analysis
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis_service.py
│   │   └── value_objects
│   │       ├── __init__.py
│   │       └── technical_indicators.py
│   ├── exceptions.py
│   ├── infrastructure
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   └── __init__.py
│   │   ├── caching
│   │   │   ├── __init__.py
│   │   │   └── cache_management_service.py
│   │   ├── connection_manager.py
│   │   ├── data_fetching
│   │   │   ├── __init__.py
│   │   │   └── stock_data_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_checker.py
│   │   ├── persistence
│   │   │   ├── __init__.py
│   │   │   └── stock_repository.py
│   │   ├── providers
│   │   │   └── __init__.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   └── repositories.py
│   │   └── sse_optimizer.py
│   ├── langchain_tools
│   │   ├── __init__.py
│   │   ├── adapters.py
│   │   └── registry.py
│   ├── logging_config.py
│   ├── memory
│   │   ├── __init__.py
│   │   └── stores.py
│   ├── monitoring
│   │   ├── __init__.py
│   │   ├── health_check.py
│   │   ├── health_monitor.py
│   │   ├── integration_example.py
│   │   ├── metrics.py
│   │   ├── middleware.py
│   │   └── status_dashboard.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── dependencies.py
│   │   ├── factories
│   │   │   ├── __init__.py
│   │   │   ├── config_factory.py
│   │   │   └── provider_factory.py
│   │   ├── implementations
│   │   │   ├── __init__.py
│   │   │   ├── cache_adapter.py
│   │   │   ├── macro_data_adapter.py
│   │   │   ├── market_data_adapter.py
│   │   │   ├── persistence_adapter.py
│   │   │   └── stock_data_adapter.py
│   │   ├── interfaces
│   │   │   ├── __init__.py
│   │   │   ├── cache.py
│   │   │   ├── config.py
│   │   │   ├── macro_data.py
│   │   │   ├── market_data.py
│   │   │   ├── persistence.py
│   │   │   └── stock_data.py
│   │   ├── llm_factory.py
│   │   ├── macro_data.py
│   │   ├── market_data.py
│   │   ├── mocks
│   │   │   ├── __init__.py
│   │   │   ├── mock_cache.py
│   │   │   ├── mock_config.py
│   │   │   ├── mock_macro_data.py
│   │   │   ├── mock_market_data.py
│   │   │   ├── mock_persistence.py
│   │   │   └── mock_stock_data.py
│   │   ├── openrouter_provider.py
│   │   ├── optimized_screening.py
│   │   ├── optimized_stock_data.py
│   │   └── stock_data.py
│   ├── README.md
│   ├── tests
│   │   ├── __init__.py
│   │   ├── README_INMEMORY_TESTS.md
│   │   ├── test_cache_debug.py
│   │   ├── test_fixes_validation.py
│   │   ├── test_in_memory_routers.py
│   │   ├── test_in_memory_server.py
│   │   ├── test_macro_data_provider.py
│   │   ├── test_mailgun_email.py
│   │   ├── test_market_calendar_caching.py
│   │   ├── test_mcp_tool_fixes_pytest.py
│   │   ├── test_mcp_tool_fixes.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_models_functional.py
│   │   ├── test_server.py
│   │   ├── test_stock_data_enhanced.py
│   │   ├── test_stock_data_provider.py
│   │   └── test_technical_analysis.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── performance_monitoring.py
│   │   ├── portfolio_manager.py
│   │   ├── risk_management.py
│   │   └── sentiment_analysis.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── agent_errors.py
│   │   ├── batch_processing.py
│   │   ├── cache_warmer.py
│   │   ├── circuit_breaker_decorators.py
│   │   ├── circuit_breaker_services.py
│   │   ├── circuit_breaker.py
│   │   ├── data_chunking.py
│   │   ├── database_monitoring.py
│   │   ├── debug_utils.py
│   │   ├── fallback_strategies.py
│   │   ├── llm_optimization.py
│   │   ├── logging_example.py
│   │   ├── logging_init.py
│   │   ├── logging.py
│   │   ├── mcp_logging.py
│   │   ├── memory_profiler.py
│   │   ├── monitoring_middleware.py
│   │   ├── monitoring.py
│   │   ├── orchestration_logging.py
│   │   ├── parallel_research.py
│   │   ├── parallel_screening.py
│   │   ├── quick_cache.py
│   │   ├── resource_manager.py
│   │   ├── shutdown.py
│   │   ├── stock_helpers.py
│   │   ├── structured_logger.py
│   │   ├── tool_monitoring.py
│   │   ├── tracing.py
│   │   └── yfinance_pool.py
│   ├── validation
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── data.py
│   │   ├── middleware.py
│   │   ├── portfolio.py
│   │   ├── responses.py
│   │   ├── screening.py
│   │   └── technical.py
│   └── workflows
│       ├── __init__.py
│       ├── agents
│       │   ├── __init__.py
│       │   ├── market_analyzer.py
│       │   ├── optimizer_agent.py
│       │   ├── strategy_selector.py
│       │   └── validator_agent.py
│       ├── backtesting_workflow.py
│       └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│   ├── dev.sh
│   ├── INSTALLATION_GUIDE.md
│   ├── load_example.py
│   ├── load_market_data.py
│   ├── load_tiingo_data.py
│   ├── migrate_db.py
│   ├── README_TIINGO_LOADER.md
│   ├── requirements_tiingo.txt
│   ├── run_stock_screening.py
│   ├── run-migrations.sh
│   ├── seed_db.py
│   ├── seed_sp500.py
│   ├── setup_database.sh
│   ├── setup_self_contained.py
│   ├── setup_sp500_database.sh
│   ├── test_seeded_data.py
│   ├── test_tiingo_loader.py
│   ├── tiingo_config.py
│   └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── core
│   │   └── test_technical_analysis.py
│   ├── data
│   │   └── test_portfolio_models.py
│   ├── domain
│   │   ├── conftest.py
│   │   ├── test_portfolio_entities.py
│   │   └── test_technical_analysis_service.py
│   ├── fixtures
│   │   └── orchestration_fixtures.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── README.md
│   │   ├── run_integration_tests.sh
│   │   ├── test_api_technical.py
│   │   ├── test_chaos_engineering.py
│   │   ├── test_config_management.py
│   │   ├── test_full_backtest_workflow_advanced.py
│   │   ├── test_full_backtest_workflow.py
│   │   ├── test_high_volume.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_orchestration_complete.py
│   │   ├── test_portfolio_persistence.py
│   │   ├── test_redis_cache.py
│   │   ├── test_security_integration.py.disabled
│   │   └── vcr_setup.py
│   ├── performance
│   │   ├── __init__.py
│   │   ├── test_benchmarks.py
│   │   ├── test_load.py
│   │   ├── test_profiling.py
│   │   └── test_stress.py
│   ├── providers
│   │   └── test_stock_data_simple.py
│   ├── README.md
│   ├── test_agents_router_mcp.py
│   ├── test_backtest_persistence.py
│   ├── test_cache_management_service.py
│   ├── test_cache_serialization.py
│   ├── test_circuit_breaker.py
│   ├── test_database_pool_config_simple.py
│   ├── test_database_pool_config.py
│   ├── test_deep_research_functional.py
│   ├── test_deep_research_integration.py
│   ├── test_deep_research_parallel_execution.py
│   ├── test_error_handling.py
│   ├── test_event_loop_integrity.py
│   ├── test_exa_research_integration.py
│   ├── test_exception_hierarchy.py
│   ├── test_financial_search.py
│   ├── test_graceful_shutdown.py
│   ├── test_integration_simple.py
│   ├── test_langgraph_workflow.py
│   ├── test_market_data_async.py
│   ├── test_market_data_simple.py
│   ├── test_mcp_orchestration_functional.py
│   ├── test_ml_strategies.py
│   ├── test_optimized_research_agent.py
│   ├── test_orchestration_integration.py
│   ├── test_orchestration_logging.py
│   ├── test_orchestration_tools_simple.py
│   ├── test_parallel_research_integration.py
│   ├── test_parallel_research_orchestrator.py
│   ├── test_parallel_research_performance.py
│   ├── test_performance_optimizations.py
│   ├── test_production_validation.py
│   ├── test_provider_architecture.py
│   ├── test_rate_limiting_enhanced.py
│   ├── test_runner_validation.py
│   ├── test_security_comprehensive.py.disabled
│   ├── test_security_cors.py
│   ├── test_security_enhancements.py.disabled
│   ├── test_security_headers.py
│   ├── test_security_penetration.py
│   ├── test_session_management.py
│   ├── test_speed_optimization_validation.py
│   ├── test_stock_analysis_dependencies.py
│   ├── test_stock_analysis_service.py
│   ├── test_stock_data_fetching_service.py
│   ├── test_supervisor_agent.py
│   ├── test_supervisor_functional.py
│   ├── test_tool_estimation_config.py
│   ├── test_visualization.py
│   └── utils
│       ├── test_agent_errors.py
│       ├── test_logging.py
│       ├── test_parallel_screening.py
│       └── test_quick_cache.py
├── tools
│   ├── check_orchestration_config.py
│   ├── experiments
│   │   ├── validation_examples.py
│   │   └── validation_fixed.py
│   ├── fast_dev.sh
│   ├── hot_reload.py
│   ├── quick_test.py
│   └── templates
│       ├── new_router_template.py
│       ├── new_tool_template.py
│       ├── screening_strategy_template.py
│       └── test_template.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/tests/test_security_cors.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Comprehensive CORS Security Tests for Maverick MCP.
  3 | 
  4 | Tests CORS configuration, validation, origin blocking, wildcard security,
  5 | and environment-specific behaviors.
  6 | """
  7 | 
  8 | import os
  9 | from unittest.mock import MagicMock, patch
 10 | 
 11 | import pytest
 12 | from fastapi import FastAPI
 13 | from fastapi.middleware.cors import CORSMiddleware
 14 | from fastapi.testclient import TestClient
 15 | 
 16 | from maverick_mcp.config.security import (
 17 |     CORSConfig,
 18 |     SecurityConfig,
 19 |     validate_security_config,
 20 | )
 21 | from maverick_mcp.config.security_utils import (
 22 |     apply_cors_to_fastapi,
 23 |     check_security_config,
 24 |     get_safe_cors_config,
 25 | )
 26 | 
 27 | 
 28 | class TestCORSConfiguration:
 29 |     """Test CORS configuration validation and creation."""
 30 | 
 31 |     def test_cors_config_valid_origins(self):
 32 |         """Test CORS config creation with valid origins."""
 33 |         config = CORSConfig(
 34 |             allowed_origins=["https://example.com", "https://app.example.com"],
 35 |             allow_credentials=True,
 36 |         )
 37 | 
 38 |         assert config.allowed_origins == [
 39 |             "https://example.com",
 40 |             "https://app.example.com",
 41 |         ]
 42 |         assert config.allow_credentials is True
 43 | 
 44 |     def test_cors_config_wildcard_with_credentials_raises_error(self):
 45 |         """Test that wildcard origins with credentials raises validation error."""
 46 |         with pytest.raises(
 47 |             ValueError,
 48 |             match="CORS Security Error.*wildcard origin.*serious security vulnerability",
 49 |         ):
 50 |             CORSConfig(allowed_origins=["*"], allow_credentials=True)
 51 | 
 52 |     def test_cors_config_wildcard_without_credentials_warns(self):
 53 |         """Test that wildcard origins without credentials logs warning."""
 54 |         with patch("logging.getLogger") as mock_logger:
 55 |             mock_logger_instance = MagicMock()
 56 |             mock_logger.return_value = mock_logger_instance
 57 | 
 58 |             config = CORSConfig(allowed_origins=["*"], allow_credentials=False)
 59 | 
 60 |             assert config.allowed_origins == ["*"]
 61 |             assert config.allow_credentials is False
 62 |             mock_logger_instance.warning.assert_called_once()
 63 | 
 64 |     def test_cors_config_multiple_origins_with_wildcard_fails(self):
 65 |         """Test that mixed origins including wildcard with credentials fails."""
 66 |         with pytest.raises(ValueError, match="CORS Security Error"):
 67 |             CORSConfig(
 68 |                 allowed_origins=["https://example.com", "*"], allow_credentials=True
 69 |             )
 70 | 
 71 |     def test_cors_config_default_values(self):
 72 |         """Test CORS config default values are secure."""
 73 |         with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
 74 |             with patch(
 75 |                 "maverick_mcp.config.security._get_cors_origins"
 76 |             ) as mock_origins:
 77 |                 mock_origins.return_value = ["http://localhost:3000"]
 78 | 
 79 |                 config = CORSConfig()
 80 | 
 81 |                 assert config.allow_credentials is True
 82 |                 assert "GET" in config.allowed_methods
 83 |                 assert "POST" in config.allowed_methods
 84 |                 assert "Authorization" in config.allowed_headers
 85 |                 assert "Content-Type" in config.allowed_headers
 86 |                 assert config.max_age == 86400
 87 | 
 88 |     def test_cors_config_expose_headers(self):
 89 |         """Test that proper headers are exposed to clients."""
 90 |         config = CORSConfig()
 91 | 
 92 |         expected_exposed = [
 93 |             "X-Process-Time",
 94 |             "X-RateLimit-Limit",
 95 |             "X-RateLimit-Remaining",
 96 |             "X-RateLimit-Reset",
 97 |             "X-Request-ID",
 98 |         ]
 99 | 
100 |         for header in expected_exposed:
101 |             assert header in config.exposed_headers
102 | 
103 | 
class TestCORSEnvironmentConfiguration:
    """Checks the CORS origins ``SecurityConfig`` reports per environment.

    Each test patches ``maverick_mcp.config.security._get_cors_origins`` so the
    origin list is controlled by the test rather than by the real lookup.
    """

    def test_production_cors_origins(self):
        """Production origins contain no loopback hosts and are all HTTPS."""
        env_patch = patch.dict(os.environ, {"ENVIRONMENT": "production"}, clear=True)
        origins_patch = patch(
            "maverick_mcp.config.security._get_cors_origins",
            return_value=[
                "https://app.maverick-mcp.com",
                "https://maverick-mcp.com",
            ],
        )

        with env_patch, origins_patch:
            config = SecurityConfig()

            for loopback in ("localhost", "127.0.0.1"):
                assert loopback not in str(config.cors.allowed_origins).lower()
            for origin in config.cors.allowed_origins:
                assert origin.startswith("https://")

    def test_development_cors_origins(self):
        """Development origins include at least one localhost entry."""
        env_patch = patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=True)
        origins_patch = patch(
            "maverick_mcp.config.security._get_cors_origins",
            return_value=[
                "http://localhost:3000",
                "http://127.0.0.1:3000",
            ],
        )

        with env_patch, origins_patch:
            config = SecurityConfig()

            assert any("localhost" in origin for origin in config.cors.allowed_origins)

    def test_staging_cors_origins(self):
        """Staging origins include at least one staging host."""
        env_patch = patch.dict(os.environ, {"ENVIRONMENT": "staging"}, clear=True)
        origins_patch = patch(
            "maverick_mcp.config.security._get_cors_origins",
            return_value=[
                "https://staging.maverick-mcp.com",
                "http://localhost:3000",
            ],
        )

        with env_patch, origins_patch:
            config = SecurityConfig()

            assert any("staging" in origin for origin in config.cors.allowed_origins)

    def test_custom_cors_origins_from_env(self):
        """Origins listed in ``CORS_ORIGINS`` show up in the resulting config."""
        custom_origins = "https://custom1.com,https://custom2.com"
        # The patched lookup returns the same entries the env var carries.
        parsed_origins = custom_origins.split(",")

        env_patch = patch.dict(
            os.environ, {"CORS_ORIGINS": custom_origins}, clear=False
        )
        origins_patch = patch(
            "maverick_mcp.config.security._get_cors_origins",
            return_value=parsed_origins,
        )

        with env_patch, origins_patch:
            config = SecurityConfig()

            assert "https://custom1.com" in config.cors.allowed_origins
            assert "https://custom2.com" in config.cors.allowed_origins
180 | 
181 | 
class TestCORSValidation:
    """Checks what ``validate_security_config`` reports for various CORS setups."""

    # Patch target that supplies the security config under validation.
    _CONFIG_TARGET = "maverick_mcp.config.security.get_security_config"

    @staticmethod
    def _stub_security_config(origins, *, allow_credentials, production):
        """Build a mock security config; only the CORS fields and production flag vary."""
        stub = MagicMock()
        stub.cors.allowed_origins = origins
        stub.cors.allow_credentials = allow_credentials
        stub.is_production.return_value = production
        stub.force_https = True
        stub.headers.x_frame_options = "DENY"
        return stub

    def _validate_with(self, stub):
        """Run ``validate_security_config`` while ``get_security_config`` returns ``stub``."""
        with patch(self._CONFIG_TARGET, return_value=stub):
            return validate_security_config()

    def test_validate_security_config_valid_cors(self):
        """An explicit HTTPS origin with credentials validates with no issues."""
        stub = self._stub_security_config(
            ["https://example.com"], allow_credentials=True, production=False
        )

        result = self._validate_with(stub)

        assert result["valid"] is True
        assert len(result["issues"]) == 0

    def test_validate_security_config_wildcard_with_credentials(self):
        """Wildcard origins plus credentials are reported as an issue."""
        stub = self._stub_security_config(
            ["*"], allow_credentials=True, production=False
        )

        result = self._validate_with(stub)

        assert result["valid"] is False
        assert any(
            "Wildcard CORS origins with credentials enabled" in issue
            for issue in result["issues"]
        )

    def test_validate_security_config_production_wildcards(self):
        """Wildcard origins in production are reported as an issue."""
        stub = self._stub_security_config(
            ["*"], allow_credentials=False, production=True
        )

        result = self._validate_with(stub)

        assert result["valid"] is False
        assert any(
            "Wildcard CORS origins in production" in issue
            for issue in result["issues"]
        )

    def test_validate_security_config_production_localhost_warning(self):
        """A localhost origin in production yields a warning but stays valid."""
        stub = self._stub_security_config(
            ["https://app.com", "http://localhost:3000"],
            allow_credentials=True,
            production=True,
        )

        result = self._validate_with(stub)

        # Localhost in production is only a warning, so the config remains valid.
        assert result["valid"] is True
        assert any("localhost" in warning.lower() for warning in result["warnings"])
257 | 
258 | 
259 | class TestCORSMiddlewareIntegration:
260 |     """Test CORS middleware integration with FastAPI."""
261 | 
262 |     def create_test_app(self, security_config=None):
263 |         """Create a test FastAPI app with CORS applied."""
264 |         app = FastAPI()
265 | 
266 |         if security_config:
267 |             with patch(
268 |                 "maverick_mcp.config.security_utils.get_security_config",
269 |                 return_value=security_config,
270 |             ):
271 |                 apply_cors_to_fastapi(app)
272 |         else:
273 |             apply_cors_to_fastapi(app)
274 | 
275 |         @app.get("/test")
276 |         async def test_endpoint():
277 |             return {"message": "test"}
278 | 
279 |         @app.post("/test")
280 |         async def test_post_endpoint():
281 |             return {"message": "post test"}
282 | 
283 |         return app
284 | 
285 |     def test_cors_middleware_allows_configured_origins(self):
286 |         """Test that CORS middleware allows configured origins."""
287 |         # Create mock security config
288 |         mock_config = MagicMock()
289 |         mock_config.get_cors_middleware_config.return_value = {
290 |             "allow_origins": ["https://allowed.com"],
291 |             "allow_credentials": True,
292 |             "allow_methods": ["GET", "POST"],
293 |             "allow_headers": ["Content-Type", "Authorization"],
294 |             "expose_headers": [],
295 |             "max_age": 86400,
296 |         }
297 | 
298 |         # Mock validation to pass
299 |         with patch(
300 |             "maverick_mcp.config.security_utils.validate_security_config"
301 |         ) as mock_validate:
302 |             mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
303 | 
304 |             app = self.create_test_app(mock_config)
305 |             client = TestClient(app)
306 | 
307 |             # Test preflight request
308 |             response = client.options(
309 |                 "/test",
310 |                 headers={
311 |                     "Origin": "https://allowed.com",
312 |                     "Access-Control-Request-Method": "POST",
313 |                     "Access-Control-Request-Headers": "Content-Type",
314 |                 },
315 |             )
316 | 
317 |             assert response.status_code == 200
318 |             assert (
319 |                 response.headers.get("Access-Control-Allow-Origin")
320 |                 == "https://allowed.com"
321 |             )
322 |             assert "POST" in response.headers.get("Access-Control-Allow-Methods", "")
323 | 
324 |     def test_cors_middleware_blocks_unauthorized_origins(self):
325 |         """Test that CORS middleware blocks unauthorized origins."""
326 |         mock_config = MagicMock()
327 |         mock_config.get_cors_middleware_config.return_value = {
328 |             "allow_origins": ["https://allowed.com"],
329 |             "allow_credentials": True,
330 |             "allow_methods": ["GET", "POST"],
331 |             "allow_headers": ["Content-Type"],
332 |             "expose_headers": [],
333 |             "max_age": 86400,
334 |         }
335 | 
336 |         with patch(
337 |             "maverick_mcp.config.security_utils.validate_security_config"
338 |         ) as mock_validate:
339 |             mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
340 | 
341 |             app = self.create_test_app(mock_config)
342 |             client = TestClient(app)
343 | 
344 |             # Test request from unauthorized origin
345 |             response = client.get(
346 |                 "/test", headers={"Origin": "https://unauthorized.com"}
347 |             )
348 | 
349 |             # The request should succeed (CORS is browser-enforced)
350 |             # but the CORS headers should not allow the unauthorized origin
351 |             assert response.status_code == 200
352 |             cors_origin = response.headers.get("Access-Control-Allow-Origin")
353 |             assert cors_origin != "https://unauthorized.com"
354 | 
355 |     def test_cors_middleware_credentials_handling(self):
356 |         """Test CORS middleware credentials handling."""
357 |         mock_config = MagicMock()
358 |         mock_config.get_cors_middleware_config.return_value = {
359 |             "allow_origins": ["https://allowed.com"],
360 |             "allow_credentials": True,
361 |             "allow_methods": ["GET", "POST"],
362 |             "allow_headers": ["Content-Type"],
363 |             "expose_headers": [],
364 |             "max_age": 86400,
365 |         }
366 | 
367 |         with patch(
368 |             "maverick_mcp.config.security_utils.validate_security_config"
369 |         ) as mock_validate:
370 |             mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
371 | 
372 |             app = self.create_test_app(mock_config)
373 |             client = TestClient(app)
374 | 
375 |             response = client.options(
376 |                 "/test",
377 |                 headers={
378 |                     "Origin": "https://allowed.com",
379 |                     "Access-Control-Request-Method": "POST",
380 |                 },
381 |             )
382 | 
383 |             assert response.headers.get("Access-Control-Allow-Credentials") == "true"
384 | 
385 |     def test_cors_middleware_exposed_headers(self):
386 |         """Test that CORS middleware exposes configured headers."""
387 |         mock_config = MagicMock()
388 |         mock_config.get_cors_middleware_config.return_value = {
389 |             "allow_origins": ["https://allowed.com"],
390 |             "allow_credentials": True,
391 |             "allow_methods": ["GET"],
392 |             "allow_headers": ["Content-Type"],
393 |             "expose_headers": ["X-Custom-Header", "X-Rate-Limit"],
394 |             "max_age": 86400,
395 |         }
396 | 
397 |         with patch(
398 |             "maverick_mcp.config.security_utils.validate_security_config"
399 |         ) as mock_validate:
400 |             mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
401 | 
402 |             app = self.create_test_app(mock_config)
403 |             client = TestClient(app)
404 | 
405 |             response = client.get("/test", headers={"Origin": "https://allowed.com"})
406 | 
407 |             exposed_headers = response.headers.get("Access-Control-Expose-Headers", "")
408 |             assert "X-Custom-Header" in exposed_headers
409 |             assert "X-Rate-Limit" in exposed_headers
410 | 
411 | 
412 | class TestCORSSecurityValidation:
413 |     """Test CORS security validation and safety measures."""
414 | 
415 |     def test_apply_cors_fails_with_invalid_config(self):
416 |         """Test that applying CORS fails with invalid configuration."""
417 |         app = FastAPI()
418 | 
419 |         # Mock invalid configuration
420 |         with patch(
421 |             "maverick_mcp.config.security_utils.validate_security_config"
422 |         ) as mock_validate:
423 |             mock_validate.return_value = {
424 |                 "valid": False,
425 |                 "issues": ["Wildcard CORS origins with credentials"],
426 |                 "warnings": [],
427 |             }
428 | 
429 |             with pytest.raises(ValueError, match="Security configuration is invalid"):
430 |                 apply_cors_to_fastapi(app)
431 | 
432 |     def test_get_safe_cors_config_production_fallback(self):
433 |         """Test safe CORS config fallback for production."""
434 |         with patch(
435 |             "maverick_mcp.config.security_utils.validate_security_config"
436 |         ) as mock_validate:
437 |             mock_validate.return_value = {
438 |                 "valid": False,
439 |                 "issues": ["Invalid config"],
440 |                 "warnings": [],
441 |             }
442 | 
443 |             with patch(
444 |                 "maverick_mcp.config.security_utils.get_security_config"
445 |             ) as mock_config:
446 |                 mock_security_config = MagicMock()
447 |                 mock_security_config.is_production.return_value = True
448 |                 mock_config.return_value = mock_security_config
449 | 
450 |                 safe_config = get_safe_cors_config()
451 | 
452 |                 assert safe_config["allow_origins"] == ["https://maverick-mcp.com"]
453 |                 assert safe_config["allow_credentials"] is True
454 |                 assert "localhost" not in str(safe_config["allow_origins"])
455 | 
456 |     def test_get_safe_cors_config_development_fallback(self):
457 |         """Test safe CORS config fallback for development."""
458 |         with patch(
459 |             "maverick_mcp.config.security_utils.validate_security_config"
460 |         ) as mock_validate:
461 |             mock_validate.return_value = {
462 |                 "valid": False,
463 |                 "issues": ["Invalid config"],
464 |                 "warnings": [],
465 |             }
466 | 
467 |             with patch(
468 |                 "maverick_mcp.config.security_utils.get_security_config"
469 |             ) as mock_config:
470 |                 mock_security_config = MagicMock()
471 |                 mock_security_config.is_production.return_value = False
472 |                 mock_config.return_value = mock_security_config
473 | 
474 |                 safe_config = get_safe_cors_config()
475 | 
476 |                 assert safe_config["allow_origins"] == ["http://localhost:3000"]
477 |                 assert safe_config["allow_credentials"] is True
478 | 
479 |     def test_check_security_config_function(self):
480 |         """Test security config check function."""
481 |         with patch(
482 |             "maverick_mcp.config.security_utils.validate_security_config"
483 |         ) as mock_validate:
484 |             # Test valid config
485 |             mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
486 |             assert check_security_config() is True
487 | 
488 |             # Test invalid config
489 |             mock_validate.return_value = {
490 |                 "valid": False,
491 |                 "issues": ["Error"],
492 |                 "warnings": [],
493 |             }
494 |             assert check_security_config() is False
495 | 
496 | 
497 | class TestCORSPreflightRequests:
498 |     """Test CORS preflight request handling."""
499 | 
500 |     def test_preflight_request_max_age(self):
501 |         """Test CORS preflight max-age header."""
502 |         app = FastAPI()
503 |         app.add_middleware(
504 |             CORSMiddleware,
505 |             allow_origins=["https://example.com"],
506 |             allow_methods=["GET", "POST"],
507 |             allow_headers=["Content-Type"],
508 |             max_age=3600,
509 |         )
510 | 
511 |         @app.get("/test")
512 |         async def test_endpoint():
513 |             return {"message": "test"}
514 | 
515 |         client = TestClient(app)
516 | 
517 |         response = client.options(
518 |             "/test",
519 |             headers={
520 |                 "Origin": "https://example.com",
521 |                 "Access-Control-Request-Method": "GET",
522 |             },
523 |         )
524 | 
525 |         assert response.headers.get("Access-Control-Max-Age") == "3600"
526 | 
527 |     def test_preflight_request_methods(self):
528 |         """Test CORS preflight allowed methods."""
529 |         app = FastAPI()
530 |         app.add_middleware(
531 |             CORSMiddleware,
532 |             allow_origins=["https://example.com"],
533 |             allow_methods=["GET", "POST", "PUT"],
534 |             allow_headers=["Content-Type"],
535 |         )
536 | 
537 |         @app.get("/test")
538 |         async def test_endpoint():
539 |             return {"message": "test"}
540 | 
541 |         client = TestClient(app)
542 | 
543 |         response = client.options(
544 |             "/test",
545 |             headers={
546 |                 "Origin": "https://example.com",
547 |                 "Access-Control-Request-Method": "PUT",
548 |             },
549 |         )
550 | 
551 |         allowed_methods = response.headers.get("Access-Control-Allow-Methods", "")
552 |         assert "PUT" in allowed_methods
553 |         assert "GET" in allowed_methods
554 |         assert "POST" in allowed_methods
555 | 
556 |     def test_preflight_request_headers(self):
557 |         """Test CORS preflight allowed headers."""
558 |         app = FastAPI()
559 |         app.add_middleware(
560 |             CORSMiddleware,
561 |             allow_origins=["https://example.com"],
562 |             allow_methods=["POST"],
563 |             allow_headers=["Content-Type", "Authorization", "X-Custom"],
564 |         )
565 | 
566 |         @app.post("/test")
567 |         async def test_endpoint():
568 |             return {"message": "test"}
569 | 
570 |         client = TestClient(app)
571 | 
572 |         response = client.options(
573 |             "/test",
574 |             headers={
575 |                 "Origin": "https://example.com",
576 |                 "Access-Control-Request-Method": "POST",
577 |                 "Access-Control-Request-Headers": "Content-Type, Authorization",
578 |             },
579 |         )
580 | 
581 |         allowed_headers = response.headers.get("Access-Control-Allow-Headers", "")
582 |         assert "Content-Type" in allowed_headers
583 |         assert "Authorization" in allowed_headers
584 | 
585 | 
586 | class TestCORSEdgeCases:
587 |     """Test CORS edge cases and security scenarios."""
588 | 
589 |     def test_cors_with_vary_header(self):
590 |         """Test that CORS responses include Vary header."""
591 |         app = FastAPI()
592 |         app.add_middleware(
593 |             CORSMiddleware,
594 |             allow_origins=["https://example.com"],
595 |             allow_methods=["GET"],
596 |             allow_headers=["Content-Type"],
597 |         )
598 | 
599 |         @app.get("/test")
600 |         async def test_endpoint():
601 |             return {"message": "test"}
602 | 
603 |         client = TestClient(app)
604 | 
605 |         response = client.get("/test", headers={"Origin": "https://example.com"})
606 | 
607 |         vary_header = response.headers.get("Vary", "")
608 |         assert "Origin" in vary_header
609 | 
610 |     def test_cors_null_origin_handling(self):
611 |         """Test CORS handling of null origin (file:// protocol)."""
612 |         app = FastAPI()
613 |         app.add_middleware(
614 |             CORSMiddleware,
615 |             allow_origins=["null"],  # Sometimes needed for file:// protocol
616 |             allow_methods=["GET"],
617 |             allow_headers=["Content-Type"],
618 |         )
619 | 
620 |         @app.get("/test")
621 |         async def test_endpoint():
622 |             return {"message": "test"}
623 | 
624 |         client = TestClient(app)
625 | 
626 |         response = client.get("/test", headers={"Origin": "null"})
627 | 
628 |         # Should handle null origin appropriately
629 |         assert response.status_code == 200
630 | 
631 |     def test_cors_case_insensitive_origin(self):
632 |         """Test CORS origin matching is case-sensitive (as it should be)."""
633 |         app = FastAPI()
634 |         app.add_middleware(
635 |             CORSMiddleware,
636 |             allow_origins=["https://Example.com"],  # Capital E
637 |             allow_methods=["GET"],
638 |             allow_headers=["Content-Type"],
639 |         )
640 | 
641 |         @app.get("/test")
642 |         async def test_endpoint():
643 |             return {"message": "test"}
644 | 
645 |         client = TestClient(app)
646 | 
647 |         # Test with different case
648 |         response = client.get(
649 |             "/test",
650 |             headers={"Origin": "https://example.com"},  # lowercase e
651 |         )
652 | 
653 |         # Should not match due to case sensitivity
654 |         cors_origin = response.headers.get("Access-Control-Allow-Origin")
655 |         assert cors_origin != "https://example.com"
656 | 
657 | 
658 | if __name__ == "__main__":  # allow running this test module directly
659 |     pytest.main([__file__, "-v"])
660 | 
```

--------------------------------------------------------------------------------
/tests/core/test_technical_analysis.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Unit tests for maverick_mcp.core.technical_analysis module.
  3 | 
  4 | This module contains comprehensive tests for all technical analysis functions
  5 | to ensure accurate financial calculations and proper error handling.
  6 | """
  7 | 
  8 | import numpy as np
  9 | import pandas as pd
 10 | import pytest
 11 | 
 12 | from maverick_mcp.core.technical_analysis import (
 13 |     add_technical_indicators,
 14 |     analyze_bollinger_bands,
 15 |     analyze_macd,
 16 |     analyze_rsi,
 17 |     analyze_stochastic,
 18 |     analyze_trend,
 19 |     analyze_volume,
 20 |     calculate_atr,
 21 |     generate_outlook,
 22 |     identify_chart_patterns,
 23 |     identify_resistance_levels,
 24 |     identify_support_levels,
 25 | )
 26 | 
 27 | 
 28 | class TestTechnicalIndicators:
 29 |     """Test the add_technical_indicators function."""
 30 | 
 31 |     def test_add_technical_indicators_basic(self):
 32 |         """Test basic technical indicators calculation."""
 33 |         # Create sample data with enough data points for all indicators
 34 |         dates = pd.date_range("2024-01-01", periods=100, freq="D")
 35 |         data = {
 36 |             "Date": dates,
 37 |             "Open": np.random.uniform(100, 200, 100),
 38 |             "High": np.random.uniform(150, 250, 100),
 39 |             "Low": np.random.uniform(50, 150, 100),
 40 |             "Close": np.random.uniform(100, 200, 100),
 41 |             "Volume": np.random.randint(1000000, 10000000, 100),
 42 |         }
 43 |         df = pd.DataFrame(data)
 44 |         df = df.set_index("Date")
 45 | 
 46 |         # Add some realistic price movement
 47 |         for i in range(1, len(df)):
 48 |             df.loc[df.index[i], "Close"] = df.iloc[i - 1]["Close"] * np.random.uniform(
 49 |                 0.98, 1.02
 50 |             )
 51 |             df.loc[df.index[i], "High"] = max(
 52 |                 df.iloc[i]["Open"], df.iloc[i]["Close"]
 53 |             ) * np.random.uniform(1.0, 1.02)
 54 |             df.loc[df.index[i], "Low"] = min(
 55 |                 df.iloc[i]["Open"], df.iloc[i]["Close"]
 56 |             ) * np.random.uniform(0.98, 1.0)
 57 | 
 58 |         result = add_technical_indicators(df)
 59 | 
 60 |         # Check that all expected indicators are added
 61 |         expected_indicators = [
 62 |             "ema_21",
 63 |             "sma_50",
 64 |             "sma_200",
 65 |             "rsi",
 66 |             "macd_12_26_9",
 67 |             "macds_12_26_9",
 68 |             "macdh_12_26_9",
 69 |             "sma_20",
 70 |             "bbu_20_2.0",
 71 |             "bbl_20_2.0",
 72 |             "stdev",
 73 |             "atr",
 74 |             "stochk_14_3_3",
 75 |             "stochd_14_3_3",
 76 |             "adx_14",
 77 |         ]
 78 | 
 79 |         for indicator in expected_indicators:
 80 |             assert indicator in result.columns
 81 | 
 82 |         # Check that indicators have reasonable values (not all NaN)
 83 |         assert not result["rsi"].iloc[-10:].isna().all()
 84 |         assert not result["ema_21"].iloc[-10:].isna().all()
 85 |         assert not result["sma_50"].iloc[-10:].isna().all()
 86 | 
 87 |     def test_add_technical_indicators_column_case_insensitive(self):
 88 |         """Test that the function handles different column case properly."""
 89 |         data = {
 90 |             "OPEN": [100, 101, 102],
 91 |             "HIGH": [105, 106, 107],
 92 |             "LOW": [95, 96, 97],
 93 |             "CLOSE": [103, 104, 105],
 94 |             "VOLUME": [1000000, 1100000, 1200000],
 95 |         }
 96 |         df = pd.DataFrame(data)
 97 | 
 98 |         result = add_technical_indicators(df)
 99 | 
100 |         # Check that columns are normalized to lowercase
101 |         assert "close" in result.columns
102 |         assert "high" in result.columns
103 |         assert "low" in result.columns
104 | 
105 |     def test_add_technical_indicators_insufficient_data(self):
106 |         """Test behavior with insufficient data."""
107 |         data = {
108 |             "Open": [100],
109 |             "High": [105],
110 |             "Low": [95],
111 |             "Close": [103],
112 |             "Volume": [1000000],
113 |         }
114 |         df = pd.DataFrame(data)
115 | 
116 |         result = add_technical_indicators(df)
117 | 
118 |         # Should handle insufficient data gracefully
119 |         assert "rsi" in result.columns
120 |         assert pd.isna(result["rsi"].iloc[0])  # Should be NaN for insufficient data
121 | 
122 |     def test_add_technical_indicators_empty_dataframe(self):
123 |         """Test behavior with empty dataframe."""
124 |         df = pd.DataFrame()
125 | 
126 |         with pytest.raises(KeyError):
127 |             add_technical_indicators(df)
128 | 
129 |     @pytest.mark.parametrize(
130 |         "bb_columns",
131 |         [
132 |             ("BBM_20_2.0", "BBU_20_2.0", "BBL_20_2.0"),
133 |             ("BBM_20_2", "BBU_20_2", "BBL_20_2"),
134 |         ],
135 |     )
136 |     def test_add_technical_indicators_supports_bbands_column_aliases(
137 |         self, monkeypatch, bb_columns
138 |     ):
139 |         """Ensure Bollinger Band column name variations are handled."""
140 | 
141 |         index = pd.date_range("2024-01-01", periods=40, freq="D")
142 |         base_series = np.linspace(100, 140, len(index))
143 |         data = {
144 |             "open": base_series,
145 |             "high": base_series + 1,
146 |             "low": base_series - 1,
147 |             "close": base_series,
148 |             "volume": np.full(len(index), 1_000_000),
149 |         }
150 |         df = pd.DataFrame(data, index=index)
151 | 
152 |         mid_column, upper_column, lower_column = bb_columns
153 | 
154 |         def fake_bbands(close, *args, **kwargs):
155 |             band_values = pd.Series(base_series, index=close.index)
156 |             return pd.DataFrame(
157 |                 {
158 |                     mid_column: band_values,
159 |                     upper_column: band_values + 2,
160 |                     lower_column: band_values - 2,
161 |                 }
162 |             )
163 | 
164 |         monkeypatch.setattr(
165 |             "maverick_mcp.core.technical_analysis.ta.bbands",
166 |             fake_bbands,
167 |         )
168 | 
169 |         result = add_technical_indicators(df)
170 | 
171 |         np.testing.assert_allclose(result["sma_20"], base_series)
172 |         np.testing.assert_allclose(result["bbu_20_2.0"], base_series + 2)
173 |         np.testing.assert_allclose(result["bbl_20_2.0"], base_series - 2)
174 | 
175 | 
176 | class TestSupportResistanceLevels:
177 |     """Test support and resistance level identification."""
178 | 
179 |     @pytest.fixture
180 |     def sample_data(self):
181 |         """Create sample price data for testing."""
182 |         data = {
183 |             "high": [105, 110, 108, 115, 112, 120, 118, 125, 122, 130] * 5,
184 |             "low": [95, 100, 98, 105, 102, 110, 108, 115, 112, 120] * 5,
185 |             "close": [100, 105, 103, 110, 107, 115, 113, 120, 117, 125] * 5,
186 |         }
187 |         return pd.DataFrame(data)
188 | 
189 |     def test_identify_support_levels(self, sample_data):
190 |         """Test support level identification."""
191 |         support_levels = identify_support_levels(sample_data)
192 | 
193 |         assert isinstance(support_levels, list)
194 |         assert len(support_levels) > 0
195 |         assert all(
196 |             isinstance(level, float | int | np.number) for level in support_levels
197 |         )
198 |         assert support_levels == sorted(support_levels)  # Should be sorted
199 | 
200 |     def test_identify_resistance_levels(self, sample_data):
201 |         """Test resistance level identification."""
202 |         resistance_levels = identify_resistance_levels(sample_data)
203 | 
204 |         assert isinstance(resistance_levels, list)
205 |         assert len(resistance_levels) > 0
206 |         assert all(
207 |             isinstance(level, float | int | np.number) for level in resistance_levels
208 |         )
209 |         assert resistance_levels == sorted(resistance_levels)  # Should be sorted
210 | 
211 |     def test_support_resistance_with_small_dataset(self):
212 |         """Test with dataset smaller than 30 days."""
213 |         data = {
214 |             "high": [105, 110, 108],
215 |             "low": [95, 100, 98],
216 |             "close": [100, 105, 103],
217 |         }
218 |         df = pd.DataFrame(data)
219 | 
220 |         support_levels = identify_support_levels(df)
221 |         resistance_levels = identify_resistance_levels(df)
222 | 
223 |         assert len(support_levels) > 0
224 |         assert len(resistance_levels) > 0
225 | 
226 | 
227 | class TestTrendAnalysis:
228 |     """Test trend analysis functionality."""
229 | 
230 |     @pytest.fixture
231 |     def trending_data(self):
232 |         """Create data with clear upward trend."""
233 |         dates = pd.date_range("2024-01-01", periods=60, freq="D")
234 |         close_prices = np.linspace(100, 150, 60)  # Clear upward trend
235 | 
236 |         data = {
237 |             "close": close_prices,
238 |             "high": close_prices * 1.02,
239 |             "low": close_prices * 0.98,
240 |             "volume": np.random.randint(1000000, 2000000, 60),
241 |         }
242 |         df = pd.DataFrame(data, index=dates)
243 |         return add_technical_indicators(df)
244 | 
245 |     def test_analyze_trend_uptrend(self, trending_data):
246 |         """Test trend analysis with upward trending data."""
247 |         trend_strength = analyze_trend(trending_data)
248 | 
249 |         assert isinstance(trend_strength, int)
250 |         assert 0 <= trend_strength <= 7
251 |         assert trend_strength > 3  # Should detect strong uptrend
252 | 
253 |     def test_analyze_trend_empty_dataframe(self):
254 |         """Test trend analysis with empty dataframe."""
255 |         df = pd.DataFrame({"close": []})
256 | 
257 |         trend_strength = analyze_trend(df)
258 | 
259 |         assert trend_strength == 0
260 | 
261 |     def test_analyze_trend_missing_indicators(self):
262 |         """Test trend analysis with missing indicators."""
263 |         data = {
264 |             "close": [100, 101, 102, 103, 104],
265 |         }
266 |         df = pd.DataFrame(data)
267 | 
268 |         trend_strength = analyze_trend(df)
269 | 
270 |         assert trend_strength == 0  # Should handle missing indicators gracefully
271 | 
272 | 
273 | class TestRSIAnalysis:
274 |     """Test RSI analysis functionality."""
275 | 
276 |     @pytest.fixture
277 |     def rsi_data(self):
278 |         """Create data with RSI indicator."""
279 |         data = {
280 |             "close": [100, 105, 103, 110, 107, 115, 113, 120, 117, 125],
281 |             "rsi": [50, 55, 52, 65, 60, 70, 68, 75, 72, 80],
282 |         }
283 |         return pd.DataFrame(data)
284 | 
285 |     def test_analyze_rsi_overbought(self, rsi_data):
286 |         """Test RSI analysis with overbought conditions."""
287 |         result = analyze_rsi(rsi_data)
288 | 
289 |         assert result["current"] == 80.0
290 |         assert result["signal"] == "overbought"
291 |         assert "overbought" in result["description"]
292 | 
293 |     def test_analyze_rsi_oversold(self):
294 |         """Test RSI analysis with oversold conditions."""
295 |         data = {
296 |             "close": [100, 95, 90, 85, 80],
297 |             "rsi": [50, 40, 30, 25, 20],
298 |         }
299 |         df = pd.DataFrame(data)
300 | 
301 |         result = analyze_rsi(df)
302 | 
303 |         assert result["current"] == 20.0
304 |         assert result["signal"] == "oversold"
305 | 
306 |     def test_analyze_rsi_bullish(self):
307 |         """Test RSI analysis with bullish conditions."""
308 |         data = {
309 |             "close": [100, 105, 110],
310 |             "rsi": [50, 55, 60],
311 |         }
312 |         df = pd.DataFrame(data)
313 | 
314 |         result = analyze_rsi(df)
315 | 
316 |         assert result["current"] == 60.0
317 |         assert result["signal"] == "bullish"
318 | 
319 |     def test_analyze_rsi_bearish(self):
320 |         """Test RSI analysis with bearish conditions."""
321 |         data = {
322 |             "close": [100, 95, 90],
323 |             "rsi": [50, 45, 40],
324 |         }
325 |         df = pd.DataFrame(data)
326 | 
327 |         result = analyze_rsi(df)
328 | 
329 |         assert result["current"] == 40.0
330 |         assert result["signal"] == "bearish"
331 | 
332 |     def test_analyze_rsi_empty_dataframe(self):
333 |         """Test RSI analysis with empty dataframe."""
334 |         df = pd.DataFrame()
335 | 
336 |         result = analyze_rsi(df)
337 | 
338 |         assert result["current"] is None
339 |         assert result["signal"] == "unavailable"
340 | 
341 |     def test_analyze_rsi_missing_column(self):
342 |         """Test RSI analysis without RSI column."""
343 |         data = {"close": [100, 105, 110]}
344 |         df = pd.DataFrame(data)
345 | 
346 |         result = analyze_rsi(df)
347 | 
348 |         assert result["current"] is None
349 |         assert result["signal"] == "unavailable"
350 | 
351 |     def test_analyze_rsi_nan_values(self):
352 |         """Test RSI analysis with NaN values."""
353 |         data = {
354 |             "close": [100, 105, 110],
355 |             "rsi": [50, 55, np.nan],
356 |         }
357 |         df = pd.DataFrame(data)
358 | 
359 |         result = analyze_rsi(df)
360 | 
361 |         assert result["current"] is None
362 |         assert result["signal"] == "unavailable"
363 | 
364 | 
365 | class TestMACDAnalysis:
366 |     """Test MACD analysis functionality."""
367 | 
368 |     @pytest.fixture
369 |     def macd_data(self):
370 |         """Create data with MACD indicators."""
371 |         data = {
372 |             "macd_12_26_9": [1.5, 2.0, 2.5, 3.0, 2.8],
373 |             "macds_12_26_9": [1.0, 1.8, 2.2, 2.7, 3.2],
374 |             "macdh_12_26_9": [0.5, 0.2, 0.3, 0.3, -0.4],
375 |         }
376 |         return pd.DataFrame(data)
377 | 
378 |     def test_analyze_macd_bullish(self, macd_data):
379 |         """Test MACD analysis with bullish signals."""
380 |         result = analyze_macd(macd_data)
381 | 
382 |         assert result["macd"] == 2.8
383 |         assert result["signal"] == 3.2
384 |         assert result["histogram"] == -0.4
385 |         assert result["indicator"] == "bearish"  # macd < signal and histogram < 0
386 | 
387 |     def test_analyze_macd_crossover_detection(self):
388 |         """Test MACD crossover detection."""
389 |         data = {
390 |             "macd_12_26_9": [1.0, 2.0, 3.0],
391 |             "macds_12_26_9": [2.0, 1.8, 2.5],
392 |             "macdh_12_26_9": [-1.0, 0.2, 0.5],
393 |         }
394 |         df = pd.DataFrame(data)
395 | 
396 |         result = analyze_macd(df)
397 | 
398 |         # Check that crossover detection works (test the logic rather than specific result)
399 |         assert "crossover" in result
400 |         assert result["crossover"] in [
401 |             "bullish crossover detected",
402 |             "bearish crossover detected",
403 |             "no recent crossover",
404 |         ]
405 | 
406 |     def test_analyze_macd_missing_data(self):
407 |         """Test MACD analysis with missing data."""
408 |         data = {
409 |             "macd_12_26_9": [np.nan],
410 |             "macds_12_26_9": [np.nan],
411 |             "macdh_12_26_9": [np.nan],
412 |         }
413 |         df = pd.DataFrame(data)
414 | 
415 |         result = analyze_macd(df)
416 | 
417 |         assert result["macd"] is None
418 |         assert result["indicator"] == "unavailable"
419 | 
420 | 
421 | class TestStochasticAnalysis:
422 |     """Test Stochastic Oscillator analysis."""
423 | 
424 |     @pytest.fixture
425 |     def stoch_data(self):
426 |         """Create data with Stochastic indicators."""
427 |         data = {
428 |             "stochk_14_3_3": [20, 30, 40, 50, 60],
429 |             "stochd_14_3_3": [25, 35, 45, 55, 65],
430 |         }
431 |         return pd.DataFrame(data)
432 | 
433 |     def test_analyze_stochastic_bearish(self, stoch_data):
434 |         """Test Stochastic analysis with bearish signal."""
435 |         result = analyze_stochastic(stoch_data)
436 | 
437 |         assert result["k"] == 60.0
438 |         assert result["d"] == 65.0
439 |         assert result["signal"] == "bearish"  # k < d
440 | 
441 |     def test_analyze_stochastic_overbought(self):
442 |         """Test Stochastic analysis with overbought conditions."""
443 |         data = {
444 |             "stochk_14_3_3": [85],
445 |             "stochd_14_3_3": [83],
446 |         }
447 |         df = pd.DataFrame(data)
448 | 
449 |         result = analyze_stochastic(df)
450 | 
451 |         assert result["signal"] == "overbought"
452 | 
453 |     def test_analyze_stochastic_oversold(self):
454 |         """Test Stochastic analysis with oversold conditions."""
455 |         data = {
456 |             "stochk_14_3_3": [15],
457 |             "stochd_14_3_3": [18],
458 |         }
459 |         df = pd.DataFrame(data)
460 | 
461 |         result = analyze_stochastic(df)
462 | 
463 |         assert result["signal"] == "oversold"
464 | 
465 |     def test_analyze_stochastic_crossover(self):
466 |         """Test Stochastic crossover detection."""
467 |         data = {
468 |             "stochk_14_3_3": [30, 45],
469 |             "stochd_14_3_3": [40, 35],
470 |         }
471 |         df = pd.DataFrame(data)
472 | 
473 |         result = analyze_stochastic(df)
474 | 
475 |         assert result["crossover"] == "bullish crossover detected"
476 | 
477 | 
478 | class TestBollingerBands:
479 |     """Test Bollinger Bands analysis."""
480 | 
481 |     @pytest.fixture
482 |     def bb_data(self):
483 |         """Create data with Bollinger Bands."""
484 |         data = {
485 |             "close": [100, 105, 110, 108, 112],
486 |             "bbu_20_2.0": [115, 116, 117, 116, 118],
487 |             "bbl_20_2.0": [85, 86, 87, 86, 88],
488 |             "sma_20": [100, 101, 102, 101, 103],
489 |         }
490 |         return pd.DataFrame(data)
491 | 
492 |     def test_analyze_bollinger_bands_above_middle(self, bb_data):
493 |         """Test Bollinger Bands with price above middle band."""
494 |         result = analyze_bollinger_bands(bb_data)
495 | 
496 |         assert result["upper_band"] == 118.0
497 |         assert result["middle_band"] == 103.0
498 |         assert result["lower_band"] == 88.0
499 |         assert result["position"] == "above middle band"
500 |         assert result["signal"] == "bullish"
501 | 
502 |     def test_analyze_bollinger_bands_above_upper(self):
503 |         """Test Bollinger Bands with price above upper band."""
504 |         data = {
505 |             "close": [120],
506 |             "bbu_20_2.0": [115],
507 |             "bbl_20_2.0": [85],
508 |             "sma_20": [100],
509 |         }
510 |         df = pd.DataFrame(data)
511 | 
512 |         result = analyze_bollinger_bands(df)
513 | 
514 |         assert result["position"] == "above upper band"
515 |         assert result["signal"] == "overbought"
516 | 
517 |     def test_analyze_bollinger_bands_below_lower(self):
518 |         """Test Bollinger Bands with price below lower band."""
519 |         data = {
520 |             "close": [80],
521 |             "bbu_20_2.0": [115],
522 |             "bbl_20_2.0": [85],
523 |             "sma_20": [100],
524 |         }
525 |         df = pd.DataFrame(data)
526 | 
527 |         result = analyze_bollinger_bands(df)
528 | 
529 |         assert result["position"] == "below lower band"
530 |         assert result["signal"] == "oversold"
531 | 
532 |     def test_analyze_bollinger_bands_volatility_calculation(self):
533 |         """Expect a 'contracting' volatility label when the band width shrinks row by row."""
534 |         # Band width (upper - lower) narrows from 20 to 4 across the five rows
535 |         data = {
536 |             "close": [100, 100, 100, 100, 100],
537 |             "bbu_20_2.0": [110, 108, 106, 104, 102],
538 |             "bbl_20_2.0": [90, 92, 94, 96, 98],
539 |             "sma_20": [100, 100, 100, 100, 100],
540 |         }
541 |         df = pd.DataFrame(data)
542 | 
543 |         result = analyze_bollinger_bands(df)
544 | 
545 |         assert "contracting" in result["volatility"]  # substring match on the description text
546 | 
547 | 
548 | class TestVolumeAnalysis:
549 |     """Tests for analyze_volume: current/average/ratio values plus description and signal labels."""
550 | 
551 |     @pytest.fixture
552 |     def volume_data(self):
553 |         """Build five rows where volume and close both rise, ending on the highest volume."""
554 |         data = {
555 |             "volume": [1000000, 1100000, 1200000, 1500000, 2000000],
556 |             "close": [100, 101, 102, 105, 108],
557 |         }
558 |         return pd.DataFrame(data)
559 | 
560 |     def test_analyze_volume_high_volume_up_move(self, volume_data):
561 |         """Expect a ratio of at least 1.4 and a bullish or neutral label for peak volume on a rising close."""
562 |         result = analyze_volume(volume_data)
563 | 
564 |         assert result["current"] == 2000000  # last volume value in the fixture
565 |         assert result["ratio"] >= 1.4  # More lenient threshold
566 |         # Check that volume analysis is working, signal may vary based on exact ratio
567 |         assert result["description"] in ["above average", "average"]
568 |         assert result["signal"] in ["bullish (high volume on up move)", "neutral"]
569 | 
570 |     def test_analyze_volume_low_volume(self):
571 |         """Expect 'below average' / 'weak conviction' when the latest volume drops to 600,000."""
572 |         data = {
573 |             "volume": [1000000, 1100000, 1200000, 1300000, 600000],
574 |             "close": [100, 101, 102, 103, 104],
575 |         }
576 |         df = pd.DataFrame(data)
577 | 
578 |         result = analyze_volume(df)
579 | 
580 |         assert result["ratio"] < 0.7
581 |         assert result["description"] == "below average"
582 |         assert result["signal"] == "weak conviction"
583 | 
584 |     def test_analyze_volume_insufficient_data(self):
585 |         """Expect current == average and a ratio of 1.0 when only one row is available."""
586 |         data = {
587 |             "volume": [1000000],
588 |             "close": [100],
589 |         }
590 |         df = pd.DataFrame(data)
591 | 
592 |         result = analyze_volume(df)
593 | 
594 |         # Should still work with single data point
595 |         assert result["current"] == 1000000
596 |         assert result["average"] == 1000000
597 |         assert result["ratio"] == 1.0
598 | 
599 |     def test_analyze_volume_invalid_data(self):
600 |         """Expect current None and signal 'unavailable' when the only volume value is NaN."""
601 |         data = {
602 |             "volume": [np.nan],
603 |             "close": [100],
604 |         }
605 |         df = pd.DataFrame(data)
606 | 
607 |         result = analyze_volume(df)
608 | 
609 |         assert result["current"] is None
610 |         assert result["signal"] == "unavailable"
611 | 
612 | 
613 | class TestChartPatterns:
614 |     """Tests for identify_chart_patterns, which is expected to return a list."""
615 | 
616 |     def test_identify_chart_patterns_double_bottom(self):
617 |         """Smoke-test detection on two 5-bar dips to 90 separated by 10 bars at 100."""
618 |         # 40 lows in total: 100 x10, 90 x5, 100 x10, 90 x5, 100 x10
619 |         prices = [100] * 10 + [90] * 5 + [100] * 10 + [90] * 5 + [100] * 10
620 |         data = {
621 |             "low": prices,
622 |             "high": [p + 10 for p in prices],
623 |             "close": [p + 5 for p in prices],
624 |         }
625 |         df = pd.DataFrame(data)
626 | 
627 |         patterns = identify_chart_patterns(df)
628 | 
629 |         # Note: The pattern detection is quite strict, so we just test it runs
630 |         assert isinstance(patterns, list)
631 | 
632 |     def test_identify_chart_patterns_insufficient_data(self):
633 |         """Expect an empty list when only three bars are supplied."""
634 |         data = {
635 |             "low": [90, 95, 92],
636 |             "high": [100, 105, 102],
637 |             "close": [95, 100, 97],
638 |         }
639 |         df = pd.DataFrame(data)
640 | 
641 |         patterns = identify_chart_patterns(df)
642 | 
643 |         assert isinstance(patterns, list)
644 |         assert len(patterns) == 0  # Not enough data for patterns
645 | 
646 | 
647 | class TestATRCalculation:
648 |     """Tests for calculate_atr, which is expected to return a Series with one value per input row."""
649 | 
650 |     @pytest.fixture
651 |     def atr_data(self):
652 |         """Build five rows of capitalized High/Low/Close columns with a constant 10-point range."""
653 |         data = {
654 |             "High": [105, 110, 108, 115, 112],
655 |             "Low": [95, 100, 98, 105, 102],
656 |             "Close": [100, 105, 103, 110, 107],
657 |         }
658 |         return pd.DataFrame(data)
659 | 
660 |     def test_calculate_atr_basic(self, atr_data):
661 |         """Expect a Series as long as the input whose non-NaN values are non-negative."""
662 |         result = calculate_atr(atr_data, period=3)
663 | 
664 |         assert isinstance(result, pd.Series)
665 |         assert len(result) == len(atr_data)
666 |         # ATR values must be non-negative wherever they are defined (NaNs are dropped first)
667 |         assert (result.dropna() >= 0).all()
668 | 
669 |     def test_calculate_atr_custom_period(self, atr_data):
670 |         """Expect the same Series type and length when a period of 2 is requested."""
671 |         result = calculate_atr(atr_data, period=2)
672 | 
673 |         assert isinstance(result, pd.Series)
674 |         assert len(result) == len(atr_data)
675 | 
676 |     def test_calculate_atr_insufficient_data(self):
677 |         """Expect a Series to come back even when only a single row is supplied."""
678 |         data = {
679 |             "High": [105],
680 |             "Low": [95],
681 |             "Close": [100],
682 |         }
683 |         df = pd.DataFrame(data)
684 | 
685 |         result = calculate_atr(df)  # no period passed, so the function's own default applies
686 | 
687 |         assert isinstance(result, pd.Series)
688 |         # Only the return type is checked; the values for a single row are not asserted
689 | 
690 | 
691 | class TestOutlookGeneration:
692 |     """Tests for generate_outlook across bullish, bearish, neutral and strongly bullish inputs."""
693 | 
694 |     def test_generate_outlook_bullish(self):
695 |         """Expect 'bullish' in the outlook when trend, RSI, MACD and stochastic all read bullish."""
696 |         df = pd.DataFrame({"close": [100, 105, 110]})
697 |         trend = "uptrend"
698 |         rsi_analysis = {"signal": "bullish"}
699 |         macd_analysis = {
700 |             "indicator": "bullish",
701 |             "crossover": "bullish crossover detected",
702 |         }
703 |         stoch_analysis = {"signal": "bullish"}
704 | 
705 |         outlook = generate_outlook(
706 |             df, trend, rsi_analysis, macd_analysis, stoch_analysis
707 |         )
708 | 
709 |         assert "bullish" in outlook  # substring check, so a stronger bullish label also passes
710 | 
711 |     def test_generate_outlook_bearish(self):
712 |         """Expect 'bearish' in the outlook when trend, RSI, MACD and stochastic all read bearish."""
713 |         df = pd.DataFrame({"close": [100, 95, 90]})
714 |         trend = "downtrend"
715 |         rsi_analysis = {"signal": "bearish"}
716 |         macd_analysis = {
717 |             "indicator": "bearish",
718 |             "crossover": "bearish crossover detected",
719 |         }
720 |         stoch_analysis = {"signal": "bearish"}
721 | 
722 |         outlook = generate_outlook(
723 |             df, trend, rsi_analysis, macd_analysis, stoch_analysis
724 |         )
725 | 
726 |         assert "bearish" in outlook  # substring check, so a stronger bearish label also passes
727 | 
728 |     def test_generate_outlook_neutral(self):
729 |         """Expect exactly 'neutral' when every input is sideways/neutral with no crossover."""
730 |         df = pd.DataFrame({"close": [100, 100, 100]})
731 |         trend = "sideways"
732 |         rsi_analysis = {"signal": "neutral"}
733 |         macd_analysis = {"indicator": "neutral", "crossover": "no recent crossover"}
734 |         stoch_analysis = {"signal": "neutral"}
735 | 
736 |         outlook = generate_outlook(
737 |             df, trend, rsi_analysis, macd_analysis, stoch_analysis
738 |         )
739 | 
740 |         assert outlook == "neutral"  # exact match, unlike the substring checks in the other cases
741 | 
742 |     def test_generate_outlook_strongly_bullish(self):
743 |         """Expect 'strongly bullish' for an uptrend with oversold RSI/stochastic and a bullish MACD crossover."""
744 |         df = pd.DataFrame({"close": [100, 105, 110]})
745 |         trend = "uptrend"
746 |         rsi_analysis = {"signal": "oversold"}  # Bullish signal
747 |         macd_analysis = {
748 |             "indicator": "bullish",
749 |             "crossover": "bullish crossover detected",
750 |         }
751 |         stoch_analysis = {"signal": "oversold"}  # Bullish signal
752 | 
753 |         outlook = generate_outlook(
754 |             df, trend, rsi_analysis, macd_analysis, stoch_analysis
755 |         )
756 | 
757 |         assert "strongly bullish" in outlook
758 | 
759 | 
760 | if __name__ == "__main__":
761 |     pytest.main([__file__])  # run only this module's tests when executed as a script
762 | 
```

--------------------------------------------------------------------------------
/tests/performance/test_load.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Load Testing for Concurrent Users and Backtest Operations.
  3 | 
  4 | This test suite covers:
  5 | - Concurrent user load testing (10, 50, 100 users)
  6 | - Response time and throughput measurement
  7 | - Memory usage under concurrent load
  8 | - Database performance with multiple connections
  9 | - API rate limiting behavior
 10 | - Queue management and task distribution
 11 | - System stability under sustained load
 12 | """
 13 | 
 14 | import asyncio
 15 | import logging
 16 | import os
 17 | import random
 18 | import statistics
 19 | import time
 20 | from dataclasses import dataclass
 21 | from typing import Any
 22 | from unittest.mock import Mock
 23 | 
 24 | import numpy as np
 25 | import pandas as pd
 26 | import psutil
 27 | import pytest
 28 | 
 29 | from maverick_mcp.backtesting import VectorBTEngine
 30 | from maverick_mcp.backtesting.persistence import BacktestPersistenceManager
 31 | from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
 32 | 
 33 | logger = logging.getLogger(__name__)
 34 | 
 35 | 
 36 | @dataclass
 37 | class LoadTestResult:
 38 |     """Aggregated metrics for one load-test run, as built by LoadTestRunner.run_load_test."""
 39 | 
 40 |     concurrent_users: int  # number of simulated user sessions launched
 41 |     total_requests: int  # backtest calls recorded across all sessions (successes + failures)
 42 |     successful_requests: int  # recorded calls whose "success" flag is true
 43 |     failed_requests: int  # total_requests - successful_requests
 44 |     total_duration: float  # wall-clock seconds for the whole run
 45 |     avg_response_time: float  # seconds; 0 when no request was recorded
 46 |     min_response_time: float  # seconds; 0 when no request was recorded
 47 |     max_response_time: float  # seconds; 0 when no request was recorded
 48 |     p50_response_time: float  # seconds
 49 |     p95_response_time: float  # seconds
 50 |     p99_response_time: float  # seconds
 51 |     requests_per_second: float  # total_requests / total_duration (failed requests included)
 52 |     errors_per_second: float  # failed_requests / total_duration
 53 |     memory_usage_mb: float  # RSS growth over the run (final - initial), not absolute usage
 54 |     cpu_usage_percent: float  # psutil Process.cpu_percent() reading taken at the end of the run
 55 | 
 56 | 
 57 | class LoadTestRunner:
 58 |     """Load test runner with realistic user simulation."""
 59 | 
 60 |     def __init__(self, data_provider):
 61 |         self.data_provider = data_provider  # handed to every VectorBTEngine the sessions create
 62 |         self.results = []  # not written by any method visible in this class
 63 |         self.active_requests = 0  # in-flight backtest calls; adjusted in simulate_user_session
 64 | 
 65 |     async def simulate_user_session(
 66 |         self, user_id: int, session_config: dict[str, Any]
 67 |     ) -> dict[str, Any]:
 68 |         """Run one backtest per (symbol, strategy) pair for a single user and record each outcome."""
 69 |         session_start = time.time()
 70 |         user_results = []
 71 |         response_times = []
 72 | 
 73 |         symbols = session_config.get("symbols", ["AAPL"])  # default used when the key is absent
 74 |         strategies = session_config.get("strategies", ["sma_cross"])  # default used when absent
 75 |         think_time_range = session_config.get("think_time", (0.5, 2.0))  # (min, max) seconds
 76 | 
 77 |         engine = VectorBTEngine(data_provider=self.data_provider)  # one engine per simulated user
 78 | 
 79 |         for symbol in symbols:
 80 |             for strategy in strategies:
 81 |                 self.active_requests += 1
 82 |                 request_start = time.time()
 83 | 
 84 |                 try:
 85 |                     parameters = STRATEGY_TEMPLATES.get(strategy, {}).get(
 86 |                         "parameters", {}
 87 |                     )  # strategies missing from the templates fall back to empty parameters
 88 | 
 89 |                     result = await engine.run_backtest(
 90 |                         symbol=symbol,
 91 |                         strategy_type=strategy,
 92 |                         parameters=parameters,
 93 |                         start_date="2023-01-01",
 94 |                         end_date="2023-12-31",
 95 |                     )
 96 | 
 97 |                     request_time = time.time() - request_start
 98 |                     response_times.append(request_time)
 99 | 
100 |                     user_results.append(
101 |                         {
102 |                             "symbol": symbol,
103 |                             "strategy": strategy,
104 |                             "success": True,
105 |                             "response_time": request_time,
106 |                             "result_size": len(str(result)),  # length of the result's string form
107 |                         }
108 |                     )
109 | 
110 |                 except Exception as e:  # any failure is recorded, not raised, so the session goes on
111 |                     request_time = time.time() - request_start
112 |                     response_times.append(request_time)  # failed calls count toward latency too
113 | 
114 |                     user_results.append(
115 |                         {
116 |                             "symbol": symbol,
117 |                             "strategy": strategy,
118 |                             "success": False,
119 |                             "response_time": request_time,
120 |                             "error": str(e),
121 |                         }
122 |                     )
123 | 
124 |                 finally:
125 |                     self.active_requests -= 1
126 | 
127 |                 # Simulate think time after every request (the last one included)
128 |                 think_time = random.uniform(*think_time_range)
129 |                 await asyncio.sleep(think_time)
130 | 
131 |         session_time = time.time() - session_start  # includes the think-time sleeps
132 | 
133 |         return {
134 |             "user_id": user_id,
135 |             "session_time": session_time,
136 |             "results": user_results,
137 |             "response_times": response_times,
138 |             "success_count": sum(1 for r in user_results if r["success"]),
139 |             "failure_count": sum(1 for r in user_results if not r["success"]),
140 |         }
141 | 
142 |     def calculate_percentiles(self, response_times: list[float]) -> dict[str, float]:
143 |         """Return the p50/p95/p99 of response_times, or all zeros for an empty list."""
144 |         if not response_times:
145 |             return {"p50": 0, "p95": 0, "p99": 0}  # skip np.percentile when there is no data
146 | 
147 |         sorted_times = sorted(response_times)  # sorted copy; the caller's list is left untouched
148 |         return {
149 |             "p50": np.percentile(sorted_times, 50),
150 |             "p95": np.percentile(sorted_times, 95),
151 |             "p99": np.percentile(sorted_times, 99),
152 |         }
153 | 
154 |     async def run_load_test(
155 |         self,
156 |         concurrent_users: int,
157 |         session_config: dict[str, Any],
158 |         duration_seconds: int = 60,
159 |     ) -> LoadTestResult:
160 |         """Run load test with specified concurrent users."""
161 |         logger.info(
162 |             f"Starting load test: {concurrent_users} concurrent users for {duration_seconds}s"
163 |         )
164 | 
165 |         process = psutil.Process(os.getpid())
166 |         initial_memory = process.memory_info().rss / 1024 / 1024  # MB
167 | 
168 |         start_time = time.time()
169 |         all_response_times = []
170 |         all_user_results = []
171 | 
172 |         # Create semaphore to control concurrency
173 |         semaphore = asyncio.Semaphore(concurrent_users)
174 | 
175 |         async def run_user_with_semaphore(user_id: int):
176 |             async with semaphore:
177 |                 return await self.simulate_user_session(user_id, session_config)
178 | 
179 |         # Generate user sessions
180 |         user_tasks = []
181 |         for user_id in range(concurrent_users):
182 |             task = run_user_with_semaphore(user_id)
183 |             user_tasks.append(task)
184 | 
185 |         # Execute all user sessions concurrently
186 |         try:
187 |             user_results = await asyncio.wait_for(
188 |                 asyncio.gather(*user_tasks, return_exceptions=True),
189 |                 timeout=duration_seconds + 30,  # Add buffer to test timeout
190 |             )
191 |         except TimeoutError:
192 |             logger.warning(f"Load test timed out after {duration_seconds + 30}s")
193 |             user_results = []
194 | 
195 |         end_time = time.time()
196 |         actual_duration = end_time - start_time
197 | 
198 |         # Process results
199 |         successful_sessions = []
200 |         failed_sessions = []
201 | 
202 |         for result in user_results:
203 |             if isinstance(result, Exception):
204 |                 failed_sessions.append(str(result))
205 |             elif isinstance(result, dict):
206 |                 successful_sessions.append(result)
207 |                 all_response_times.extend(result.get("response_times", []))
208 |                 all_user_results.extend(result.get("results", []))
209 | 
210 |         # Calculate metrics
211 |         total_requests = len(all_user_results)
212 |         successful_requests = sum(
213 |             1 for r in all_user_results if r.get("success", False)
214 |         )
215 |         failed_requests = total_requests - successful_requests
216 | 
217 |         # Response time statistics
218 |         percentiles = self.calculate_percentiles(all_response_times)
219 |         avg_response_time = (
220 |             statistics.mean(all_response_times) if all_response_times else 0
221 |         )
222 |         min_response_time = min(all_response_times) if all_response_times else 0
223 |         max_response_time = max(all_response_times) if all_response_times else 0
224 | 
225 |         # Throughput metrics
226 |         requests_per_second = (
227 |             total_requests / actual_duration if actual_duration > 0 else 0
228 |         )
229 |         errors_per_second = (
230 |             failed_requests / actual_duration if actual_duration > 0 else 0
231 |         )
232 | 
233 |         # Resource usage
234 |         final_memory = process.memory_info().rss / 1024 / 1024
235 |         memory_usage = final_memory - initial_memory
236 |         cpu_usage = process.cpu_percent()
237 | 
238 |         result = LoadTestResult(
239 |             concurrent_users=concurrent_users,
240 |             total_requests=total_requests,
241 |             successful_requests=successful_requests,
242 |             failed_requests=failed_requests,
243 |             total_duration=actual_duration,
244 |             avg_response_time=avg_response_time,
245 |             min_response_time=min_response_time,
246 |             max_response_time=max_response_time,
247 |             p50_response_time=percentiles["p50"],
248 |             p95_response_time=percentiles["p95"],
249 |             p99_response_time=percentiles["p99"],
250 |             requests_per_second=requests_per_second,
251 |             errors_per_second=errors_per_second,
252 |             memory_usage_mb=memory_usage,
253 |             cpu_usage_percent=cpu_usage,
254 |         )
255 | 
256 |         logger.info(
257 |             f"Load Test Results ({concurrent_users} users):\n"
258 |             f"  • Total Requests: {total_requests}\n"
259 |             f"  • Success Rate: {successful_requests / total_requests * 100:.1f}%\n"
260 |             f"  • Avg Response Time: {avg_response_time:.3f}s\n"
261 |             f"  • 95th Percentile: {percentiles['p95']:.3f}s\n"
262 |             f"  • Throughput: {requests_per_second:.1f} req/s\n"
263 |             f"  • Memory Usage: {memory_usage:.1f}MB\n"
264 |             f"  • Duration: {actual_duration:.1f}s"
265 |         )
266 | 
267 |         return result
268 | 
269 | 
270 | class TestLoadTesting:
271 |     """Load testing suite for concurrent users."""
272 | 
273 |     @pytest.fixture
274 |     async def optimized_data_provider(self):
275 |         """Create optimized data provider for load testing."""
276 |         provider = Mock()
277 | 
278 |         # Pre-generate data for common symbols to reduce computation
279 |         symbol_data_cache = {}
280 | 
281 |         def get_cached_data(symbol: str) -> pd.DataFrame:
282 |             """Get or generate cached data for symbol."""
283 |             if symbol not in symbol_data_cache:
284 |                 # Generate deterministic data based on symbol hash
285 |                 seed = hash(symbol) % 1000
286 |                 np.random.seed(seed)
287 | 
288 |                 dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
289 |                 returns = np.random.normal(0.001, 0.02, len(dates))
290 |                 prices = 100 * np.cumprod(1 + returns)
291 | 
292 |                 symbol_data_cache[symbol] = pd.DataFrame(
293 |                     {
294 |                         "Open": prices * np.random.uniform(0.99, 1.01, len(dates)),
295 |                         "High": prices * np.random.uniform(1.01, 1.03, len(dates)),
296 |                         "Low": prices * np.random.uniform(0.97, 0.99, len(dates)),
297 |                         "Close": prices,
298 |                         "Volume": np.random.randint(1000000, 10000000, len(dates)),
299 |                         "Adj Close": prices,
300 |                     },
301 |                     index=dates,
302 |                 )
303 | 
304 |                 # Ensure OHLC constraints
305 |                 data = symbol_data_cache[symbol]
306 |                 data["High"] = np.maximum(
307 |                     data["High"], np.maximum(data["Open"], data["Close"])
308 |                 )
309 |                 data["Low"] = np.minimum(
310 |                     data["Low"], np.minimum(data["Open"], data["Close"])
311 |                 )
312 | 
313 |             return symbol_data_cache[symbol].copy()
314 | 
315 |         provider.get_stock_data.side_effect = get_cached_data
316 |         return provider
317 | 
318 |     async def test_concurrent_users_10(self, optimized_data_provider, benchmark_timer):
319 |         """Run 10 users x 2 symbols x 2 strategies (40 backtests) and check throughput, latency and memory."""
320 |         load_runner = LoadTestRunner(optimized_data_provider)
321 | 
322 |         session_config = {
323 |             "symbols": ["AAPL", "GOOGL"],
324 |             "strategies": ["sma_cross", "rsi"],
325 |             "think_time": (0.1, 0.5),  # Faster think time for testing
326 |         }
327 | 
328 |         with benchmark_timer():
329 |             result = await load_runner.run_load_test(
330 |                 concurrent_users=10, session_config=session_config, duration_seconds=30
331 |             )  # duration_seconds only bounds the run: the timeout is 30s plus a 30s buffer
332 | 
333 |         # Performance assertions for 10 users
334 |         assert result.requests_per_second >= 2.0, (
335 |             f"Throughput too low: {result.requests_per_second:.1f} req/s"
336 |         )
337 |         assert result.avg_response_time <= 5.0, (
338 |             f"Response time too high: {result.avg_response_time:.2f}s"
339 |         )
340 |         assert result.p95_response_time <= 10.0, (
341 |             f"95th percentile too high: {result.p95_response_time:.2f}s"
342 |         )
343 |         assert result.successful_requests / result.total_requests >= 0.9, (
344 |             "Success rate too low"
345 |         )  # at least 90% of recorded requests must succeed
346 |         assert result.memory_usage_mb <= 500, (
347 |             f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
348 |         )  # memory_usage_mb is RSS growth during the run, not total usage
349 | 
350 |         return result
351 | 
352 |     async def test_concurrent_users_50(self, optimized_data_provider, benchmark_timer):
353 |         """Run 50 users x 3 symbols x 3 strategies (450 backtests) and check throughput, latency and memory."""
354 |         load_runner = LoadTestRunner(optimized_data_provider)
355 | 
356 |         session_config = {
357 |             "symbols": ["AAPL", "MSFT", "GOOGL"],
358 |             "strategies": ["sma_cross", "rsi", "macd"],
359 |             "think_time": (0.2, 1.0),
360 |         }
361 | 
362 |         with benchmark_timer():
363 |             result = await load_runner.run_load_test(
364 |                 concurrent_users=50, session_config=session_config, duration_seconds=60
365 |             )  # duration_seconds only bounds the run: the timeout is 60s plus a 30s buffer
366 | 
367 |         # Performance assertions for 50 users
368 |         assert result.requests_per_second >= 5.0, (
369 |             f"Throughput too low: {result.requests_per_second:.1f} req/s"
370 |         )
371 |         assert result.avg_response_time <= 8.0, (
372 |             f"Response time too high: {result.avg_response_time:.2f}s"
373 |         )
374 |         assert result.p95_response_time <= 15.0, (
375 |             f"95th percentile too high: {result.p95_response_time:.2f}s"
376 |         )
377 |         assert result.successful_requests / result.total_requests >= 0.85, (
378 |             "Success rate too low"
379 |         )  # at least 85% of recorded requests must succeed
380 |         assert result.memory_usage_mb <= 1000, (
381 |             f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
382 |         )  # memory_usage_mb is RSS growth during the run, not total usage
383 | 
384 |         return result
385 | 
386 |     async def test_concurrent_users_100(self, optimized_data_provider, benchmark_timer):
387 |         """Run 100 users x 4 symbols x 2 strategies (800 backtests) against relaxed thresholds."""
388 |         load_runner = LoadTestRunner(optimized_data_provider)
389 | 
390 |         session_config = {
391 |             "symbols": ["AAPL", "MSFT", "GOOGL", "AMZN"],
392 |             "strategies": ["sma_cross", "rsi"],  # Reduced strategies for higher load
393 |             "think_time": (0.5, 1.5),
394 |         }
395 | 
396 |         with benchmark_timer():
397 |             result = await load_runner.run_load_test(
398 |                 concurrent_users=100, session_config=session_config, duration_seconds=90
399 |             )  # duration_seconds only bounds the run: the timeout is 90s plus a 30s buffer
400 | 
401 |         # More relaxed performance assertions for 100 users
402 |         assert result.requests_per_second >= 3.0, (
403 |             f"Throughput too low: {result.requests_per_second:.1f} req/s"
404 |         )
405 |         assert result.avg_response_time <= 15.0, (
406 |             f"Response time too high: {result.avg_response_time:.2f}s"
407 |         )
408 |         assert result.p95_response_time <= 30.0, (
409 |             f"95th percentile too high: {result.p95_response_time:.2f}s"
410 |         )
411 |         assert result.successful_requests / result.total_requests >= 0.8, (
412 |             "Success rate too low"
413 |         )  # at least 80% of recorded requests must succeed
414 |         assert result.memory_usage_mb <= 2000, (
415 |             f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
416 |         )  # memory_usage_mb is RSS growth during the run, not total usage
417 | 
418 |         return result
419 | 
420 |     async def test_load_scalability_analysis(self, optimized_data_provider):
421 |         """Run the same session at 5, 10, 20 and 40 users and compare each run against the 5-user baseline."""
422 |         load_runner = LoadTestRunner(optimized_data_provider)
423 | 
424 |         session_config = {
425 |             "symbols": ["AAPL", "GOOGL"],
426 |             "strategies": ["sma_cross"],
427 |             "think_time": (0.3, 0.7),
428 |         }
429 | 
430 |         user_loads = [5, 10, 20, 40]
431 |         scalability_results = []
432 | 
433 |         for user_count in user_loads:  # runs execute one after another, not in parallel
434 |             logger.info(f"Testing scalability with {user_count} users")
435 | 
436 |             result = await load_runner.run_load_test(
437 |                 concurrent_users=user_count,
438 |                 session_config=session_config,
439 |                 duration_seconds=30,
440 |             )
441 | 
442 |             scalability_results.append(result)
443 | 
444 |         # Analyze scalability metrics
445 |         throughput_efficiency = []
446 |         response_time_degradation = []
447 | 
448 |         baseline_rps = scalability_results[0].requests_per_second  # the 5-user run
449 |         baseline_response_time = scalability_results[0].avg_response_time
450 | 
451 |         for i, result in enumerate(scalability_results):
452 |             expected_rps = baseline_rps * user_loads[i] / user_loads[0]  # assumes linear scaling
453 |             actual_efficiency = (
454 |                 result.requests_per_second / expected_rps if expected_rps > 0 else 0
455 |             )  # the baseline row scores 1.0 against itself whenever baseline_rps > 0
456 |             throughput_efficiency.append(actual_efficiency)
457 | 
458 |             response_degradation = (
459 |                 result.avg_response_time / baseline_response_time
460 |                 if baseline_response_time > 0
461 |                 else 1
462 |             )
463 |             response_time_degradation.append(response_degradation)
464 | 
465 |             logger.info(
466 |                 f"Scalability Analysis ({user_loads[i]} users):\n"
467 |                 f"  • RPS: {result.requests_per_second:.2f}\n"
468 |                 f"  • RPS Efficiency: {actual_efficiency:.2%}\n"
469 |                 f"  • Response Time: {result.avg_response_time:.3f}s\n"
470 |                 f"  • Response Degradation: {response_degradation:.2f}x\n"
471 |                 f"  • Memory: {result.memory_usage_mb:.1f}MB"
472 |             )
473 | 
474 |         # Scalability assertions
475 |         avg_efficiency = statistics.mean(throughput_efficiency)  # mean includes the baseline row
476 |         max_response_degradation = max(response_time_degradation)
477 | 
478 |         assert avg_efficiency >= 0.5, (
479 |             f"Average throughput efficiency too low: {avg_efficiency:.2%}"
480 |         )
481 |         assert max_response_degradation <= 5.0, (
482 |             f"Response time degradation too high: {max_response_degradation:.1f}x"
483 |         )
484 | 
485 |         return {
486 |             "user_loads": user_loads,
487 |             "results": scalability_results,
488 |             "throughput_efficiency": throughput_efficiency,
489 |             "response_time_degradation": response_time_degradation,
490 |             "avg_efficiency": avg_efficiency,
491 |         }
492 | 
493 |     async def test_sustained_load_stability(self, optimized_data_provider):
494 |         """Run 25 users x 2 symbols x 2 strategies and check error rate, memory growth and latency spread."""
495 |         load_runner = LoadTestRunner(optimized_data_provider)
496 | 
497 |         session_config = {
498 |             "symbols": ["AAPL", "MSFT"],
499 |             "strategies": ["sma_cross", "rsi"],
500 |             "think_time": (0.5, 1.0),
501 |         }
502 | 
503 |         # NOTE: duration_seconds only raises the timeout; each user still runs its 4 backtests once
504 |         result = await load_runner.run_load_test(
505 |             concurrent_users=25,
506 |             session_config=session_config,
507 |             duration_seconds=300,  # upper bound (plus 30s buffer), not a guaranteed run length
508 |         )
509 | 
510 |         # Stability assertions
511 |         assert result.errors_per_second <= 0.1, (
512 |             f"Error rate too high: {result.errors_per_second:.3f} err/s"
513 |         )
514 |         assert result.successful_requests / result.total_requests >= 0.95, (
515 |             "Success rate degraded over time"
516 |         )
517 |         assert result.memory_usage_mb <= 800, (
518 |             f"Memory usage grew too much: {result.memory_usage_mb:.1f}MB"
519 |         )
520 | 
521 |         # Latency spread check; raises ZeroDivisionError if p50 is 0 (e.g. nothing was recorded)
522 |         assert result.p99_response_time / result.p50_response_time <= 5.0, (
523 |             "Response time variance too high"
524 |         )
525 | 
526 |         logger.info(
527 |             f"Sustained Load Results (25 users, 5 minutes):\n"
528 |             f"  • Total Requests: {result.total_requests}\n"
529 |             f"  • Success Rate: {result.successful_requests / result.total_requests * 100:.2f}%\n"
530 |             f"  • Avg Throughput: {result.requests_per_second:.2f} req/s\n"
531 |             f"  • Response Time (50/95/99): {result.p50_response_time:.2f}s/"
532 |             f"{result.p95_response_time:.2f}s/{result.p99_response_time:.2f}s\n"
533 |             f"  • Memory Growth: {result.memory_usage_mb:.1f}MB\n"
534 |             f"  • Error Rate: {result.errors_per_second:.4f} err/s"
535 |         )
536 | 
537 |         return result
538 | 
539 |     async def test_database_connection_pooling_under_load(
540 |         self, optimized_data_provider, db_session
541 |     ):
542 |         """Test database connection pooling under concurrent load."""
543 |         # Generate backtest results to save to database
544 |         engine = VectorBTEngine(data_provider=optimized_data_provider)
545 |         test_symbols = ["DB_LOAD_1", "DB_LOAD_2", "DB_LOAD_3"]
546 | 
547 |         # Pre-generate results for database testing
548 |         backtest_results = []
549 |         for symbol in test_symbols:
550 |             result = await engine.run_backtest(
551 |                 symbol=symbol,
552 |                 strategy_type="sma_cross",
553 |                 parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
554 |                 start_date="2023-01-01",
555 |                 end_date="2023-12-31",
556 |             )
557 |             backtest_results.append(result)
558 | 
559 |         # Test concurrent database operations
560 |         async def concurrent_database_operations(operation_id: int) -> dict[str, Any]:
561 |             """Simulate concurrent database save/retrieve operations."""
562 |             start_time = time.time()
563 |             operations_completed = 0
564 |             errors = []
565 | 
566 |             try:
567 |                 with BacktestPersistenceManager(session=db_session) as persistence:
568 |                     # Save operations
569 |                     for result in backtest_results:
570 |                         try:
571 |                             backtest_id = persistence.save_backtest_result(
572 |                                 vectorbt_results=result,
573 |                                 execution_time=2.0,
574 |                                 notes=f"Load test operation {operation_id}",
575 |                             )
576 |                             operations_completed += 1
577 | 
578 |                             # Retrieve operation
579 |                             retrieved = persistence.get_backtest_by_id(backtest_id)
580 |                             if retrieved:
581 |                                 operations_completed += 1
582 | 
583 |                         except Exception as e:
584 |                             errors.append(str(e))
585 | 
586 |             except Exception as e:
587 |                 errors.append(f"Session error: {str(e)}")
588 | 
589 |             operation_time = time.time() - start_time
590 | 
591 |             return {
592 |                 "operation_id": operation_id,
593 |                 "operations_completed": operations_completed,
594 |                 "errors": errors,
595 |                 "operation_time": operation_time,
596 |             }
597 | 
598 |         # Run concurrent database operations
599 |         concurrent_operations = 20
600 |         db_tasks = [
601 |             concurrent_database_operations(i) for i in range(concurrent_operations)
602 |         ]
603 | 
604 |         start_time = time.time()
605 |         db_results = await asyncio.gather(*db_tasks, return_exceptions=True)
606 |         total_time = time.time() - start_time
607 | 
608 |         # Analyze database performance under load
609 |         successful_operations = [r for r in db_results if isinstance(r, dict)]
610 |         failed_operations = len(db_results) - len(successful_operations)
611 | 
612 |         total_operations = sum(r["operations_completed"] for r in successful_operations)
613 |         total_errors = sum(len(r["errors"]) for r in successful_operations)
614 |         avg_operation_time = statistics.mean(
615 |             [r["operation_time"] for r in successful_operations]
616 |         )
617 | 
618 |         db_throughput = total_operations / total_time if total_time > 0 else 0
619 |         error_rate = total_errors / total_operations if total_operations > 0 else 0
620 | 
621 |         logger.info(
622 |             f"Database Load Test Results:\n"
623 |             f"  • Concurrent Operations: {concurrent_operations}\n"
624 |             f"  • Successful Sessions: {len(successful_operations)}\n"
625 |             f"  • Failed Sessions: {failed_operations}\n"
626 |             f"  • Total DB Operations: {total_operations}\n"
627 |             f"  • DB Throughput: {db_throughput:.2f} ops/s\n"
628 |             f"  • Error Rate: {error_rate:.3%}\n"
629 |             f"  • Avg Operation Time: {avg_operation_time:.3f}s"
630 |         )
631 | 
632 |         # Database performance assertions
633 |         assert len(successful_operations) / len(db_results) >= 0.9, (
634 |             "DB session success rate too low"
635 |         )
636 |         assert error_rate <= 0.05, f"DB error rate too high: {error_rate:.3%}"
637 |         assert db_throughput >= 5.0, f"DB throughput too low: {db_throughput:.2f} ops/s"
638 | 
639 |         return {
640 |             "concurrent_operations": concurrent_operations,
641 |             "db_throughput": db_throughput,
642 |             "error_rate": error_rate,
643 |             "avg_operation_time": avg_operation_time,
644 |         }
645 | 
646 | 
if __name__ == "__main__":
    # Run the load testing suite and hand pytest's exit status back to the
    # shell, so a failing run is visible to CI and to calling scripts.
    raise SystemExit(
        pytest.main(
            [
                __file__,
                "-v",
                "--tb=short",
                "--asyncio-mode=auto",
                "--timeout=600",  # 10 minute timeout for load tests
                "--durations=10",
            ]
        )
    )
659 | 
```

--------------------------------------------------------------------------------
/maverick_mcp/backtesting/strategies/ml/ensemble.py:
--------------------------------------------------------------------------------

```python
  1 | """Strategy ensemble methods for combining multiple trading strategies."""
  2 | 
  3 | import logging
  4 | from typing import Any
  5 | 
  6 | import numpy as np
  7 | import pandas as pd
  8 | from pandas import DataFrame, Series
  9 | 
 10 | from maverick_mcp.backtesting.strategies.base import Strategy
 11 | 
 12 | logger = logging.getLogger(__name__)
 13 | 
 14 | 
 15 | class StrategyEnsemble(Strategy):
 16 |     """Ensemble strategy that combines multiple strategies with dynamic weighting."""
 17 | 
 18 |     def __init__(
 19 |         self,
 20 |         strategies: list[Strategy],
 21 |         weighting_method: str = "performance",
 22 |         lookback_period: int = 50,
 23 |         rebalance_frequency: int = 20,
 24 |         parameters: dict[str, Any] = None,
 25 |     ):
 26 |         """Initialize strategy ensemble.
 27 | 
 28 |         Args:
 29 |             strategies: List of base strategies to combine
 30 |             weighting_method: Method for calculating weights ('performance', 'equal', 'volatility')
 31 |             lookback_period: Period for calculating performance metrics
 32 |             rebalance_frequency: How often to update weights
 33 |             parameters: Additional parameters
 34 |         """
 35 |         super().__init__(parameters)
 36 |         self.strategies = strategies
 37 |         self.weighting_method = weighting_method
 38 |         self.lookback_period = lookback_period
 39 |         self.rebalance_frequency = rebalance_frequency
 40 | 
 41 |         # Initialize strategy weights
 42 |         self.weights = np.ones(len(strategies)) / len(strategies)
 43 |         self.strategy_returns = {}
 44 |         self.strategy_signals = {}
 45 |         self.last_rebalance = 0
 46 | 
 47 |     @property
 48 |     def name(self) -> str:
 49 |         """Get strategy name."""
 50 |         strategy_names = [s.name for s in self.strategies]
 51 |         return f"Ensemble({','.join(strategy_names)})"
 52 | 
 53 |     @property
 54 |     def description(self) -> str:
 55 |         """Get strategy description."""
 56 |         return f"Dynamic ensemble combining {len(self.strategies)} strategies using {self.weighting_method} weighting"
 57 | 
 58 |     def calculate_performance_weights(self, data: DataFrame) -> np.ndarray:
 59 |         """Calculate performance-based weights for strategies.
 60 | 
 61 |         Args:
 62 |             data: Price data for performance calculation
 63 | 
 64 |         Returns:
 65 |             Array of strategy weights
 66 |         """
 67 |         if len(self.strategy_returns) < 2:
 68 |             return self.weights
 69 | 
 70 |         # Calculate Sharpe ratios for each strategy
 71 |         sharpe_ratios = []
 72 |         for i, _strategy in enumerate(self.strategies):
 73 |             if (
 74 |                 i in self.strategy_returns
 75 |                 and len(self.strategy_returns[i]) >= self.lookback_period
 76 |             ):
 77 |                 returns = pd.Series(self.strategy_returns[i][-self.lookback_period :])
 78 |                 sharpe = returns.mean() / (returns.std() + 1e-8) * np.sqrt(252)
 79 |                 sharpe_ratios.append(max(0, sharpe))  # Ensure non-negative
 80 |             else:
 81 |                 sharpe_ratios.append(0.1)  # Small positive weight for new strategies
 82 | 
 83 |         # Convert to weights (softmax-like normalization)
 84 |         sharpe_array = np.array(sharpe_ratios)
 85 |         # Fix: Properly check for empty array and zero sum conditions
 86 |         if sharpe_array.size == 0 or np.sum(sharpe_array) == 0:
 87 |             weights = np.ones(len(self.strategies)) / len(self.strategies)
 88 |         else:
 89 |             # Exponential weighting to emphasize better performers
 90 |             exp_sharpe = np.exp(sharpe_array * 2)
 91 |             weights = exp_sharpe / exp_sharpe.sum()
 92 | 
 93 |         return weights
 94 | 
 95 |     def calculate_volatility_weights(self, data: DataFrame) -> np.ndarray:
 96 |         """Calculate inverse volatility weights for strategies.
 97 | 
 98 |         Args:
 99 |             data: Price data for volatility calculation
100 | 
101 |         Returns:
102 |             Array of strategy weights
103 |         """
104 |         if len(self.strategy_returns) < 2:
105 |             return self.weights
106 | 
107 |         # Calculate volatilities for each strategy
108 |         volatilities = []
109 |         for i, _strategy in enumerate(self.strategies):
110 |             if (
111 |                 i in self.strategy_returns
112 |                 and len(self.strategy_returns[i]) >= self.lookback_period
113 |             ):
114 |                 returns = pd.Series(self.strategy_returns[i][-self.lookback_period :])
115 |                 vol = returns.std() * np.sqrt(252)
116 |                 volatilities.append(max(0.01, vol))  # Minimum volatility
117 |             else:
118 |                 volatilities.append(0.2)  # Default volatility assumption
119 | 
120 |         # Inverse volatility weighting
121 |         vol_array = np.array(volatilities)
122 |         inv_vol = 1.0 / vol_array
123 |         weights = inv_vol / inv_vol.sum()
124 | 
125 |         return weights
126 | 
127 |     def update_weights(self, data: DataFrame, current_index: int) -> None:
128 |         """Update strategy weights based on recent performance.
129 | 
130 |         Args:
131 |             data: Price data
132 |             current_index: Current position in data
133 |         """
134 |         # Check if it's time to rebalance
135 |         if current_index - self.last_rebalance < self.rebalance_frequency:
136 |             return
137 | 
138 |         try:
139 |             if self.weighting_method == "performance":
140 |                 self.weights = self.calculate_performance_weights(data)
141 |             elif self.weighting_method == "volatility":
142 |                 self.weights = self.calculate_volatility_weights(data)
143 |             elif self.weighting_method == "equal":
144 |                 self.weights = np.ones(len(self.strategies)) / len(self.strategies)
145 |             else:
146 |                 logger.warning(f"Unknown weighting method: {self.weighting_method}")
147 | 
148 |             self.last_rebalance = current_index
149 | 
150 |             logger.debug(
151 |                 f"Updated ensemble weights: {dict(zip([s.name for s in self.strategies], self.weights, strict=False))}"
152 |             )
153 | 
154 |         except Exception as e:
155 |             logger.error(f"Error updating weights: {e}")
156 | 
157 |     def generate_individual_signals(
158 |         self, data: DataFrame
159 |     ) -> dict[int, tuple[Series, Series]]:
160 |         """Generate signals from all individual strategies with enhanced error handling.
161 | 
162 |         Args:
163 |             data: Price data
164 | 
165 |         Returns:
166 |             Dictionary mapping strategy index to (entry_signals, exit_signals)
167 |         """
168 |         signals = {}
169 |         failed_strategies = []
170 | 
171 |         for i, strategy in enumerate(self.strategies):
172 |             try:
173 |                 # Generate signals with timeout protection
174 |                 entry_signals, exit_signals = strategy.generate_signals(data)
175 | 
176 |                 # Validate signals
177 |                 if not isinstance(entry_signals, pd.Series) or not isinstance(
178 |                     exit_signals, pd.Series
179 |                 ):
180 |                     raise ValueError(
181 |                         f"Strategy {strategy.name} returned invalid signal types"
182 |                     )
183 | 
184 |                 if len(entry_signals) != len(data) or len(exit_signals) != len(data):
185 |                     raise ValueError(
186 |                         f"Strategy {strategy.name} returned signals with wrong length"
187 |                     )
188 | 
189 |                 if not entry_signals.dtype == bool or not exit_signals.dtype == bool:
190 |                     # Convert to boolean if necessary
191 |                     entry_signals = entry_signals.astype(bool)
192 |                     exit_signals = exit_signals.astype(bool)
193 | 
194 |                 signals[i] = (entry_signals, exit_signals)
195 | 
196 |                 # Calculate strategy returns for weight updates (with error handling)
197 |                 try:
198 |                     positions = entry_signals.astype(int) - exit_signals.astype(int)
199 |                     price_returns = data["close"].pct_change()
200 |                     returns = positions.shift(1) * price_returns
201 | 
202 |                     # Remove invalid returns
203 |                     valid_returns = returns.dropna()
204 |                     valid_returns = valid_returns[np.isfinite(valid_returns)]
205 | 
206 |                     if i not in self.strategy_returns:
207 |                         self.strategy_returns[i] = []
208 | 
209 |                     if len(valid_returns) > 0:
210 |                         self.strategy_returns[i].extend(valid_returns.tolist())
211 | 
212 |                         # Keep only recent returns for performance calculation
213 |                         if len(self.strategy_returns[i]) > self.lookback_period * 2:
214 |                             self.strategy_returns[i] = self.strategy_returns[i][
215 |                                 -self.lookback_period * 2 :
216 |                             ]
217 | 
218 |                 except Exception as return_error:
219 |                     logger.debug(
220 |                         f"Error calculating returns for strategy {strategy.name}: {return_error}"
221 |                     )
222 | 
223 |                 logger.debug(
224 |                     f"Strategy {strategy.name}: {entry_signals.sum()} entries, {exit_signals.sum()} exits"
225 |                 )
226 | 
227 |             except Exception as e:
228 |                 logger.error(
229 |                     f"Error generating signals for strategy {strategy.name}: {e}"
230 |                 )
231 |                 failed_strategies.append(i)
232 | 
233 |                 # Create safe fallback signals
234 |                 try:
235 |                     signals[i] = (
236 |                         pd.Series(False, index=data.index),
237 |                         pd.Series(False, index=data.index),
238 |                     )
239 |                 except Exception:
240 |                     # If even creating empty signals fails, skip this strategy
241 |                     logger.error(f"Cannot create fallback signals for strategy {i}")
242 |                     continue
243 | 
244 |         # Log summary of strategy performance
245 |         if failed_strategies:
246 |             failed_names = [self.strategies[i].name for i in failed_strategies]
247 |             logger.warning(f"Failed strategies: {failed_names}")
248 | 
249 |         successful_strategies = len(signals) - len(failed_strategies)
250 |         logger.info(
251 |             f"Successfully generated signals from {successful_strategies}/{len(self.strategies)} strategies"
252 |         )
253 | 
254 |         return signals
255 | 
256 |     def combine_signals(
257 |         self, individual_signals: dict[int, tuple[Series, Series]]
258 |     ) -> tuple[Series, Series]:
259 |         """Combine individual strategy signals using enhanced weighted voting.
260 | 
261 |         Args:
262 |             individual_signals: Dictionary of individual strategy signals
263 | 
264 |         Returns:
265 |             Tuple of combined (entry_signals, exit_signals)
266 |         """
267 |         if not individual_signals:
268 |             # Return empty series with minimal index when no individual signals available
269 |             empty_index = pd.Index([])
270 |             return pd.Series(False, index=empty_index), pd.Series(
271 |                 False, index=empty_index
272 |             )
273 | 
274 |         # Get data index from first strategy
275 |         first_signals = next(iter(individual_signals.values()))
276 |         data_index = first_signals[0].index
277 | 
278 |         # Initialize voting arrays
279 |         entry_votes = np.zeros(len(data_index))
280 |         exit_votes = np.zeros(len(data_index))
281 |         total_weights = 0
282 | 
283 |         # Collect votes with weights and confidence scores
284 |         valid_strategies = 0
285 | 
286 |         for i, (entry_signals, exit_signals) in individual_signals.items():
287 |             weight = self.weights[i] if i < len(self.weights) else 0
288 | 
289 |             if weight > 0:
290 |                 # Add weighted votes
291 |                 entry_votes += weight * entry_signals.astype(float)
292 |                 exit_votes += weight * exit_signals.astype(float)
293 |                 total_weights += weight
294 |                 valid_strategies += 1
295 | 
296 |         if total_weights == 0 or valid_strategies == 0:
297 |             logger.warning("No valid strategies with positive weights")
298 |             return pd.Series(False, index=data_index), pd.Series(
299 |                 False, index=data_index
300 |             )
301 | 
302 |         # Normalize votes by total weights
303 |         entry_votes = entry_votes / total_weights
304 |         exit_votes = exit_votes / total_weights
305 | 
306 |         # Enhanced voting mechanisms
307 |         voting_method = self.parameters.get("voting_method", "weighted")
308 | 
309 |         if voting_method == "majority":
310 |             # Simple majority vote (more than half of strategies agree)
311 |             entry_threshold = 0.5
312 |             exit_threshold = 0.5
313 |         elif voting_method == "supermajority":
314 |             # Require 2/3 agreement
315 |             entry_threshold = 0.67
316 |             exit_threshold = 0.67
317 |         elif voting_method == "consensus":
318 |             # Require near-unanimous agreement
319 |             entry_threshold = 0.8
320 |             exit_threshold = 0.8
321 |         else:  # weighted (default)
322 |             entry_threshold = self.parameters.get("entry_threshold", 0.5)
323 |             exit_threshold = self.parameters.get("exit_threshold", 0.5)
324 | 
325 |         # Anti-conflict mechanism: don't signal entry and exit simultaneously
326 |         combined_entry = entry_votes > entry_threshold
327 |         combined_exit = exit_votes > exit_threshold
328 | 
329 |         # Resolve conflicts (simultaneous entry and exit signals)
330 |         conflicts = combined_entry & combined_exit
331 |         # Fix: Check array size and ensure it's not empty before evaluating boolean truth
332 |         if conflicts.size > 0 and np.any(conflicts):
333 |             logger.debug(f"Resolving {conflicts.sum()} signal conflicts")
334 |             # In case of conflict, use the stronger signal
335 |             entry_strength = entry_votes[conflicts]
336 |             exit_strength = exit_votes[conflicts]
337 | 
338 |             # Keep only the stronger signal
339 |             stronger_entry = entry_strength > exit_strength
340 |             combined_entry[conflicts] = stronger_entry
341 |             combined_exit[conflicts] = ~stronger_entry
342 | 
343 |         # Quality filter: require minimum signal strength
344 |         min_signal_strength = self.parameters.get("min_signal_strength", 0.1)
345 |         weak_entry_signals = (combined_entry) & (entry_votes < min_signal_strength)
346 |         weak_exit_signals = (combined_exit) & (exit_votes < min_signal_strength)
347 | 
348 |         # Fix: Ensure arrays are not empty before boolean indexing
349 |         if weak_entry_signals.size > 0:
350 |             combined_entry[weak_entry_signals] = False
351 |         if weak_exit_signals.size > 0:
352 |             combined_exit[weak_exit_signals] = False
353 | 
354 |         # Convert to pandas Series
355 |         combined_entry = pd.Series(combined_entry, index=data_index)
356 |         combined_exit = pd.Series(combined_exit, index=data_index)
357 | 
358 |         return combined_entry, combined_exit
359 | 
360 |     def generate_signals(self, data: DataFrame) -> tuple[Series, Series]:
361 |         """Generate ensemble trading signals.
362 | 
363 |         Args:
364 |             data: Price data with OHLCV columns
365 | 
366 |         Returns:
367 |             Tuple of (entry_signals, exit_signals) as boolean Series
368 |         """
369 |         try:
370 |             # Generate signals from all individual strategies
371 |             individual_signals = self.generate_individual_signals(data)
372 | 
373 |             if not individual_signals:
374 |                 return pd.Series(False, index=data.index), pd.Series(
375 |                     False, index=data.index
376 |                 )
377 | 
378 |             # Update weights periodically
379 |             for idx in range(
380 |                 self.rebalance_frequency, len(data), self.rebalance_frequency
381 |             ):
382 |                 self.update_weights(data.iloc[:idx], idx)
383 | 
384 |             # Combine signals
385 |             entry_signals, exit_signals = self.combine_signals(individual_signals)
386 | 
387 |             logger.info(
388 |                 f"Generated ensemble signals: {entry_signals.sum()} entries, {exit_signals.sum()} exits"
389 |             )
390 | 
391 |             return entry_signals, exit_signals
392 | 
393 |         except Exception as e:
394 |             logger.error(f"Error generating ensemble signals: {e}")
395 |             return pd.Series(False, index=data.index), pd.Series(
396 |                 False, index=data.index
397 |             )
398 | 
399 |     def get_strategy_weights(self) -> dict[str, float]:
400 |         """Get current strategy weights.
401 | 
402 |         Returns:
403 |             Dictionary mapping strategy names to weights
404 |         """
405 |         return dict(zip([s.name for s in self.strategies], self.weights, strict=False))
406 | 
407 |     def get_strategy_performance(self) -> dict[str, dict[str, float]]:
408 |         """Get performance metrics for individual strategies.
409 | 
410 |         Returns:
411 |             Dictionary mapping strategy names to performance metrics
412 |         """
413 |         performance = {}
414 | 
415 |         for i, strategy in enumerate(self.strategies):
416 |             if i in self.strategy_returns and len(self.strategy_returns[i]) > 0:
417 |                 returns = pd.Series(self.strategy_returns[i])
418 | 
419 |                 performance[strategy.name] = {
420 |                     "total_return": returns.sum(),
421 |                     "annual_return": returns.mean() * 252,
422 |                     "volatility": returns.std() * np.sqrt(252),
423 |                     "sharpe_ratio": returns.mean()
424 |                     / (returns.std() + 1e-8)
425 |                     * np.sqrt(252),
426 |                     "max_drawdown": (
427 |                         returns.cumsum() - returns.cumsum().expanding().max()
428 |                     ).min(),
429 |                     "win_rate": (returns > 0).mean(),
430 |                     "current_weight": self.weights[i],
431 |                 }
432 |             else:
433 |                 performance[strategy.name] = {
434 |                     "total_return": 0.0,
435 |                     "annual_return": 0.0,
436 |                     "volatility": 0.0,
437 |                     "sharpe_ratio": 0.0,
438 |                     "max_drawdown": 0.0,
439 |                     "win_rate": 0.0,
440 |                     "current_weight": self.weights[i] if i < len(self.weights) else 0.0,
441 |                 }
442 | 
443 |         return performance
444 | 
445 |     def validate_parameters(self) -> bool:
446 |         """Validate ensemble parameters.
447 | 
448 |         Returns:
449 |             True if parameters are valid
450 |         """
451 |         if not self.strategies:
452 |             return False
453 | 
454 |         if self.weighting_method not in ["performance", "equal", "volatility"]:
455 |             return False
456 | 
457 |         if self.lookback_period <= 0 or self.rebalance_frequency <= 0:
458 |             return False
459 | 
460 |         # Validate individual strategies
461 |         for strategy in self.strategies:
462 |             if not strategy.validate_parameters():
463 |                 return False
464 | 
465 |         return True
466 | 
467 |     def get_default_parameters(self) -> dict[str, Any]:
468 |         """Get default ensemble parameters.
469 | 
470 |         Returns:
471 |             Dictionary of default parameters
472 |         """
473 |         return {
474 |             "weighting_method": "performance",
475 |             "lookback_period": 50,
476 |             "rebalance_frequency": 20,
477 |             "entry_threshold": 0.5,
478 |             "exit_threshold": 0.5,
479 |             "voting_method": "weighted",  # weighted, majority, supermajority, consensus
480 |             "min_signal_strength": 0.1,  # Minimum signal strength to avoid weak signals
481 |             "conflict_resolution": "stronger",  # How to resolve entry/exit conflicts
482 |         }
483 | 
484 |     def to_dict(self) -> dict[str, Any]:
485 |         """Convert ensemble to dictionary representation.
486 | 
487 |         Returns:
488 |             Dictionary with ensemble details
489 |         """
490 |         base_dict = super().to_dict()
491 |         base_dict.update(
492 |             {
493 |                 "strategies": [s.to_dict() for s in self.strategies],
494 |                 "current_weights": self.get_strategy_weights(),
495 |                 "weighting_method": self.weighting_method,
496 |                 "lookback_period": self.lookback_period,
497 |                 "rebalance_frequency": self.rebalance_frequency,
498 |             }
499 |         )
500 | 
501 |         return base_dict
502 | 
503 | 
504 | class RiskAdjustedEnsemble(StrategyEnsemble):
505 |     """Risk-adjusted ensemble with position sizing and risk management."""
506 | 
507 |     def __init__(
508 |         self,
509 |         strategies: list[Strategy],
510 |         max_position_size: float = 0.1,
511 |         max_correlation: float = 0.7,
512 |         risk_target: float = 0.15,
513 |         **kwargs,
514 |     ):
515 |         """Initialize risk-adjusted ensemble.
516 | 
517 |         Args:
518 |             strategies: List of base strategies
519 |             max_position_size: Maximum position size per strategy
520 |             max_correlation: Maximum correlation between strategies
521 |             risk_target: Target portfolio volatility
522 |             **kwargs: Additional parameters for base ensemble
523 |         """
524 |         super().__init__(strategies, **kwargs)
525 |         self.max_position_size = max_position_size
526 |         self.max_correlation = max_correlation
527 |         self.risk_target = risk_target
528 | 
529 |     def calculate_correlation_matrix(self) -> pd.DataFrame:
530 |         """Calculate correlation matrix between strategy returns.
531 | 
532 |         Returns:
533 |             Correlation matrix as DataFrame
534 |         """
535 |         if len(self.strategy_returns) < 2:
536 |             return pd.DataFrame()
537 | 
538 |         # Create returns DataFrame
539 |         min_length = min(
540 |             len(returns)
541 |             for returns in self.strategy_returns.values()
542 |             if len(returns) > 0
543 |         )
544 |         if min_length == 0:
545 |             return pd.DataFrame()
546 | 
547 |         returns_data = {}
548 |         for i, strategy in enumerate(self.strategies):
549 |             if (
550 |                 i in self.strategy_returns
551 |                 and len(self.strategy_returns[i]) >= min_length
552 |             ):
553 |                 returns_data[strategy.name] = self.strategy_returns[i][-min_length:]
554 | 
555 |         if not returns_data:
556 |             return pd.DataFrame()
557 | 
558 |         returns_df = pd.DataFrame(returns_data)
559 |         return returns_df.corr()
560 | 
561 |     def adjust_weights_for_correlation(self, weights: np.ndarray) -> np.ndarray:
562 |         """Adjust weights to account for strategy correlation.
563 | 
564 |         Args:
565 |             weights: Original weights
566 | 
567 |         Returns:
568 |             Correlation-adjusted weights
569 |         """
570 |         corr_matrix = self.calculate_correlation_matrix()
571 | 
572 |         if corr_matrix.empty:
573 |             return weights
574 | 
575 |         try:
576 |             # Penalize highly correlated strategies
577 |             adjusted_weights = weights.copy()
578 | 
579 |             for i, strategy_i in enumerate(self.strategies):
580 |                 for j, strategy_j in enumerate(self.strategies):
581 |                     if (
582 |                         i != j
583 |                         and strategy_i.name in corr_matrix.index
584 |                         and strategy_j.name in corr_matrix.columns
585 |                     ):
586 |                         correlation = abs(
587 |                             corr_matrix.loc[strategy_i.name, strategy_j.name]
588 |                         )
589 | 
590 |                         if correlation > self.max_correlation:
591 |                             # Reduce weight of both strategies
592 |                             penalty = (correlation - self.max_correlation) * 0.5
593 |                             adjusted_weights[i] *= 1 - penalty
594 |                             adjusted_weights[j] *= 1 - penalty
595 | 
596 |             # Renormalize weights
597 |             # Fix: Check array size and sum properly before normalization
598 |             if adjusted_weights.size > 0 and np.sum(adjusted_weights) > 0:
599 |                 adjusted_weights /= adjusted_weights.sum()
600 |             else:
601 |                 adjusted_weights = np.ones(len(self.strategies)) / len(self.strategies)
602 | 
603 |             return adjusted_weights
604 | 
605 |         except Exception as e:
606 |             logger.error(f"Error adjusting weights for correlation: {e}")
607 |             return weights
608 | 
609 |     def calculate_risk_adjusted_weights(self, data: DataFrame) -> np.ndarray:
610 |         """Calculate risk-adjusted weights based on target volatility.
611 | 
612 |         Args:
613 |             data: Price data
614 | 
615 |         Returns:
616 |             Risk-adjusted weights
617 |         """
618 |         # Start with performance-based weights
619 |         base_weights = self.calculate_performance_weights(data)
620 | 
621 |         # Adjust for correlation
622 |         corr_adjusted_weights = self.adjust_weights_for_correlation(base_weights)
623 | 
624 |         # Apply position size limits
625 |         position_adjusted_weights = np.minimum(
626 |             corr_adjusted_weights, self.max_position_size
627 |         )
628 | 
629 |         # Renormalize
630 |         # Fix: Check array size and sum properly before normalization
631 |         if position_adjusted_weights.size > 0 and np.sum(position_adjusted_weights) > 0:
632 |             position_adjusted_weights /= position_adjusted_weights.sum()
633 |         else:
634 |             position_adjusted_weights = np.ones(len(self.strategies)) / len(
635 |                 self.strategies
636 |             )
637 | 
638 |         return position_adjusted_weights
639 | 
640 |     def update_weights(self, data: DataFrame, current_index: int) -> None:
641 |         """Update risk-adjusted weights.
642 | 
643 |         Args:
644 |             data: Price data
645 |             current_index: Current position in data
646 |         """
647 |         if current_index - self.last_rebalance < self.rebalance_frequency:
648 |             return
649 | 
650 |         try:
651 |             self.weights = self.calculate_risk_adjusted_weights(data)
652 |             self.last_rebalance = current_index
653 | 
654 |             logger.debug(
655 |                 f"Updated risk-adjusted weights: {dict(zip([s.name for s in self.strategies], self.weights, strict=False))}"
656 |             )
657 | 
658 |         except Exception as e:
659 |             logger.error(f"Error updating risk-adjusted weights: {e}")
660 | 
661 |     @property
662 |     def name(self) -> str:
663 |         """Get strategy name."""
664 |         return f"RiskAdjusted{super().name}"
665 | 
666 |     @property
667 |     def description(self) -> str:
668 |         """Get strategy description."""
669 |         return "Risk-adjusted ensemble with correlation control and position sizing"
670 | 
```

--------------------------------------------------------------------------------
/maverick_mcp/agents/technical_analysis.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Technical Analysis Agent with pattern recognition and multi-timeframe analysis.
  3 | """
  4 | 
  5 | import logging
  6 | from datetime import datetime
  7 | from typing import Any
  8 | 
  9 | from langchain_core.messages import HumanMessage
 10 | from langchain_core.tools import BaseTool
 11 | from langgraph.checkpoint.memory import MemorySaver
 12 | from langgraph.graph import END, START, StateGraph
 13 | 
 14 | from maverick_mcp.agents.circuit_breaker import circuit_manager
 15 | from maverick_mcp.langchain_tools import get_tool_registry
 16 | from maverick_mcp.memory import ConversationStore
 17 | from maverick_mcp.tools.risk_management import TechnicalStopsTool
 18 | from maverick_mcp.workflows.state import TechnicalAnalysisState
 19 | 
 20 | from .base import PersonaAwareAgent
 21 | 
 22 | logger = logging.getLogger(__name__)
 23 | 
 24 | 
 25 | class TechnicalAnalysisAgent(PersonaAwareAgent):
 26 |     """
 27 |     Professional technical analysis agent with pattern recognition.
 28 | 
 29 |     Features:
 30 |     - Chart pattern detection (head & shoulders, triangles, flags)
 31 |     - Multi-timeframe analysis
 32 |     - Indicator confluence scoring
 33 |     - Support/resistance clustering
 34 |     - Volume profile analysis
 35 |     - LLM-powered technical narratives
 36 |     """
 37 | 
 38 |     def __init__(
 39 |         self,
 40 |         llm,
 41 |         persona: str = "moderate",
 42 |         ttl_hours: int = 1,
 43 |     ):
 44 |         """
 45 |         Initialize technical analysis agent.
 46 | 
 47 |         Args:
 48 |             llm: Language model
 49 |             persona: Investor persona
 50 |             ttl_hours: Cache TTL in hours
 51 |             postgres_url: Optional PostgreSQL URL for checkpointing
 52 |         """
 53 |         # Store persona temporarily for tool configuration
 54 |         self._temp_persona = persona
 55 | 
 56 |         # Get technical analysis tools
 57 |         tools = self._get_technical_tools()
 58 | 
 59 |         # Initialize with MemorySaver
 60 |         super().__init__(
 61 |             llm=llm,
 62 |             tools=tools,
 63 |             persona=persona,
 64 |             checkpointer=MemorySaver(),
 65 |             ttl_hours=ttl_hours,
 66 |         )
 67 | 
 68 |         # Initialize conversation store
 69 |         self.conversation_store = ConversationStore(ttl_hours=ttl_hours)
 70 | 
 71 |     def _get_technical_tools(self) -> list[BaseTool]:
 72 |         """Get comprehensive technical analysis tools."""
 73 |         registry = get_tool_registry()
 74 | 
 75 |         # Core technical tools
 76 |         technical_tools = [
 77 |             registry.get_tool("get_technical_indicators"),
 78 |             registry.get_tool("calculate_support_resistance"),
 79 |             registry.get_tool("detect_chart_patterns"),
 80 |             registry.get_tool("calculate_moving_averages"),
 81 |             registry.get_tool("calculate_oscillators"),
 82 |         ]
 83 | 
 84 |         # Price action tools
 85 |         price_tools = [
 86 |             registry.get_tool("get_stock_price"),
 87 |             registry.get_tool("get_stock_history"),
 88 |             registry.get_tool("get_intraday_data"),
 89 |         ]
 90 | 
 91 |         # Volume analysis tools
 92 |         volume_tools = [
 93 |             registry.get_tool("analyze_volume_profile"),
 94 |             registry.get_tool("detect_volume_patterns"),
 95 |         ]
 96 | 
 97 |         # Risk tools
 98 |         risk_tools = [
 99 |             TechnicalStopsTool(),
100 |         ]
101 | 
102 |         # Combine and filter
103 |         all_tools = technical_tools + price_tools + volume_tools + risk_tools
104 |         tools = [t for t in all_tools if t is not None]
105 | 
106 |         # Configure persona for PersonaAwareTools
107 |         for tool in tools:
108 |             if hasattr(tool, "set_persona"):
109 |                 tool.set_persona(self._temp_persona)
110 | 
111 |         if not tools:
112 |             logger.warning("No technical tools available, using mock tools")
113 |             tools = self._create_mock_tools()
114 | 
115 |         return tools
116 | 
117 |     def get_state_schema(self) -> type:
118 |         """Return enhanced state schema for technical analysis."""
119 |         return TechnicalAnalysisState
120 | 
    def _build_system_prompt(self) -> str:
        """Build comprehensive system prompt for technical analysis.

        Returns:
            The base class's system prompt followed by the technical-analysis
            instructions below. The appended text embeds today's date
            (``datetime.now()``, ``YYYY-MM-DD``) and, in the persona framework
            section, prints ``self.persona.name`` next to the heading that
            matches it and ``N/A`` next to the others.
        """
        base_prompt = super()._build_system_prompt()

        # The whole block is sent to the model verbatim: wording, markdown and
        # blank lines are all part of the runtime prompt.
        technical_prompt = f"""

You are a professional technical analyst specializing in pattern recognition and multi-timeframe analysis.
Current date: {datetime.now().strftime("%Y-%m-%d")}

## Core Responsibilities:

1. **Pattern Recognition**:
   - Chart patterns: Head & Shoulders, Triangles, Flags, Wedges
   - Candlestick patterns: Doji, Hammer, Engulfing, etc.
   - Support/Resistance: Dynamic and static levels
   - Trend lines and channels

2. **Multi-Timeframe Analysis**:
   - Align signals across daily, hourly, and 5-minute charts
   - Identify confluences between timeframes
   - Spot divergences early
   - Time entries based on lower timeframe setups

3. **Indicator Analysis**:
   - Trend: Moving averages, ADX, MACD
   - Momentum: RSI, Stochastic, CCI
   - Volume: OBV, Volume Profile, VWAP
   - Volatility: Bollinger Bands, ATR, Keltner Channels

4. **Trade Setup Construction**:
   - Entry points with specific triggers
   - Stop loss placement using ATR or structure
   - Profit targets based on measured moves
   - Risk/Reward ratio calculation

## Analysis Framework by Persona:

**Conservative ({self.persona.name if self.persona.name == "Conservative" else "N/A"})**:
- Wait for confirmed patterns only
- Use wider stops above/below structure
- Target 1.5:1 risk/reward minimum
- Focus on daily/weekly timeframes

**Moderate ({self.persona.name if self.persona.name == "Moderate" else "N/A"})**:
- Balance pattern quality with opportunity
- Standard ATR-based stops
- Target 2:1 risk/reward
- Use daily/4H timeframes

**Aggressive ({self.persona.name if self.persona.name == "Aggressive" else "N/A"})**:
- Trade emerging patterns
- Tighter stops for larger positions
- Target 3:1+ risk/reward
- Include intraday timeframes

**Day Trader ({self.persona.name if self.persona.name == "Day Trader" else "N/A"})**:
- Focus on intraday patterns
- Use tick/volume charts
- Quick scalps with tight stops
- Multiple entries/exits

## Technical Analysis Process:

1. **Market Structure**: Identify trend direction and strength
2. **Key Levels**: Map support/resistance zones
3. **Pattern Search**: Scan for actionable patterns
4. **Indicator Confluence**: Check for agreement
5. **Volume Confirmation**: Validate with volume
6. **Risk Definition**: Calculate stops and targets
7. **Setup Quality**: Rate A+ to C based on confluence

Remember to:
- Be specific with price levels
- Explain pattern psychology
- Highlight invalidation levels
- Consider market context
- Provide clear action plans
"""

        return base_prompt + technical_prompt
201 | 
202 |     def _build_graph(self):
203 |         """Build enhanced graph with technical analysis nodes."""
204 |         workflow = StateGraph(TechnicalAnalysisState)
205 | 
206 |         # Add specialized nodes with unique names
207 |         workflow.add_node("analyze_structure", self._analyze_structure)
208 |         workflow.add_node("detect_patterns", self._detect_patterns)
209 |         workflow.add_node("analyze_indicators", self._analyze_indicators)
210 |         workflow.add_node("construct_trade_setup", self._construct_trade_setup)
211 |         workflow.add_node("agent", self._agent_node)
212 | 
213 |         # Create tool node if tools available
214 |         if self.tools:
215 |             from langgraph.prebuilt import ToolNode
216 | 
217 |             tool_node = ToolNode(self.tools)
218 |             workflow.add_node("tools", tool_node)
219 | 
220 |         # Define flow
221 |         workflow.add_edge(START, "analyze_structure")
222 |         workflow.add_edge("analyze_structure", "detect_patterns")
223 |         workflow.add_edge("detect_patterns", "analyze_indicators")
224 |         workflow.add_edge("analyze_indicators", "construct_trade_setup")
225 |         workflow.add_edge("construct_trade_setup", "agent")
226 | 
227 |         if self.tools:
228 |             workflow.add_conditional_edges(
229 |                 "agent",
230 |                 self._should_continue,
231 |                 {
232 |                     "continue": "tools",
233 |                     "end": END,
234 |                 },
235 |             )
236 |             workflow.add_edge("tools", "agent")
237 |         else:
238 |             workflow.add_edge("agent", END)
239 | 
240 |         return workflow.compile(checkpointer=self.checkpointer)
241 | 
242 |     async def _analyze_structure(self, state: TechnicalAnalysisState) -> dict[str, Any]:
243 |         """Analyze market structure and identify key levels."""
244 |         try:
245 |             # Get support/resistance tool
246 |             sr_tool = next(
247 |                 (t for t in self.tools if "support_resistance" in t.name), None
248 |             )
249 | 
250 |             if sr_tool and state.get("symbol"):
251 |                 circuit_breaker = await circuit_manager.get_or_create("technical")
252 | 
253 |                 async def get_levels():
254 |                     return await sr_tool.ainvoke(
255 |                         {
256 |                             "symbol": state["symbol"],
257 |                             "lookback_days": state.get("lookback_days", 20),
258 |                         }
259 |                     )
260 | 
261 |                 levels_data = await circuit_breaker.call(get_levels)
262 | 
263 |                 # Extract support/resistance levels
264 |                 if isinstance(levels_data, dict):
265 |                     state["support_levels"] = levels_data.get("support_levels", [])
266 |                     state["resistance_levels"] = levels_data.get(
267 |                         "resistance_levels", []
268 |                     )
269 | 
270 |                     # Determine trend based on structure
271 |                     if levels_data.get("trend"):
272 |                         state["trend_direction"] = levels_data["trend"]
273 |                     else:
274 |                         # Simple trend determination
275 |                         current = state.get("current_price", 0)
276 |                         ma_50 = levels_data.get("ma_50", current)
277 |                         state["trend_direction"] = (
278 |                             "bullish" if current > ma_50 else "bearish"
279 |                         )
280 | 
281 |         except Exception as e:
282 |             logger.error(f"Error analyzing structure: {e}")
283 | 
284 |         state["api_calls_made"] = state.get("api_calls_made", 0) + 1
285 |         return {
286 |             "support_levels": state.get("support_levels", []),
287 |             "resistance_levels": state.get("resistance_levels", []),
288 |             "trend_direction": state.get("trend_direction", "neutral"),
289 |         }
290 | 
291 |     async def _detect_patterns(self, state: TechnicalAnalysisState) -> dict[str, Any]:
292 |         """Detect chart patterns."""
293 |         try:
294 |             # Get pattern detection tool
295 |             pattern_tool = next((t for t in self.tools if "pattern" in t.name), None)
296 | 
297 |             if pattern_tool and state.get("symbol"):
298 |                 circuit_breaker = await circuit_manager.get_or_create("technical")
299 | 
300 |                 async def detect():
301 |                     return await pattern_tool.ainvoke(
302 |                         {
303 |                             "symbol": state["symbol"],
304 |                             "timeframe": state.get("timeframe", "1d"),
305 |                         }
306 |                     )
307 | 
308 |                 pattern_data = await circuit_breaker.call(detect)
309 | 
310 |                 if isinstance(pattern_data, dict) and "patterns" in pattern_data:
311 |                     patterns = pattern_data["patterns"]
312 |                     state["patterns"] = patterns
313 | 
314 |                     # Calculate pattern confidence scores
315 |                     pattern_confidence = {}
316 |                     for pattern in patterns:
317 |                         name = pattern.get("name", "Unknown")
318 |                         confidence = pattern.get("confidence", 50)
319 |                         pattern_confidence[name] = confidence
320 | 
321 |                     state["pattern_confidence"] = pattern_confidence
322 | 
323 |         except Exception as e:
324 |             logger.error(f"Error detecting patterns: {e}")
325 | 
326 |         state["api_calls_made"] = state.get("api_calls_made", 0) + 1
327 |         return {
328 |             "patterns": state.get("patterns", []),
329 |             "pattern_confidence": state.get("pattern_confidence", {}),
330 |         }
331 | 
332 |     async def _analyze_indicators(
333 |         self, state: TechnicalAnalysisState
334 |     ) -> dict[str, Any]:
335 |         """Analyze technical indicators."""
336 |         try:
337 |             # Get indicators tool
338 |             indicators_tool = next(
339 |                 (t for t in self.tools if "technical_indicators" in t.name), None
340 |             )
341 | 
342 |             if indicators_tool and state.get("symbol"):
343 |                 circuit_breaker = await circuit_manager.get_or_create("technical")
344 | 
345 |                 indicators = state.get("indicators", ["RSI", "MACD", "BB"])
346 | 
347 |                 async def get_indicators():
348 |                     return await indicators_tool.ainvoke(
349 |                         {
350 |                             "symbol": state["symbol"],
351 |                             "indicators": indicators,
352 |                             "period": state.get("lookback_days", 20),
353 |                         }
354 |                     )
355 | 
356 |                 indicator_data = await circuit_breaker.call(get_indicators)
357 | 
358 |                 if isinstance(indicator_data, dict):
359 |                     # Store indicator values
360 |                     state["indicator_values"] = indicator_data.get("values", {})
361 | 
362 |                     # Generate indicator signals
363 |                     signals = self._generate_indicator_signals(indicator_data)
364 |                     state["indicator_signals"] = signals
365 | 
366 |                     # Check for divergences
367 |                     divergences = self._check_divergences(
368 |                         state.get("price_history", {}), indicator_data
369 |                     )
370 |                     state["divergences"] = divergences
371 | 
372 |         except Exception as e:
373 |             logger.error(f"Error analyzing indicators: {e}")
374 | 
375 |         state["api_calls_made"] = state.get("api_calls_made", 0) + 1
376 |         return {
377 |             "indicator_values": state.get("indicator_values", {}),
378 |             "indicator_signals": state.get("indicator_signals", {}),
379 |             "divergences": state.get("divergences", []),
380 |         }
381 | 
382 |     async def _construct_trade_setup(
383 |         self, state: TechnicalAnalysisState
384 |     ) -> dict[str, Any]:
385 |         """Construct complete trade setup."""
386 |         try:
387 |             current_price = state.get("current_price", 0)
388 | 
389 |             if current_price > 0:
390 |                 # Calculate entry points based on patterns and levels
391 |                 entry_points = self._calculate_entry_points(state)
392 |                 state["entry_points"] = entry_points
393 | 
394 |                 # Get stop loss recommendation
395 |                 stops_tool = next(
396 |                     (t for t in self.tools if isinstance(t, TechnicalStopsTool)), None
397 |                 )
398 | 
399 |                 if stops_tool:
400 |                     stops_data = await stops_tool.ainvoke(
401 |                         {
402 |                             "symbol": state["symbol"],
403 |                             "lookback_days": 20,
404 |                         }
405 |                     )
406 | 
407 |                     if isinstance(stops_data, dict):
408 |                         stop_loss = stops_data.get(
409 |                             "recommended_stop", current_price * 0.95
410 |                         )
411 |                     else:
412 |                         stop_loss = current_price * 0.95
413 |                 else:
414 |                     stop_loss = current_price * 0.95
415 | 
416 |                 state["stop_loss"] = stop_loss
417 | 
418 |                 # Calculate profit targets
419 |                 risk = current_price - stop_loss
420 |                 targets = [
421 |                     current_price + (risk * 1.5),  # 1.5R
422 |                     current_price + (risk * 2.0),  # 2R
423 |                     current_price + (risk * 3.0),  # 3R
424 |                 ]
425 |                 state["profit_targets"] = targets
426 | 
427 |                 # Calculate risk/reward
428 |                 state["risk_reward_ratio"] = 2.0  # Default target
429 | 
430 |                 # Rate setup quality
431 |                 quality = self._rate_setup_quality(state)
432 |                 state["setup_quality"] = quality
433 | 
434 |                 # Calculate confidence score
435 |                 confidence = self._calculate_confidence_score(state)
436 |                 state["confidence_score"] = confidence
437 | 
438 |         except Exception as e:
439 |             logger.error(f"Error constructing trade setup: {e}")
440 | 
441 |         return {
442 |             "entry_points": state.get("entry_points", []),
443 |             "stop_loss": state.get("stop_loss", 0),
444 |             "profit_targets": state.get("profit_targets", []),
445 |             "risk_reward_ratio": state.get("risk_reward_ratio", 0),
446 |             "setup_quality": state.get("setup_quality", "C"),
447 |             "confidence_score": state.get("confidence_score", 0),
448 |         }
449 | 
450 |     def _generate_indicator_signals(self, indicator_data: dict) -> dict[str, str]:
451 |         """Generate buy/sell/hold signals from indicators."""
452 |         signals = {}
453 | 
454 |         # RSI signals
455 |         rsi = indicator_data.get("RSI", {}).get("value", 50)
456 |         if rsi < 30:
457 |             signals["RSI"] = "buy"
458 |         elif rsi > 70:
459 |             signals["RSI"] = "sell"
460 |         else:
461 |             signals["RSI"] = "hold"
462 | 
463 |         # MACD signals
464 |         macd = indicator_data.get("MACD", {})
465 |         if macd.get("histogram", 0) > 0 and macd.get("signal_cross", "") == "bullish":
466 |             signals["MACD"] = "buy"
467 |         elif macd.get("histogram", 0) < 0 and macd.get("signal_cross", "") == "bearish":
468 |             signals["MACD"] = "sell"
469 |         else:
470 |             signals["MACD"] = "hold"
471 | 
472 |         return signals
473 | 
474 |     def _check_divergences(
475 |         self, price_history: dict, indicator_data: dict
476 |     ) -> list[dict[str, Any]]:
477 |         """Check for price/indicator divergences."""
478 |         divergences: list[dict[str, Any]] = []
479 | 
480 |         # Simplified divergence detection
481 |         # In production, would use more sophisticated analysis
482 | 
483 |         return divergences
484 | 
485 |     def _calculate_entry_points(self, state: TechnicalAnalysisState) -> list[float]:
486 |         """Calculate optimal entry points."""
487 |         current_price = state.get("current_price", 0)
488 |         support_levels = state.get("support_levels", [])
489 |         patterns = state.get("patterns", [])
490 | 
491 |         entries = []
492 | 
493 |         # Pattern-based entries
494 |         for pattern in patterns:
495 |             if pattern.get("entry_price"):
496 |                 entries.append(pattern["entry_price"])
497 | 
498 |         # Support-based entries
499 |         for support in support_levels:
500 |             if support < current_price:
501 |                 # Entry just above support
502 |                 entries.append(support * 1.01)
503 | 
504 |         # Current price entry if momentum
505 |         if state.get("trend_direction") == "bullish":
506 |             entries.append(current_price)
507 | 
508 |         return sorted(set(entries))[:3]  # Top 3 unique entries
509 | 
510 |     def _rate_setup_quality(self, state: TechnicalAnalysisState) -> str:
511 |         """Rate the quality of the trade setup."""
512 |         score = 0
513 | 
514 |         # Pattern quality
515 |         if state.get("patterns"):
516 |             max_confidence = max(p.get("confidence", 0) for p in state["patterns"])
517 |             if max_confidence > 80:
518 |                 score += 30
519 |             elif max_confidence > 60:
520 |                 score += 20
521 |             else:
522 |                 score += 10
523 | 
524 |         # Indicator confluence
525 |         signals = state.get("indicator_signals", {})
526 |         buy_signals = sum(1 for s in signals.values() if s == "buy")
527 |         if buy_signals >= 3:
528 |             score += 30
529 |         elif buy_signals >= 2:
530 |             score += 20
531 |         else:
532 |             score += 10
533 | 
534 |         # Risk/Reward
535 |         rr = state.get("risk_reward_ratio", 0)
536 |         if rr >= 3:
537 |             score += 20
538 |         elif rr >= 2:
539 |             score += 15
540 |         else:
541 |             score += 5
542 | 
543 |         # Volume confirmation (would check in real implementation)
544 |         score += 10
545 | 
546 |         # Market alignment (would check in real implementation)
547 |         score += 10
548 | 
549 |         # Convert score to grade
550 |         if score >= 85:
551 |             return "A+"
552 |         elif score >= 75:
553 |             return "A"
554 |         elif score >= 65:
555 |             return "B"
556 |         else:
557 |             return "C"
558 | 
559 |     def _calculate_confidence_score(self, state: TechnicalAnalysisState) -> float:
560 |         """Calculate overall confidence score for the setup."""
561 |         factors = []
562 | 
563 |         # Pattern confidence
564 |         if state.get("pattern_confidence"):
565 |             factors.append(max(state["pattern_confidence"].values()) / 100)
566 | 
567 |         # Indicator agreement
568 |         signals = state.get("indicator_signals", {})
569 |         if signals:
570 |             buy_count = sum(1 for s in signals.values() if s == "buy")
571 |             factors.append(buy_count / len(signals))
572 | 
573 |         # Setup quality
574 |         quality_scores = {"A+": 1.0, "A": 0.85, "B": 0.70, "C": 0.50}
575 |         factors.append(quality_scores.get(state.get("setup_quality", "C"), 0.5))
576 | 
577 |         # Average confidence
578 |         return round(sum(factors) / len(factors) * 100, 1) if factors else 50.0
579 | 
580 |     async def analyze_stock(
581 |         self,
582 |         symbol: str,
583 |         timeframe: str = "1d",
584 |         indicators: list[str] | None = None,
585 |         **kwargs,
586 |     ) -> dict[str, Any]:
587 |         """
588 |         Perform comprehensive technical analysis on a stock.
589 | 
590 |         Args:
591 |             symbol: Stock symbol
592 |             timeframe: Chart timeframe
593 |             indicators: List of indicators to analyze
594 |             **kwargs: Additional parameters
595 | 
596 |         Returns:
597 |             Complete technical analysis with trade setup
598 |         """
599 |         start_time = datetime.now()
600 | 
601 |         # Default indicators
602 |         if indicators is None:
603 |             indicators = ["RSI", "MACD", "BB", "EMA", "VWAP"]
604 | 
605 |         # Prepare query
606 |         query = f"Analyze {symbol} on {timeframe} timeframe with focus on patterns and trade setup"
607 | 
608 |         # Initial state
609 |         initial_state = {
610 |             "messages": [HumanMessage(content=query)],
611 |             "symbol": symbol,
612 |             "timeframe": timeframe,
613 |             "indicators": indicators,
614 |             "lookback_days": kwargs.get("lookback_days", 20),
615 |             "pattern_detection": True,
616 |             "multi_timeframe": kwargs.get("multi_timeframe", False),
617 |             "persona": self.persona.name,
618 |             "session_id": kwargs.get(
619 |                 "session_id", f"{symbol}_{datetime.now().timestamp()}"
620 |             ),
621 |             "timestamp": datetime.now(),
622 |             "api_calls_made": 0,
623 |         }
624 | 
625 |         # Run analysis
626 |         result = await self.ainvoke(
627 |             query, initial_state["session_id"], initial_state=initial_state
628 |         )
629 | 
630 |         # Calculate execution time
631 |         execution_time = (datetime.now() - start_time).total_seconds() * 1000
632 | 
633 |         # Extract results
634 |         return self._format_analysis_results(result, execution_time)
635 | 
636 |     def _format_analysis_results(
637 |         self, result: dict[str, Any], execution_time: float
638 |     ) -> dict[str, Any]:
639 |         """Format technical analysis results."""
640 |         state = result.get("state", {})
641 |         messages = result.get("messages", [])
642 | 
643 |         return {
644 |             "status": "success",
645 |             "timestamp": datetime.now().isoformat(),
646 |             "execution_time_ms": execution_time,
647 |             "symbol": state.get("symbol", ""),
648 |             "analysis": {
649 |                 "market_structure": {
650 |                     "trend": state.get("trend_direction", "neutral"),
651 |                     "support_levels": state.get("support_levels", []),
652 |                     "resistance_levels": state.get("resistance_levels", []),
653 |                 },
654 |                 "patterns": {
655 |                     "detected": state.get("patterns", []),
656 |                     "confidence": state.get("pattern_confidence", {}),
657 |                 },
658 |                 "indicators": {
659 |                     "values": state.get("indicator_values", {}),
660 |                     "signals": state.get("indicator_signals", {}),
661 |                     "divergences": state.get("divergences", []),
662 |                 },
663 |                 "trade_setup": {
664 |                     "entries": state.get("entry_points", []),
665 |                     "stop_loss": state.get("stop_loss", 0),
666 |                     "targets": state.get("profit_targets", []),
667 |                     "risk_reward": state.get("risk_reward_ratio", 0),
668 |                     "quality": state.get("setup_quality", "C"),
669 |                     "confidence": state.get("confidence_score", 0),
670 |                 },
671 |             },
672 |             "recommendation": messages[-1].content if messages else "",
673 |             "persona_adjusted": True,
674 |             "risk_profile": self.persona.name,
675 |         }
676 | 
    def _create_mock_tools(self) -> list:
        """Create mock tools for testing.

        Returns:
            Two tools built with the ``@tool`` decorator that ignore their
            arguments and return fixed sample data: one for technical
            indicators (RSI, MACD, BB) and one for support/resistance levels
            with a bullish trend.
        """
        # Imported here so the decorator is only needed when mocks are built.
        from langchain_core.tools import tool

        # NOTE(review): the function names and docstrings below are presumably
        # what @tool exposes as each tool's name and description, so treat
        # them as runtime text - confirm before rewording.
        @tool
        def mock_technical_indicators(symbol: str, indicators: list[str]) -> dict:
            """Mock technical indicators tool."""
            return {
                "RSI": {"value": 45, "trend": "neutral"},
                "MACD": {"histogram": 0.5, "signal_cross": "bullish"},
                "BB": {"upper": 150, "middle": 145, "lower": 140},
            }

        @tool
        def mock_support_resistance(symbol: str) -> dict:
            """Mock support/resistance tool."""
            return {
                "support_levels": [140, 135, 130],
                "resistance_levels": [150, 155, 160],
                "trend": "bullish",
            }

        return [mock_technical_indicators, mock_support_resistance]
700 | 
```

--------------------------------------------------------------------------------
/maverick_mcp/backtesting/ab_testing.py:
--------------------------------------------------------------------------------

```python
  1 | """A/B testing framework for comparing ML model performance."""
  2 | 
  3 | import logging
  4 | import random
  5 | from datetime import datetime
  6 | from typing import Any
  7 | 
  8 | import numpy as np
  9 | import pandas as pd
 10 | from scipy import stats
 11 | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
 12 | 
 13 | from .model_manager import ModelManager
 14 | 
 15 | logger = logging.getLogger(__name__)
 16 | 
 17 | 
 18 | class ABTestGroup:
 19 |     """Represents a group in an A/B test."""
 20 | 
 21 |     def __init__(
 22 |         self,
 23 |         group_id: str,
 24 |         model_id: str,
 25 |         model_version: str,
 26 |         traffic_allocation: float,
 27 |         description: str = "",
 28 |     ):
 29 |         """Initialize A/B test group.
 30 | 
 31 |         Args:
 32 |             group_id: Unique identifier for the group
 33 |             model_id: Model identifier
 34 |             model_version: Model version
 35 |             traffic_allocation: Fraction of traffic allocated to this group (0-1)
 36 |             description: Description of the group
 37 |         """
 38 |         self.group_id = group_id
 39 |         self.model_id = model_id
 40 |         self.model_version = model_version
 41 |         self.traffic_allocation = traffic_allocation
 42 |         self.description = description
 43 |         self.created_at = datetime.now()
 44 | 
 45 |         # Performance tracking
 46 |         self.predictions: list[Any] = []
 47 |         self.actual_values: list[Any] = []
 48 |         self.prediction_timestamps: list[datetime] = []
 49 |         self.prediction_confidence: list[float] = []
 50 | 
 51 |     def add_prediction(
 52 |         self,
 53 |         prediction: Any,
 54 |         actual: Any,
 55 |         confidence: float = 1.0,
 56 |         timestamp: datetime | None = None,
 57 |     ):
 58 |         """Add a prediction result to the group.
 59 | 
 60 |         Args:
 61 |             prediction: Model prediction
 62 |             actual: Actual value
 63 |             confidence: Prediction confidence score
 64 |             timestamp: Prediction timestamp
 65 |         """
 66 |         self.predictions.append(prediction)
 67 |         self.actual_values.append(actual)
 68 |         self.prediction_confidence.append(confidence)
 69 |         self.prediction_timestamps.append(timestamp or datetime.now())
 70 | 
 71 |     def get_metrics(self) -> dict[str, float]:
 72 |         """Calculate performance metrics for the group.
 73 | 
 74 |         Returns:
 75 |             Dictionary of performance metrics
 76 |         """
 77 |         if not self.predictions or not self.actual_values:
 78 |             return {}
 79 | 
 80 |         try:
 81 |             predictions = np.array(self.predictions)
 82 |             actuals = np.array(self.actual_values)
 83 | 
 84 |             metrics = {
 85 |                 "sample_count": len(predictions),
 86 |                 "accuracy": accuracy_score(actuals, predictions),
 87 |                 "precision": precision_score(
 88 |                     actuals, predictions, average="weighted", zero_division=0
 89 |                 ),
 90 |                 "recall": recall_score(
 91 |                     actuals, predictions, average="weighted", zero_division=0
 92 |                 ),
 93 |                 "f1_score": f1_score(
 94 |                     actuals, predictions, average="weighted", zero_division=0
 95 |                 ),
 96 |                 "avg_confidence": np.mean(self.prediction_confidence),
 97 |             }
 98 | 
 99 |             # Add confusion matrix for binary/multiclass
100 |             unique_labels = np.unique(np.concatenate([predictions, actuals]))
101 |             if len(unique_labels) <= 10:  # Reasonable number of classes
102 |                 from sklearn.metrics import confusion_matrix
103 | 
104 |                 cm = confusion_matrix(actuals, predictions, labels=unique_labels)
105 |                 metrics["confusion_matrix"] = cm.tolist()
106 |                 metrics["unique_labels"] = unique_labels.tolist()
107 | 
108 |             return metrics
109 | 
110 |         except Exception as e:
111 |             logger.error(f"Error calculating metrics for group {self.group_id}: {e}")
112 |             return {"error": str(e)}
113 | 
114 |     def to_dict(self) -> dict[str, Any]:
115 |         """Convert group to dictionary representation."""
116 |         return {
117 |             "group_id": self.group_id,
118 |             "model_id": self.model_id,
119 |             "model_version": self.model_version,
120 |             "traffic_allocation": self.traffic_allocation,
121 |             "description": self.description,
122 |             "created_at": self.created_at.isoformat(),
123 |             "metrics": self.get_metrics(),
124 |         }
125 | 
126 | 
127 | class ABTest:
128 |     """Manages an A/B test between different model versions."""
129 | 
130 |     def __init__(
131 |         self,
132 |         test_id: str,
133 |         name: str,
134 |         description: str = "",
135 |         random_seed: int | None = None,
136 |     ):
137 |         """Initialize A/B test.
138 | 
139 |         Args:
140 |             test_id: Unique identifier for the test
141 |             name: Human-readable name for the test
142 |             description: Description of the test
143 |             random_seed: Random seed for reproducible traffic splitting
144 |         """
145 |         self.test_id = test_id
146 |         self.name = name
147 |         self.description = description
148 |         self.created_at = datetime.now()
149 |         self.started_at: datetime | None = None
150 |         self.ended_at: datetime | None = None
151 |         self.status = "created"  # created, running, completed, cancelled
152 | 
153 |         # Groups in the test
154 |         self.groups: dict[str, ABTestGroup] = {}
155 | 
156 |         # Traffic allocation
157 |         self.traffic_splitter = TrafficSplitter(random_seed)
158 | 
159 |         # Test configuration
160 |         self.min_samples_per_group = 100
161 |         self.confidence_level = 0.95
162 |         self.minimum_detectable_effect = 0.05
163 | 
164 |     def add_group(
165 |         self,
166 |         group_id: str,
167 |         model_id: str,
168 |         model_version: str,
169 |         traffic_allocation: float,
170 |         description: str = "",
171 |     ) -> bool:
172 |         """Add a group to the A/B test.
173 | 
174 |         Args:
175 |             group_id: Unique identifier for the group
176 |             model_id: Model identifier
177 |             model_version: Model version
178 |             traffic_allocation: Fraction of traffic (0-1)
179 |             description: Description of the group
180 | 
181 |         Returns:
182 |             True if successful
183 |         """
184 |         if self.status != "created":
185 |             logger.error(
186 |                 f"Cannot add group to test {self.test_id} - test already started"
187 |             )
188 |             return False
189 | 
190 |         if group_id in self.groups:
191 |             logger.error(f"Group {group_id} already exists in test {self.test_id}")
192 |             return False
193 | 
194 |         # Validate traffic allocation
195 |         current_total = sum(g.traffic_allocation for g in self.groups.values())
196 |         if (
197 |             current_total + traffic_allocation > 1.0001
198 |         ):  # Small tolerance for floating point
199 |             logger.error(
200 |                 f"Traffic allocation would exceed 100%: {current_total + traffic_allocation}"
201 |             )
202 |             return False
203 | 
204 |         group = ABTestGroup(
205 |             group_id=group_id,
206 |             model_id=model_id,
207 |             model_version=model_version,
208 |             traffic_allocation=traffic_allocation,
209 |             description=description,
210 |         )
211 | 
212 |         self.groups[group_id] = group
213 |         self.traffic_splitter.update_allocation(
214 |             {gid: g.traffic_allocation for gid, g in self.groups.items()}
215 |         )
216 | 
217 |         logger.info(f"Added group {group_id} to test {self.test_id}")
218 |         return True
219 | 
220 |     def start_test(self) -> bool:
221 |         """Start the A/B test.
222 | 
223 |         Returns:
224 |             True if successful
225 |         """
226 |         if self.status != "created":
227 |             logger.error(
228 |                 f"Cannot start test {self.test_id} - invalid status: {self.status}"
229 |             )
230 |             return False
231 | 
232 |         if len(self.groups) < 2:
233 |             logger.error(f"Cannot start test {self.test_id} - need at least 2 groups")
234 |             return False
235 | 
236 |         # Validate traffic allocation sums to approximately 1.0
237 |         total_allocation = sum(g.traffic_allocation for g in self.groups.values())
238 |         if abs(total_allocation - 1.0) > 0.01:
239 |             logger.error(f"Traffic allocation does not sum to 1.0: {total_allocation}")
240 |             return False
241 | 
242 |         self.status = "running"
243 |         self.started_at = datetime.now()
244 |         logger.info(f"Started A/B test {self.test_id} with {len(self.groups)} groups")
245 |         return True
246 | 
247 |     def assign_traffic(self, user_id: str | None = None) -> str | None:
248 |         """Assign traffic to a group.
249 | 
250 |         Args:
251 |             user_id: User identifier for consistent assignment
252 | 
253 |         Returns:
254 |             Group ID or None if test not running
255 |         """
256 |         if self.status != "running":
257 |             return None
258 | 
259 |         return self.traffic_splitter.assign_group(user_id)
260 | 
261 |     def record_prediction(
262 |         self,
263 |         group_id: str,
264 |         prediction: Any,
265 |         actual: Any,
266 |         confidence: float = 1.0,
267 |         timestamp: datetime | None = None,
268 |     ) -> bool:
269 |         """Record a prediction result for a group.
270 | 
271 |         Args:
272 |             group_id: Group identifier
273 |             prediction: Model prediction
274 |             actual: Actual value
275 |             confidence: Prediction confidence
276 |             timestamp: Prediction timestamp
277 | 
278 |         Returns:
279 |             True if successful
280 |         """
281 |         if group_id not in self.groups:
282 |             logger.error(f"Group {group_id} not found in test {self.test_id}")
283 |             return False
284 | 
285 |         self.groups[group_id].add_prediction(prediction, actual, confidence, timestamp)
286 |         return True
287 | 
288 |     def get_results(self) -> dict[str, Any]:
289 |         """Get current A/B test results.
290 | 
291 |         Returns:
292 |             Dictionary with test results
293 |         """
294 |         results = {
295 |             "test_id": self.test_id,
296 |             "name": self.name,
297 |             "description": self.description,
298 |             "status": self.status,
299 |             "created_at": self.created_at.isoformat(),
300 |             "started_at": self.started_at.isoformat() if self.started_at else None,
301 |             "ended_at": self.ended_at.isoformat() if self.ended_at else None,
302 |             "groups": {},
303 |             "statistical_analysis": {},
304 |         }
305 | 
306 |         # Group results
307 |         for group_id, group in self.groups.items():
308 |             results["groups"][group_id] = group.to_dict()
309 | 
310 |         # Statistical analysis
311 |         if len(self.groups) >= 2:
312 |             results["statistical_analysis"] = self._perform_statistical_analysis()
313 | 
314 |         return results
315 | 
316 |     def _perform_statistical_analysis(self) -> dict[str, Any]:
317 |         """Perform statistical analysis of A/B test results.
318 | 
319 |         Returns:
320 |             Statistical analysis results
321 |         """
322 |         analysis = {
323 |             "ready_for_analysis": True,
324 |             "sample_size_adequate": True,
325 |             "statistical_significance": {},
326 |             "effect_sizes": {},
327 |             "recommendations": [],
328 |         }
329 | 
330 |         # Check sample sizes
331 |         sample_sizes = {
332 |             group_id: len(group.predictions) for group_id, group in self.groups.items()
333 |         }
334 | 
335 |         min_samples = min(sample_sizes.values()) if sample_sizes else 0
336 |         if min_samples < self.min_samples_per_group:
337 |             analysis["ready_for_analysis"] = False
338 |             analysis["sample_size_adequate"] = False
339 |             analysis["recommendations"].append(
340 |                 f"Need at least {self.min_samples_per_group} samples per group (current min: {min_samples})"
341 |             )
342 | 
343 |         if not analysis["ready_for_analysis"]:
344 |             return analysis
345 | 
346 |         # Pairwise comparisons
347 |         group_ids = list(self.groups.keys())
348 |         for i, group_a_id in enumerate(group_ids):
349 |             for group_b_id in group_ids[i + 1 :]:
350 |                 comparison_key = f"{group_a_id}_vs_{group_b_id}"
351 | 
352 |                 try:
353 |                     group_a = self.groups[group_a_id]
354 |                     group_b = self.groups[group_b_id]
355 | 
356 |                     # Compare accuracy scores
357 |                     accuracy_a = accuracy_score(
358 |                         group_a.actual_values, group_a.predictions
359 |                     )
360 |                     accuracy_b = accuracy_score(
361 |                         group_b.actual_values, group_b.predictions
362 |                     )
363 | 
364 |                     # Perform statistical test
365 |                     # For classification accuracy, we can use a proportion test
366 |                     n_correct_a = sum(
367 |                         np.array(group_a.predictions) == np.array(group_a.actual_values)
368 |                     )
369 |                     n_correct_b = sum(
370 |                         np.array(group_b.predictions) == np.array(group_b.actual_values)
371 |                     )
372 |                     n_total_a = len(group_a.predictions)
373 |                     n_total_b = len(group_b.predictions)
374 | 
375 |                     # Two-proportion z-test
376 |                     p_combined = (n_correct_a + n_correct_b) / (n_total_a + n_total_b)
377 |                     se = np.sqrt(
378 |                         p_combined * (1 - p_combined) * (1 / n_total_a + 1 / n_total_b)
379 |                     )
380 | 
381 |                     if se > 0:
382 |                         z_score = (accuracy_a - accuracy_b) / se
383 |                         p_value = 2 * (1 - stats.norm.cdf(abs(z_score)))
384 | 
385 |                         # Effect size (Cohen's h for proportions)
386 |                         h = 2 * (
387 |                             np.arcsin(np.sqrt(accuracy_a))
388 |                             - np.arcsin(np.sqrt(accuracy_b))
389 |                         )
390 | 
391 |                         analysis["statistical_significance"][comparison_key] = {
392 |                             "accuracy_a": accuracy_a,
393 |                             "accuracy_b": accuracy_b,
394 |                             "difference": accuracy_a - accuracy_b,
395 |                             "z_score": z_score,
396 |                             "p_value": p_value,
397 |                             "significant": p_value < (1 - self.confidence_level),
398 |                             "effect_size_h": h,
399 |                         }
400 | 
401 |                         # Recommendations based on results
402 |                         if p_value < (1 - self.confidence_level):
403 |                             if accuracy_a > accuracy_b:
404 |                                 analysis["recommendations"].append(
405 |                                     f"Group {group_a_id} significantly outperforms {group_b_id} "
406 |                                     f"(p={p_value:.4f}, effect_size={h:.4f})"
407 |                                 )
408 |                             else:
409 |                                 analysis["recommendations"].append(
410 |                                     f"Group {group_b_id} significantly outperforms {group_a_id} "
411 |                                     f"(p={p_value:.4f}, effect_size={h:.4f})"
412 |                                 )
413 |                         else:
414 |                             analysis["recommendations"].append(
415 |                                 f"No significant difference between {group_a_id} and {group_b_id} "
416 |                                 f"(p={p_value:.4f})"
417 |                             )
418 | 
419 |                 except Exception as e:
420 |                     logger.error(
421 |                         f"Error in statistical analysis for {comparison_key}: {e}"
422 |                     )
423 |                     analysis["statistical_significance"][comparison_key] = {
424 |                         "error": str(e)
425 |                     }
426 | 
427 |         return analysis
428 | 
429 |     def stop_test(self, reason: str = "completed") -> bool:
430 |         """Stop the A/B test.
431 | 
432 |         Args:
433 |             reason: Reason for stopping
434 | 
435 |         Returns:
436 |             True if successful
437 |         """
438 |         if self.status != "running":
439 |             logger.error(f"Cannot stop test {self.test_id} - not running")
440 |             return False
441 | 
442 |         self.status = "completed" if reason == "completed" else "cancelled"
443 |         self.ended_at = datetime.now()
444 |         logger.info(f"Stopped A/B test {self.test_id}: {reason}")
445 |         return True
446 | 
447 |     def to_dict(self) -> dict[str, Any]:
448 |         """Convert test to dictionary representation."""
449 |         return {
450 |             "test_id": self.test_id,
451 |             "name": self.name,
452 |             "description": self.description,
453 |             "status": self.status,
454 |             "created_at": self.created_at.isoformat(),
455 |             "started_at": self.started_at.isoformat() if self.started_at else None,
456 |             "ended_at": self.ended_at.isoformat() if self.ended_at else None,
457 |             "groups": {gid: g.to_dict() for gid, g in self.groups.items()},
458 |             "configuration": {
459 |                 "min_samples_per_group": self.min_samples_per_group,
460 |                 "confidence_level": self.confidence_level,
461 |                 "minimum_detectable_effect": self.minimum_detectable_effect,
462 |             },
463 |         }
464 | 
465 | 
466 | class TrafficSplitter:
467 |     """Handles traffic splitting for A/B tests."""
468 | 
469 |     def __init__(self, random_seed: int | None = None):
470 |         """Initialize traffic splitter.
471 | 
472 |         Args:
473 |             random_seed: Random seed for reproducible splitting
474 |         """
475 |         self.random_seed = random_seed
476 |         self.group_allocation: dict[str, float] = {}
477 |         self.cumulative_allocation: list[tuple[str, float]] = []
478 | 
479 |     def update_allocation(self, allocation: dict[str, float]):
480 |         """Update group traffic allocation.
481 | 
482 |         Args:
483 |             allocation: Dictionary mapping group_id to allocation fraction
484 |         """
485 |         self.group_allocation = allocation.copy()
486 | 
487 |         # Create cumulative distribution for sampling
488 |         self.cumulative_allocation = []
489 |         cumulative = 0.0
490 | 
491 |         for group_id, fraction in allocation.items():
492 |             cumulative += fraction
493 |             self.cumulative_allocation.append((group_id, cumulative))
494 | 
495 |     def assign_group(self, user_id: str | None = None) -> str | None:
496 |         """Assign a user to a group.
497 | 
498 |         Args:
499 |             user_id: User identifier for consistent assignment
500 | 
501 |         Returns:
502 |             Group ID or None if no groups configured
503 |         """
504 |         if not self.cumulative_allocation:
505 |             return None
506 | 
507 |         # Generate random value
508 |         if user_id is not None:
509 |             # Hash user_id for consistent assignment
510 |             import hashlib
511 | 
512 |             hash_object = hashlib.md5(user_id.encode())
513 |             hash_int = int(hash_object.hexdigest(), 16)
514 |             rand_value = (hash_int % 10000) / 10000.0  # Normalize to [0, 1)
515 |         else:
516 |             if self.random_seed is not None:
517 |                 random.seed(self.random_seed)
518 |             rand_value = random.random()
519 | 
520 |         # Find group based on cumulative allocation
521 |         for group_id, cumulative_threshold in self.cumulative_allocation:
522 |             if rand_value <= cumulative_threshold:
523 |                 return group_id
524 | 
525 |         # Fallback to last group
526 |         return self.cumulative_allocation[-1][0] if self.cumulative_allocation else None
527 | 
528 | 
class ABTestManager:
    """Manages multiple A/B tests."""

    def __init__(self, model_manager: ModelManager):
        """Initialize A/B test manager.

        Args:
            model_manager: Model manager instance used to load the models
                under comparison
        """
        self.model_manager = model_manager
        self.tests: dict[str, ABTest] = {}

    def create_test(
        self,
        test_id: str,
        name: str,
        description: str = "",
        random_seed: int | None = None,
    ) -> ABTest:
        """Create a new A/B test.

        Args:
            test_id: Unique identifier for the test
            name: Human-readable name
            description: Description
            random_seed: Random seed for reproducible splitting

        Returns:
            ABTest instance

        Raises:
            ValueError: If a test with ``test_id`` is already registered
        """
        if test_id in self.tests:
            raise ValueError(f"Test {test_id} already exists")

        test = ABTest(test_id, name, description, random_seed)
        self.tests[test_id] = test
        logger.info(f"Created A/B test {test_id}: {name}")
        return test

    def get_test(self, test_id: str) -> ABTest | None:
        """Get an A/B test by ID.

        Args:
            test_id: Test identifier

        Returns:
            ABTest instance or None
        """
        return self.tests.get(test_id)

    def list_tests(self, status_filter: str | None = None) -> list[dict[str, Any]]:
        """List all A/B tests.

        Args:
            status_filter: Filter by status (created, running, completed, cancelled)

        Returns:
            List of test summaries, newest first
        """
        tests = [
            {
                "test_id": test.test_id,
                "name": test.name,
                "status": test.status,
                "groups_count": len(test.groups),
                "created_at": test.created_at.isoformat(),
                "started_at": test.started_at.isoformat()
                if test.started_at
                else None,
            }
            for test in self.tests.values()
            if status_filter is None or test.status == status_filter
        ]

        # Sort by creation time (newest first)
        tests.sort(key=lambda x: x["created_at"], reverse=True)
        return tests

    def run_model_comparison(
        self,
        test_name: str,
        model_versions: list[tuple[str, str]],  # (model_id, version)
        test_data: pd.DataFrame,
        feature_extractor: Any,
        target_extractor: Any,
        traffic_allocation: list[float] | None = None,
        test_duration_hours: int = 24,
    ) -> str:
        """Run a model comparison A/B test.

        Each row of the extracted features is routed to one group (using the
        row index as the user id), scored by that group's model, and recorded
        against the matching target. The test is left in the ``running`` state
        when this method returns.

        Args:
            test_name: Name for the test
            model_versions: List of (model_id, version) tuples to compare
            test_data: Test data for evaluation
            feature_extractor: Function to extract features
            target_extractor: Function to extract targets
            traffic_allocation: Custom traffic allocation (defaults to equal split)
            test_duration_hours: Currently unused by this method

        Returns:
            Test ID. If the test could not be started (for example fewer than
            two groups, or allocations not summing to 1), the ID of the
            registered but empty test is still returned and an error is logged.

        Raises:
            ValueError: If ``model_versions`` is empty and no allocation is
                given, or if ``traffic_allocation`` has fewer entries than
                ``model_versions``
        """
        # Validate before registering anything so a bad call does not leave a
        # half-configured test behind.
        if traffic_allocation is None:
            if not model_versions:
                raise ValueError("model_versions must not be empty")
            # Default equal traffic allocation
            traffic_allocation = [1.0 / len(model_versions)] * len(model_versions)
        elif len(traffic_allocation) < len(model_versions):
            raise ValueError(
                f"traffic_allocation has {len(traffic_allocation)} entries "
                f"for {len(model_versions)} model versions"
            )

        # Generate unique test ID. The timestamp only has one-second
        # resolution, so add a counter when two comparisons start within the
        # same second.
        base_id = f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        test_id = base_id
        suffix = 1
        while test_id in self.tests:
            test_id = f"{base_id}_{suffix}"
            suffix += 1

        # Create test
        test = self.create_test(
            test_id=test_id,
            name=test_name,
            description=f"Comparing {len(model_versions)} model versions",
        )

        # Add groups
        for i, ((model_id, version), allocation) in enumerate(
            zip(model_versions, traffic_allocation)
        ):
            group_id = f"group_{i}_{model_id}_{version}"
            added = test.add_group(
                group_id=group_id,
                model_id=model_id,
                model_version=version,
                traffic_allocation=allocation,
                description=f"Model {model_id} version {version}",
            )
            if not added:
                logger.warning(f"Could not add group {group_id} to test {test_id}")

        # Start test. Without a running test no traffic can be assigned, so
        # there is nothing to evaluate.
        if not test.start_test():
            logger.error(
                f"Model comparison test {test_id} could not be started; "
                "no predictions recorded"
            )
            return test_id

        # Extract features and targets
        features = feature_extractor(test_data)
        targets = target_extractor(test_data)

        # Load each group's model once instead of once per row.
        # NOTE(review): assumes load_model returns the same artifact for the
        # same (model_id, version) during a run - confirm against ModelManager.
        loaded_models: dict[str, Any] = {}

        # Simulate predictions for each group
        for row_index, row in features.iterrows():
            # Assign traffic, using the row index as the user id
            group_id = test.assign_traffic(str(row_index))
            if group_id is None:
                continue

            if group_id not in loaded_models:
                group = test.groups[group_id]
                loaded = self.model_manager.load_model(
                    group.model_id, group.model_version
                )
                if loaded is None or loaded.model is None:
                    logger.warning(f"Could not load model for group {group_id}")
                    loaded = None
                loaded_models[group_id] = loaded

            model_version = loaded_models[group_id]
            if model_version is None:
                continue

            try:
                # Make prediction on a single-row 2-D array
                X = row.values.reshape(1, -1)
                if model_version.scaler is not None:
                    X = model_version.scaler.transform(X)

                prediction = model_version.model.predict(X)[0]

                # Get confidence if available
                confidence = 1.0
                if hasattr(model_version.model, "predict_proba"):
                    proba = model_version.model.predict_proba(X)[0]
                    confidence = max(proba)

                # Get actual value
                actual = targets.loc[row_index]

                # Record prediction
                test.record_prediction(group_id, prediction, actual, confidence)

            except Exception as e:
                # Best effort: one bad row must not abort the whole comparison
                logger.warning(f"Error making prediction for group {group_id}: {e}")

        logger.info(f"Completed model comparison test {test_id}")
        return test_id

    def get_test_summary(self) -> dict[str, Any]:
        """Get summary of all A/B tests.

        Returns:
            Dictionary with the total number of tests, a count per status, and
            the ten most recently created tests (newest first)
        """
        total_tests = len(self.tests)
        status_counts: dict[str, int] = {}

        for test in self.tests.values():
            status_counts[test.status] = status_counts.get(test.status, 0) + 1

        recent_tests = sorted(
            [
                {
                    "test_id": test.test_id,
                    "name": test.name,
                    "status": test.status,
                    "created_at": test.created_at.isoformat(),
                }
                for test in self.tests.values()
            ],
            key=lambda x: x["created_at"],
            reverse=True,
        )[:10]

        return {
            "total_tests": total_tests,
            "status_distribution": status_counts,
            "recent_tests": recent_tests,
        }
738 | 
```
Page 22/39FirstPrevNextLast