This is page 23 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_visualization.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive tests for backtesting visualization module.
3 |
4 | Tests cover:
5 | - Chart generation and base64 encoding with matplotlib
6 | - Equity curve plotting with drawdown subplots
7 | - Trade scatter plots on price charts
8 | - Parameter optimization heatmaps
9 | - Portfolio allocation pie charts
10 | - Strategy comparison line charts
11 | - Performance dashboard table generation
12 | - Theme support (light/dark modes)
13 | - Image resolution and size optimization
14 | - Error handling for malformed data
15 | """
16 |
17 | import base64
18 | from unittest.mock import patch
19 |
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 | import pandas as pd
23 | import pytest
24 |
25 | from maverick_mcp.backtesting.visualization import (
26 | generate_equity_curve,
27 | generate_optimization_heatmap,
28 | generate_performance_dashboard,
29 | generate_portfolio_allocation,
30 | generate_strategy_comparison,
31 | generate_trade_scatter,
32 | image_to_base64,
33 | set_chart_style,
34 | )
35 |
36 |
37 | class TestVisualizationUtilities:
38 | """Test suite for visualization utility functions."""
39 |
40 | def test_set_chart_style_light_theme(self):
41 | """Test light theme styling configuration."""
42 | set_chart_style("light")
43 |
44 | # Test that matplotlib parameters are set correctly
45 | assert plt.rcParams["axes.facecolor"] == "white"
46 | assert plt.rcParams["figure.facecolor"] == "white"
47 | assert plt.rcParams["font.size"] == 10
48 | assert plt.rcParams["text.color"] == "black"
49 | assert plt.rcParams["axes.labelcolor"] == "black"
50 | assert plt.rcParams["xtick.color"] == "black"
51 | assert plt.rcParams["ytick.color"] == "black"
52 |
53 | def test_set_chart_style_dark_theme(self):
54 | """Test dark theme styling configuration."""
55 | set_chart_style("dark")
56 |
57 | # Test that matplotlib parameters are set correctly
58 | assert plt.rcParams["axes.facecolor"] == "#1E1E1E"
59 | assert plt.rcParams["figure.facecolor"] == "#121212"
60 | assert plt.rcParams["font.size"] == 10
61 | assert plt.rcParams["text.color"] == "white"
62 | assert plt.rcParams["axes.labelcolor"] == "white"
63 | assert plt.rcParams["xtick.color"] == "white"
64 | assert plt.rcParams["ytick.color"] == "white"
65 |
66 | def test_image_to_base64_conversion(self):
67 | """Test image to base64 conversion with proper formatting."""
68 | # Create a simple test figure
69 | fig, ax = plt.subplots(figsize=(6, 4))
70 | ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
71 | ax.set_title("Test Chart")
72 |
73 | # Convert to base64
74 | base64_str = image_to_base64(fig, dpi=100)
75 |
76 | # Test base64 string properties
77 | assert isinstance(base64_str, str)
78 | assert len(base64_str) > 100 # Should contain substantial data
79 |
80 | # Test that it's valid base64
81 | try:
82 | decoded_bytes = base64.b64decode(base64_str)
83 | assert len(decoded_bytes) > 0
84 | except Exception as e:
85 | pytest.fail(f"Invalid base64 encoding: {e}")
86 |
87 | def test_image_to_base64_size_optimization(self):
88 | """Test image size optimization and aspect ratio preservation."""
89 | # Create large figure
90 | fig, ax = plt.subplots(figsize=(20, 15)) # Large size
91 | ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
92 |
93 | original_width, original_height = fig.get_size_inches()
94 | original_aspect = original_height / original_width
95 |
96 | # Convert with size constraint
97 | base64_str = image_to_base64(fig, dpi=100, max_width=800)
98 |
99 | # Test that resizing occurred
100 | final_width, final_height = fig.get_size_inches()
101 | final_aspect = final_height / final_width
102 |
103 | assert final_width <= 8.0 # 800px / 100dpi = 8 inches
104 | assert abs(final_aspect - original_aspect) < 0.01 # Aspect ratio preserved
105 | assert len(base64_str) > 0
106 |
107 | def test_image_to_base64_error_handling(self):
108 | """Test error handling in base64 conversion."""
109 | with patch(
110 | "matplotlib.figure.Figure.savefig", side_effect=Exception("Save error")
111 | ):
112 | fig, ax = plt.subplots()
113 | ax.plot([1, 2, 3])
114 |
115 | result = image_to_base64(fig)
116 | assert result == "" # Should return empty string on error
117 |
118 |
119 | class TestEquityCurveGeneration:
120 | """Test suite for equity curve chart generation."""
121 |
122 | @pytest.fixture
123 | def sample_returns_data(self):
124 | """Create sample returns data for testing."""
125 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
126 | returns = np.random.normal(0.001, 0.02, len(dates))
127 | cumulative_returns = np.cumprod(1 + returns)
128 |
129 | # Create drawdown series
130 | running_max = np.maximum.accumulate(cumulative_returns)
131 | drawdown = (cumulative_returns - running_max) / running_max * 100
132 |
133 | return pd.Series(cumulative_returns, index=dates), pd.Series(
134 | drawdown, index=dates
135 | )
136 |
137 | def test_generate_equity_curve_basic(self, sample_returns_data):
138 | """Test basic equity curve generation."""
139 | returns, drawdown = sample_returns_data
140 |
141 | base64_str = generate_equity_curve(returns, title="Test Equity Curve")
142 |
143 | assert isinstance(base64_str, str)
144 | assert len(base64_str) > 100
145 |
146 | # Test that it's valid base64 image
147 | try:
148 | decoded_bytes = base64.b64decode(base64_str)
149 | assert decoded_bytes.startswith(b"\x89PNG") # PNG header
150 | except Exception as e:
151 | pytest.fail(f"Invalid PNG image: {e}")
152 |
153 | def test_generate_equity_curve_with_drawdown(self, sample_returns_data):
154 | """Test equity curve generation with drawdown subplot."""
155 | returns, drawdown = sample_returns_data
156 |
157 | base64_str = generate_equity_curve(
158 | returns, drawdown=drawdown, title="Equity Curve with Drawdown", theme="dark"
159 | )
160 |
161 | assert isinstance(base64_str, str)
162 | assert len(base64_str) > 100
163 |
164 | # Should be larger image due to subplot
165 | base64_no_dd = generate_equity_curve(returns, title="No Drawdown")
166 | assert len(base64_str) >= len(base64_no_dd)
167 |
168 | def test_generate_equity_curve_themes(self, sample_returns_data):
169 | """Test equity curve generation with different themes."""
170 | returns, _ = sample_returns_data
171 |
172 | light_chart = generate_equity_curve(returns, theme="light")
173 | dark_chart = generate_equity_curve(returns, theme="dark")
174 |
175 | assert len(light_chart) > 100
176 | assert len(dark_chart) > 100
177 | # Different themes should produce different images
178 | assert light_chart != dark_chart
179 |
180 | def test_generate_equity_curve_error_handling(self):
181 | """Test error handling in equity curve generation."""
182 | # Test with invalid data
183 | invalid_returns = pd.Series([]) # Empty series
184 |
185 | result = generate_equity_curve(invalid_returns)
186 | assert result == ""
187 |
188 | # Test with NaN data
189 | nan_returns = pd.Series([np.nan, np.nan, np.nan])
190 | result = generate_equity_curve(nan_returns)
191 | assert result == ""
192 |
193 |
194 | class TestTradeScatterGeneration:
195 | """Test suite for trade scatter plot generation."""
196 |
197 | @pytest.fixture
198 | def sample_trade_data(self):
199 | """Create sample trade data for testing."""
200 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
201 |         prices = pd.Series(100 + np.cumsum(np.random.normal(0, 1, len(dates))), index=dates)
202 |
203 | # Create sample trades
204 | trade_dates = dates[::30] # Every 30 days
205 | trades = []
206 |
207 | for i, trade_date in enumerate(trade_dates):
208 | if i % 2 == 0: # Entry
209 | trades.append(
210 | {
211 | "date": trade_date,
212 | "price": prices.loc[trade_date],
213 | "type": "entry",
214 | }
215 | )
216 | else: # Exit
217 | trades.append(
218 | {
219 | "date": trade_date,
220 | "price": prices.loc[trade_date],
221 | "type": "exit",
222 | }
223 | )
224 |
225 | trades_df = pd.DataFrame(trades).set_index("date")
226 | return prices, trades_df
227 |
228 | def test_generate_trade_scatter_basic(self, sample_trade_data):
229 | """Test basic trade scatter plot generation."""
230 | prices, trades = sample_trade_data
231 |
232 | base64_str = generate_trade_scatter(prices, trades, title="Trade Scatter Plot")
233 |
234 | assert isinstance(base64_str, str)
235 | assert len(base64_str) > 100
236 |
237 | # Verify valid PNG
238 | try:
239 | decoded_bytes = base64.b64decode(base64_str)
240 | assert decoded_bytes.startswith(b"\x89PNG")
241 | except Exception as e:
242 | pytest.fail(f"Invalid PNG image: {e}")
243 |
244 | def test_generate_trade_scatter_themes(self, sample_trade_data):
245 | """Test trade scatter plots with different themes."""
246 | prices, trades = sample_trade_data
247 |
248 | light_chart = generate_trade_scatter(prices, trades, theme="light")
249 | dark_chart = generate_trade_scatter(prices, trades, theme="dark")
250 |
251 | assert len(light_chart) > 100
252 | assert len(dark_chart) > 100
253 | assert light_chart != dark_chart
254 |
255 | def test_generate_trade_scatter_empty_trades(self, sample_trade_data):
256 | """Test trade scatter plot with empty trade data."""
257 | prices, _ = sample_trade_data
258 | empty_trades = pd.DataFrame(columns=["price", "type"])
259 |
260 | result = generate_trade_scatter(prices, empty_trades)
261 | assert result == ""
262 |
263 | def test_generate_trade_scatter_error_handling(self):
264 | """Test error handling in trade scatter generation."""
265 | # Test with mismatched data
266 | prices = pd.Series([1, 2, 3])
267 | trades = pd.DataFrame({"price": [10, 20], "type": ["entry", "exit"]})
268 |
269 | # Should handle gracefully
270 | result = generate_trade_scatter(prices, trades)
271 | # Might return empty string or valid chart depending on implementation
272 | assert isinstance(result, str)
273 |
274 |
275 | class TestOptimizationHeatmapGeneration:
276 | """Test suite for parameter optimization heatmap generation."""
277 |
278 | @pytest.fixture
279 | def sample_optimization_data(self):
280 | """Create sample optimization results for testing."""
281 | parameters = ["param1", "param2", "param3"]
282 | results = {}
283 |
284 | for p1 in parameters:
285 | results[p1] = {}
286 | for p2 in parameters:
287 | # Create some performance metric
288 | results[p1][p2] = np.random.uniform(0.5, 2.0)
289 |
290 | return results
291 |
292 | def test_generate_optimization_heatmap_basic(self, sample_optimization_data):
293 | """Test basic optimization heatmap generation."""
294 | base64_str = generate_optimization_heatmap(
295 | sample_optimization_data, title="Parameter Optimization Heatmap"
296 | )
297 |
298 | assert isinstance(base64_str, str)
299 | assert len(base64_str) > 100
300 |
301 | # Verify valid PNG
302 | try:
303 | decoded_bytes = base64.b64decode(base64_str)
304 | assert decoded_bytes.startswith(b"\x89PNG")
305 | except Exception as e:
306 | pytest.fail(f"Invalid PNG image: {e}")
307 |
308 | def test_generate_optimization_heatmap_themes(self, sample_optimization_data):
309 | """Test optimization heatmap with different themes."""
310 | light_chart = generate_optimization_heatmap(
311 | sample_optimization_data, theme="light"
312 | )
313 | dark_chart = generate_optimization_heatmap(
314 | sample_optimization_data, theme="dark"
315 | )
316 |
317 | assert len(light_chart) > 100
318 | assert len(dark_chart) > 100
319 | assert light_chart != dark_chart
320 |
321 | def test_generate_optimization_heatmap_empty_data(self):
322 | """Test heatmap generation with empty data."""
323 | empty_data = {}
324 |
325 | result = generate_optimization_heatmap(empty_data)
326 | assert result == ""
327 |
328 | def test_generate_optimization_heatmap_error_handling(self):
329 | """Test error handling in heatmap generation."""
330 | # Test with malformed data
331 | malformed_data = {"param1": "not_a_dict"}
332 |
333 | result = generate_optimization_heatmap(malformed_data)
334 | assert result == ""
335 |
336 |
337 | class TestPortfolioAllocationGeneration:
338 | """Test suite for portfolio allocation chart generation."""
339 |
340 | @pytest.fixture
341 | def sample_allocation_data(self):
342 | """Create sample allocation data for testing."""
343 | return {
344 | "AAPL": 0.25,
345 | "GOOGL": 0.20,
346 | "MSFT": 0.15,
347 | "TSLA": 0.15,
348 | "AMZN": 0.10,
349 | "Cash": 0.15,
350 | }
351 |
352 | def test_generate_portfolio_allocation_basic(self, sample_allocation_data):
353 | """Test basic portfolio allocation chart generation."""
354 | base64_str = generate_portfolio_allocation(
355 | sample_allocation_data, title="Portfolio Allocation"
356 | )
357 |
358 | assert isinstance(base64_str, str)
359 | assert len(base64_str) > 100
360 |
361 | # Verify valid PNG
362 | try:
363 | decoded_bytes = base64.b64decode(base64_str)
364 | assert decoded_bytes.startswith(b"\x89PNG")
365 | except Exception as e:
366 | pytest.fail(f"Invalid PNG image: {e}")
367 |
368 | def test_generate_portfolio_allocation_themes(self, sample_allocation_data):
369 | """Test portfolio allocation with different themes."""
370 | light_chart = generate_portfolio_allocation(
371 | sample_allocation_data, theme="light"
372 | )
373 | dark_chart = generate_portfolio_allocation(sample_allocation_data, theme="dark")
374 |
375 | assert len(light_chart) > 100
376 | assert len(dark_chart) > 100
377 | assert light_chart != dark_chart
378 |
379 | def test_generate_portfolio_allocation_empty_data(self):
380 | """Test allocation chart with empty data."""
381 | empty_data = {}
382 |
383 | result = generate_portfolio_allocation(empty_data)
384 | assert result == ""
385 |
386 | def test_generate_portfolio_allocation_single_asset(self):
387 | """Test allocation chart with single asset."""
388 | single_asset = {"AAPL": 1.0}
389 |
390 | result = generate_portfolio_allocation(single_asset)
391 | assert isinstance(result, str)
392 | assert len(result) > 100 # Should still generate valid chart
393 |
394 | def test_generate_portfolio_allocation_error_handling(self):
395 | """Test error handling in allocation chart generation."""
396 | # Test with invalid allocation values
397 | invalid_data = {"AAPL": "invalid_value"}
398 |
399 | result = generate_portfolio_allocation(invalid_data)
400 | assert result == ""
401 |
402 |
403 | class TestStrategyComparisonGeneration:
404 | """Test suite for strategy comparison chart generation."""
405 |
406 | @pytest.fixture
407 | def sample_strategy_data(self):
408 | """Create sample strategy comparison data."""
409 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
410 |
411 | strategies = {
412 | "Momentum": pd.Series(
413 | np.cumprod(1 + np.random.normal(0.0008, 0.015, len(dates))), index=dates
414 | ),
415 | "Mean Reversion": pd.Series(
416 | np.cumprod(1 + np.random.normal(0.0005, 0.012, len(dates))), index=dates
417 | ),
418 | "Breakout": pd.Series(
419 | np.cumprod(1 + np.random.normal(0.0012, 0.020, len(dates))), index=dates
420 | ),
421 | }
422 |
423 | return strategies
424 |
425 | def test_generate_strategy_comparison_basic(self, sample_strategy_data):
426 | """Test basic strategy comparison chart generation."""
427 | base64_str = generate_strategy_comparison(
428 | sample_strategy_data, title="Strategy Performance Comparison"
429 | )
430 |
431 | assert isinstance(base64_str, str)
432 | assert len(base64_str) > 100
433 |
434 | # Verify valid PNG
435 | try:
436 | decoded_bytes = base64.b64decode(base64_str)
437 | assert decoded_bytes.startswith(b"\x89PNG")
438 | except Exception as e:
439 | pytest.fail(f"Invalid PNG image: {e}")
440 |
441 | def test_generate_strategy_comparison_themes(self, sample_strategy_data):
442 | """Test strategy comparison with different themes."""
443 | light_chart = generate_strategy_comparison(sample_strategy_data, theme="light")
444 | dark_chart = generate_strategy_comparison(sample_strategy_data, theme="dark")
445 |
446 | assert len(light_chart) > 100
447 | assert len(dark_chart) > 100
448 | assert light_chart != dark_chart
449 |
450 | def test_generate_strategy_comparison_single_strategy(self):
451 | """Test comparison chart with single strategy."""
452 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
453 | single_strategy = {
454 | "Only Strategy": pd.Series(
455 | np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
456 | )
457 | }
458 |
459 | result = generate_strategy_comparison(single_strategy)
460 | assert isinstance(result, str)
461 | assert len(result) > 100
462 |
463 | def test_generate_strategy_comparison_empty_data(self):
464 | """Test comparison chart with empty data."""
465 | empty_data = {}
466 |
467 | result = generate_strategy_comparison(empty_data)
468 | assert result == ""
469 |
470 | def test_generate_strategy_comparison_error_handling(self):
471 | """Test error handling in strategy comparison generation."""
472 | # Test with mismatched data lengths
473 | dates1 = pd.date_range(start="2023-01-01", end="2023-06-30", freq="D")
474 | dates2 = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
475 |
476 | mismatched_data = {
477 | "Strategy1": pd.Series(np.random.normal(0, 1, len(dates1)), index=dates1),
478 | "Strategy2": pd.Series(np.random.normal(0, 1, len(dates2)), index=dates2),
479 | }
480 |
481 | # Should handle gracefully
482 | result = generate_strategy_comparison(mismatched_data)
483 | assert isinstance(result, str) # Might be empty or valid
484 |
485 |
486 | class TestPerformanceDashboardGeneration:
487 | """Test suite for performance dashboard generation."""
488 |
489 | @pytest.fixture
490 | def sample_metrics_data(self):
491 | """Create sample performance metrics for testing."""
492 | return {
493 | "Total Return": 0.156,
494 | "Sharpe Ratio": 1.25,
495 | "Max Drawdown": -0.082,
496 | "Win Rate": 0.583,
497 | "Profit Factor": 1.35,
498 | "Total Trades": 24,
499 | "Annualized Return": 0.18,
500 | "Volatility": 0.16,
501 | "Calmar Ratio": 1.10,
502 | "Best Trade": 0.12,
503 | }
504 |
505 | def test_generate_performance_dashboard_basic(self, sample_metrics_data):
506 | """Test basic performance dashboard generation."""
507 | base64_str = generate_performance_dashboard(
508 | sample_metrics_data, title="Performance Dashboard"
509 | )
510 |
511 | assert isinstance(base64_str, str)
512 | assert len(base64_str) > 100
513 |
514 | # Verify valid PNG
515 | try:
516 | decoded_bytes = base64.b64decode(base64_str)
517 | assert decoded_bytes.startswith(b"\x89PNG")
518 | except Exception as e:
519 | pytest.fail(f"Invalid PNG image: {e}")
520 |
521 | def test_generate_performance_dashboard_themes(self, sample_metrics_data):
522 | """Test performance dashboard with different themes."""
523 | light_chart = generate_performance_dashboard(sample_metrics_data, theme="light")
524 | dark_chart = generate_performance_dashboard(sample_metrics_data, theme="dark")
525 |
526 | assert len(light_chart) > 100
527 | assert len(dark_chart) > 100
528 | assert light_chart != dark_chart
529 |
530 | def test_generate_performance_dashboard_mixed_data_types(self):
531 | """Test dashboard with mixed data types."""
532 | mixed_metrics = {
533 | "Total Return": 0.156,
534 | "Strategy": "Momentum",
535 | "Symbol": "AAPL",
536 | "Duration": "365 days",
537 | "Sharpe Ratio": 1.25,
538 | "Status": "Completed",
539 | }
540 |
541 | result = generate_performance_dashboard(mixed_metrics)
542 | assert isinstance(result, str)
543 | assert len(result) > 100
544 |
545 | def test_generate_performance_dashboard_empty_data(self):
546 | """Test dashboard with empty metrics."""
547 | empty_metrics = {}
548 |
549 | result = generate_performance_dashboard(empty_metrics)
550 | assert result == ""
551 |
552 | def test_generate_performance_dashboard_large_dataset(self):
553 | """Test dashboard with large number of metrics."""
554 | large_metrics = {f"Metric_{i}": np.random.uniform(-1, 2) for i in range(50)}
555 |
556 | result = generate_performance_dashboard(large_metrics)
557 | assert isinstance(result, str)
558 | # Might be empty if table becomes too large, or valid if handled properly
559 |
560 | def test_generate_performance_dashboard_error_handling(self):
561 | """Test error handling in dashboard generation."""
562 | # Test with invalid data that might cause table generation to fail
563 | problematic_metrics = {
564 | "Valid Metric": 1.25,
565 | "Problematic": [1, 2, 3], # List instead of scalar
566 | "Another Valid": 0.85,
567 | }
568 |
569 | result = generate_performance_dashboard(problematic_metrics)
570 | assert isinstance(result, str)
571 |
572 |
573 | class TestVisualizationIntegration:
574 | """Integration tests for visualization functions working together."""
575 |
576 | def test_consistent_theming_across_charts(self):
577 | """Test that theming is consistent across different chart types."""
578 | # Create sample data for different chart types
579 | dates = pd.date_range(start="2023-01-01", end="2023-06-30", freq="D")
580 | returns = pd.Series(
581 | np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
582 | )
583 |
584 | allocation = {"AAPL": 0.4, "GOOGL": 0.3, "MSFT": 0.3}
585 | metrics = {"Return": 0.15, "Sharpe": 1.2, "Drawdown": -0.08}
586 |
587 | # Generate charts with same theme
588 | equity_chart = generate_equity_curve(returns, theme="dark")
589 | allocation_chart = generate_portfolio_allocation(allocation, theme="dark")
590 | dashboard_chart = generate_performance_dashboard(metrics, theme="dark")
591 |
592 | # All should generate valid base64 strings
593 | charts = [equity_chart, allocation_chart, dashboard_chart]
594 | for chart in charts:
595 | assert isinstance(chart, str)
596 | assert len(chart) > 100
597 |
598 | # Verify valid PNG
599 | try:
600 | decoded_bytes = base64.b64decode(chart)
601 | assert decoded_bytes.startswith(b"\x89PNG")
602 | except Exception as e:
603 | pytest.fail(f"Invalid PNG in themed charts: {e}")
604 |
605 | def test_memory_cleanup_after_chart_generation(self):
606 | """Test that matplotlib figures are properly cleaned up."""
607 | import matplotlib.pyplot as plt
608 |
609 | initial_figure_count = len(plt.get_fignums())
610 |
611 | # Generate multiple charts
612 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
613 | returns = pd.Series(
614 | np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
615 | )
616 |
617 | for i in range(10):
618 | chart = generate_equity_curve(returns, title=f"Test Chart {i}")
619 | assert len(chart) > 0
620 |
621 | final_figure_count = len(plt.get_fignums())
622 |
623 | # Figure count should not have increased (figures should be closed)
624 | assert final_figure_count <= initial_figure_count + 1 # Allow for 1 open figure
625 |
626 | def test_chart_generation_performance_benchmark(self, benchmark_timer):
627 | """Test chart generation performance benchmarks."""
628 | # Create substantial dataset
629 | dates = pd.date_range(
630 | start="2023-01-01", end="2023-12-31", freq="H"
631 | ) # Hourly data
632 | returns = pd.Series(
633 | np.cumprod(1 + np.random.normal(0.0001, 0.005, len(dates))), index=dates
634 | )
635 |
636 | with benchmark_timer() as timer:
637 | chart = generate_equity_curve(returns, title="Performance Test")
638 |
639 | # Should complete within reasonable time even with large dataset
640 | assert timer.elapsed < 5.0 # < 5 seconds
641 | assert len(chart) > 100 # Valid chart generated
642 |
643 | def test_concurrent_chart_generation(self):
644 | """Test concurrent chart generation doesn't cause conflicts."""
645 | import queue
646 | import threading
647 |
648 | results_queue = queue.Queue()
649 | error_queue = queue.Queue()
650 |
651 | def generate_chart(thread_id):
652 | try:
653 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
654 | returns = pd.Series(
655 | np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))),
656 | index=dates,
657 | )
658 |
659 | chart = generate_equity_curve(returns, title=f"Thread {thread_id}")
660 | results_queue.put((thread_id, len(chart)))
661 | except Exception as e:
662 | error_queue.put(f"Thread {thread_id}: {e}")
663 |
664 | # Create multiple threads
665 | threads = []
666 | for i in range(5):
667 | thread = threading.Thread(target=generate_chart, args=(i,))
668 | threads.append(thread)
669 | thread.start()
670 |
671 | # Wait for completion
672 | for thread in threads:
673 | thread.join(timeout=10)
674 |
675 | # Check results
676 | assert error_queue.empty(), f"Errors: {list(error_queue.queue)}"
677 | assert results_queue.qsize() == 5
678 |
679 | # All should have generated valid charts
680 | while not results_queue.empty():
681 | thread_id, chart_length = results_queue.get()
682 | assert chart_length > 100
683 |
684 |
685 | if __name__ == "__main__":
686 | # Run tests with detailed output
687 | pytest.main([__file__, "-v", "--tb=short", "--asyncio-mode=auto"])
688 |
```
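A minimal usage sketch of the visualization helpers exercised above, inferred only from the test assertions (each generator returns a base64-encoded PNG string on success and an empty string on failure); the argument names shown are those used in the tests, not a definitive statement of the module's API.

```python
import base64

import numpy as np
import pandas as pd

from maverick_mcp.backtesting.visualization import (
    generate_equity_curve,
    generate_performance_dashboard,
)

# Synthetic equity curve, mirroring the fixtures used in the tests above.
dates = pd.date_range("2023-01-01", "2023-12-31", freq="D")
equity = pd.Series(
    np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
)

# Each helper returns a base64-encoded PNG ("" on error), so the chart can be
# embedded directly in a JSON/MCP response without touching the filesystem.
chart_b64 = generate_equity_curve(equity, title="Demo Equity Curve", theme="dark")
if chart_b64:
    assert base64.b64decode(chart_b64).startswith(b"\x89PNG")

dashboard_b64 = generate_performance_dashboard(
    {"Total Return": 0.156, "Sharpe Ratio": 1.25, "Max Drawdown": -0.082}
)
```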
--------------------------------------------------------------------------------
/tests/integration/test_full_backtest_workflow_advanced.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Advanced End-to-End Integration Tests for VectorBT Backtesting Workflow.
3 |
4 | This comprehensive test suite covers:
5 | - Complete workflow integration from data fetch to results
6 | - All 15 strategies (9 traditional + 6 ML) testing
7 | - Parallel execution capabilities
8 | - Cache behavior and optimization
9 | - Real production-like scenarios
10 | - Error recovery and resilience
11 | - Resource management and cleanup
12 | """
13 |
14 | import asyncio
15 | import logging
16 | import time
17 | from unittest.mock import Mock
18 | from uuid import UUID
19 |
20 | import numpy as np
21 | import pandas as pd
22 | import pytest
23 |
24 | from maverick_mcp.backtesting import (
25 | VectorBTEngine,
26 | )
27 | from maverick_mcp.backtesting.persistence import BacktestPersistenceManager
28 | from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
29 | from maverick_mcp.backtesting.visualization import (
30 | generate_equity_curve,
31 | generate_performance_dashboard,
32 | )
33 |
34 | logger = logging.getLogger(__name__)
35 |
36 | # Strategy definitions for comprehensive testing
37 | TRADITIONAL_STRATEGIES = [
38 | "sma_cross",
39 | "ema_cross",
40 | "rsi",
41 | "macd",
42 | "bollinger",
43 | "momentum",
44 | "breakout",
45 | "mean_reversion",
46 | "volume_momentum",
47 | ]
48 |
49 | ML_STRATEGIES = [
50 | "ml_predictor",
51 | "adaptive",
52 | "ensemble",
53 | "regime_aware",
54 | "online_learning",
55 | "reinforcement_learning",
56 | ]
57 |
58 | ALL_STRATEGIES = TRADITIONAL_STRATEGIES + ML_STRATEGIES
59 |
60 |
61 | class TestAdvancedBacktestWorkflowIntegration:
62 | """Advanced integration tests for complete backtesting workflow."""
63 |
64 | @pytest.fixture
65 | async def enhanced_stock_data_provider(self):
66 | """Create enhanced mock stock data provider with realistic multi-year data."""
67 | provider = Mock()
68 |
69 | # Generate 3 years of realistic stock data with different market conditions
70 | dates = pd.date_range(start="2021-01-01", end="2023-12-31", freq="D")
71 |
72 | # Simulate different market regimes
73 | bull_period = len(dates) // 3 # First third: bull market
74 | sideways_period = len(dates) // 3 # Second third: sideways
75 | bear_period = len(dates) - bull_period - sideways_period # Final: bear market
76 |
77 | # Generate returns for different regimes
78 | bull_returns = np.random.normal(0.0015, 0.015, bull_period) # Positive drift
79 | sideways_returns = np.random.normal(0.0002, 0.02, sideways_period) # Low drift
80 | bear_returns = np.random.normal(-0.001, 0.025, bear_period) # Negative drift
81 |
82 | all_returns = np.concatenate([bull_returns, sideways_returns, bear_returns])
83 | prices = 100 * np.cumprod(1 + all_returns) # Start at $100
84 |
85 | # Add realistic volume patterns
86 | volumes = np.random.randint(500000, 5000000, len(dates)).astype(float)
87 | volumes += np.random.normal(0, volumes * 0.1) # Add volume volatility
88 | volumes = np.maximum(volumes, 100000) # Minimum volume
89 | volumes = volumes.astype(int) # Convert back to integers
90 |
91 | stock_data = pd.DataFrame(
92 | {
93 | "Open": prices * np.random.uniform(0.995, 1.005, len(dates)),
94 | "High": prices * np.random.uniform(1.002, 1.025, len(dates)),
95 | "Low": prices * np.random.uniform(0.975, 0.998, len(dates)),
96 | "Close": prices,
97 | "Volume": volumes.astype(int),
98 | "Adj Close": prices,
99 | },
100 | index=dates,
101 | )
102 |
103 | # Ensure OHLC constraints
104 | stock_data["High"] = np.maximum(
105 | stock_data["High"], np.maximum(stock_data["Open"], stock_data["Close"])
106 | )
107 | stock_data["Low"] = np.minimum(
108 | stock_data["Low"], np.minimum(stock_data["Open"], stock_data["Close"])
109 | )
110 |
111 | provider.get_stock_data.return_value = stock_data
112 | return provider
113 |
114 | @pytest.fixture
115 | async def complete_vectorbt_engine(self, enhanced_stock_data_provider):
116 | """Create complete VectorBT engine with all strategies enabled."""
117 | engine = VectorBTEngine(data_provider=enhanced_stock_data_provider)
118 | return engine
119 |
120 | async def test_all_15_strategies_integration(
121 | self, complete_vectorbt_engine, benchmark_timer
122 | ):
123 | """Test all 15 strategies (9 traditional + 6 ML) in complete workflow."""
124 | results = {}
125 | failed_strategies = []
126 |
127 | with benchmark_timer() as timer:
128 | # Test traditional strategies
129 | for strategy in TRADITIONAL_STRATEGIES:
130 | try:
131 | if strategy in STRATEGY_TEMPLATES:
132 | parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
133 | result = await complete_vectorbt_engine.run_backtest(
134 | symbol="COMPREHENSIVE_TEST",
135 | strategy_type=strategy,
136 | parameters=parameters,
137 | start_date="2022-01-01",
138 | end_date="2023-12-31",
139 | )
140 | results[strategy] = result
141 |
142 | # Validate basic result structure
143 | assert "metrics" in result
144 | assert "trades" in result
145 | assert "equity_curve" in result
146 | assert result["symbol"] == "COMPREHENSIVE_TEST"
147 |
148 | logger.info(f"✓ {strategy} strategy executed successfully")
149 | else:
150 | logger.warning(f"Strategy {strategy} not found in templates")
151 |
152 | except Exception as e:
153 | failed_strategies.append(strategy)
154 | logger.error(f"✗ {strategy} strategy failed: {str(e)}")
155 |
156 | # Test ML strategies (mock implementation for integration test)
157 | for strategy in ML_STRATEGIES:
158 | try:
159 | # Mock ML strategy execution
160 | mock_ml_result = {
161 | "symbol": "COMPREHENSIVE_TEST",
162 | "strategy_type": strategy,
163 | "metrics": {
164 | "total_return": np.random.uniform(-0.2, 0.3),
165 | "sharpe_ratio": np.random.uniform(0.5, 2.0),
166 | "max_drawdown": np.random.uniform(-0.3, -0.05),
167 | "total_trades": np.random.randint(10, 100),
168 | },
169 | "trades": [],
170 |                     "equity_curve": np.cumsum(
171 | np.random.normal(0.001, 0.02, 252)
172 | ).tolist(),
173 | "ml_specific": {
174 | "model_accuracy": np.random.uniform(0.55, 0.85),
175 | "feature_importance": {
176 | "momentum": 0.3,
177 | "volatility": 0.25,
178 | "volume": 0.45,
179 | },
180 | },
181 | }
182 | results[strategy] = mock_ml_result
183 | logger.info(f"✓ {strategy} ML strategy simulated successfully")
184 |
185 | except Exception as e:
186 | failed_strategies.append(strategy)
187 | logger.error(f"✗ {strategy} ML strategy failed: {str(e)}")
188 |
189 | execution_time = timer.elapsed
190 |
191 | # Validate overall results
192 | successful_strategies = len(results)
193 | total_strategies = len(ALL_STRATEGIES)
194 | success_rate = successful_strategies / total_strategies
195 |
196 | # Performance requirements
197 | assert execution_time < 180.0 # Should complete within 3 minutes
198 | assert success_rate >= 0.8 # At least 80% success rate
199 | assert successful_strategies >= 12 # At least 12 strategies should work
200 |
201 | # Log comprehensive results
202 | logger.info(
203 | f"Strategy Integration Test Results:\n"
204 | f" • Total Strategies: {total_strategies}\n"
205 | f" • Successful: {successful_strategies}\n"
206 | f" • Failed: {len(failed_strategies)}\n"
207 | f" • Success Rate: {success_rate:.1%}\n"
208 | f" • Execution Time: {execution_time:.2f}s\n"
209 | f" • Failed Strategies: {failed_strategies}"
210 | )
211 |
212 | return {
213 | "total_strategies": total_strategies,
214 | "successful_strategies": successful_strategies,
215 | "failed_strategies": failed_strategies,
216 | "success_rate": success_rate,
217 | "execution_time": execution_time,
218 | "results": results,
219 | }
220 |
221 | async def test_parallel_execution_capabilities(
222 | self, complete_vectorbt_engine, benchmark_timer
223 | ):
224 | """Test parallel execution of multiple backtests."""
225 | symbols = ["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN", "META", "NFLX", "NVDA"]
226 | strategies = ["sma_cross", "rsi", "macd", "bollinger"]
227 |
228 | async def run_single_backtest(symbol, strategy):
229 | """Run a single backtest."""
230 | try:
231 | parameters = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})
232 | result = await complete_vectorbt_engine.run_backtest(
233 | symbol=symbol,
234 | strategy_type=strategy,
235 | parameters=parameters,
236 | start_date="2023-01-01",
237 | end_date="2023-12-31",
238 | )
239 | return {
240 | "symbol": symbol,
241 | "strategy": strategy,
242 | "result": result,
243 | "success": True,
244 | }
245 | except Exception as e:
246 | return {
247 | "symbol": symbol,
248 | "strategy": strategy,
249 | "error": str(e),
250 | "success": False,
251 | }
252 |
253 | with benchmark_timer() as timer:
254 | # Create all combinations
255 | tasks = []
256 | for symbol in symbols:
257 | for strategy in strategies:
258 | tasks.append(run_single_backtest(symbol, strategy))
259 |
260 | # Execute in parallel with semaphore to control concurrency
261 | semaphore = asyncio.Semaphore(8) # Max 8 concurrent executions
262 |
263 | async def run_with_semaphore(task):
264 | async with semaphore:
265 | return await task
266 |
267 | results = await asyncio.gather(
268 | *[run_with_semaphore(task) for task in tasks], return_exceptions=True
269 | )
270 |
271 | execution_time = timer.elapsed
272 |
273 | # Analyze results
274 | total_executions = len(tasks)
275 | successful_executions = sum(
276 | 1 for r in results if isinstance(r, dict) and r.get("success", False)
277 | )
278 | failed_executions = total_executions - successful_executions
279 |
280 | # Performance assertions
281 | assert execution_time < 300.0 # Should complete within 5 minutes
282 | assert successful_executions >= total_executions * 0.7 # At least 70% success
283 |
284 | # Calculate average execution time per backtest
285 | avg_time_per_backtest = execution_time / total_executions
286 |
287 | logger.info(
288 | f"Parallel Execution Results:\n"
289 | f" • Total Executions: {total_executions}\n"
290 | f" • Successful: {successful_executions}\n"
291 | f" • Failed: {failed_executions}\n"
292 | f" • Success Rate: {successful_executions / total_executions:.1%}\n"
293 | f" • Total Time: {execution_time:.2f}s\n"
294 | f" • Avg Time/Backtest: {avg_time_per_backtest:.2f}s\n"
295 | f" • Parallel Speedup: ~{total_executions * avg_time_per_backtest / execution_time:.1f}x"
296 | )
297 |
298 | return {
299 | "total_executions": total_executions,
300 | "successful_executions": successful_executions,
301 | "execution_time": execution_time,
302 | "avg_time_per_backtest": avg_time_per_backtest,
303 | }
304 |
305 | async def test_cache_behavior_and_optimization(self, complete_vectorbt_engine):
306 | """Test cache behavior and optimization in integrated workflow."""
307 | symbol = "CACHE_TEST_SYMBOL"
308 | strategy = "sma_cross"
309 | parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
310 |
311 | # First run - should populate cache
312 | start_time = time.time()
313 | result1 = await complete_vectorbt_engine.run_backtest(
314 | symbol=symbol,
315 | strategy_type=strategy,
316 | parameters=parameters,
317 | start_date="2023-01-01",
318 | end_date="2023-12-31",
319 | )
320 | first_run_time = time.time() - start_time
321 |
322 | # Second run - should use cache
323 | start_time = time.time()
324 | result2 = await complete_vectorbt_engine.run_backtest(
325 | symbol=symbol,
326 | strategy_type=strategy,
327 | parameters=parameters,
328 | start_date="2023-01-01",
329 | end_date="2023-12-31",
330 | )
331 | second_run_time = time.time() - start_time
332 |
333 | # Third run with different parameters - should not use cache
334 | modified_parameters = {
335 | **parameters,
336 | "fast_period": parameters.get("fast_period", 10) + 5,
337 | }
338 | start_time = time.time()
339 | await complete_vectorbt_engine.run_backtest(
340 | symbol=symbol,
341 | strategy_type=strategy,
342 | parameters=modified_parameters,
343 | start_date="2023-01-01",
344 | end_date="2023-12-31",
345 | )
346 | third_run_time = time.time() - start_time
347 |
348 | # Validate results consistency (for cached runs)
349 | assert result1["symbol"] == result2["symbol"]
350 | assert result1["strategy_type"] == result2["strategy_type"]
351 |
352 | # Cache effectiveness check (second run might be faster, but not guaranteed)
353 | cache_speedup = first_run_time / second_run_time if second_run_time > 0 else 1.0
354 |
355 | logger.info(
356 | f"Cache Behavior Test Results:\n"
357 | f" • First Run: {first_run_time:.3f}s\n"
358 | f" • Second Run (cached): {second_run_time:.3f}s\n"
359 | f" • Third Run (different params): {third_run_time:.3f}s\n"
360 | f" • Cache Speedup: {cache_speedup:.2f}x\n"
361 | )
362 |
363 | return {
364 | "first_run_time": first_run_time,
365 | "second_run_time": second_run_time,
366 | "third_run_time": third_run_time,
367 | "cache_speedup": cache_speedup,
368 | }
369 |
370 | async def test_database_persistence_integration(
371 | self, complete_vectorbt_engine, db_session
372 | ):
373 | """Test complete database persistence integration."""
374 | # Generate test results
375 | result = await complete_vectorbt_engine.run_backtest(
376 | symbol="PERSISTENCE_TEST",
377 | strategy_type="sma_cross",
378 | parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
379 | start_date="2023-01-01",
380 | end_date="2023-12-31",
381 | )
382 |
383 | # Test persistence workflow
384 | with BacktestPersistenceManager(session=db_session) as persistence:
385 | # Save backtest result
386 | backtest_id = persistence.save_backtest_result(
387 | vectorbt_results=result,
388 | execution_time=2.5,
389 | notes="Integration test - complete persistence workflow",
390 | )
391 |
392 | # Validate saved data
393 | assert backtest_id is not None
394 | assert UUID(backtest_id) # Valid UUID
395 |
396 | # Retrieve and validate
397 | saved_result = persistence.get_backtest_by_id(backtest_id)
398 | assert saved_result is not None
399 | assert saved_result.symbol == "PERSISTENCE_TEST"
400 | assert saved_result.strategy_type == "sma_cross"
401 | assert saved_result.execution_time == 2.5
402 |
403 | # Test batch operations
404 | batch_results = []
405 | for i in range(5):
406 | batch_result = await complete_vectorbt_engine.run_backtest(
407 | symbol=f"BATCH_TEST_{i}",
408 | strategy_type="rsi",
409 | parameters=STRATEGY_TEMPLATES["rsi"]["parameters"],
410 | start_date="2023-06-01",
411 | end_date="2023-12-31",
412 | )
413 | batch_results.append(batch_result)
414 |
415 | # Save batch results
416 | batch_ids = []
417 | for i, batch_result in enumerate(batch_results):
418 | batch_id = persistence.save_backtest_result(
419 | vectorbt_results=batch_result,
420 | execution_time=1.8 + i * 0.1,
421 | notes=f"Batch test #{i + 1}",
422 | )
423 | batch_ids.append(batch_id)
424 |
425 | # Query saved batch results
426 | saved_batch = [persistence.get_backtest_by_id(bid) for bid in batch_ids]
427 | assert all(saved is not None for saved in saved_batch)
428 | assert len(saved_batch) == 5
429 |
430 | # Test filtering and querying
431 | rsi_results = persistence.get_backtests_by_strategy("rsi")
432 | assert len(rsi_results) >= 5 # At least our batch results
433 |
434 | logger.info("Database persistence test completed successfully")
435 | return {"batch_ids": batch_ids, "single_id": backtest_id}
436 |
437 | async def test_visualization_integration_complete(self, complete_vectorbt_engine):
438 | """Test complete visualization integration workflow."""
439 | # Run backtest to get data for visualization
440 | result = await complete_vectorbt_engine.run_backtest(
441 | symbol="VIZ_TEST",
442 | strategy_type="macd",
443 | parameters=STRATEGY_TEMPLATES["macd"]["parameters"],
444 | start_date="2023-01-01",
445 | end_date="2023-12-31",
446 | )
447 |
448 | # Test all visualization components
449 | visualizations = {}
450 |
451 | # 1. Equity curve visualization
452 | equity_data = pd.Series(result["equity_curve"])
453 | drawdown_data = pd.Series(result["drawdown_series"])
454 |
455 | equity_chart = generate_equity_curve(
456 | equity_data,
457 | drawdown=drawdown_data,
458 | title="Complete Integration Test - Equity Curve",
459 | )
460 | visualizations["equity_curve"] = equity_chart
461 |
462 | # 2. Performance dashboard
463 | dashboard_chart = generate_performance_dashboard(
464 | result["metrics"], title="Complete Integration Test - Performance Dashboard"
465 | )
466 | visualizations["dashboard"] = dashboard_chart
467 |
468 | # 3. Validate all visualizations
469 | for viz_name, viz_data in visualizations.items():
470 | assert isinstance(viz_data, str), f"{viz_name} should return string"
471 | assert len(viz_data) > 100, f"{viz_name} should have substantial content"
472 |
473 | # Try to decode as base64 (should be valid image)
474 | try:
475 | import base64
476 |
477 | decoded = base64.b64decode(viz_data)
478 | assert len(decoded) > 0, f"{viz_name} should have valid image data"
479 | logger.info(f"✓ {viz_name} visualization generated successfully")
480 | except Exception as e:
481 | logger.error(f"✗ {viz_name} visualization failed: {e}")
482 | raise
483 |
484 | return visualizations
485 |
486 | async def test_error_recovery_comprehensive(self, complete_vectorbt_engine):
487 | """Test comprehensive error recovery across the workflow."""
488 | recovery_results = {}
489 |
490 | # 1. Invalid symbol handling
491 | try:
492 | result = await complete_vectorbt_engine.run_backtest(
493 | symbol="", # Empty symbol
494 | strategy_type="sma_cross",
495 | parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
496 | start_date="2023-01-01",
497 | end_date="2023-12-31",
498 | )
499 | recovery_results["empty_symbol"] = {"recovered": True, "result": result}
500 | except Exception as e:
501 | recovery_results["empty_symbol"] = {"recovered": False, "error": str(e)}
502 |
503 | # 2. Invalid date range handling
504 | try:
505 | result = await complete_vectorbt_engine.run_backtest(
506 | symbol="ERROR_TEST",
507 | strategy_type="sma_cross",
508 | parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
509 | start_date="2025-01-01", # Future date
510 | end_date="2025-12-31",
511 | )
512 | recovery_results["future_dates"] = {"recovered": True, "result": result}
513 | except Exception as e:
514 | recovery_results["future_dates"] = {"recovered": False, "error": str(e)}
515 |
516 | # 3. Invalid strategy parameters
517 | try:
518 | invalid_params = {
519 | "fast_period": -10,
520 | "slow_period": -20,
521 | } # Invalid negative values
522 | result = await complete_vectorbt_engine.run_backtest(
523 | symbol="ERROR_TEST",
524 | strategy_type="sma_cross",
525 | parameters=invalid_params,
526 | start_date="2023-01-01",
527 | end_date="2023-12-31",
528 | )
529 | recovery_results["invalid_params"] = {"recovered": True, "result": result}
530 | except Exception as e:
531 | recovery_results["invalid_params"] = {"recovered": False, "error": str(e)}
532 |
533 | # 4. Unknown strategy handling
534 | try:
535 | result = await complete_vectorbt_engine.run_backtest(
536 | symbol="ERROR_TEST",
537 | strategy_type="nonexistent_strategy",
538 | parameters={},
539 | start_date="2023-01-01",
540 | end_date="2023-12-31",
541 | )
542 | recovery_results["unknown_strategy"] = {"recovered": True, "result": result}
543 | except Exception as e:
544 | recovery_results["unknown_strategy"] = {"recovered": False, "error": str(e)}
545 |
546 | # Analyze recovery effectiveness
547 | total_tests = len(recovery_results)
548 | recovered_tests = sum(
549 | 1 for r in recovery_results.values() if r.get("recovered", False)
550 | )
551 | recovery_rate = recovered_tests / total_tests if total_tests > 0 else 0
552 |
553 | logger.info(
554 | f"Error Recovery Test Results:\n"
555 | f" • Total Error Scenarios: {total_tests}\n"
556 | f" • Successfully Recovered: {recovered_tests}\n"
557 | f" • Recovery Rate: {recovery_rate:.1%}\n"
558 | )
559 |
560 | for scenario, result in recovery_results.items():
561 | status = "✓ RECOVERED" if result.get("recovered") else "✗ FAILED"
562 | logger.info(f" • {scenario}: {status}")
563 |
564 | return recovery_results
565 |
566 | async def test_resource_management_comprehensive(self, complete_vectorbt_engine):
567 | """Test comprehensive resource management across workflow."""
568 | import os
569 |
570 | import psutil
571 |
572 | process = psutil.Process(os.getpid())
573 |
574 | # Baseline measurements
575 | initial_memory = process.memory_info().rss / 1024 / 1024 # MB
576 | initial_threads = process.num_threads()
577 | resource_snapshots = []
578 |
579 | # Run multiple backtests while monitoring resources
580 | for i in range(10):
581 | await complete_vectorbt_engine.run_backtest(
582 | symbol=f"RESOURCE_TEST_{i}",
583 | strategy_type="sma_cross",
584 | parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
585 | start_date="2023-01-01",
586 | end_date="2023-12-31",
587 | )
588 |
589 | # Take resource snapshot
590 | current_memory = process.memory_info().rss / 1024 / 1024 # MB
591 | current_threads = process.num_threads()
592 | current_cpu = process.cpu_percent()
593 |
594 | resource_snapshots.append(
595 | {
596 | "iteration": i + 1,
597 | "memory_mb": current_memory,
598 | "threads": current_threads,
599 | "cpu_percent": current_cpu,
600 | }
601 | )
602 |
603 | # Final measurements
604 | final_memory = process.memory_info().rss / 1024 / 1024 # MB
605 | final_threads = process.num_threads()
606 |
607 | # Calculate resource growth
608 | memory_growth = final_memory - initial_memory
609 | thread_growth = final_threads - initial_threads
610 | peak_memory = max(snapshot["memory_mb"] for snapshot in resource_snapshots)
611 | avg_threads = sum(snapshot["threads"] for snapshot in resource_snapshots) / len(
612 | resource_snapshots
613 | )
614 |
615 | # Resource management assertions
616 | assert memory_growth < 500, (
617 | f"Memory growth too high: {memory_growth:.1f}MB"
618 | ) # Max 500MB growth
619 | assert thread_growth <= 10, (
620 | f"Thread growth too high: {thread_growth}"
621 | ) # Max 10 additional threads
622 | assert peak_memory < initial_memory + 1000, (
623 | f"Peak memory too high: {peak_memory:.1f}MB"
624 | ) # Peak within 1GB of initial
625 |
626 | logger.info(
627 | f"Resource Management Test Results:\n"
628 | f" • Initial Memory: {initial_memory:.1f}MB\n"
629 | f" • Final Memory: {final_memory:.1f}MB\n"
630 | f" • Memory Growth: {memory_growth:.1f}MB\n"
631 | f" • Peak Memory: {peak_memory:.1f}MB\n"
632 | f" • Initial Threads: {initial_threads}\n"
633 | f" • Final Threads: {final_threads}\n"
634 | f" • Thread Growth: {thread_growth}\n"
635 | f" • Avg Threads: {avg_threads:.1f}"
636 | )
637 |
638 | return {
639 | "memory_growth": memory_growth,
640 | "thread_growth": thread_growth,
641 | "peak_memory": peak_memory,
642 | "resource_snapshots": resource_snapshots,
643 | }
644 |
645 |
646 | if __name__ == "__main__":
647 | # Run advanced integration tests
648 | pytest.main(
649 | [
650 | __file__,
651 | "-v",
652 | "--tb=short",
653 | "--asyncio-mode=auto",
654 | "--timeout=600", # 10 minute timeout for comprehensive tests
655 | "-x", # Stop on first failure
656 | "--durations=10", # Show 10 slowest tests
657 | ]
658 | )
659 |
```
--------------------------------------------------------------------------------
/tests/test_tool_estimation_config.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive tests for ToolEstimationConfig.
3 |
4 | This module tests the centralized tool cost estimation configuration that replaces
5 | magic numbers scattered throughout the codebase. Tests cover:
6 | - All tool-specific estimates
7 | - Confidence levels and estimation basis
8 | - Monitoring thresholds and alert conditions
9 | - Edge cases and error handling
10 | - Integration with server.py patterns
11 | """
12 |
13 | from unittest.mock import patch
14 |
15 | import pytest
16 |
17 | from maverick_mcp.config.tool_estimation import (
18 | EstimationBasis,
19 | MonitoringThresholds,
20 | ToolComplexity,
21 | ToolEstimate,
22 | ToolEstimationConfig,
23 | get_tool_estimate,
24 | get_tool_estimation_config,
25 | should_alert_for_usage,
26 | )
27 |
28 |
29 | class TestToolEstimate:
30 | """Test ToolEstimate model validation and behavior."""
31 |
32 | def test_valid_tool_estimate(self):
33 | """Test creating a valid ToolEstimate."""
34 | estimate = ToolEstimate(
35 | llm_calls=5,
36 | total_tokens=8000,
37 | confidence=0.8,
38 | based_on=EstimationBasis.EMPIRICAL,
39 | complexity=ToolComplexity.COMPLEX,
40 | notes="Test estimate",
41 | )
42 |
43 | assert estimate.llm_calls == 5
44 | assert estimate.total_tokens == 8000
45 | assert estimate.confidence == 0.8
46 | assert estimate.based_on == EstimationBasis.EMPIRICAL
47 | assert estimate.complexity == ToolComplexity.COMPLEX
48 | assert estimate.notes == "Test estimate"
49 |
50 | def test_confidence_validation(self):
51 | """Test confidence level validation."""
52 | # Valid confidence levels
53 | for confidence in [0.0, 0.5, 1.0]:
54 | estimate = ToolEstimate(
55 | llm_calls=1,
56 | total_tokens=100,
57 | confidence=confidence,
58 | based_on=EstimationBasis.EMPIRICAL,
59 | complexity=ToolComplexity.SIMPLE,
60 | )
61 | assert estimate.confidence == confidence
62 |
63 | # Invalid confidence levels - Pydantic ValidationError
64 | from pydantic import ValidationError
65 |
66 | with pytest.raises(ValidationError):
67 | ToolEstimate(
68 | llm_calls=1,
69 | total_tokens=100,
70 | confidence=-0.1,
71 | based_on=EstimationBasis.EMPIRICAL,
72 | complexity=ToolComplexity.SIMPLE,
73 | )
74 |
75 | with pytest.raises(ValidationError):
76 | ToolEstimate(
77 | llm_calls=1,
78 | total_tokens=100,
79 | confidence=1.1,
80 | based_on=EstimationBasis.EMPIRICAL,
81 | complexity=ToolComplexity.SIMPLE,
82 | )
83 |
84 | def test_negative_values_validation(self):
85 | """Test that negative values are not allowed."""
86 | from pydantic import ValidationError
87 |
88 | with pytest.raises(ValidationError):
89 | ToolEstimate(
90 | llm_calls=-1,
91 | total_tokens=100,
92 | confidence=0.8,
93 | based_on=EstimationBasis.EMPIRICAL,
94 | complexity=ToolComplexity.SIMPLE,
95 | )
96 |
97 | with pytest.raises(ValidationError):
98 | ToolEstimate(
99 | llm_calls=1,
100 | total_tokens=-100,
101 | confidence=0.8,
102 | based_on=EstimationBasis.EMPIRICAL,
103 | complexity=ToolComplexity.SIMPLE,
104 | )
105 |
106 |
107 | class TestMonitoringThresholds:
108 | """Test MonitoringThresholds model and validation."""
109 |
110 | def test_default_thresholds(self):
111 | """Test default monitoring thresholds."""
112 | thresholds = MonitoringThresholds()
113 |
114 | assert thresholds.llm_calls_warning == 15
115 | assert thresholds.llm_calls_critical == 25
116 | assert thresholds.tokens_warning == 20000
117 | assert thresholds.tokens_critical == 35000
118 | assert thresholds.variance_warning == 0.5
119 | assert thresholds.variance_critical == 1.0
120 |
121 | def test_custom_thresholds(self):
122 | """Test custom monitoring thresholds."""
123 | thresholds = MonitoringThresholds(
124 | llm_calls_warning=10,
125 | llm_calls_critical=20,
126 | tokens_warning=15000,
127 | tokens_critical=30000,
128 | variance_warning=0.3,
129 | variance_critical=0.8,
130 | )
131 |
132 | assert thresholds.llm_calls_warning == 10
133 | assert thresholds.llm_calls_critical == 20
134 | assert thresholds.tokens_warning == 15000
135 | assert thresholds.tokens_critical == 30000
136 | assert thresholds.variance_warning == 0.3
137 | assert thresholds.variance_critical == 0.8
138 |
139 |
140 | class TestToolEstimationConfig:
141 | """Test the main ToolEstimationConfig class."""
142 |
143 | def test_default_configuration(self):
144 | """Test default configuration initialization."""
145 | config = ToolEstimationConfig()
146 |
147 | # Test default estimates by complexity
148 | assert config.simple_default.complexity == ToolComplexity.SIMPLE
149 | assert config.standard_default.complexity == ToolComplexity.STANDARD
150 | assert config.complex_default.complexity == ToolComplexity.COMPLEX
151 | assert config.premium_default.complexity == ToolComplexity.PREMIUM
152 |
153 | # Test unknown tool fallback
154 | assert config.unknown_tool_estimate.complexity == ToolComplexity.STANDARD
155 | assert config.unknown_tool_estimate.confidence == 0.3
156 | assert config.unknown_tool_estimate.based_on == EstimationBasis.CONSERVATIVE
157 |
158 | def test_get_estimate_known_tools(self):
159 | """Test getting estimates for known tools."""
160 | config = ToolEstimationConfig()
161 |
162 | # Test simple tools
163 | simple_tools = [
164 | "get_stock_price",
165 | "get_company_info",
166 | "get_stock_info",
167 | "calculate_sma",
168 | "get_market_hours",
169 | "get_chart_links",
170 | "list_available_agents",
171 | "clear_cache",
172 | "get_cached_price_data",
173 | "get_watchlist",
174 | "generate_dev_token",
175 | ]
176 |
177 | for tool in simple_tools:
178 | estimate = config.get_estimate(tool)
179 | assert estimate.complexity == ToolComplexity.SIMPLE
180 | assert estimate.llm_calls <= 1 # Simple tools should have minimal LLM usage
181 | assert estimate.confidence >= 0.8 # Should have high confidence
182 |
183 | # Test standard tools
184 | standard_tools = [
185 | "get_rsi_analysis",
186 | "get_macd_analysis",
187 | "get_support_resistance",
188 | "fetch_stock_data",
189 | "get_maverick_stocks",
190 | "get_news_sentiment",
191 | "get_economic_calendar",
192 | ]
193 |
194 | for tool in standard_tools:
195 | estimate = config.get_estimate(tool)
196 | assert estimate.complexity == ToolComplexity.STANDARD
197 | assert 1 <= estimate.llm_calls <= 5
198 | assert estimate.confidence >= 0.7
199 |
200 | # Test complex tools
201 | complex_tools = [
202 | "get_full_technical_analysis",
203 | "risk_adjusted_analysis",
204 | "compare_tickers",
205 | "portfolio_correlation_analysis",
206 | "get_market_overview",
207 | "get_all_screening_recommendations",
208 | ]
209 |
210 | for tool in complex_tools:
211 | estimate = config.get_estimate(tool)
212 | assert estimate.complexity == ToolComplexity.COMPLEX
213 | assert 4 <= estimate.llm_calls <= 8
214 | assert estimate.confidence >= 0.7
215 |
216 | # Test premium tools
217 | premium_tools = [
218 | "analyze_market_with_agent",
219 | "get_agent_streaming_analysis",
220 | "compare_personas_analysis",
221 | ]
222 |
223 | for tool in premium_tools:
224 | estimate = config.get_estimate(tool)
225 | assert estimate.complexity == ToolComplexity.PREMIUM
226 | assert estimate.llm_calls >= 8
227 | assert estimate.total_tokens >= 10000
228 |
229 | def test_get_estimate_unknown_tool(self):
230 | """Test getting estimate for unknown tools."""
231 | config = ToolEstimationConfig()
232 | estimate = config.get_estimate("unknown_tool_name")
233 |
234 | assert estimate == config.unknown_tool_estimate
235 | assert estimate.complexity == ToolComplexity.STANDARD
236 | assert estimate.confidence == 0.3
237 | assert estimate.based_on == EstimationBasis.CONSERVATIVE
238 |
239 | def test_get_default_for_complexity(self):
240 | """Test getting default estimates by complexity."""
241 | config = ToolEstimationConfig()
242 |
243 | simple = config.get_default_for_complexity(ToolComplexity.SIMPLE)
244 | assert simple == config.simple_default
245 |
246 | standard = config.get_default_for_complexity(ToolComplexity.STANDARD)
247 | assert standard == config.standard_default
248 |
249 | complex_est = config.get_default_for_complexity(ToolComplexity.COMPLEX)
250 | assert complex_est == config.complex_default
251 |
252 | premium = config.get_default_for_complexity(ToolComplexity.PREMIUM)
253 | assert premium == config.premium_default
254 |
255 | def test_should_alert_critical_thresholds(self):
256 | """Test alert conditions for critical thresholds."""
257 | config = ToolEstimationConfig()
258 |
259 | # Test critical LLM calls threshold
260 | should_alert, message = config.should_alert("test_tool", 30, 5000)
261 | assert should_alert
262 | assert "Critical: LLM calls (30) exceeded threshold (25)" in message
263 |
264 | # Test critical token threshold
265 | should_alert, message = config.should_alert("test_tool", 5, 40000)
266 | assert should_alert
267 | assert "Critical: Token usage (40000) exceeded threshold (35000)" in message
268 |
269 | def test_should_alert_variance_thresholds(self):
270 | """Test alert conditions for variance thresholds."""
271 | config = ToolEstimationConfig()
272 |
273 | # Test tool with known estimate for variance calculation
274 | # get_stock_price: llm_calls=0, total_tokens=200
275 |
276 | # Test critical LLM variance (infinite variance since estimate is 0)
277 | should_alert, message = config.should_alert("get_stock_price", 5, 200)
278 | assert should_alert
279 | assert "Critical: LLM call variance" in message
280 |
281 | # Test critical token variance (5x the estimate = 400% variance)
282 | should_alert, message = config.should_alert("get_stock_price", 0, 1000)
283 | assert should_alert
284 | assert "Critical: Token variance" in message
285 |
286 | def test_should_alert_warning_thresholds(self):
287 | """Test alert conditions for warning thresholds."""
288 | config = ToolEstimationConfig()
289 |
290 | # Test warning LLM calls threshold (15-24 should trigger warning)
291 | # Use unknown tool which has reasonable base estimates to avoid variance issues
292 | should_alert, message = config.should_alert("unknown_tool", 18, 5000)
293 | assert should_alert
294 | assert (
295 | "Warning" in message or "Critical" in message
296 | ) # May trigger critical due to variance
297 |
298 | # Test warning token threshold with a tool that has known estimates
299 | # get_rsi_analysis: llm_calls=2, total_tokens=3000
300 | should_alert, message = config.should_alert("get_rsi_analysis", 2, 25000)
301 | assert should_alert
302 | assert (
303 | "Warning" in message or "Critical" in message
304 | ) # High token variance may trigger critical
305 |
306 | def test_should_alert_no_alert(self):
307 | """Test cases where no alert should be triggered."""
308 | config = ToolEstimationConfig()
309 |
310 | # Normal usage within expected ranges
311 | should_alert, message = config.should_alert("get_stock_price", 0, 200)
312 | assert not should_alert
313 | assert message == ""
314 |
315 | # Slightly above estimate but within acceptable variance
316 | should_alert, message = config.should_alert("get_stock_price", 0, 250)
317 | assert not should_alert
318 | assert message == ""
319 |
320 | def test_get_tools_by_complexity(self):
321 | """Test filtering tools by complexity category."""
322 | config = ToolEstimationConfig()
323 |
324 | simple_tools = config.get_tools_by_complexity(ToolComplexity.SIMPLE)
325 | standard_tools = config.get_tools_by_complexity(ToolComplexity.STANDARD)
326 | complex_tools = config.get_tools_by_complexity(ToolComplexity.COMPLEX)
327 | premium_tools = config.get_tools_by_complexity(ToolComplexity.PREMIUM)
328 |
329 | # Verify all tools are categorized
330 | all_tools = simple_tools + standard_tools + complex_tools + premium_tools
331 | assert len(all_tools) == len(config.tool_estimates)
332 |
333 | # Verify no overlap between categories
334 | assert len(set(all_tools)) == len(all_tools)
335 |
336 | # Verify specific known tools are in correct categories
337 | assert "get_stock_price" in simple_tools
338 | assert "get_rsi_analysis" in standard_tools
339 | assert "get_full_technical_analysis" in complex_tools
340 | assert "analyze_market_with_agent" in premium_tools
341 |
342 | def test_get_summary_stats(self):
343 | """Test summary statistics generation."""
344 | config = ToolEstimationConfig()
345 | stats = config.get_summary_stats()
346 |
347 | # Verify structure
348 | assert "total_tools" in stats
349 | assert "by_complexity" in stats
350 | assert "avg_llm_calls" in stats
351 | assert "avg_tokens" in stats
352 | assert "avg_confidence" in stats
353 | assert "basis_distribution" in stats
354 |
355 | # Verify content
356 | assert stats["total_tools"] > 0
357 | assert len(stats["by_complexity"]) == 4 # All complexity levels
358 | assert stats["avg_llm_calls"] >= 0
359 | assert stats["avg_tokens"] > 0
360 | assert 0 <= stats["avg_confidence"] <= 1
361 |
362 | # Verify complexity distribution adds up
363 | complexity_sum = sum(stats["by_complexity"].values())
364 | assert complexity_sum == stats["total_tools"]
365 |
366 |
367 | class TestModuleFunctions:
368 | """Test module-level functions."""
369 |
370 | def test_get_tool_estimation_config_singleton(self):
371 | """Test that get_tool_estimation_config returns a singleton."""
372 | config1 = get_tool_estimation_config()
373 | config2 = get_tool_estimation_config()
374 |
375 | # Should return the same instance
376 | assert config1 is config2
377 |
378 | @patch("maverick_mcp.config.tool_estimation._config", None)
379 | def test_get_tool_estimation_config_initialization(self):
380 | """Test that configuration is initialized correctly."""
381 | config = get_tool_estimation_config()
382 |
383 | assert isinstance(config, ToolEstimationConfig)
384 | assert len(config.tool_estimates) > 0
385 |
386 | def test_get_tool_estimate_function(self):
387 | """Test the get_tool_estimate convenience function."""
388 | estimate = get_tool_estimate("get_stock_price")
389 |
390 | assert isinstance(estimate, ToolEstimate)
391 | assert estimate.complexity == ToolComplexity.SIMPLE
392 |
393 | # Test unknown tool
394 | unknown_estimate = get_tool_estimate("unknown_tool")
395 | assert unknown_estimate.based_on == EstimationBasis.CONSERVATIVE
396 |
397 | def test_should_alert_for_usage_function(self):
398 | """Test the should_alert_for_usage convenience function."""
399 | should_alert, message = should_alert_for_usage("test_tool", 30, 5000)
400 |
401 | assert isinstance(should_alert, bool)
402 | assert isinstance(message, str)
403 |
404 | # Should trigger alert for high LLM calls
405 | assert should_alert
406 | assert "Critical" in message
407 |
408 |
409 | class TestMagicNumberReplacement:
410 | """Test that configuration correctly replaces magic numbers from server.py."""
411 |
412 | def test_all_usage_tier_tools_have_estimates(self):
413 | """Test that all tools referenced in server.py have estimates."""
414 | config = ToolEstimationConfig()
415 |
416 | # These are tools that were using magic numbers in server.py
417 | # Based on the TOOL usage tier mapping pattern
418 | critical_tools = [
419 | # Simple tools (baseline tier)
420 | "get_stock_price",
421 | "get_company_info",
422 | "get_stock_info",
423 | "calculate_sma",
424 | "get_market_hours",
425 | "get_chart_links",
426 | # Standard tools (core analysis tier)
427 | "get_rsi_analysis",
428 | "get_macd_analysis",
429 | "get_support_resistance",
430 | "fetch_stock_data",
431 | "get_maverick_stocks",
432 | "get_news_sentiment",
433 | # Complex tools (advanced analysis tier)
434 | "get_full_technical_analysis",
435 | "risk_adjusted_analysis",
436 | "compare_tickers",
437 | "portfolio_correlation_analysis",
438 | "get_market_overview",
439 | # Premium tools (orchestration tier)
440 | "analyze_market_with_agent",
441 | "get_agent_streaming_analysis",
442 | "compare_personas_analysis",
443 | ]
444 |
445 | for tool in critical_tools:
446 | estimate = config.get_estimate(tool)
447 | # Should not get the fallback estimate
448 | assert estimate != config.unknown_tool_estimate, (
449 | f"Tool {tool} missing specific estimate"
450 | )
451 | # Should have reasonable confidence
452 | assert estimate.confidence > 0.5, f"Tool {tool} has low confidence estimate"
453 |
454 | def test_estimates_align_with_usage_tiers(self):
455 | """Test that tool estimates align with usage complexity tiers."""
456 | config = ToolEstimationConfig()
457 |
458 | # Simple tools should require minimal resources
459 | simple_tools = [
460 | "get_stock_price",
461 | "get_company_info",
462 | "get_stock_info",
463 | "calculate_sma",
464 | "get_market_hours",
465 | "get_chart_links",
466 | ]
467 |
468 | for tool in simple_tools:
469 | estimate = config.get_estimate(tool)
470 | assert estimate.complexity == ToolComplexity.SIMPLE
471 | assert estimate.llm_calls <= 1 # Should require minimal/no LLM calls
472 |
473 | # Standard tools perform moderate analysis
474 | standard_tools = [
475 | "get_rsi_analysis",
476 | "get_macd_analysis",
477 | "get_support_resistance",
478 | "fetch_stock_data",
479 | "get_maverick_stocks",
480 | ]
481 |
482 | for tool in standard_tools:
483 | estimate = config.get_estimate(tool)
484 | assert estimate.complexity == ToolComplexity.STANDARD
485 | assert 1 <= estimate.llm_calls <= 5 # Moderate LLM usage
486 |
487 | # Complex tools orchestrate heavier workloads
488 | complex_tools = [
489 | "get_full_technical_analysis",
490 | "risk_adjusted_analysis",
491 | "compare_tickers",
492 | "portfolio_correlation_analysis",
493 | ]
494 |
495 | for tool in complex_tools:
496 | estimate = config.get_estimate(tool)
497 | assert estimate.complexity == ToolComplexity.COMPLEX
498 | assert 4 <= estimate.llm_calls <= 8 # Multiple LLM interactions
499 |
500 | # Premium tools coordinate multi-stage workflows
501 | premium_tools = [
502 | "analyze_market_with_agent",
503 | "get_agent_streaming_analysis",
504 | "compare_personas_analysis",
505 | ]
506 |
507 | for tool in premium_tools:
508 | estimate = config.get_estimate(tool)
509 | assert estimate.complexity == ToolComplexity.PREMIUM
510 | assert estimate.llm_calls >= 8 # Extensive LLM coordination
511 |
512 | def test_no_hardcoded_estimates_remain(self):
513 | """Test that estimates are data-driven, not hardcoded."""
514 | config = ToolEstimationConfig()
515 |
516 | # All tool estimates should have basis information
517 | for tool_name, estimate in config.tool_estimates.items():
518 | assert estimate.based_on in EstimationBasis
519 | assert estimate.complexity in ToolComplexity
520 | assert estimate.notes is not None, f"Tool {tool_name} missing notes"
521 |
522 | # Empirical estimates should generally have reasonable confidence
523 | if estimate.based_on == EstimationBasis.EMPIRICAL:
524 | assert estimate.confidence >= 0.6, (
525 | f"Empirical estimate for {tool_name} has very low confidence"
526 | )
527 |
528 | # Conservative estimates should have lower confidence
529 | if estimate.based_on == EstimationBasis.CONSERVATIVE:
530 | assert estimate.confidence <= 0.6, (
531 | f"Conservative estimate for {tool_name} has unexpectedly high confidence"
532 | )
533 |
534 |
535 | class TestEdgeCases:
536 | """Test edge cases and error conditions."""
537 |
538 | def test_empty_configuration(self):
539 | """Test behavior with empty tool estimates."""
540 | config = ToolEstimationConfig(tool_estimates={})
541 |
542 | # Should fall back to unknown tool estimate
543 | estimate = config.get_estimate("any_tool")
544 | assert estimate == config.unknown_tool_estimate
545 |
546 | # Summary stats should handle empty case
547 | stats = config.get_summary_stats()
548 | assert stats == {}
549 |
550 | def test_alert_with_zero_estimates(self):
551 | """Test alert calculation when estimates are zero."""
552 | config = ToolEstimationConfig()
553 |
554 | # Tool with zero LLM calls in estimate
555 | should_alert, message = config.should_alert("get_stock_price", 1, 200)
556 | # Should alert because variance is infinite (1 vs 0 expected)
557 | assert should_alert
558 |
559 | def test_variance_calculation_edge_cases(self):
560 | """Test variance calculation with edge cases."""
561 | config = ToolEstimationConfig()
562 |
563 | # Perfect match should not alert
564 | should_alert, message = config.should_alert("get_rsi_analysis", 2, 3000)
565 | # get_rsi_analysis has: llm_calls=2, total_tokens=3000
566 | assert not should_alert
567 |
568 | def test_performance_with_large_usage(self):
569 | """Test performance and behavior with extremely large usage values."""
570 | config = ToolEstimationConfig()
571 |
572 | # Very large values should still work
573 | should_alert, message = config.should_alert("test_tool", 1000, 1000000)
574 | assert should_alert
575 | assert "Critical" in message
576 |
577 | def test_custom_monitoring_thresholds(self):
578 | """Test configuration with custom monitoring thresholds."""
579 | custom_monitoring = MonitoringThresholds(
580 | llm_calls_warning=5,
581 | llm_calls_critical=10,
582 | tokens_warning=1000,
583 | tokens_critical=5000,
584 | variance_warning=0.1,
585 | variance_critical=0.2,
586 | )
587 |
588 | config = ToolEstimationConfig(monitoring=custom_monitoring)
589 |
590 | # Should use custom thresholds
591 | # Test critical threshold first (easier to predict)
592 | should_alert, message = config.should_alert("test_tool", 12, 500)
593 | assert should_alert
594 | assert "Critical" in message
595 |
596 | # Test LLM calls warning threshold
597 | should_alert, message = config.should_alert(
598 | "test_tool", 6, 100
599 | ) # Lower tokens to avoid variance issues
600 | assert should_alert
601 | # May be warning or critical depending on variance calculation
602 |
603 |
604 | class TestIntegrationPatterns:
605 | """Test patterns that match server.py integration."""
606 |
607 | def test_low_confidence_logging_pattern(self):
608 | """Test identifying tools that need monitoring due to low confidence."""
609 | config = ToolEstimationConfig()
610 |
611 | low_confidence_tools = []
612 | for tool_name, estimate in config.tool_estimates.items():
613 | if estimate.confidence < 0.8:
614 | low_confidence_tools.append(tool_name)
615 |
616 | # These tools should be logged for monitoring in production
617 | assert len(low_confidence_tools) > 0
618 |
619 | # Verify these are typically more complex tools
620 | for tool_name in low_confidence_tools:
621 | estimate = config.get_estimate(tool_name)
622 | # Low confidence tools should typically be complex, premium, or analytical standard tools
623 | assert estimate.complexity in [
624 | ToolComplexity.STANDARD,
625 | ToolComplexity.COMPLEX,
626 | ToolComplexity.PREMIUM,
627 | ], (
628 | f"Tool {tool_name} with low confidence has unexpected complexity {estimate.complexity}"
629 | )
630 |
631 | def test_error_handling_fallback_pattern(self):
632 | """Test the error handling pattern used in server.py."""
633 | config = ToolEstimationConfig()
634 |
635 | # Simulate error case - should fall back to unknown tool estimate
636 | try:
637 | # This would be the pattern in server.py when get_tool_estimate fails
638 | estimate = config.get_estimate("nonexistent_tool")
639 | fallback_estimate = config.unknown_tool_estimate
640 |
641 | # Verify fallback has conservative characteristics
642 | assert fallback_estimate.based_on == EstimationBasis.CONSERVATIVE
643 | assert fallback_estimate.confidence == 0.3
644 | assert fallback_estimate.complexity == ToolComplexity.STANDARD
645 |
646 | # Should be the same as what get_estimate returns for unknown tools
647 | assert estimate == fallback_estimate
648 |
649 | except Exception:
650 | # If estimation fails entirely, should be able to use fallback
651 | fallback = config.unknown_tool_estimate
652 | assert fallback.llm_calls > 0
653 | assert fallback.total_tokens > 0
654 |
655 | def test_usage_logging_extra_fields(self):
656 | """Test that estimates provide all fields needed for logging."""
657 | config = ToolEstimationConfig()
658 |
659 | for _tool_name, estimate in config.tool_estimates.items():
660 | # Verify all fields needed for server.py logging are present
661 | assert hasattr(estimate, "confidence")
662 | assert hasattr(estimate, "based_on")
663 | assert hasattr(estimate, "complexity")
664 | assert hasattr(estimate, "llm_calls")
665 | assert hasattr(estimate, "total_tokens")
666 |
667 | # Verify fields have appropriate types for logging
668 | assert isinstance(estimate.confidence, float)
669 | assert isinstance(estimate.based_on, EstimationBasis)
670 | assert isinstance(estimate.complexity, ToolComplexity)
671 | assert isinstance(estimate.llm_calls, int)
672 | assert isinstance(estimate.total_tokens, int)
673 |
```
--------------------------------------------------------------------------------
/docs/PORTFOLIO_PERSONALIZATION_PLAN.md:
--------------------------------------------------------------------------------
```markdown
1 | # PORTFOLIO PERSONALIZATION - EXECUTION PLAN
2 |
3 | ## 1. Big Picture / Goal
4 |
5 | **Objective:** Transform MaverickMCP's portfolio analysis tools from stateless, repetitive-input operations into an intelligent, personalized AI financial assistant through persistent portfolio storage and context-aware tool integration.
6 |
7 | **Architectural Goal:** Implement a two-phase system that (1) adds persistent portfolio storage with cost basis tracking using established DDD patterns, and (2) intelligently enhances existing tools to auto-detect user holdings and provide personalized analysis without breaking the stateless MCP tool contract.
8 |
9 | **Success Criteria (Mandatory):**
10 | - **Phase 1 Complete:** 4 new MCP tools (`add_portfolio_position`, `get_my_portfolio`, `remove_portfolio_position`, `clear_my_portfolio`) and 1 MCP resource (`portfolio://my-holdings`) fully functional
11 | - **Database Integration:** SQLAlchemy models with proper cost basis averaging, Alembic migration creating tables without conflicts
12 | - **Phase 2 Integration:** 3 existing tools enhanced (`risk_adjusted_analysis`, `portfolio_correlation_analysis`, `compare_tickers`) with automatic portfolio detection
13 | - **AI Context Injection:** Portfolio resource provides live P&L, diversification metrics, and position details to AI agents automatically
14 | - **Test Coverage:** 85%+ test coverage with unit, integration, and domain tests passing
15 | - **Code Quality:** Zero linting errors (ruff), full type annotations (ty), all hooks passing
16 | - **Documentation:** PORTFOLIO.md guide, updated tool docstrings, usage examples in Claude Desktop
17 |
18 | **Financial Disclaimer:** All portfolio features include educational disclaimers. No investment recommendations. Local-first storage only. No tax advice provided.
19 |
20 | ## 2. To-Do List (High Level)
21 |
22 | ### Phase 1: Persistent Portfolio Storage Foundation (4-5 days)
23 | - [ ] **Spike 1:** Research cost basis averaging algorithms and edge cases (FIFO, average cost)
24 | - [ ] **Domain Entities:** Create `Portfolio` and `Position` domain entities with business logic
25 | - [ ] **Database Models:** Implement `UserPortfolio` and `PortfolioPosition` SQLAlchemy models
26 | - [ ] **Migration:** Create Alembic migration with proper indexes and constraints
27 | - [ ] **MCP Tools:** Implement 4 portfolio management tools with validation
28 | - [ ] **MCP Resource:** Implement `portfolio://my-holdings` with live P&L calculations
29 | - [ ] **Unit Tests:** Comprehensive domain entity and cost basis tests
30 | - [ ] **Integration Tests:** Database operation and transaction tests
31 |
32 | ### Phase 2: Intelligent Tool Integration (2-3 days)
33 | - [ ] **Risk Analysis Enhancement:** Add position awareness to `risk_adjusted_analysis`
34 | - [ ] **Correlation Enhancement:** Enable `portfolio_correlation_analysis` with no arguments
35 | - [ ] **Comparison Enhancement:** Enable `compare_tickers` with optional portfolio auto-fill
36 | - [ ] **Resource Enhancement:** Add live market data to portfolio resource
37 | - [ ] **Integration Tests:** Cross-tool functionality validation
38 | - [ ] **Documentation:** Update existing tool docstrings with new capabilities
39 |
40 | ### Phase 3: Polish & Documentation (1-2 days)
41 | - [ ] **Manual Testing:** Claude Desktop end-to-end workflow validation
42 | - [ ] **Error Handling:** Edge case coverage (partial sells, zero shares, invalid tickers)
43 | - [ ] **Performance:** Query optimization, batch operations, caching strategy
44 | - [ ] **Documentation:** Complete PORTFOLIO.md with examples and screenshots
45 | - [ ] **Migration Testing:** Test upgrade/downgrade paths
46 |
47 | ## 3. Plan Details (Spikes & Features)
48 |
49 | ### Spike 1: Cost Basis Averaging Research
50 |
51 | **Action:** Investigate cost basis calculation methods (FIFO, LIFO, average cost) and determine the optimal approach for educational portfolio tracking.
52 |
53 | **Steps:**
54 | 1. Research IRS cost basis methods and educational best practices
55 | 2. Analyze existing `PortfolioManager` tool (JSON-based, average cost) for patterns
56 | 3. Design algorithm for averaging purchases and handling partial sells
57 | 4. Create specification document for edge cases:
58 | - Multiple purchases at different prices
59 | - Partial position sales
60 | - Zero/negative share handling
61 | - Rounding and precision (financial data uses Numeric(12,4))
62 | 5. Benchmark performance for 100+ positions with 1000+ transactions
63 |
64 | **Expected Outcome:** A clear specification for the cost basis implementation using the **average cost method** (the simplest option for educational use; matches the existing PortfolioManager), with edge-case handling documented.
65 |
66 | **Decision Rationale:** Average cost is simpler than FIFO/LIFO, appropriate for educational context, and avoids tax accounting complexity.
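
To make the averaging rule concrete, here is a minimal worked sketch (illustrative numbers only; the real logic lives in the domain entities in Feature A below):

```python
from decimal import Decimal

# Two buys: 10 shares @ $100.00, then 5 shares @ $130.00
shares_1, price_1 = Decimal("10"), Decimal("100.00")
shares_2, price_2 = Decimal("5"), Decimal("130.00")

total_shares = shares_1 + shares_2                    # 15
total_cost = shares_1 * price_1 + shares_2 * price_2  # 1650.00
avg_cost_basis = total_cost / total_shares            # 110.00

# A partial sell of 6 shares under average cost leaves the per-share basis
# unchanged; only the share count and the total cost shrink.
remaining_shares = total_shares - Decimal("6")        # 9
remaining_cost = remaining_shares * avg_cost_basis    # 990.00
```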
67 |
68 | ---
69 |
70 | ### Feature A: Domain Entities (DDD Pattern)
71 |
72 | **Goal:** Create pure business logic entities following MaverickMCP's established DDD patterns (similar to backtesting domain entities).
73 |
74 | **Files to Create:**
75 | - `maverick_mcp/domain/portfolio.py` - Core domain entities
76 | - `maverick_mcp/domain/position.py` - Position value objects
77 |
78 | **Domain Entity Design:**
79 |
80 | ```python
81 | # maverick_mcp/domain/portfolio.py
82 | from dataclasses import dataclass
83 | from datetime import UTC, datetime
84 | from decimal import Decimal
85 | from typing import List, Optional
86 |
87 | @dataclass
88 | class Position:
89 | """Value object representing a single portfolio position."""
90 | ticker: str
91 | shares: Decimal # Use Decimal for precision
92 | average_cost_basis: Decimal
93 | total_cost: Decimal
94 | purchase_date: datetime # Earliest purchase
95 | notes: Optional[str] = None
96 |
97 | def add_shares(self, shares: Decimal, price: Decimal, date: datetime) -> "Position":
98 | """Add shares with automatic cost basis averaging."""
99 | new_total_shares = self.shares + shares
100 | new_total_cost = self.total_cost + (shares * price)
101 | new_avg_cost = new_total_cost / new_total_shares
102 |
103 | return Position(
104 | ticker=self.ticker,
105 | shares=new_total_shares,
106 | average_cost_basis=new_avg_cost,
107 | total_cost=new_total_cost,
108 | purchase_date=min(self.purchase_date, date),
109 | notes=self.notes
110 | )
111 |
112 | def remove_shares(self, shares: Decimal) -> Optional["Position"]:
113 | """Remove shares, return None if position fully closed."""
114 | if shares >= self.shares:
115 | return None # Full position close
116 |
117 | new_shares = self.shares - shares
118 | new_total_cost = new_shares * self.average_cost_basis
119 |
120 | return Position(
121 | ticker=self.ticker,
122 | shares=new_shares,
123 | average_cost_basis=self.average_cost_basis,
124 | total_cost=new_total_cost,
125 | purchase_date=self.purchase_date,
126 | notes=self.notes
127 | )
128 |
129 | def calculate_current_value(self, current_price: Decimal) -> dict:
130 | """Calculate live P&L metrics."""
131 | current_value = self.shares * current_price
132 | unrealized_pnl = current_value - self.total_cost
133 | pnl_percentage = (unrealized_pnl / self.total_cost * 100) if self.total_cost else Decimal(0)
134 |
135 | return {
136 | "current_value": current_value,
137 | "unrealized_pnl": unrealized_pnl,
138 | "pnl_percentage": pnl_percentage
139 | }
140 |
141 | @dataclass
142 | class Portfolio:
143 | """Aggregate root for user portfolio."""
144 | portfolio_id: str # UUID
145 | user_id: str # "default" for single-user
146 | name: str
147 | positions: List[Position]
148 | created_at: datetime
149 | updated_at: datetime
150 |
151 | def add_position(self, ticker: str, shares: Decimal, price: Decimal,
152 | date: datetime, notes: Optional[str] = None) -> None:
153 | """Add or update position with automatic averaging."""
154 | # Find existing position
155 | for i, pos in enumerate(self.positions):
156 | if pos.ticker == ticker:
157 | self.positions[i] = pos.add_shares(shares, price, date)
158 | self.updated_at = datetime.now(UTC)
159 | return
160 |
161 | # Create new position
162 | new_position = Position(
163 | ticker=ticker,
164 | shares=shares,
165 | average_cost_basis=price,
166 | total_cost=shares * price,
167 | purchase_date=date,
168 | notes=notes
169 | )
170 | self.positions.append(new_position)
171 | self.updated_at = datetime.now(UTC)
172 |
173 | def remove_position(self, ticker: str, shares: Optional[Decimal] = None) -> bool:
174 | """Remove position or partial shares."""
175 | for i, pos in enumerate(self.positions):
176 | if pos.ticker == ticker:
177 | if shares is None or shares >= pos.shares:
178 | # Full position removal
179 | self.positions.pop(i)
180 | else:
181 | # Partial removal
182 | updated_pos = pos.remove_shares(shares)
183 | if updated_pos:
184 | self.positions[i] = updated_pos
185 | else:
186 | self.positions.pop(i)
187 |
188 | self.updated_at = datetime.now(UTC)
189 | return True
190 | return False
191 |
192 | def get_position(self, ticker: str) -> Optional[Position]:
193 | """Get position by ticker."""
194 | return next((pos for pos in self.positions if pos.ticker == ticker), None)
195 |
196 | def get_total_invested(self) -> Decimal:
197 | """Calculate total capital invested."""
198 | return sum(pos.total_cost for pos in self.positions)
199 |
200 | def calculate_portfolio_metrics(self, current_prices: dict[str, Decimal]) -> dict:
201 | """Calculate comprehensive portfolio metrics."""
202 | total_value = Decimal(0)
203 | total_cost = Decimal(0)
204 | position_details = []
205 |
206 | for pos in self.positions:
207 | current_price = current_prices.get(pos.ticker, pos.average_cost_basis)
208 | metrics = pos.calculate_current_value(current_price)
209 |
210 | total_value += metrics["current_value"]
211 | total_cost += pos.total_cost
212 |
213 | position_details.append({
214 | "ticker": pos.ticker,
215 | "shares": float(pos.shares),
216 | "cost_basis": float(pos.average_cost_basis),
217 | "current_price": float(current_price),
218 | **{k: float(v) for k, v in metrics.items()}
219 | })
220 |
221 | total_pnl = total_value - total_cost
222 | total_pnl_pct = (total_pnl / total_cost * 100) if total_cost else Decimal(0)
223 |
224 | return {
225 | "total_value": float(total_value),
226 | "total_invested": float(total_cost),
227 | "total_pnl": float(total_pnl),
228 | "total_pnl_percentage": float(total_pnl_pct),
229 | "position_count": len(self.positions),
230 | "positions": position_details
231 | }
232 | ```
233 |
234 | **Testing Strategy:**
235 | - Unit tests for cost basis averaging edge cases
236 | - Property-based tests for arithmetic precision (sketched below)
237 | - Edge case tests: zero shares, negative P&L, division by zero
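
A minimal sketch of such a property-based test, assuming the `hypothesis` library is available and the `Position` entity is implemented as planned above:

```python
from datetime import UTC, datetime
from decimal import Decimal

from hypothesis import given
from hypothesis import strategies as st

from maverick_mcp.domain.portfolio import Position  # planned module

shares_st = st.decimals(min_value=Decimal("0.0001"), max_value=Decimal("10000"), places=4)
price_st = st.decimals(min_value=Decimal("0.01"), max_value=Decimal("100000"), places=2)


@given(s1=shares_st, p1=price_st, s2=shares_st, p2=price_st)
def test_add_shares_preserves_invested_capital(s1, p1, s2, p2):
    """Averaging must never create or destroy invested capital."""
    now = datetime.now(UTC)
    pos = Position(
        ticker="TEST",
        shares=s1,
        average_cost_basis=p1,
        total_cost=s1 * p1,
        purchase_date=now,
    )
    merged = pos.add_shares(s2, p2, now)

    assert merged.shares == s1 + s2
    assert merged.total_cost == s1 * p1 + s2 * p2
    # The averaged basis times the share count should reproduce the total cost
    # up to Decimal context rounding.
    assert abs(merged.average_cost_basis * merged.shares - merged.total_cost) <= Decimal("0.0001")
```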
238 |
239 | ---
240 |
241 | ### Feature B: Database Models (SQLAlchemy ORM)
242 |
243 | **Goal:** Create persistent storage models following established patterns in `maverick_mcp/data/models.py`.
244 |
245 | **Files to Modify:**
246 | - `maverick_mcp/data/models.py` - Add new models (lines ~1700+)
247 |
248 | **Model Design:**
249 |
250 | ```python
251 | # Add to maverick_mcp/data/models.py
252 |
253 | class UserPortfolio(TimestampMixin, Base):
254 | """
255 | User portfolio for tracking investment holdings.
256 |
257 | Follows personal-use design: single user_id="default"
258 | """
259 | __tablename__ = "mcp_portfolios"
260 |
261 | id = Column(Uuid, primary_key=True, default=uuid.uuid4)
262 | user_id = Column(String(50), nullable=False, default="default", index=True)
263 | name = Column(String(200), nullable=False, default="My Portfolio")
264 |
265 | # Relationships
266 | positions = relationship(
267 | "PortfolioPosition",
268 | back_populates="portfolio",
269 | cascade="all, delete-orphan",
270 | lazy="selectin" # Efficient loading
271 | )
272 |
273 | # Indexes for queries
274 | __table_args__ = (
275 | Index("idx_portfolio_user", "user_id"),
276 | UniqueConstraint("user_id", "name", name="uq_user_portfolio_name"),
277 | )
278 |
279 | def __repr__(self):
280 | return f"<UserPortfolio(id={self.id}, name='{self.name}', positions={len(self.positions)})>"
281 |
282 |
283 | class PortfolioPosition(TimestampMixin, Base):
284 | """
285 | Individual position within a portfolio with cost basis tracking.
286 | """
287 | __tablename__ = "mcp_portfolio_positions"
288 |
289 | id = Column(Uuid, primary_key=True, default=uuid.uuid4)
290 | portfolio_id = Column(Uuid, ForeignKey("mcp_portfolios.id", ondelete="CASCADE"), nullable=False)
291 |
292 | # Position details
293 | ticker = Column(String(20), nullable=False, index=True)
294 | shares = Column(Numeric(20, 8), nullable=False) # High precision for fractional shares
295 | average_cost_basis = Column(Numeric(12, 4), nullable=False) # Financial precision
296 | total_cost = Column(Numeric(20, 4), nullable=False) # Total capital invested
297 | purchase_date = Column(DateTime(timezone=True), nullable=False) # Earliest purchase
298 | notes = Column(Text, nullable=True) # Optional user notes
299 |
300 | # Relationships
301 | portfolio = relationship("UserPortfolio", back_populates="positions")
302 |
303 | # Indexes for efficient queries
304 | __table_args__ = (
305 | Index("idx_position_portfolio", "portfolio_id"),
306 | Index("idx_position_ticker", "ticker"),
307 | Index("idx_position_portfolio_ticker", "portfolio_id", "ticker"),
308 | UniqueConstraint("portfolio_id", "ticker", name="uq_portfolio_position_ticker"),
309 | )
310 |
311 | def __repr__(self):
312 | return f"<PortfolioPosition(ticker='{self.ticker}', shares={self.shares}, cost_basis={self.average_cost_basis})>"
313 | ```
314 |
315 | **Key Design Decisions:**
316 | 1. **Table Names:** `mcp_portfolios` and `mcp_portfolio_positions` (consistent with `mcp_*` pattern)
317 | 2. **user_id:** Default "default" for single-user personal use
318 | 3. **Numeric Precision:** Matches existing financial data patterns (12,4 for prices, 20,8 for shares)
319 | 4. **Cascade Delete:** Portfolio deletion removes all positions automatically
320 | 5. **Unique Constraint:** One position per ticker per portfolio
321 | 6. **Indexes:** Optimized for common queries (user lookup, ticker filtering)
322 |
323 | ---
324 |
325 | ### Feature C: Alembic Migration
326 |
327 | **Goal:** Create database migration following established patterns without conflicts.
328 |
329 | **File to Create:**
330 | - `alembic/versions/014_add_portfolio_models.py`
331 |
332 | **Migration Pattern:**
333 |
334 | ```python
335 | """Add portfolio and position models
336 |
337 | Revision ID: 014_add_portfolio_models
338 | Revises: 013_add_backtest_persistence_models
339 | Create Date: 2025-11-01 10:00:00.000000
340 | """
341 | from alembic import op
342 | import sqlalchemy as sa
343 | from sqlalchemy.dialects import postgresql
344 |
345 | # revision identifiers
346 | revision = '014_add_portfolio_models'
347 | down_revision = '013_add_backtest_persistence_models'
348 | branch_labels = None
349 | depends_on = None
350 |
351 |
352 | def upgrade():
353 | """Create portfolio management tables."""
354 |
355 | # Create portfolios table
356 | op.create_table(
357 | 'mcp_portfolios',
358 | sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
359 | sa.Column('user_id', sa.String(50), nullable=False, server_default='default'),
360 | sa.Column('name', sa.String(200), nullable=False, server_default='My Portfolio'),
361 | sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
362 | sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
363 | )
364 |
365 | # Create indexes on portfolios
366 | op.create_index('idx_portfolio_user', 'mcp_portfolios', ['user_id'])
367 | op.create_unique_constraint('uq_user_portfolio_name', 'mcp_portfolios', ['user_id', 'name'])
368 |
369 | # Create positions table
370 | op.create_table(
371 | 'mcp_portfolio_positions',
372 | sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
373 | sa.Column('portfolio_id', postgresql.UUID(as_uuid=True), nullable=False),
374 | sa.Column('ticker', sa.String(20), nullable=False),
375 | sa.Column('shares', sa.Numeric(20, 8), nullable=False),
376 | sa.Column('average_cost_basis', sa.Numeric(12, 4), nullable=False),
377 | sa.Column('total_cost', sa.Numeric(20, 4), nullable=False),
378 | sa.Column('purchase_date', sa.DateTime(timezone=True), nullable=False),
379 | sa.Column('notes', sa.Text, nullable=True),
380 | sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
381 | sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
382 | sa.ForeignKeyConstraint(['portfolio_id'], ['mcp_portfolios.id'], ondelete='CASCADE'),
383 | )
384 |
385 | # Create indexes on positions
386 | op.create_index('idx_position_portfolio', 'mcp_portfolio_positions', ['portfolio_id'])
387 | op.create_index('idx_position_ticker', 'mcp_portfolio_positions', ['ticker'])
388 | op.create_index('idx_position_portfolio_ticker', 'mcp_portfolio_positions', ['portfolio_id', 'ticker'])
389 | op.create_unique_constraint('uq_portfolio_position_ticker', 'mcp_portfolio_positions', ['portfolio_id', 'ticker'])
390 |
391 |
392 | def downgrade():
393 | """Drop portfolio management tables."""
394 | op.drop_table('mcp_portfolio_positions')
395 | op.drop_table('mcp_portfolios')
396 | ```
397 |
398 | **Testing** (round-trip sketch below):
399 | - Test upgrade: `alembic upgrade head`
400 | - Test downgrade: `alembic downgrade -1`
401 | - Verify indexes created: SQL query inspection
402 | - Test with SQLite and PostgreSQL
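
A minimal sketch of a round-trip migration test, assuming the project's `alembic.ini` honors a `sqlalchemy.url` override (the SQLite URL is illustrative; the PostgreSQL-specific UUID columns may require pointing this at a disposable PostgreSQL database instead):

```python
import sqlalchemy as sa
from alembic import command
from alembic.config import Config


def test_portfolio_migration_round_trip(tmp_path):
    # Illustrative URL; swap for a disposable PostgreSQL DSN if needed
    url = f"sqlite:///{tmp_path / 'migration_test.db'}"

    cfg = Config("alembic.ini")
    cfg.set_main_option("sqlalchemy.url", url)

    command.upgrade(cfg, "head")
    tables = sa.inspect(sa.create_engine(url)).get_table_names()
    assert "mcp_portfolios" in tables
    assert "mcp_portfolio_positions" in tables

    command.downgrade(cfg, "-1")
    tables = sa.inspect(sa.create_engine(url)).get_table_names()
    assert "mcp_portfolio_positions" not in tables
```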
403 |
404 | ---
405 |
406 | ### Feature D: MCP Tools Implementation
407 |
408 | **Goal:** Implement 4 portfolio management tools following tool_registry.py pattern.
409 |
410 | **Files to Create:**
411 | - `maverick_mcp/api/routers/portfolio_management.py` - New tool implementations
412 | - `maverick_mcp/api/services/portfolio_persistence_service.py` - Service layer
413 | - `maverick_mcp/validation/portfolio_management.py` - Pydantic validation
414 |
415 | **Service Layer Pattern:**
416 |
417 | ```python
418 | # maverick_mcp/api/services/portfolio_persistence_service.py
419 |
420 | class PortfolioPersistenceService(BaseService):
421 | """Service for portfolio CRUD operations."""
422 |
423 | async def get_or_create_default_portfolio(self) -> UserPortfolio:
424 |         """Get the default portfolio, creating it if it doesn't exist."""
425 | pass
426 |
427 | async def add_position(self, ticker: str, shares: Decimal,
428 | price: Decimal, date: datetime,
429 | notes: Optional[str]) -> PortfolioPosition:
430 | """Add or update position with cost averaging."""
431 | pass
432 |
433 | async def get_portfolio_with_live_data(self) -> dict:
434 | """Fetch portfolio with current market prices."""
435 | pass
436 |
437 | async def remove_position(self, ticker: str,
438 | shares: Optional[Decimal]) -> bool:
439 | """Remove position or partial shares."""
440 | pass
441 |
442 | async def clear_portfolio(self) -> bool:
443 | """Delete all positions."""
444 | pass
445 | ```
446 |
447 | **Tool Registration:**
448 |
449 | ```python
450 | # Add to maverick_mcp/api/routers/tool_registry.py
451 |
452 | def register_portfolio_management_tools(mcp: FastMCP) -> None:
453 | """Register portfolio management tools."""
454 | from maverick_mcp.api.routers.portfolio_management import (
455 | add_portfolio_position,
456 | get_my_portfolio,
457 | remove_portfolio_position,
458 | clear_my_portfolio
459 | )
460 |
461 | mcp.tool(name="portfolio_add_position")(add_portfolio_position)
462 | mcp.tool(name="portfolio_get_my_portfolio")(get_my_portfolio)
463 | mcp.tool(name="portfolio_remove_position")(remove_portfolio_position)
464 | mcp.tool(name="portfolio_clear")(clear_my_portfolio)
465 | ```
466 |
467 | ---
468 |
469 | ### Feature E: MCP Resource Implementation
470 |
471 | **Goal:** Create `portfolio://my-holdings` resource for automatic AI context injection.
472 |
473 | **File to Modify:**
474 | - `maverick_mcp/api/server.py` - Add resource alongside existing health:// and dashboard:// resources
475 |
476 | **Resource Implementation:**
477 |
478 | ```python
479 | # Add to maverick_mcp/api/server.py (around line 823, near other resources)
480 |
481 | @mcp.resource("portfolio://my-holdings")
482 | def portfolio_holdings_resource() -> dict[str, Any]:
483 | """
484 | Portfolio holdings resource for AI context injection.
485 |
486 | Provides comprehensive portfolio context to AI agents including:
487 | - Current positions with live P&L
488 | - Portfolio metrics and diversification
489 | - Sector exposure analysis
490 | - Top/bottom performers
491 |
492 | This resource is automatically available to AI agents during conversations,
493 | enabling personalized analysis without requiring manual ticker input.
494 | """
495 | # Implementation using service layer with async handling
496 | pass
497 | ```
498 |
499 | ---
500 |
501 | ### Feature F: Phase 2 Tool Enhancements
502 |
503 | **Goal:** Enhance existing tools to auto-detect portfolio holdings.
504 |
505 | **Files to Modify:**
506 | 1. `maverick_mcp/api/routers/portfolio.py` - Enhance 3 existing tools
507 | 2. `maverick_mcp/validation/portfolio.py` - Update validation to allow optional parameters
508 |
509 | **Enhancement Pattern** (sketched below):
510 | - Add optional parameters (tickers can be None)
511 | - Check portfolio for holdings if no tickers provided
512 | - Add position awareness to analysis results
513 | - Maintain backward compatibility
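
A minimal sketch of the auto-detection pattern (the function and service names are illustrative; the actual signatures in `maverick_mcp/api/routers/portfolio.py` may differ):

```python
async def portfolio_correlation_analysis(
    tickers: list[str] | None = None,
    lookback_days: int = 252,
) -> dict:
    """Correlation analysis that falls back to stored holdings when no tickers are given."""
    if not tickers:
        # Hypothetical service call: load the persisted default portfolio
        portfolio = await portfolio_service.get_or_create_default_portfolio()
        tickers = [pos.ticker for pos in portfolio.positions]
        if len(tickers) < 2:
            return {
                "status": "error",
                "message": "Add at least two positions or pass tickers explicitly.",
            }

    # Existing stateless behavior is unchanged when tickers are supplied explicitly
    return await _run_correlation_analysis(tickers, lookback_days)
```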
514 |
515 | ---
516 |
517 | ## 4. Progress (Living Document Section)
518 |
519 | | Date | Time | Item Completed / Status Update | Resulting Changes (LOC/Files) |
520 | |:-----|:-----|:------------------------------|:------------------------------|
521 | | 2025-11-01 | Start | Plan approved and documented | PORTFOLIO_PERSONALIZATION_PLAN.md created |
522 | | TBD | TBD | Implementation begins | - |
523 |
524 | _(This section will be updated during implementation)_
525 |
526 | ---
527 |
528 | ## 5. Surprises and Discoveries
529 |
530 | _(Technical issues discovered during implementation will be documented here)_
531 |
532 | **Anticipated Challenges:**
533 | 1. **MCP Resource Async Context:** Resources are sync functions but need async database calls - solved with event loop management (see the existing health_resource pattern and the sketch after this list)
534 | 2. **Cost Basis Precision:** Financial calculations require Decimal precision, not floats - use Numeric(12,4) for prices, Numeric(20,8) for shares
535 | 3. **Portfolio Resource Performance:** Live price fetching could be slow - implement caching strategy, consider async batching
536 | 4. **Single User Assumption:** No user authentication means all operations use user_id="default" - acceptable for personal use
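
A minimal sketch of the event-loop bridging mentioned in item 1, assuming the hypothetical service call shown and that blocking briefly inside the resource is acceptable (the existing health resource may handle this differently):

```python
import asyncio
import concurrent.futures
from typing import Any


def portfolio_holdings_resource() -> dict[str, Any]:
    """Sync MCP resource that delegates to async service code."""

    async def _load() -> dict[str, Any]:
        # Hypothetical async service call
        return await portfolio_service.get_portfolio_with_live_data()

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread, so it is safe to create one
        return asyncio.run(_load())

    # A loop is already running in this thread: run the coroutine on a
    # short-lived worker thread with its own loop instead of blocking it.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _load()).result()
```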
537 |
538 | ---
539 |
540 | ## 6. Decision Log
541 |
542 | | Date | Decision | Rationale |
543 | |:-----|:---------|:----------|
544 | | 2025-11-01 | **Cost Basis Method: Average Cost** | Simplest for educational use, matches existing PortfolioManager, avoids tax accounting complexity |
545 | | 2025-11-01 | **Table Names: mcp_portfolios, mcp_portfolio_positions** | Consistent with existing mcp_* naming convention for MCP-specific tables |
546 | | 2025-11-01 | **User ID: "default" for all users** | Single-user personal-use design, consistent with auth disabled architecture |
547 | | 2025-11-01 | **Numeric Precision: Numeric(12,4) for prices, Numeric(20,8) for shares** | Matches existing financial data patterns, supports fractional shares |
548 | | 2025-11-01 | **Optional tickers parameter for Phase 2** | Enables "just works" UX while maintaining backward compatibility |
549 | | 2025-11-01 | **MCP Resource for AI context** | Most elegant solution for automatic context injection without breaking tool contracts |
550 | | 2025-11-01 | **Domain-Driven Design pattern** | Follows established MaverickMCP architecture, clean separation of concerns |
551 |
552 | ---
553 |
554 | ## 7. Implementation Phases
555 |
556 | ### Phase 1: Foundation (4-5 days)
557 | **Files Created:** 8 new files
558 | **Files Modified:** 3 existing files
559 | **Estimated LOC:** ~2,500 lines
560 | **Tests:** ~1,200 lines
561 |
562 | ### Phase 2: Integration (2-3 days)
563 | **Files Modified:** 4 existing files
564 | **Estimated LOC:** ~800 lines additional
565 | **Tests:** ~600 lines additional
566 |
567 | ### Phase 3: Polish (1-2 days)
568 | **Documentation:** PORTFOLIO.md (~300 lines)
569 | **Performance:** Query optimization
570 | **Testing:** Manual Claude Desktop validation
571 |
572 | **Total Effort:** 7-10 days
573 | **Total New Code:** ~3,500 lines (including tests)
574 | **Total Tests:** ~1,800 lines
575 |
576 | ---
577 |
578 | ## 8. Risk Assessment
579 |
580 | **Low Risk:**
581 | - ✅ Follows established patterns
582 | - ✅ No breaking changes to existing tools
583 | - ✅ Optional Phase 2 enhancements
584 | - ✅ Well-scoped feature
585 |
586 | **Medium Risk:**
587 | - ⚠️ MCP resource performance with live prices
588 | - ⚠️ Migration compatibility (SQLite vs PostgreSQL)
589 | - ⚠️ Edge cases in cost basis averaging
590 |
591 | **Mitigation Strategies:**
592 | 1. **Performance:** Implement caching, batch price fetches, add timeout protection (sketch below)
593 | 2. **Migration:** Test with both SQLite and PostgreSQL, provide rollback path
594 | 3. **Edge Cases:** Comprehensive unit tests, property-based testing for arithmetic
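
A minimal sketch of batched price fetching with per-symbol timeout protection (the `price_client` name is a placeholder for whichever async quote provider the service layer ends up using):

```python
import asyncio


async def fetch_prices_with_timeout(
    tickers: list[str], timeout_seconds: float = 5.0
) -> dict[str, float | None]:
    """Fetch current prices concurrently; a symbol that fails or times out maps to None."""

    async def _one(ticker: str) -> tuple[str, float | None]:
        try:
            # Hypothetical async quote lookup; swap in the project's real provider
            price = await asyncio.wait_for(
                price_client.get_latest_price(ticker), timeout_seconds
            )
            return ticker, float(price)
        except Exception:
            return ticker, None  # Caller falls back to the stored cost basis

    results = await asyncio.gather(*(_one(t) for t in tickers))
    return dict(results)
```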
595 |
596 | ---
597 |
598 | ## 9. Testing Strategy
599 |
600 | **Unit Tests (~60% of test code):**
601 | - Domain entity logic (Position, Portfolio)
602 | - Cost basis averaging edge cases
603 | - Numeric precision validation
604 | - Business logic validation
605 |
606 | **Integration Tests (~30% of test code):**
607 | - Database CRUD operations
608 | - Migration upgrade/downgrade
609 | - Service layer with real database
610 | - Cross-tool functionality
611 |
612 | **Manual Tests (~10% of effort):**
613 | - Claude Desktop end-to-end workflows
614 | - Natural language interactions
615 | - MCP resource visibility
616 | - Tool integration scenarios
617 |
618 | **Test Coverage Target:** 85%+
619 |
620 | ---
621 |
622 | ## 10. Success Metrics
623 |
624 | **Functional Success:**
625 | - [ ] All 4 new tools work in Claude Desktop
626 | - [ ] Portfolio resource visible to AI agents
627 | - [ ] Cost basis averaging accurate to 4 decimal places
628 | - [ ] Migration works on SQLite and PostgreSQL
629 | - [ ] 3 enhanced tools auto-detect portfolio
630 |
631 | **Quality Success:**
632 | - [ ] 85%+ test coverage
633 | - [ ] Zero linting errors (ruff)
634 | - [ ] Full type annotations (ty check passes)
635 | - [ ] All pre-commit hooks pass
636 |
637 | **UX Success:**
638 | - [ ] "Analyze my portfolio" works without ticker input
639 | - [ ] AI agents reference actual holdings in responses
640 | - [ ] Natural language interactions feel seamless
641 | - [ ] Error messages are clear and actionable
642 |
643 | ---
644 |
645 | ## 11. Related Documentation
646 |
647 | - **Original Issue:** [#40 - Portfolio Personalization](https://github.com/wshobson/maverick-mcp/issues/40)
648 | - **User Documentation:** `docs/PORTFOLIO.md` (to be created)
649 | - **API Documentation:** Tool docstrings and MCP introspection
650 | - **Testing Guide:** `tests/README.md` (to be updated)
651 |
652 | ---
653 |
654 | This execution plan provides a comprehensive roadmap following the PLANS.md rubric structure. The implementation is well-scoped, follows established patterns, and delivers significant UX improvement while maintaining code quality and architectural integrity.
655 |
```
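The plan above calls for property-based testing of the cost-basis arithmetic and 4-decimal-place accuracy. A minimal sketch of what such a test could look like, assuming the `hypothesis` library and an illustrative `average_cost_basis` helper (neither the helper name nor its signature comes from the plan):

```python
# Illustrative property-based test for cost-basis averaging.
# `average_cost_basis` is a hypothetical helper, not an existing repo function.
from decimal import Decimal

from hypothesis import given
from hypothesis import strategies as st


def average_cost_basis(
    old_shares: Decimal, old_cost: Decimal, new_shares: Decimal, new_price: Decimal
) -> Decimal:
    """Weighted-average cost basis after buying new_shares at new_price."""
    total_shares = old_shares + new_shares
    total_cost = old_shares * old_cost + new_shares * new_price
    # Quantize to 4 decimal places, matching the plan's accuracy target
    return (total_cost / total_shares).quantize(Decimal("0.0001"))


positive_decimals = st.decimals(
    min_value="0.0001",
    max_value="1000000",
    places=4,
    allow_nan=False,
    allow_infinity=False,
)


@given(positive_decimals, positive_decimals, positive_decimals, positive_decimals)
def test_average_cost_stays_between_input_prices(
    old_shares, old_cost, new_shares, new_price
):
    # A weighted average of two prices can never fall outside those prices
    result = average_cost_basis(old_shares, old_cost, new_shares, new_price)
    low, high = sorted([old_cost, new_price])
    assert low <= result <= high
```

Run with pytest; hypothesis will shrink any counter-example (for instance rounding drift at extreme share counts) to a minimal failing case.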
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/retraining_pipeline.py:
--------------------------------------------------------------------------------
```python
1 | """Automated retraining pipeline for ML models with data drift detection."""
2 |
3 | import logging
4 | from collections.abc import Callable
5 | from datetime import datetime
6 | from typing import Any
7 |
8 | import numpy as np
9 | import pandas as pd
10 | from scipy import stats
11 | from sklearn.base import BaseEstimator
12 | from sklearn.metrics import accuracy_score, classification_report
13 | from sklearn.model_selection import train_test_split
14 | from sklearn.preprocessing import StandardScaler
15 |
16 | from .model_manager import ModelManager
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class DataDriftDetector:
22 | """Detects data drift in features and targets."""
23 |
24 | def __init__(self, significance_level: float = 0.05):
25 | """Initialize drift detector.
26 |
27 | Args:
28 | significance_level: Statistical significance level for drift detection
29 | """
30 | self.significance_level = significance_level
31 | self.reference_data: pd.DataFrame | None = None
32 | self.reference_target: pd.Series | None = None
33 | self.feature_stats: dict[str, dict[str, float]] = {}
34 |
35 | def set_reference_data(
36 | self, features: pd.DataFrame, target: pd.Series | None = None
37 | ):
38 | """Set reference data for drift detection.
39 |
40 | Args:
41 | features: Reference feature data
42 | target: Reference target data (optional)
43 | """
44 | self.reference_data = features.copy()
45 | self.reference_target = target.copy() if target is not None else None
46 |
47 | # Calculate reference statistics
48 | self.feature_stats = {}
49 | for col in features.columns:
50 | if features[col].dtype in ["float64", "float32", "int64", "int32"]:
51 | self.feature_stats[col] = {
52 | "mean": features[col].mean(),
53 | "std": features[col].std(),
54 | "min": features[col].min(),
55 | "max": features[col].max(),
56 | "median": features[col].median(),
57 | }
58 |
59 | logger.info(
60 | f"Set reference data with {len(features)} samples and {len(features.columns)} features"
61 | )
62 |
63 | def detect_feature_drift(
64 | self, new_features: pd.DataFrame
65 | ) -> dict[str, dict[str, Any]]:
66 | """Detect drift in features using statistical tests.
67 |
68 | Args:
69 | new_features: New feature data to compare
70 |
71 | Returns:
72 | Dictionary with drift detection results per feature
73 | """
74 | if self.reference_data is None:
75 | raise ValueError("Reference data not set")
76 |
77 | drift_results = {}
78 |
79 | for col in new_features.columns:
80 | if col not in self.reference_data.columns:
81 | continue
82 |
83 | if new_features[col].dtype not in ["float64", "float32", "int64", "int32"]:
84 | continue
85 |
86 | ref_data = self.reference_data[col].dropna()
87 | new_data = new_features[col].dropna()
88 |
89 | if len(ref_data) == 0 or len(new_data) == 0:
90 | continue
91 |
92 | # Perform statistical tests
93 | drift_detected = False
94 | test_results = {}
95 |
96 | try:
97 | # Kolmogorov-Smirnov test for distribution change
98 | ks_statistic, ks_p_value = stats.ks_2samp(ref_data, new_data)
99 | test_results["ks_statistic"] = ks_statistic
100 | test_results["ks_p_value"] = ks_p_value
101 | ks_drift = ks_p_value < self.significance_level
102 |
103 | # Mann-Whitney U test for location shift
104 | mw_statistic, mw_p_value = stats.mannwhitneyu(
105 | ref_data, new_data, alternative="two-sided"
106 | )
107 | test_results["mw_statistic"] = mw_statistic
108 | test_results["mw_p_value"] = mw_p_value
109 | mw_drift = mw_p_value < self.significance_level
110 |
111 | # Levene test for variance change
112 | levene_statistic, levene_p_value = stats.levene(ref_data, new_data)
113 | test_results["levene_statistic"] = levene_statistic
114 | test_results["levene_p_value"] = levene_p_value
115 | levene_drift = levene_p_value < self.significance_level
116 |
117 | # Overall drift detection
118 | drift_detected = ks_drift or mw_drift or levene_drift
119 |
120 | # Calculate effect sizes
121 | test_results["mean_diff"] = new_data.mean() - ref_data.mean()
122 | test_results["std_ratio"] = new_data.std() / (ref_data.std() + 1e-8)
123 |
124 | except Exception as e:
125 | logger.warning(f"Error in drift detection for {col}: {e}")
126 | test_results["error"] = str(e)
127 |
128 | drift_results[col] = {
129 | "drift_detected": drift_detected,
130 | "test_results": test_results,
131 | "reference_stats": self.feature_stats.get(col, {}),
132 | "new_stats": {
133 | "mean": new_data.mean(),
134 | "std": new_data.std(),
135 | "min": new_data.min(),
136 | "max": new_data.max(),
137 | "median": new_data.median(),
138 | },
139 | }
140 |
141 | return drift_results
142 |
143 | def detect_target_drift(self, new_target: pd.Series) -> dict[str, Any]:
144 | """Detect drift in target variable.
145 |
146 | Args:
147 | new_target: New target data to compare
148 |
149 | Returns:
150 | Dictionary with target drift results
151 | """
152 | if self.reference_target is None:
153 | logger.warning("No reference target data set")
154 | return {"drift_detected": False, "reason": "no_reference_target"}
155 |
156 | ref_target = self.reference_target.dropna()
157 | new_target = new_target.dropna()
158 |
159 | if len(ref_target) == 0 or len(new_target) == 0:
160 | return {"drift_detected": False, "reason": "insufficient_data"}
161 |
162 | drift_results = {"drift_detected": False}
163 |
164 | try:
165 | # For categorical targets, use chi-square test
166 | if ref_target.dtype == "object" or ref_target.nunique() < 10:
167 | ref_counts = ref_target.value_counts()
168 | new_counts = new_target.value_counts()
169 |
170 | # Align the categories
171 | all_categories = set(ref_counts.index) | set(new_counts.index)
172 | ref_aligned = [ref_counts.get(cat, 0) for cat in all_categories]
173 | new_aligned = [new_counts.get(cat, 0) for cat in all_categories]
174 |
175 | if sum(ref_aligned) > 0 and sum(new_aligned) > 0:
176 | chi2_stat, chi2_p_value = stats.chisquare(new_aligned, ref_aligned)
177 | drift_results.update(
178 | {
179 | "test_type": "chi_square",
180 | "chi2_statistic": chi2_stat,
181 | "chi2_p_value": chi2_p_value,
182 | "drift_detected": chi2_p_value < self.significance_level,
183 | }
184 | )
185 |
186 | # For continuous targets
187 | else:
188 | ks_statistic, ks_p_value = stats.ks_2samp(ref_target, new_target)
189 | drift_results.update(
190 | {
191 | "test_type": "kolmogorov_smirnov",
192 | "ks_statistic": ks_statistic,
193 | "ks_p_value": ks_p_value,
194 | "drift_detected": ks_p_value < self.significance_level,
195 | }
196 | )
197 |
198 | except Exception as e:
199 | logger.warning(f"Error in target drift detection: {e}")
200 | drift_results["error"] = str(e)
201 |
202 | return drift_results
203 |
204 | def get_drift_summary(
205 | self, feature_drift: dict[str, dict], target_drift: dict[str, Any]
206 | ) -> dict[str, Any]:
207 | """Get summary of drift detection results.
208 |
209 | Args:
210 | feature_drift: Feature drift results
211 | target_drift: Target drift results
212 |
213 | Returns:
214 | Summary dictionary
215 | """
216 | total_features = len(feature_drift)
217 | drifted_features = sum(
218 | 1 for result in feature_drift.values() if result["drift_detected"]
219 | )
220 | target_drift_detected = target_drift.get("drift_detected", False)
221 |
222 | drift_severity = "none"
223 | if target_drift_detected or drifted_features > total_features * 0.5:
224 | drift_severity = "high"
225 | elif drifted_features > total_features * 0.2:
226 | drift_severity = "medium"
227 | elif drifted_features > 0:
228 | drift_severity = "low"
229 |
230 | return {
231 | "total_features": total_features,
232 | "drifted_features": drifted_features,
233 | "drift_percentage": drifted_features / max(total_features, 1) * 100,
234 | "target_drift_detected": target_drift_detected,
235 | "drift_severity": drift_severity,
236 | "recommendation": self._get_retraining_recommendation(
237 | drift_severity, target_drift_detected
238 | ),
239 | }
240 |
241 | def _get_retraining_recommendation(
242 | self, drift_severity: str, target_drift: bool
243 | ) -> str:
244 | """Get retraining recommendation based on drift severity."""
245 | if target_drift:
246 | return "immediate_retraining"
247 | elif drift_severity == "high":
248 | return "urgent_retraining"
249 | elif drift_severity == "medium":
250 | return "scheduled_retraining"
251 | elif drift_severity == "low":
252 | return "monitor_closely"
253 | else:
254 | return "no_action_needed"
255 |
256 |
257 | class ModelPerformanceMonitor:
258 | """Monitors model performance and detects degradation."""
259 |
260 | def __init__(self, performance_threshold: float = 0.05):
261 | """Initialize performance monitor.
262 |
263 | Args:
264 | performance_threshold: Threshold for performance degradation detection
265 | """
266 | self.performance_threshold = performance_threshold
267 | self.baseline_metrics: dict[str, float] = {}
268 | self.performance_history: list[dict[str, Any]] = []
269 |
270 | def set_baseline_performance(self, metrics: dict[str, float]):
271 | """Set baseline performance metrics.
272 |
273 | Args:
274 | metrics: Baseline performance metrics
275 | """
276 | self.baseline_metrics = metrics.copy()
277 | logger.info(f"Set baseline performance: {metrics}")
278 |
279 | def evaluate_performance(
280 | self,
281 | model: BaseEstimator,
282 | X_test: pd.DataFrame,
283 | y_test: pd.Series,
284 | additional_metrics: dict[str, float] | None = None,
285 | ) -> dict[str, Any]:
286 | """Evaluate current model performance.
287 |
288 | Args:
289 | model: Trained model
290 | X_test: Test features
291 | y_test: Test targets
292 | additional_metrics: Additional metrics to include
293 |
294 | Returns:
295 | Performance evaluation results
296 | """
297 | try:
298 | # Make predictions
299 | y_pred = model.predict(X_test)
300 |
301 | # Calculate metrics
302 | metrics = {
303 | "accuracy": accuracy_score(y_test, y_pred),
304 | "timestamp": datetime.now().isoformat(),
305 | }
306 |
307 | # Add additional metrics if provided
308 | if additional_metrics:
309 | metrics.update(additional_metrics)
310 |
311 | # Detect performance degradation
312 | degradation_detected = False
313 | degradation_details = {}
314 |
315 | for metric_name, current_value in metrics.items():
316 | if metric_name in self.baseline_metrics and metric_name != "timestamp":
317 | baseline_value = self.baseline_metrics[metric_name]
318 | degradation = (baseline_value - current_value) / abs(baseline_value)
319 |
320 | if degradation > self.performance_threshold:
321 | degradation_detected = True
322 | degradation_details[metric_name] = {
323 | "baseline": baseline_value,
324 | "current": current_value,
325 | "degradation": degradation,
326 | }
327 |
328 | evaluation_result = {
329 | "metrics": metrics,
330 | "degradation_detected": degradation_detected,
331 | "degradation_details": degradation_details,
332 | "classification_report": classification_report(
333 | y_test, y_pred, output_dict=True
334 | ),
335 | }
336 |
337 | # Store in history
338 | self.performance_history.append(evaluation_result)
339 |
340 | # Keep only recent history
341 | if len(self.performance_history) > 100:
342 | self.performance_history = self.performance_history[-100:]
343 |
344 | return evaluation_result
345 |
346 | except Exception as e:
347 | logger.error(f"Error evaluating model performance: {e}")
348 | return {"error": str(e)}
349 |
350 | def get_performance_trend(self, metric_name: str = "accuracy") -> dict[str, Any]:
351 | """Analyze performance trend over time.
352 |
353 | Args:
354 | metric_name: Metric to analyze
355 |
356 | Returns:
357 | Trend analysis results
358 | """
359 | if not self.performance_history:
360 | return {"trend": "no_data"}
361 |
362 | values = []
363 | timestamps = []
364 |
365 | for record in self.performance_history:
366 | if metric_name in record["metrics"]:
367 | values.append(record["metrics"][metric_name])
368 | timestamps.append(record["metrics"]["timestamp"])
369 |
370 | if len(values) < 3:
371 | return {"trend": "insufficient_data"}
372 |
373 | # Calculate trend
374 | x = np.arange(len(values))
375 | slope, _, r_value, p_value, _ = stats.linregress(x, values)
376 |
377 | trend_direction = "stable"
378 | if p_value < 0.05: # Statistically significant trend
379 | if slope > 0:
380 | trend_direction = "improving"
381 | else:
382 | trend_direction = "degrading"
383 |
384 | return {
385 | "trend": trend_direction,
386 | "slope": slope,
387 | "r_squared": r_value**2,
388 | "p_value": p_value,
389 | "recent_values": values[-5:],
390 | "timestamps": timestamps[-5:],
391 | }
392 |
393 |
394 | class AutoRetrainingPipeline:
395 | """Automated pipeline for model retraining with drift detection and performance monitoring."""
396 |
397 | def __init__(
398 | self,
399 | model_manager: ModelManager,
400 | model_factory: Callable[[], BaseEstimator],
401 | feature_extractor: Callable[[pd.DataFrame], pd.DataFrame],
402 | target_extractor: Callable[[pd.DataFrame], pd.Series],
403 | retraining_schedule_hours: int = 24,
404 | min_samples_for_retraining: int = 100,
405 | ):
406 | """Initialize auto-retraining pipeline.
407 |
408 | Args:
409 | model_manager: Model manager instance
410 | model_factory: Function that creates new model instances
411 | feature_extractor: Function to extract features from data
412 | target_extractor: Function to extract targets from data
413 | retraining_schedule_hours: Hours between scheduled retraining checks
414 | min_samples_for_retraining: Minimum samples required for retraining
415 | """
416 | self.model_manager = model_manager
417 | self.model_factory = model_factory
418 | self.feature_extractor = feature_extractor
419 | self.target_extractor = target_extractor
420 | self.retraining_schedule_hours = retraining_schedule_hours
421 | self.min_samples_for_retraining = min_samples_for_retraining
422 |
423 | self.drift_detector = DataDriftDetector()
424 | self.performance_monitor = ModelPerformanceMonitor()
425 |
426 | self.last_retraining: dict[str, datetime] = {}
427 | self.retraining_history: list[dict[str, Any]] = []
428 |
429 | def should_retrain(
430 | self,
431 | model_id: str,
432 | new_data: pd.DataFrame,
433 | force_check: bool = False,
434 | ) -> tuple[bool, str]:
435 | """Determine if a model should be retrained.
436 |
437 | Args:
438 | model_id: Model identifier
439 | new_data: New data for evaluation
440 | force_check: Force retraining check regardless of schedule
441 |
442 | Returns:
443 | Tuple of (should_retrain, reason)
444 | """
445 | # Check schedule
446 | last_retrain = self.last_retraining.get(model_id)
447 | if not force_check and last_retrain is not None:
448 | time_since_retrain = datetime.now() - last_retrain
449 | if (
450 | time_since_retrain.total_seconds()
451 | < self.retraining_schedule_hours * 3600
452 | ):
453 | return False, "schedule_not_due"
454 |
455 | # Check minimum samples
456 | if len(new_data) < self.min_samples_for_retraining:
457 | return False, "insufficient_samples"
458 |
459 | # Extract features and targets
460 | try:
461 | features = self.feature_extractor(new_data)
462 | targets = self.target_extractor(new_data)
463 | except Exception as e:
464 | logger.error(f"Error extracting features/targets: {e}")
465 | return False, f"extraction_error: {e}"
466 |
467 | # Check for data drift
468 | if self.drift_detector.reference_data is not None:
469 | feature_drift = self.drift_detector.detect_feature_drift(features)
470 | target_drift = self.drift_detector.detect_target_drift(targets)
471 | drift_summary = self.drift_detector.get_drift_summary(
472 | feature_drift, target_drift
473 | )
474 |
475 | if drift_summary["recommendation"] in [
476 | "immediate_retraining",
477 | "urgent_retraining",
478 | ]:
479 | return True, f"data_drift_{drift_summary['drift_severity']}"
480 |
481 | # Check performance degradation
482 | current_model = self.model_manager.load_model(model_id)
483 | if current_model is not None and current_model.model is not None:
484 | try:
485 | # Split data for evaluation
486 | X_train, X_test, y_train, y_test = train_test_split(
487 | features, targets, test_size=0.3, random_state=42, stratify=targets
488 | )
489 |
490 | # Scale features if scaler is available
491 | if current_model.scaler is not None:
492 | X_test_scaled = current_model.scaler.transform(X_test)
493 | else:
494 | X_test_scaled = X_test
495 |
496 | # Evaluate performance
497 | performance_result = self.performance_monitor.evaluate_performance(
498 | current_model.model, X_test_scaled, y_test
499 | )
500 |
501 | if performance_result.get("degradation_detected", False):
502 | return True, "performance_degradation"
503 |
504 | except Exception as e:
505 | logger.warning(f"Error evaluating model performance: {e}")
506 |
507 | return False, "no_triggers"
508 |
509 | def retrain_model(
510 | self,
511 | model_id: str,
512 | training_data: pd.DataFrame,
513 | validation_split: float = 0.2,
514 | **model_params,
515 | ) -> str | None:
516 | """Retrain a model with new data.
517 |
518 | Args:
519 | model_id: Model identifier
520 | training_data: Training data
521 | validation_split: Fraction of data to use for validation
522 | **model_params: Additional parameters for model training
523 |
524 | Returns:
525 | New model version string if successful, None otherwise
526 | """
527 | try:
528 | # Extract features and targets
529 | features = self.feature_extractor(training_data)
530 | targets = self.target_extractor(training_data)
531 |
532 | # Split data
533 | X_train, X_val, y_train, y_val = train_test_split(
534 | features,
535 | targets,
536 | test_size=validation_split,
537 | random_state=42,
538 | stratify=targets,
539 | )
540 |
541 | # Scale features
542 | scaler = StandardScaler()
543 | X_train_scaled = scaler.fit_transform(X_train)
544 | X_val_scaled = scaler.transform(X_val)
545 |
546 | # Create and train new model
547 | model = self.model_factory()
548 | model.set_params(**model_params)
549 | model.fit(X_train_scaled, y_train)
550 |
551 | # Evaluate model
552 | train_score = model.score(X_train_scaled, y_train)
553 | val_score = model.score(X_val_scaled, y_val)
554 |
555 | # Create version string
556 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
557 | new_version = f"v_{timestamp}"
558 |
559 | # Prepare metadata
560 | metadata = {
561 | "training_samples": len(X_train),
562 | "validation_samples": len(X_val),
563 | "features_count": X_train.shape[1],
564 | "model_params": model_params,
565 | "retraining_trigger": "automated",
566 | }
567 |
568 | # Prepare performance metrics
569 | performance_metrics = {
570 | "train_accuracy": train_score,
571 | "validation_accuracy": val_score,
572 | "overfitting_gap": train_score - val_score,
573 | }
574 |
575 | # Save model
576 | success = self.model_manager.save_model(
577 | model_id=model_id,
578 | version=new_version,
579 | model=model,
580 | scaler=scaler,
581 | metadata=metadata,
582 | performance_metrics=performance_metrics,
583 |                 set_as_active=True,  # Promote the new version to active immediately
584 | )
585 |
586 | if success:
587 | # Update retraining history
588 | self.last_retraining[model_id] = datetime.now()
589 | self.retraining_history.append(
590 | {
591 | "model_id": model_id,
592 | "version": new_version,
593 | "timestamp": datetime.now().isoformat(),
594 | "training_samples": len(X_train),
595 | "validation_accuracy": val_score,
596 | }
597 | )
598 |
599 | # Update drift detector reference data
600 | self.drift_detector.set_reference_data(features, targets)
601 |
602 | # Update performance monitor baseline
603 | self.performance_monitor.set_baseline_performance(performance_metrics)
604 |
605 | logger.info(
606 | f"Successfully retrained model {model_id} -> {new_version} (val_acc: {val_score:.4f})"
607 | )
608 | return new_version
609 | else:
610 | logger.error(f"Failed to save retrained model {model_id}")
611 | return None
612 |
613 | except Exception as e:
614 | logger.error(f"Error retraining model {model_id}: {e}")
615 | return None
616 |
617 | def run_retraining_check(
618 | self, model_id: str, new_data: pd.DataFrame
619 | ) -> dict[str, Any]:
620 | """Run complete retraining check and execute if needed.
621 |
622 | Args:
623 | model_id: Model identifier
624 | new_data: New data for evaluation
625 |
626 | Returns:
627 | Dictionary with check results and actions taken
628 | """
629 | start_time = datetime.now()
630 |
631 | try:
632 | # Check if retraining is needed
633 | should_retrain, reason = self.should_retrain(model_id, new_data)
634 |
635 | result = {
636 | "model_id": model_id,
637 | "timestamp": start_time.isoformat(),
638 | "should_retrain": should_retrain,
639 | "reason": reason,
640 | "data_samples": len(new_data),
641 | "new_version": None,
642 | "success": False,
643 | }
644 |
645 | if should_retrain:
646 | logger.info(f"Retraining {model_id} due to: {reason}")
647 | new_version = self.retrain_model(model_id, new_data)
648 |
649 | if new_version:
650 | result.update(
651 | {
652 | "new_version": new_version,
653 | "success": True,
654 | "action": "retrained",
655 | }
656 | )
657 | else:
658 | result.update(
659 | {
660 | "action": "retrain_failed",
661 | "error": "Model retraining failed",
662 | }
663 | )
664 | else:
665 | result.update(
666 | {
667 | "action": "no_retrain",
668 | "success": True,
669 | }
670 | )
671 |
672 | # Calculate execution time
673 | execution_time = (datetime.now() - start_time).total_seconds()
674 | result["execution_time_seconds"] = execution_time
675 |
676 | return result
677 |
678 | except Exception as e:
679 | logger.error(f"Error in retraining check for {model_id}: {e}")
680 | return {
681 | "model_id": model_id,
682 | "timestamp": start_time.isoformat(),
683 | "should_retrain": False,
684 | "reason": "check_error",
685 | "error": str(e),
686 | "success": False,
687 | "execution_time_seconds": (datetime.now() - start_time).total_seconds(),
688 | }
689 |
690 | def get_retraining_summary(self) -> dict[str, Any]:
691 | """Get summary of retraining pipeline status.
692 |
693 | Returns:
694 | Summary dictionary
695 | """
696 | return {
697 | "total_models_managed": len(self.last_retraining),
698 | "total_retrainings": len(self.retraining_history),
699 | "recent_retrainings": self.retraining_history[-10:],
700 | "last_retraining_times": {
701 | model_id: timestamp.isoformat()
702 | for model_id, timestamp in self.last_retraining.items()
703 | },
704 | "retraining_schedule_hours": self.retraining_schedule_hours,
705 | "min_samples_for_retraining": self.min_samples_for_retraining,
706 | }
707 |
708 |
709 | # Alias for backward compatibility
710 | RetrainingPipeline = AutoRetrainingPipeline
711 |
712 | # Ensure all expected names are available
713 | __all__ = [
714 | "DataDriftDetector",
715 | "ModelPerformanceMonitor",
716 | "AutoRetrainingPipeline",
717 | "RetrainingPipeline", # Alias for backward compatibility
718 | ]
719 |
```
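A minimal usage sketch for the pipeline above (not code from the repo): it assumes `ModelManager` accepts a storage path (its actual constructor lives in `model_manager.py` and is not shown on this page) and uses a scikit-learn classifier as the model factory; the extractor functions are illustrative only.

```python
# Usage sketch only; ModelManager("./models") is an assumed constructor signature.
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from maverick_mcp.backtesting.model_manager import ModelManager
from maverick_mcp.backtesting.retraining_pipeline import AutoRetrainingPipeline


def extract_features(data: pd.DataFrame) -> pd.DataFrame:
    # Hypothetical features: daily return and 20-day rolling volatility
    features = pd.DataFrame(index=data.index)
    features["return_1d"] = data["close"].pct_change()
    features["volatility_20d"] = features["return_1d"].rolling(20).std()
    return features.dropna()


def extract_target(data: pd.DataFrame) -> pd.Series:
    # Hypothetical binary target: 1 if the next close is higher, else 0
    target = (data["close"].shift(-1) > data["close"]).astype(int)
    return target.loc[extract_features(data).index]


pipeline = AutoRetrainingPipeline(
    model_manager=ModelManager("./models"),  # assumed signature
    model_factory=lambda: RandomForestClassifier(n_estimators=200, random_state=42),
    feature_extractor=extract_features,
    target_extractor=extract_target,
    retraining_schedule_hours=24,
    min_samples_for_retraining=100,
)

# new_bars: freshly fetched OHLCV frame with a "close" column
# result = pipeline.run_retraining_check("AAPL_direction", new_bars)
# -> {"should_retrain": ..., "reason": ..., "new_version": ..., ...}
```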
--------------------------------------------------------------------------------
/maverick_mcp/data/performance.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Performance optimization utilities for Maverick-MCP.
3 |
4 | This module provides Redis connection pooling, request caching,
5 | and query optimization features to improve application performance.
6 | """
7 |
8 | import hashlib
9 | import json
10 | import logging
11 | import time
12 | from collections.abc import Callable
13 | from contextlib import asynccontextmanager
14 | from functools import wraps
15 | from typing import Any, TypeVar, cast
16 |
17 | import redis.asyncio as redis
18 | from redis.asyncio.client import Pipeline
19 | from sqlalchemy import text
20 | from sqlalchemy.ext.asyncio import AsyncSession
21 |
22 | from maverick_mcp.config.settings import get_settings
23 | from maverick_mcp.data.session_management import get_async_db_session
24 |
25 | settings = get_settings()
26 | logger = logging.getLogger(__name__)
27 |
28 | # Type variables for generic typing
29 | F = TypeVar("F", bound=Callable[..., Any])
30 |
31 |
32 | class RedisConnectionManager:
33 | """
34 | Centralized Redis connection manager with connection pooling.
35 |
36 | This manager provides:
37 | - Connection pooling with configurable limits
38 | - Automatic failover and retry logic
39 | - Health monitoring and metrics
40 | - Graceful degradation when Redis is unavailable
41 | """
42 |
43 | def __init__(self):
44 | self._pool: redis.ConnectionPool | None = None
45 | self._client: redis.Redis | None = None
46 | self._initialized = False
47 | self._healthy = False
48 | self._last_health_check = 0
49 | self._health_check_interval = 30 # seconds
50 |
51 | # Connection pool configuration
52 | self._max_connections = settings.db.redis_max_connections
53 | self._retry_on_timeout = settings.db.redis_retry_on_timeout
54 | self._socket_timeout = settings.db.redis_socket_timeout
55 | self._socket_connect_timeout = settings.db.redis_socket_connect_timeout
56 | self._health_check_interval_sec = 30
57 |
58 | # Metrics
59 | self._metrics = {
60 | "connections_created": 0,
61 | "connections_closed": 0,
62 | "commands_executed": 0,
63 | "errors": 0,
64 | "health_checks": 0,
65 | "last_error": None,
66 | }
67 |
68 | async def initialize(self) -> bool:
69 | """
70 | Initialize Redis connection pool.
71 |
72 | Returns:
73 | bool: True if initialization successful, False otherwise
74 | """
75 | if self._initialized:
76 | return self._healthy
77 |
78 | try:
79 | # Create connection pool
80 | self._pool = redis.ConnectionPool.from_url(
81 | settings.redis.url,
82 | max_connections=self._max_connections,
83 | retry_on_timeout=self._retry_on_timeout,
84 | socket_timeout=self._socket_timeout,
85 | socket_connect_timeout=self._socket_connect_timeout,
86 | decode_responses=True,
87 | health_check_interval=self._health_check_interval_sec,
88 | )
89 |
90 | # Create Redis client
91 | self._client = redis.Redis(connection_pool=self._pool)
92 |
93 | client = self._client
94 | if client is None: # Defensive guard for static type checking
95 | msg = "Redis client initialization failed"
96 | raise RuntimeError(msg)
97 |
98 | # Test connection
99 | await client.ping()
100 |
101 | self._healthy = True
102 | self._initialized = True
103 | self._metrics["connections_created"] += 1
104 |
105 | logger.info(
106 | f"Redis connection pool initialized: "
107 | f"max_connections={self._max_connections}, "
108 | f"url={settings.redis.url}"
109 | )
110 |
111 | return True
112 |
113 | except Exception as e:
114 | logger.error(f"Failed to initialize Redis connection pool: {e}")
115 | self._metrics["errors"] += 1
116 | self._metrics["last_error"] = str(e)
117 | self._healthy = False
118 | return False
119 |
120 | async def get_client(self) -> redis.Redis | None:
121 | """
122 | Get Redis client from the connection pool.
123 |
124 | Returns:
125 | Redis client or None if unavailable
126 | """
127 | if not self._initialized:
128 | await self.initialize()
129 |
130 | if not self._healthy:
131 | await self._health_check()
132 |
133 | return self._client if self._healthy else None
134 |
135 | async def _health_check(self) -> bool:
136 | """
137 | Perform health check on Redis connection.
138 |
139 | Returns:
140 | bool: True if healthy, False otherwise
141 | """
142 | current_time = time.time()
143 |
144 | # Skip health check if recently performed
145 | if (current_time - self._last_health_check) < self._health_check_interval:
146 | return self._healthy
147 |
148 | self._last_health_check = current_time
149 | self._metrics["health_checks"] += 1
150 |
151 | try:
152 | if self._client:
153 | await self._client.ping()
154 | self._healthy = True
155 | logger.debug("Redis health check passed")
156 | else:
157 | self._healthy = False
158 |
159 | except Exception as e:
160 | logger.warning(f"Redis health check failed: {e}")
161 | self._healthy = False
162 | self._metrics["errors"] += 1
163 | self._metrics["last_error"] = str(e)
164 |
165 | # Try to reinitialize
166 | await self.initialize()
167 |
168 | return self._healthy
169 |
170 | async def execute_command(self, command: str, *args, **kwargs) -> Any:
171 | """
172 | Execute Redis command with error handling and metrics.
173 |
174 | Args:
175 | command: Redis command name
176 | *args: Command arguments
177 | **kwargs: Command keyword arguments
178 |
179 | Returns:
180 | Command result or None if failed
181 | """
182 | client = await self.get_client()
183 | if not client:
184 | return None
185 |
186 | try:
187 | self._metrics["commands_executed"] += 1
188 | result = await getattr(client, command)(*args, **kwargs)
189 | return result
190 |
191 | except Exception as e:
192 | logger.error(f"Redis command '{command}' failed: {e}")
193 | self._metrics["errors"] += 1
194 | self._metrics["last_error"] = str(e)
195 | return None
196 |
197 | async def pipeline(self) -> Pipeline | None:
198 | """
199 | Create Redis pipeline for batch operations.
200 |
201 | Returns:
202 | Redis pipeline or None if unavailable
203 | """
204 | client = await self.get_client()
205 | if not client:
206 | return None
207 |
208 | return client.pipeline()
209 |
210 | def get_metrics(self) -> dict[str, Any]:
211 | """Get connection pool metrics."""
212 | metrics = self._metrics.copy()
213 | metrics.update(
214 | {
215 | "healthy": self._healthy,
216 | "initialized": self._initialized,
217 | "pool_size": self._max_connections,
218 | "pool_created": bool(self._pool),
219 | }
220 | )
221 |
222 | if self._pool:
223 | # Safely get pool metrics with fallbacks for missing attributes
224 | try:
225 | metrics["pool_created_connections"] = getattr(
226 | self._pool, "created_connections", 0
227 | )
228 | except AttributeError:
229 | metrics["pool_created_connections"] = 0
230 |
231 | try:
232 | metrics["pool_available_connections"] = len(
233 | getattr(self._pool, "_available_connections", [])
234 | )
235 | except (AttributeError, TypeError):
236 | metrics["pool_available_connections"] = 0
237 |
238 | try:
239 | metrics["pool_in_use_connections"] = len(
240 | getattr(self._pool, "_in_use_connections", [])
241 | )
242 | except (AttributeError, TypeError):
243 | metrics["pool_in_use_connections"] = 0
244 |
245 | return metrics
246 |
247 | async def close(self):
248 | """Close connection pool gracefully."""
249 | if self._client:
250 | # Use aclose() instead of close() to avoid deprecation warning
251 | # aclose() is the new async close method in redis-py 5.0+
252 | if hasattr(self._client, "aclose"):
253 | await self._client.aclose()
254 | else:
255 | # Fallback for older versions
256 | await self._client.close()
257 | self._metrics["connections_closed"] += 1
258 |
259 | if self._pool:
260 | await self._pool.disconnect()
261 |
262 | self._initialized = False
263 | self._healthy = False
264 | logger.info("Redis connection pool closed")
265 |
266 |
267 | # Global Redis connection manager instance
268 | redis_manager = RedisConnectionManager()
269 |
270 |
271 | class RequestCache:
272 | """
273 | Smart request-level caching system.
274 |
275 | This system provides:
276 | - Automatic cache key generation based on function signature
277 | - TTL strategies for different data types
278 | - Cache invalidation mechanisms
279 | - Hit/miss metrics and monitoring
280 | """
281 |
282 | def __init__(self):
283 | self._hit_count = 0
284 | self._miss_count = 0
285 | self._error_count = 0
286 |
287 | # Default TTL values for different data types (in seconds)
288 | self._default_ttls = {
289 | "stock_data": 3600, # 1 hour for stock data
290 | "technical_analysis": 1800, # 30 minutes for technical indicators
291 | "market_data": 300, # 5 minutes for market data
292 | "screening": 7200, # 2 hours for screening results
293 | "portfolio": 1800, # 30 minutes for portfolio analysis
294 | "macro_data": 3600, # 1 hour for macro data
295 | "default": 900, # 15 minutes default
296 | }
297 |
298 | def _generate_cache_key(self, prefix: str, *args, **kwargs) -> str:
299 | """
300 | Generate cache key from function arguments.
301 |
302 | Args:
303 | prefix: Cache key prefix
304 | *args: Function arguments
305 | **kwargs: Function keyword arguments
306 |
307 | Returns:
308 | Generated cache key
309 | """
310 | # Create a hash of the arguments
311 | key_data = {
312 | "args": args,
313 | "kwargs": sorted(kwargs.items()),
314 | }
315 |
316 | key_hash = hashlib.sha256(
317 | json.dumps(key_data, sort_keys=True, default=str).encode()
318 | ).hexdigest()[:16] # Use first 16 chars for brevity
319 |
320 | return f"cache:{prefix}:{key_hash}"
321 |
322 | def _get_ttl(self, data_type: str) -> int:
323 | """Get TTL for data type."""
324 | return self._default_ttls.get(data_type, self._default_ttls["default"])
325 |
326 | async def get(self, key: str) -> Any | None:
327 | """
328 | Get value from cache.
329 |
330 | Args:
331 | key: Cache key
332 |
333 | Returns:
334 | Cached value or None if not found
335 | """
336 | try:
337 | client = await redis_manager.get_client()
338 | if not client:
339 | return None
340 |
341 | data = await client.get(key)
342 | if data:
343 | self._hit_count += 1
344 | logger.debug(f"Cache hit for key: {key}")
345 | return json.loads(data)
346 | else:
347 | self._miss_count += 1
348 | logger.debug(f"Cache miss for key: {key}")
349 | return None
350 |
351 | except Exception as e:
352 | logger.error(f"Error getting from cache: {e}")
353 | self._error_count += 1
354 | return None
355 |
356 | async def set(
357 | self, key: str, value: Any, ttl: int | None = None, data_type: str = "default"
358 | ) -> bool:
359 | """
360 | Set value in cache.
361 |
362 | Args:
363 | key: Cache key
364 | value: Value to cache
365 | ttl: Time to live in seconds
366 | data_type: Data type for TTL determination
367 |
368 | Returns:
369 | True if successful, False otherwise
370 | """
371 | try:
372 | client = await redis_manager.get_client()
373 | if not client:
374 | return False
375 |
376 | if ttl is None:
377 | ttl = self._get_ttl(data_type)
378 |
379 | serialized_value = json.dumps(value, default=str)
380 | success = await client.setex(key, ttl, serialized_value)
381 |
382 | if success:
383 | logger.debug(f"Cached value for key: {key} (TTL: {ttl}s)")
384 |
385 | return bool(success)
386 |
387 | except Exception as e:
388 | logger.error(f"Error setting cache: {e}")
389 | self._error_count += 1
390 | return False
391 |
392 | async def delete(self, key: str) -> bool:
393 | """Delete key from cache."""
394 | try:
395 | client = await redis_manager.get_client()
396 | if not client:
397 | return False
398 |
399 | result = await client.delete(key)
400 | return bool(result)
401 |
402 | except Exception as e:
403 | logger.error(f"Error deleting from cache: {e}")
404 | self._error_count += 1
405 | return False
406 |
407 | async def delete_pattern(self, pattern: str) -> int:
408 | """Delete all keys matching pattern."""
409 | try:
410 | client = await redis_manager.get_client()
411 | if not client:
412 | return 0
413 |
414 | keys = await client.keys(pattern)
415 | if keys:
416 | result = await client.delete(*keys)
417 | logger.info(f"Deleted {result} keys matching pattern: {pattern}")
418 | return result
419 |
420 | return 0
421 |
422 | except Exception as e:
423 | logger.error(f"Error deleting pattern: {e}")
424 | self._error_count += 1
425 | return 0
426 |
427 | def get_metrics(self) -> dict[str, Any]:
428 | """Get cache metrics."""
429 | total_requests = self._hit_count + self._miss_count
430 | hit_rate = (self._hit_count / total_requests) if total_requests > 0 else 0
431 |
432 | return {
433 | "hit_count": self._hit_count,
434 | "miss_count": self._miss_count,
435 | "error_count": self._error_count,
436 | "total_requests": total_requests,
437 | "hit_rate": hit_rate,
438 | "ttl_config": self._default_ttls,
439 | }
440 |
441 |
442 | # Global request cache instance
443 | request_cache = RequestCache()
444 |
445 |
446 | def cached(
447 | data_type: str = "default",
448 | ttl: int | None = None,
449 | key_prefix: str | None = None,
450 | invalidate_patterns: list[str] | None = None,
451 | ):
452 | """
453 | Decorator for automatic function result caching.
454 |
455 | Args:
456 | data_type: Data type for TTL determination
457 | ttl: Custom TTL in seconds
458 | key_prefix: Custom cache key prefix
459 | invalidate_patterns: Patterns to invalidate on update
460 |
461 | Example:
462 | @cached(data_type="stock_data", ttl=3600)
463 | async def get_stock_price(symbol: str) -> float:
464 | # Expensive operation
465 | return price
466 | """
467 |
468 | def decorator(func: F) -> F:
469 | @wraps(func)
470 | async def wrapper(*args, **kwargs):
471 | # Generate cache key
472 | prefix = key_prefix or f"{func.__module__}.{func.__name__}"
473 | cache_key = request_cache._generate_cache_key(prefix, *args, **kwargs)
474 |
475 | # Try to get from cache
476 | cached_result = await request_cache.get(cache_key)
477 | if cached_result is not None:
478 | return cached_result
479 |
480 | # Execute function
481 | result = await func(*args, **kwargs)
482 |
483 | # Cache result
484 | if result is not None:
485 | await request_cache.set(cache_key, result, ttl, data_type)
486 |
487 | return result
488 |
489 | # Add cache invalidation method
490 | async def invalidate_cache(*args, **kwargs):
491 | """Invalidate cache for this function."""
492 | prefix = key_prefix or f"{func.__module__}.{func.__name__}"
493 | cache_key = request_cache._generate_cache_key(prefix, *args, **kwargs)
494 | await request_cache.delete(cache_key)
495 |
496 | # Invalidate patterns if specified
497 | if invalidate_patterns:
498 | for pattern in invalidate_patterns:
499 | await request_cache.delete_pattern(pattern)
500 |
501 | typed_wrapper = cast(F, wrapper)
502 | cast(Any, typed_wrapper).invalidate_cache = invalidate_cache
503 | return typed_wrapper
504 |
505 | return decorator
506 |
507 |
508 | class QueryOptimizer:
509 | """
510 | Database query optimization utilities.
511 |
512 | This class provides:
513 | - Query performance monitoring
514 | - Index recommendations
515 | - N+1 query detection
516 | - Connection pool monitoring
517 | """
518 |
519 | def __init__(self):
520 | self._query_stats = {}
521 | self._slow_query_threshold = 1.0 # seconds
522 | self._slow_queries = []
523 |
524 | def monitor_query(self, query_name: str):
525 | """
526 | Decorator for monitoring query performance.
527 |
528 | Args:
529 | query_name: Name for the query (for metrics)
530 | """
531 |
532 | def decorator(func: F) -> F:
533 | @wraps(func)
534 | async def wrapper(*args, **kwargs):
535 | start_time = time.time()
536 |
537 | try:
538 | result = await func(*args, **kwargs)
539 | execution_time = time.time() - start_time
540 |
541 | # Update statistics
542 | if query_name not in self._query_stats:
543 | self._query_stats[query_name] = {
544 | "count": 0,
545 | "total_time": 0,
546 | "avg_time": 0,
547 | "max_time": 0,
548 | "min_time": float("inf"),
549 | }
550 |
551 | stats = self._query_stats[query_name]
552 | stats["count"] += 1
553 | stats["total_time"] += execution_time
554 | stats["avg_time"] = stats["total_time"] / stats["count"]
555 | stats["max_time"] = max(stats["max_time"], execution_time)
556 | stats["min_time"] = min(stats["min_time"], execution_time)
557 |
558 | # Track slow queries
559 | if execution_time > self._slow_query_threshold:
560 | self._slow_queries.append(
561 | {
562 | "query_name": query_name,
563 | "execution_time": execution_time,
564 | "timestamp": time.time(),
565 | "args": str(args)[:200], # Truncate long args
566 | }
567 | )
568 |
569 | # Keep only last 100 slow queries
570 | if len(self._slow_queries) > 100:
571 | self._slow_queries = self._slow_queries[-100:]
572 |
573 | logger.warning(
574 | f"Slow query detected: {query_name} took {execution_time:.2f}s"
575 | )
576 |
577 | return result
578 |
579 | except Exception as e:
580 | execution_time = time.time() - start_time
581 | logger.error(
582 | f"Query {query_name} failed after {execution_time:.2f}s: {e}"
583 | )
584 | raise
585 |
586 | return cast(F, wrapper)
587 |
588 | return decorator
589 |
590 | def get_query_stats(self) -> dict[str, Any]:
591 | """Get query performance statistics."""
592 | return {
593 | "query_stats": self._query_stats,
594 | "slow_queries": self._slow_queries[-10:], # Last 10 slow queries
595 | "slow_query_threshold": self._slow_query_threshold,
596 | }
597 |
598 | async def analyze_missing_indexes(
599 | self, session: AsyncSession
600 | ) -> list[dict[str, Any]]:
601 | """
602 | Analyze database for missing indexes.
603 |
604 | Args:
605 | session: Database session
606 |
607 | Returns:
608 | List of recommended indexes
609 | """
610 | recommendations = []
611 |
612 | try:
613 | # Check for common missing indexes
614 | queries = [
615 | # PriceCache table analysis
616 | {
617 | "name": "PriceCache date range queries",
618 | "query": """
619 | SELECT schemaname, tablename, attname, n_distinct, correlation
620 | FROM pg_stats
621 | WHERE tablename = 'stocks_pricecache'
622 | AND attname IN ('date', 'stock_id', 'volume')
623 | """,
624 | "recommendation": "Consider composite index on (stock_id, date) if not exists",
625 | },
626 | # Stock lookup performance
627 | {
628 | "name": "Stock ticker lookups",
629 | "query": """
630 | SELECT schemaname, tablename, attname, n_distinct, correlation
631 | FROM pg_stats
632 | WHERE tablename = 'stocks_stock'
633 | AND attname = 'ticker_symbol'
634 | """,
635 | "recommendation": "Ensure unique index on ticker_symbol exists",
636 | },
637 | # Screening tables
638 | {
639 | "name": "Maverick screening queries",
640 | "query": """
641 | SELECT schemaname, tablename, attname, n_distinct
642 | FROM pg_stats
643 | WHERE tablename IN ('stocks_maverickstocks', 'stocks_maverickbearstocks', 'stocks_supply_demand_breakouts')
644 | AND attname IN ('score', 'rank', 'date_analyzed')
645 | """,
646 | "recommendation": "Consider indexes on score, rank, and date_analyzed columns",
647 | },
648 | ]
649 |
650 | for query_info in queries:
651 | try:
652 | result = await session.execute(text(query_info["query"]))
653 | rows = result.fetchall()
654 |
655 | if rows:
656 | recommendations.append(
657 | {
658 | "analysis": query_info["name"],
659 | "recommendation": query_info["recommendation"],
660 | "stats": [dict(row._mapping) for row in rows],
661 | }
662 | )
663 |
664 | except Exception as e:
665 | logger.error(f"Failed to analyze {query_info['name']}: {e}")
666 |
667 | # Check for tables without proper indexes
668 | missing_indexes_query = """
669 | SELECT
670 | schemaname,
671 | tablename,
672 | seq_scan,
673 | seq_tup_read,
674 | idx_scan,
675 | idx_tup_fetch,
676 | CASE
677 | WHEN seq_scan = 0 THEN 0
678 | ELSE seq_tup_read / seq_scan
679 | END as avg_seq_read
680 | FROM pg_stat_user_tables
681 | WHERE schemaname = 'public'
682 | AND tablename LIKE 'stocks_%'
683 | ORDER BY seq_tup_read DESC
684 | """
685 |
686 | result = await session.execute(text(missing_indexes_query))
687 | scan_stats = result.fetchall()
688 |
689 | for row in scan_stats:
690 | if row.seq_scan > 100 and row.avg_seq_read > 1000:
691 | recommendations.append(
692 | {
693 | "analysis": f"High sequential scans on {row.tablename}",
694 | "recommendation": f"Consider adding indexes to reduce {row.seq_tup_read} sequential reads",
695 | "stats": dict(row._mapping),
696 | }
697 | )
698 |
699 | except Exception as e:
700 | logger.error(f"Error analyzing missing indexes: {e}")
701 |
702 | return recommendations
703 |
704 |
705 | # Global query optimizer instance
706 | query_optimizer = QueryOptimizer()
707 |
708 |
709 | async def initialize_performance_systems():
710 | """Initialize all performance optimization systems."""
711 | logger.info("Initializing performance optimization systems...")
712 |
713 | # Initialize Redis connection manager
714 | redis_success = await redis_manager.initialize()
715 |
716 | logger.info(
717 | f"Performance systems initialized: Redis={'✓' if redis_success else '✗'}"
718 | )
719 |
720 | return {
721 | "redis_manager": redis_success,
722 | "request_cache": True,
723 | "query_optimizer": True,
724 | }
725 |
726 |
727 | async def get_performance_metrics() -> dict[str, Any]:
728 | """Get comprehensive performance metrics."""
729 | return {
730 | "redis_manager": redis_manager.get_metrics(),
731 | "request_cache": request_cache.get_metrics(),
732 | "query_optimizer": query_optimizer.get_query_stats(),
733 | "timestamp": time.time(),
734 | }
735 |
736 |
737 | async def cleanup_performance_systems():
738 | """Cleanup performance systems gracefully."""
739 | logger.info("Cleaning up performance optimization systems...")
740 |
741 | await redis_manager.close()
742 |
743 | logger.info("Performance systems cleanup completed")
744 |
745 |
746 | # Context manager for database session with query monitoring
747 | @asynccontextmanager
748 | async def monitored_db_session(query_name: str = "unknown"):
749 | """
750 | Context manager for database sessions with automatic query monitoring.
751 |
752 | Args:
753 | query_name: Name for the query (for metrics)
754 |
755 | Example:
756 | async with monitored_db_session("get_stock_data") as session:
757 | result = await session.execute(
758 | text("SELECT * FROM stocks_stock WHERE ticker_symbol = :symbol"),
759 | {"symbol": "AAPL"},
760 | )
761 | stock = result.first()
762 | """
763 | async with get_async_db_session() as session:
764 | start_time = time.time()
765 |
766 | try:
767 | yield session
768 |
769 | # Record successful query
770 | execution_time = time.time() - start_time
771 | if query_name not in query_optimizer._query_stats:
772 | query_optimizer._query_stats[query_name] = {
773 | "count": 0,
774 | "total_time": 0,
775 | "avg_time": 0,
776 | "max_time": 0,
777 | "min_time": float("inf"),
778 | }
779 |
780 | stats = query_optimizer._query_stats[query_name]
781 | stats["count"] += 1
782 | stats["total_time"] += execution_time
783 | stats["avg_time"] = stats["total_time"] / stats["count"]
784 | stats["max_time"] = max(stats["max_time"], execution_time)
785 | stats["min_time"] = min(stats["min_time"], execution_time)
786 |
787 | except Exception as e:
788 | execution_time = time.time() - start_time
789 | logger.error(
790 | f"Database query '{query_name}' failed after {execution_time:.2f}s: {e}"
791 | )
792 | raise
793 |
```
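A minimal wiring sketch (not part of the module above): hooking the lifecycle helpers into a FastAPI lifespan so the Redis pool is opened at startup and closed at shutdown. The app construction and route path are assumptions; the three imported helpers are the functions defined in the module above.

```python
# Wiring sketch; the route path and app setup are assumptions.
from contextlib import asynccontextmanager
from typing import Any

from fastapi import FastAPI

from maverick_mcp.data.performance import (
    cleanup_performance_systems,
    get_performance_metrics,
    initialize_performance_systems,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize the Redis connection pool (cache and optimizer are module singletons)
    app.state.performance_status = await initialize_performance_systems()
    try:
        yield
    finally:
        # Close the Redis connection pool gracefully
        await cleanup_performance_systems()


app = FastAPI(lifespan=lifespan)


@app.get("/metrics/performance")
async def performance_metrics() -> dict[str, Any]:
    # Cache hit rate, Redis pool health, and slow-query stats in one payload
    return await get_performance_metrics()
```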