This is page 23 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .dockerignore
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── config.yml
│   │   ├── feature_request.md
│   │   ├── question.md
│   │   └── security_report.md
│   ├── pull_request_template.md
│   └── workflows
│       ├── claude-code-review.yml
│       └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│   ├── launch.json
│   └── settings.json
├── alembic
│   ├── env.py
│   ├── script.py.mako
│   └── versions
│       ├── 001_initial_schema.py
│       ├── 003_add_performance_indexes.py
│       ├── 006_rename_metadata_columns.py
│       ├── 008_performance_optimization_indexes.py
│       ├── 009_rename_to_supply_demand.py
│       ├── 010_self_contained_schema.py
│       ├── 011_remove_proprietary_terms.py
│       ├── 013_add_backtest_persistence_models.py
│       ├── 014_add_portfolio_models.py
│       ├── 08e3945a0c93_merge_heads.py
│       ├── 9374a5c9b679_merge_heads_for_testing.py
│       ├── abf9b9afb134_merge_multiple_heads.py
│       ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│       ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│       ├── f0696e2cac15_add_essential_performance_indexes.py
│       └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│   ├── api
│   │   └── backtesting.md
│   ├── BACKTESTING.md
│   ├── COST_BASIS_SPECIFICATION.md
│   ├── deep_research_agent.md
│   ├── exa_research_testing_strategy.md
│   ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│   ├── PORTFOLIO.md
│   ├── SETUP_SELF_CONTAINED.md
│   └── speed_testing_framework.md
├── examples
│   ├── complete_speed_validation.py
│   ├── deep_research_integration.py
│   ├── llm_optimization_example.py
│   ├── llm_speed_demo.py
│   ├── monitoring_example.py
│   ├── parallel_research_example.py
│   ├── speed_optimization_demo.py
│   └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── circuit_breaker.py
│   │   ├── deep_research.py
│   │   ├── market_analysis.py
│   │   ├── optimized_research.py
│   │   ├── supervisor.py
│   │   └── technical_analysis.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── api_server.py
│   │   ├── connection_manager.py
│   │   ├── dependencies
│   │   │   ├── __init__.py
│   │   │   ├── stock_analysis.py
│   │   │   └── technical_analysis.py
│   │   ├── error_handling.py
│   │   ├── inspector_compatible_sse.py
│   │   ├── inspector_sse.py
│   │   ├── middleware
│   │   │   ├── error_handling.py
│   │   │   ├── mcp_logging.py
│   │   │   ├── rate_limiting_enhanced.py
│   │   │   └── security.py
│   │   ├── openapi_config.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── agents.py
│   │   │   ├── backtesting.py
│   │   │   ├── data_enhanced.py
│   │   │   ├── data.py
│   │   │   ├── health_enhanced.py
│   │   │   ├── health_tools.py
│   │   │   ├── health.py
│   │   │   ├── intelligent_backtesting.py
│   │   │   ├── introspection.py
│   │   │   ├── mcp_prompts.py
│   │   │   ├── monitoring.py
│   │   │   ├── news_sentiment_enhanced.py
│   │   │   ├── performance.py
│   │   │   ├── portfolio.py
│   │   │   ├── research.py
│   │   │   ├── screening_ddd.py
│   │   │   ├── screening_parallel.py
│   │   │   ├── screening.py
│   │   │   ├── technical_ddd.py
│   │   │   ├── technical_enhanced.py
│   │   │   ├── technical.py
│   │   │   └── tool_registry.py
│   │   ├── server.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── base_service.py
│   │   │   ├── market_service.py
│   │   │   ├── portfolio_service.py
│   │   │   ├── prompt_service.py
│   │   │   └── resource_service.py
│   │   ├── simple_sse.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── insomnia_export.py
│   │       └── postman_export.py
│   ├── application
│   │   ├── __init__.py
│   │   ├── commands
│   │   │   └── __init__.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_dto.py
│   │   ├── queries
│   │   │   ├── __init__.py
│   │   │   └── get_technical_analysis.py
│   │   └── screening
│   │       ├── __init__.py
│   │       ├── dtos.py
│   │       └── queries.py
│   ├── backtesting
│   │   ├── __init__.py
│   │   ├── ab_testing.py
│   │   ├── analysis.py
│   │   ├── batch_processing_stub.py
│   │   ├── batch_processing.py
│   │   ├── model_manager.py
│   │   ├── optimization.py
│   │   ├── persistence.py
│   │   ├── retraining_pipeline.py
│   │   ├── strategies
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── ml
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adaptive.py
│   │   │   │   ├── ensemble.py
│   │   │   │   ├── feature_engineering.py
│   │   │   │   └── regime_aware.py
│   │   │   ├── ml_strategies.py
│   │   │   ├── parser.py
│   │   │   └── templates.py
│   │   ├── strategy_executor.py
│   │   ├── vectorbt_engine.py
│   │   └── visualization.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── database_self_contained.py
│   │   ├── database.py
│   │   ├── llm_optimization_config.py
│   │   ├── logging_settings.py
│   │   ├── plotly_config.py
│   │   ├── security_utils.py
│   │   ├── security.py
│   │   ├── settings.py
│   │   ├── technical_constants.py
│   │   ├── tool_estimation.py
│   │   └── validation.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── technical_analysis.py
│   │   └── visualization.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cache_manager.py
│   │   ├── cache.py
│   │   ├── django_adapter.py
│   │   ├── health.py
│   │   ├── models.py
│   │   ├── performance.py
│   │   ├── session_management.py
│   │   └── validation.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── optimization.py
│   ├── dependencies.py
│   ├── domain
│   │   ├── __init__.py
│   │   ├── entities
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis.py
│   │   ├── events
│   │   │   └── __init__.py
│   │   ├── portfolio.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   ├── entities.py
│   │   │   ├── services.py
│   │   │   └── value_objects.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_service.py
│   │   ├── stock_analysis
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis_service.py
│   │   └── value_objects
│   │       ├── __init__.py
│   │       └── technical_indicators.py
│   ├── exceptions.py
│   ├── infrastructure
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   └── __init__.py
│   │   ├── caching
│   │   │   ├── __init__.py
│   │   │   └── cache_management_service.py
│   │   ├── connection_manager.py
│   │   ├── data_fetching
│   │   │   ├── __init__.py
│   │   │   └── stock_data_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_checker.py
│   │   ├── persistence
│   │   │   ├── __init__.py
│   │   │   └── stock_repository.py
│   │   ├── providers
│   │   │   └── __init__.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   └── repositories.py
│   │   └── sse_optimizer.py
│   ├── langchain_tools
│   │   ├── __init__.py
│   │   ├── adapters.py
│   │   └── registry.py
│   ├── logging_config.py
│   ├── memory
│   │   ├── __init__.py
│   │   └── stores.py
│   ├── monitoring
│   │   ├── __init__.py
│   │   ├── health_check.py
│   │   ├── health_monitor.py
│   │   ├── integration_example.py
│   │   ├── metrics.py
│   │   ├── middleware.py
│   │   └── status_dashboard.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── dependencies.py
│   │   ├── factories
│   │   │   ├── __init__.py
│   │   │   ├── config_factory.py
│   │   │   └── provider_factory.py
│   │   ├── implementations
│   │   │   ├── __init__.py
│   │   │   ├── cache_adapter.py
│   │   │   ├── macro_data_adapter.py
│   │   │   ├── market_data_adapter.py
│   │   │   ├── persistence_adapter.py
│   │   │   └── stock_data_adapter.py
│   │   ├── interfaces
│   │   │   ├── __init__.py
│   │   │   ├── cache.py
│   │   │   ├── config.py
│   │   │   ├── macro_data.py
│   │   │   ├── market_data.py
│   │   │   ├── persistence.py
│   │   │   └── stock_data.py
│   │   ├── llm_factory.py
│   │   ├── macro_data.py
│   │   ├── market_data.py
│   │   ├── mocks
│   │   │   ├── __init__.py
│   │   │   ├── mock_cache.py
│   │   │   ├── mock_config.py
│   │   │   ├── mock_macro_data.py
│   │   │   ├── mock_market_data.py
│   │   │   ├── mock_persistence.py
│   │   │   └── mock_stock_data.py
│   │   ├── openrouter_provider.py
│   │   ├── optimized_screening.py
│   │   ├── optimized_stock_data.py
│   │   └── stock_data.py
│   ├── README.md
│   ├── tests
│   │   ├── __init__.py
│   │   ├── README_INMEMORY_TESTS.md
│   │   ├── test_cache_debug.py
│   │   ├── test_fixes_validation.py
│   │   ├── test_in_memory_routers.py
│   │   ├── test_in_memory_server.py
│   │   ├── test_macro_data_provider.py
│   │   ├── test_mailgun_email.py
│   │   ├── test_market_calendar_caching.py
│   │   ├── test_mcp_tool_fixes_pytest.py
│   │   ├── test_mcp_tool_fixes.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_models_functional.py
│   │   ├── test_server.py
│   │   ├── test_stock_data_enhanced.py
│   │   ├── test_stock_data_provider.py
│   │   └── test_technical_analysis.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── performance_monitoring.py
│   │   ├── portfolio_manager.py
│   │   ├── risk_management.py
│   │   └── sentiment_analysis.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── agent_errors.py
│   │   ├── batch_processing.py
│   │   ├── cache_warmer.py
│   │   ├── circuit_breaker_decorators.py
│   │   ├── circuit_breaker_services.py
│   │   ├── circuit_breaker.py
│   │   ├── data_chunking.py
│   │   ├── database_monitoring.py
│   │   ├── debug_utils.py
│   │   ├── fallback_strategies.py
│   │   ├── llm_optimization.py
│   │   ├── logging_example.py
│   │   ├── logging_init.py
│   │   ├── logging.py
│   │   ├── mcp_logging.py
│   │   ├── memory_profiler.py
│   │   ├── monitoring_middleware.py
│   │   ├── monitoring.py
│   │   ├── orchestration_logging.py
│   │   ├── parallel_research.py
│   │   ├── parallel_screening.py
│   │   ├── quick_cache.py
│   │   ├── resource_manager.py
│   │   ├── shutdown.py
│   │   ├── stock_helpers.py
│   │   ├── structured_logger.py
│   │   ├── tool_monitoring.py
│   │   ├── tracing.py
│   │   └── yfinance_pool.py
│   ├── validation
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── data.py
│   │   ├── middleware.py
│   │   ├── portfolio.py
│   │   ├── responses.py
│   │   ├── screening.py
│   │   └── technical.py
│   └── workflows
│       ├── __init__.py
│       ├── agents
│       │   ├── __init__.py
│       │   ├── market_analyzer.py
│       │   ├── optimizer_agent.py
│       │   ├── strategy_selector.py
│       │   └── validator_agent.py
│       ├── backtesting_workflow.py
│       └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│   ├── dev.sh
│   ├── INSTALLATION_GUIDE.md
│   ├── load_example.py
│   ├── load_market_data.py
│   ├── load_tiingo_data.py
│   ├── migrate_db.py
│   ├── README_TIINGO_LOADER.md
│   ├── requirements_tiingo.txt
│   ├── run_stock_screening.py
│   ├── run-migrations.sh
│   ├── seed_db.py
│   ├── seed_sp500.py
│   ├── setup_database.sh
│   ├── setup_self_contained.py
│   ├── setup_sp500_database.sh
│   ├── test_seeded_data.py
│   ├── test_tiingo_loader.py
│   ├── tiingo_config.py
│   └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── core
│   │   └── test_technical_analysis.py
│   ├── data
│   │   └── test_portfolio_models.py
│   ├── domain
│   │   ├── conftest.py
│   │   ├── test_portfolio_entities.py
│   │   └── test_technical_analysis_service.py
│   ├── fixtures
│   │   └── orchestration_fixtures.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── README.md
│   │   ├── run_integration_tests.sh
│   │   ├── test_api_technical.py
│   │   ├── test_chaos_engineering.py
│   │   ├── test_config_management.py
│   │   ├── test_full_backtest_workflow_advanced.py
│   │   ├── test_full_backtest_workflow.py
│   │   ├── test_high_volume.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_orchestration_complete.py
│   │   ├── test_portfolio_persistence.py
│   │   ├── test_redis_cache.py
│   │   ├── test_security_integration.py.disabled
│   │   └── vcr_setup.py
│   ├── performance
│   │   ├── __init__.py
│   │   ├── test_benchmarks.py
│   │   ├── test_load.py
│   │   ├── test_profiling.py
│   │   └── test_stress.py
│   ├── providers
│   │   └── test_stock_data_simple.py
│   ├── README.md
│   ├── test_agents_router_mcp.py
│   ├── test_backtest_persistence.py
│   ├── test_cache_management_service.py
│   ├── test_cache_serialization.py
│   ├── test_circuit_breaker.py
│   ├── test_database_pool_config_simple.py
│   ├── test_database_pool_config.py
│   ├── test_deep_research_functional.py
│   ├── test_deep_research_integration.py
│   ├── test_deep_research_parallel_execution.py
│   ├── test_error_handling.py
│   ├── test_event_loop_integrity.py
│   ├── test_exa_research_integration.py
│   ├── test_exception_hierarchy.py
│   ├── test_financial_search.py
│   ├── test_graceful_shutdown.py
│   ├── test_integration_simple.py
│   ├── test_langgraph_workflow.py
│   ├── test_market_data_async.py
│   ├── test_market_data_simple.py
│   ├── test_mcp_orchestration_functional.py
│   ├── test_ml_strategies.py
│   ├── test_optimized_research_agent.py
│   ├── test_orchestration_integration.py
│   ├── test_orchestration_logging.py
│   ├── test_orchestration_tools_simple.py
│   ├── test_parallel_research_integration.py
│   ├── test_parallel_research_orchestrator.py
│   ├── test_parallel_research_performance.py
│   ├── test_performance_optimizations.py
│   ├── test_production_validation.py
│   ├── test_provider_architecture.py
│   ├── test_rate_limiting_enhanced.py
│   ├── test_runner_validation.py
│   ├── test_security_comprehensive.py.disabled
│   ├── test_security_cors.py
│   ├── test_security_enhancements.py.disabled
│   ├── test_security_headers.py
│   ├── test_security_penetration.py
│   ├── test_session_management.py
│   ├── test_speed_optimization_validation.py
│   ├── test_stock_analysis_dependencies.py
│   ├── test_stock_analysis_service.py
│   ├── test_stock_data_fetching_service.py
│   ├── test_supervisor_agent.py
│   ├── test_supervisor_functional.py
│   ├── test_tool_estimation_config.py
│   ├── test_visualization.py
│   └── utils
│       ├── test_agent_errors.py
│       ├── test_logging.py
│       ├── test_parallel_screening.py
│       └── test_quick_cache.py
├── tools
│   ├── check_orchestration_config.py
│   ├── experiments
│   │   ├── validation_examples.py
│   │   └── validation_fixed.py
│   ├── fast_dev.sh
│   ├── hot_reload.py
│   ├── quick_test.py
│   └── templates
│       ├── new_router_template.py
│       ├── new_tool_template.py
│       ├── screening_strategy_template.py
│       └── test_template.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/tests/test_visualization.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Comprehensive tests for the backtesting visualization module.
  3 | 
  4 | Tests cover:
  5 | - Chart generation and base64 encoding with matplotlib
  6 | - Equity curve plotting with drawdown subplots
  7 | - Trade scatter plots on price charts
  8 | - Parameter optimization heatmaps
  9 | - Portfolio allocation pie charts
 10 | - Strategy comparison line charts
 11 | - Performance dashboard table generation
 12 | - Theme support (light/dark modes)
 13 | - Image resolution and size optimization
 14 | - Error handling for malformed data
 15 | """
 16 | 
 17 | import base64
 18 | from unittest.mock import patch
 19 | 
 20 | import matplotlib.pyplot as plt
 21 | import numpy as np
 22 | import pandas as pd
 23 | import pytest
 24 | 
 25 | from maverick_mcp.backtesting.visualization import (
 26 |     generate_equity_curve,
 27 |     generate_optimization_heatmap,
 28 |     generate_performance_dashboard,
 29 |     generate_portfolio_allocation,
 30 |     generate_strategy_comparison,
 31 |     generate_trade_scatter,
 32 |     image_to_base64,
 33 |     set_chart_style,
 34 | )
 35 | 
 36 | 
 37 | class TestVisualizationUtilities:
 38 |     """Test suite for visualization utility functions."""
 39 | 
 40 |     def test_set_chart_style_light_theme(self):
 41 |         """Test light theme styling configuration."""
 42 |         set_chart_style("light")
 43 | 
 44 |         # Test that matplotlib parameters are set correctly
 45 |         assert plt.rcParams["axes.facecolor"] == "white"
 46 |         assert plt.rcParams["figure.facecolor"] == "white"
 47 |         assert plt.rcParams["font.size"] == 10
 48 |         assert plt.rcParams["text.color"] == "black"
 49 |         assert plt.rcParams["axes.labelcolor"] == "black"
 50 |         assert plt.rcParams["xtick.color"] == "black"
 51 |         assert plt.rcParams["ytick.color"] == "black"
 52 | 
 53 |     def test_set_chart_style_dark_theme(self):
 54 |         """Test dark theme styling configuration."""
 55 |         set_chart_style("dark")
 56 | 
 57 |         # Test that matplotlib parameters are set correctly
 58 |         assert plt.rcParams["axes.facecolor"] == "#1E1E1E"
 59 |         assert plt.rcParams["figure.facecolor"] == "#121212"
 60 |         assert plt.rcParams["font.size"] == 10
 61 |         assert plt.rcParams["text.color"] == "white"
 62 |         assert plt.rcParams["axes.labelcolor"] == "white"
 63 |         assert plt.rcParams["xtick.color"] == "white"
 64 |         assert plt.rcParams["ytick.color"] == "white"
 65 | 
 66 |     def test_image_to_base64_conversion(self):
 67 |         """Test image to base64 conversion with proper formatting."""
 68 |         # Create a simple test figure
 69 |         fig, ax = plt.subplots(figsize=(6, 4))
 70 |         ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
 71 |         ax.set_title("Test Chart")
 72 | 
 73 |         # Convert to base64
 74 |         base64_str = image_to_base64(fig, dpi=100)
 75 | 
 76 |         # Test base64 string properties
 77 |         assert isinstance(base64_str, str)
 78 |         assert len(base64_str) > 100  # Should contain substantial data
 79 | 
 80 |         # Test that it's valid base64
 81 |         try:
 82 |             decoded_bytes = base64.b64decode(base64_str)
 83 |             assert len(decoded_bytes) > 0
 84 |         except Exception as e:
 85 |             pytest.fail(f"Invalid base64 encoding: {e}")
 86 | 
 87 |     def test_image_to_base64_size_optimization(self):
 88 |         """Test image size optimization and aspect ratio preservation."""
 89 |         # Create large figure
 90 |         fig, ax = plt.subplots(figsize=(20, 15))  # Large size
 91 |         ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
 92 | 
 93 |         original_width, original_height = fig.get_size_inches()
 94 |         original_aspect = original_height / original_width
 95 | 
 96 |         # Convert with size constraint
 97 |         base64_str = image_to_base64(fig, dpi=100, max_width=800)
 98 | 
 99 |         # Test that resizing occurred
100 |         final_width, final_height = fig.get_size_inches()
101 |         final_aspect = final_height / final_width
102 | 
103 |         assert final_width <= 8.0  # 800px / 100dpi = 8 inches
104 |         assert abs(final_aspect - original_aspect) < 0.01  # Aspect ratio preserved
105 |         assert len(base64_str) > 0
106 | 
107 |     def test_image_to_base64_error_handling(self):
108 |         """Test error handling in base64 conversion."""
109 |         with patch(
110 |             "matplotlib.figure.Figure.savefig", side_effect=Exception("Save error")
111 |         ):
112 |             fig, ax = plt.subplots()
113 |             ax.plot([1, 2, 3])
114 | 
115 |             result = image_to_base64(fig)
116 |             assert result == ""  # Should return empty string on error
117 | 
118 | 
119 | class TestEquityCurveGeneration:
120 |     """Test suite for equity curve chart generation."""
121 | 
122 |     @pytest.fixture
123 |     def sample_returns_data(self):
124 |         """Create sample returns data for testing."""
125 |         dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
126 |         returns = np.random.normal(0.001, 0.02, len(dates))
127 |         cumulative_returns = np.cumprod(1 + returns)
128 | 
129 |         # Create drawdown series
130 |         running_max = np.maximum.accumulate(cumulative_returns)
131 |         drawdown = (cumulative_returns - running_max) / running_max * 100
132 | 
133 |         return pd.Series(cumulative_returns, index=dates), pd.Series(
134 |             drawdown, index=dates
135 |         )
136 | 
137 |     def test_generate_equity_curve_basic(self, sample_returns_data):
138 |         """Test basic equity curve generation."""
139 |         returns, drawdown = sample_returns_data
140 | 
141 |         base64_str = generate_equity_curve(returns, title="Test Equity Curve")
142 | 
143 |         assert isinstance(base64_str, str)
144 |         assert len(base64_str) > 100
145 | 
146 |         # Test that it's valid base64 image
147 |         try:
148 |             decoded_bytes = base64.b64decode(base64_str)
149 |             assert decoded_bytes.startswith(b"\x89PNG")  # PNG header
150 |         except Exception as e:
151 |             pytest.fail(f"Invalid PNG image: {e}")
152 | 
153 |     def test_generate_equity_curve_with_drawdown(self, sample_returns_data):
154 |         """Test equity curve generation with drawdown subplot."""
155 |         returns, drawdown = sample_returns_data
156 | 
157 |         base64_str = generate_equity_curve(
158 |             returns, drawdown=drawdown, title="Equity Curve with Drawdown", theme="dark"
159 |         )
160 | 
161 |         assert isinstance(base64_str, str)
162 |         assert len(base64_str) > 100
163 | 
164 |         # Should be larger image due to subplot
165 |         base64_no_dd = generate_equity_curve(returns, title="No Drawdown")
166 |         assert len(base64_str) >= len(base64_no_dd)
167 | 
168 |     def test_generate_equity_curve_themes(self, sample_returns_data):
169 |         """Test equity curve generation with different themes."""
170 |         returns, _ = sample_returns_data
171 | 
172 |         light_chart = generate_equity_curve(returns, theme="light")
173 |         dark_chart = generate_equity_curve(returns, theme="dark")
174 | 
175 |         assert len(light_chart) > 100
176 |         assert len(dark_chart) > 100
177 |         # Different themes should produce different images
178 |         assert light_chart != dark_chart
179 | 
180 |     def test_generate_equity_curve_error_handling(self):
181 |         """Test error handling in equity curve generation."""
182 |         # Test with invalid data
183 |         invalid_returns = pd.Series([])  # Empty series
184 | 
185 |         result = generate_equity_curve(invalid_returns)
186 |         assert result == ""
187 | 
188 |         # Test with NaN data
189 |         nan_returns = pd.Series([np.nan, np.nan, np.nan])
190 |         result = generate_equity_curve(nan_returns)
191 |         assert result == ""
192 | 
193 | 
194 | class TestTradeScatterGeneration:
195 |     """Test suite for trade scatter plot generation."""
196 | 
197 |     @pytest.fixture
198 |     def sample_trade_data(self):
199 |         """Create sample trade data for testing."""
200 |         dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
201 |         prices = pd.Series(100 + np.cumsum(np.random.normal(0, 1, len(dates))), index=dates)
202 | 
203 |         # Create sample trades
204 |         trade_dates = dates[::30]  # Every 30 days
205 |         trades = []
206 | 
207 |         for i, trade_date in enumerate(trade_dates):
208 |             if i % 2 == 0:  # Entry
209 |                 trades.append(
210 |                     {
211 |                         "date": trade_date,
212 |                         "price": prices.loc[trade_date],
213 |                         "type": "entry",
214 |                     }
215 |                 )
216 |             else:  # Exit
217 |                 trades.append(
218 |                     {
219 |                         "date": trade_date,
220 |                         "price": prices.loc[trade_date],
221 |                         "type": "exit",
222 |                     }
223 |                 )
224 | 
225 |         trades_df = pd.DataFrame(trades).set_index("date")
226 |         return prices, trades_df
227 | 
228 |     def test_generate_trade_scatter_basic(self, sample_trade_data):
229 |         """Test basic trade scatter plot generation."""
230 |         prices, trades = sample_trade_data
231 | 
232 |         base64_str = generate_trade_scatter(prices, trades, title="Trade Scatter Plot")
233 | 
234 |         assert isinstance(base64_str, str)
235 |         assert len(base64_str) > 100
236 | 
237 |         # Verify valid PNG
238 |         try:
239 |             decoded_bytes = base64.b64decode(base64_str)
240 |             assert decoded_bytes.startswith(b"\x89PNG")
241 |         except Exception as e:
242 |             pytest.fail(f"Invalid PNG image: {e}")
243 | 
244 |     def test_generate_trade_scatter_themes(self, sample_trade_data):
245 |         """Test trade scatter plots with different themes."""
246 |         prices, trades = sample_trade_data
247 | 
248 |         light_chart = generate_trade_scatter(prices, trades, theme="light")
249 |         dark_chart = generate_trade_scatter(prices, trades, theme="dark")
250 | 
251 |         assert len(light_chart) > 100
252 |         assert len(dark_chart) > 100
253 |         assert light_chart != dark_chart
254 | 
255 |     def test_generate_trade_scatter_empty_trades(self, sample_trade_data):
256 |         """Test trade scatter plot with empty trade data."""
257 |         prices, _ = sample_trade_data
258 |         empty_trades = pd.DataFrame(columns=["price", "type"])
259 | 
260 |         result = generate_trade_scatter(prices, empty_trades)
261 |         assert result == ""
262 | 
263 |     def test_generate_trade_scatter_error_handling(self):
264 |         """Test error handling in trade scatter generation."""
265 |         # Test with mismatched data
266 |         prices = pd.Series([1, 2, 3])
267 |         trades = pd.DataFrame({"price": [10, 20], "type": ["entry", "exit"]})
268 | 
269 |         # Should handle gracefully
270 |         result = generate_trade_scatter(prices, trades)
271 |         # Might return empty string or valid chart depending on implementation
272 |         assert isinstance(result, str)
273 | 
274 | 
275 | class TestOptimizationHeatmapGeneration:
276 |     """Test suite for parameter optimization heatmap generation."""
277 | 
278 |     @pytest.fixture
279 |     def sample_optimization_data(self):
280 |         """Create sample optimization results for testing."""
281 |         parameters = ["param1", "param2", "param3"]
282 |         results = {}
283 | 
284 |         for p1 in parameters:
285 |             results[p1] = {}
286 |             for p2 in parameters:
287 |                 # Create some performance metric
288 |                 results[p1][p2] = np.random.uniform(0.5, 2.0)
289 | 
290 |         return results
291 | 
292 |     def test_generate_optimization_heatmap_basic(self, sample_optimization_data):
293 |         """Test basic optimization heatmap generation."""
294 |         base64_str = generate_optimization_heatmap(
295 |             sample_optimization_data, title="Parameter Optimization Heatmap"
296 |         )
297 | 
298 |         assert isinstance(base64_str, str)
299 |         assert len(base64_str) > 100
300 | 
301 |         # Verify valid PNG
302 |         try:
303 |             decoded_bytes = base64.b64decode(base64_str)
304 |             assert decoded_bytes.startswith(b"\x89PNG")
305 |         except Exception as e:
306 |             pytest.fail(f"Invalid PNG image: {e}")
307 | 
308 |     def test_generate_optimization_heatmap_themes(self, sample_optimization_data):
309 |         """Test optimization heatmap with different themes."""
310 |         light_chart = generate_optimization_heatmap(
311 |             sample_optimization_data, theme="light"
312 |         )
313 |         dark_chart = generate_optimization_heatmap(
314 |             sample_optimization_data, theme="dark"
315 |         )
316 | 
317 |         assert len(light_chart) > 100
318 |         assert len(dark_chart) > 100
319 |         assert light_chart != dark_chart
320 | 
321 |     def test_generate_optimization_heatmap_empty_data(self):
322 |         """Test heatmap generation with empty data."""
323 |         empty_data = {}
324 | 
325 |         result = generate_optimization_heatmap(empty_data)
326 |         assert result == ""
327 | 
328 |     def test_generate_optimization_heatmap_error_handling(self):
329 |         """Test error handling in heatmap generation."""
330 |         # Test with malformed data
331 |         malformed_data = {"param1": "not_a_dict"}
332 | 
333 |         result = generate_optimization_heatmap(malformed_data)
334 |         assert result == ""
335 | 
336 | 
337 | class TestPortfolioAllocationGeneration:
338 |     """Test suite for portfolio allocation chart generation."""
339 | 
340 |     @pytest.fixture
341 |     def sample_allocation_data(self):
342 |         """Create sample allocation data for testing."""
343 |         return {
344 |             "AAPL": 0.25,
345 |             "GOOGL": 0.20,
346 |             "MSFT": 0.15,
347 |             "TSLA": 0.15,
348 |             "AMZN": 0.10,
349 |             "Cash": 0.15,
350 |         }
351 | 
352 |     def test_generate_portfolio_allocation_basic(self, sample_allocation_data):
353 |         """Test basic portfolio allocation chart generation."""
354 |         base64_str = generate_portfolio_allocation(
355 |             sample_allocation_data, title="Portfolio Allocation"
356 |         )
357 | 
358 |         assert isinstance(base64_str, str)
359 |         assert len(base64_str) > 100
360 | 
361 |         # Verify valid PNG
362 |         try:
363 |             decoded_bytes = base64.b64decode(base64_str)
364 |             assert decoded_bytes.startswith(b"\x89PNG")
365 |         except Exception as e:
366 |             pytest.fail(f"Invalid PNG image: {e}")
367 | 
368 |     def test_generate_portfolio_allocation_themes(self, sample_allocation_data):
369 |         """Test portfolio allocation with different themes."""
370 |         light_chart = generate_portfolio_allocation(
371 |             sample_allocation_data, theme="light"
372 |         )
373 |         dark_chart = generate_portfolio_allocation(sample_allocation_data, theme="dark")
374 | 
375 |         assert len(light_chart) > 100
376 |         assert len(dark_chart) > 100
377 |         assert light_chart != dark_chart
378 | 
379 |     def test_generate_portfolio_allocation_empty_data(self):
380 |         """Test allocation chart with empty data."""
381 |         empty_data = {}
382 | 
383 |         result = generate_portfolio_allocation(empty_data)
384 |         assert result == ""
385 | 
386 |     def test_generate_portfolio_allocation_single_asset(self):
387 |         """Test allocation chart with single asset."""
388 |         single_asset = {"AAPL": 1.0}
389 | 
390 |         result = generate_portfolio_allocation(single_asset)
391 |         assert isinstance(result, str)
392 |         assert len(result) > 100  # Should still generate valid chart
393 | 
394 |     def test_generate_portfolio_allocation_error_handling(self):
395 |         """Test error handling in allocation chart generation."""
396 |         # Test with invalid allocation values
397 |         invalid_data = {"AAPL": "invalid_value"}
398 | 
399 |         result = generate_portfolio_allocation(invalid_data)
400 |         assert result == ""
401 | 
402 | 
403 | class TestStrategyComparisonGeneration:
404 |     """Test suite for strategy comparison chart generation."""
405 | 
406 |     @pytest.fixture
407 |     def sample_strategy_data(self):
408 |         """Create sample strategy comparison data."""
409 |         dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
410 | 
411 |         strategies = {
412 |             "Momentum": pd.Series(
413 |                 np.cumprod(1 + np.random.normal(0.0008, 0.015, len(dates))), index=dates
414 |             ),
415 |             "Mean Reversion": pd.Series(
416 |                 np.cumprod(1 + np.random.normal(0.0005, 0.012, len(dates))), index=dates
417 |             ),
418 |             "Breakout": pd.Series(
419 |                 np.cumprod(1 + np.random.normal(0.0012, 0.020, len(dates))), index=dates
420 |             ),
421 |         }
422 | 
423 |         return strategies
424 | 
425 |     def test_generate_strategy_comparison_basic(self, sample_strategy_data):
426 |         """Test basic strategy comparison chart generation."""
427 |         base64_str = generate_strategy_comparison(
428 |             sample_strategy_data, title="Strategy Performance Comparison"
429 |         )
430 | 
431 |         assert isinstance(base64_str, str)
432 |         assert len(base64_str) > 100
433 | 
434 |         # Verify valid PNG
435 |         try:
436 |             decoded_bytes = base64.b64decode(base64_str)
437 |             assert decoded_bytes.startswith(b"\x89PNG")
438 |         except Exception as e:
439 |             pytest.fail(f"Invalid PNG image: {e}")
440 | 
441 |     def test_generate_strategy_comparison_themes(self, sample_strategy_data):
442 |         """Test strategy comparison with different themes."""
443 |         light_chart = generate_strategy_comparison(sample_strategy_data, theme="light")
444 |         dark_chart = generate_strategy_comparison(sample_strategy_data, theme="dark")
445 | 
446 |         assert len(light_chart) > 100
447 |         assert len(dark_chart) > 100
448 |         assert light_chart != dark_chart
449 | 
450 |     def test_generate_strategy_comparison_single_strategy(self):
451 |         """Test comparison chart with single strategy."""
452 |         dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
453 |         single_strategy = {
454 |             "Only Strategy": pd.Series(
455 |                 np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
456 |             )
457 |         }
458 | 
459 |         result = generate_strategy_comparison(single_strategy)
460 |         assert isinstance(result, str)
461 |         assert len(result) > 100
462 | 
463 |     def test_generate_strategy_comparison_empty_data(self):
464 |         """Test comparison chart with empty data."""
465 |         empty_data = {}
466 | 
467 |         result = generate_strategy_comparison(empty_data)
468 |         assert result == ""
469 | 
470 |     def test_generate_strategy_comparison_error_handling(self):
471 |         """Test error handling in strategy comparison generation."""
472 |         # Test with mismatched data lengths
473 |         dates1 = pd.date_range(start="2023-01-01", end="2023-06-30", freq="D")
474 |         dates2 = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
475 | 
476 |         mismatched_data = {
477 |             "Strategy1": pd.Series(np.random.normal(0, 1, len(dates1)), index=dates1),
478 |             "Strategy2": pd.Series(np.random.normal(0, 1, len(dates2)), index=dates2),
479 |         }
480 | 
481 |         # Should handle gracefully
482 |         result = generate_strategy_comparison(mismatched_data)
483 |         assert isinstance(result, str)  # Might be empty or valid
484 | 
485 | 
486 | class TestPerformanceDashboardGeneration:
487 |     """Test suite for performance dashboard generation."""
488 | 
489 |     @pytest.fixture
490 |     def sample_metrics_data(self):
491 |         """Create sample performance metrics for testing."""
492 |         return {
493 |             "Total Return": 0.156,
494 |             "Sharpe Ratio": 1.25,
495 |             "Max Drawdown": -0.082,
496 |             "Win Rate": 0.583,
497 |             "Profit Factor": 1.35,
498 |             "Total Trades": 24,
499 |             "Annualized Return": 0.18,
500 |             "Volatility": 0.16,
501 |             "Calmar Ratio": 1.10,
502 |             "Best Trade": 0.12,
503 |         }
504 | 
505 |     def test_generate_performance_dashboard_basic(self, sample_metrics_data):
506 |         """Test basic performance dashboard generation."""
507 |         base64_str = generate_performance_dashboard(
508 |             sample_metrics_data, title="Performance Dashboard"
509 |         )
510 | 
511 |         assert isinstance(base64_str, str)
512 |         assert len(base64_str) > 100
513 | 
514 |         # Verify valid PNG
515 |         try:
516 |             decoded_bytes = base64.b64decode(base64_str)
517 |             assert decoded_bytes.startswith(b"\x89PNG")
518 |         except Exception as e:
519 |             pytest.fail(f"Invalid PNG image: {e}")
520 | 
521 |     def test_generate_performance_dashboard_themes(self, sample_metrics_data):
522 |         """Test performance dashboard with different themes."""
523 |         light_chart = generate_performance_dashboard(sample_metrics_data, theme="light")
524 |         dark_chart = generate_performance_dashboard(sample_metrics_data, theme="dark")
525 | 
526 |         assert len(light_chart) > 100
527 |         assert len(dark_chart) > 100
528 |         assert light_chart != dark_chart
529 | 
530 |     def test_generate_performance_dashboard_mixed_data_types(self):
531 |         """Test dashboard with mixed data types."""
532 |         mixed_metrics = {
533 |             "Total Return": 0.156,
534 |             "Strategy": "Momentum",
535 |             "Symbol": "AAPL",
536 |             "Duration": "365 days",
537 |             "Sharpe Ratio": 1.25,
538 |             "Status": "Completed",
539 |         }
540 | 
541 |         result = generate_performance_dashboard(mixed_metrics)
542 |         assert isinstance(result, str)
543 |         assert len(result) > 100
544 | 
545 |     def test_generate_performance_dashboard_empty_data(self):
546 |         """Test dashboard with empty metrics."""
547 |         empty_metrics = {}
548 | 
549 |         result = generate_performance_dashboard(empty_metrics)
550 |         assert result == ""
551 | 
552 |     def test_generate_performance_dashboard_large_dataset(self):
553 |         """Test dashboard with large number of metrics."""
554 |         large_metrics = {f"Metric_{i}": np.random.uniform(-1, 2) for i in range(50)}
555 | 
556 |         result = generate_performance_dashboard(large_metrics)
557 |         assert isinstance(result, str)
558 |         # Might be empty if table becomes too large, or valid if handled properly
559 | 
560 |     def test_generate_performance_dashboard_error_handling(self):
561 |         """Test error handling in dashboard generation."""
562 |         # Test with invalid data that might cause table generation to fail
563 |         problematic_metrics = {
564 |             "Valid Metric": 1.25,
565 |             "Problematic": [1, 2, 3],  # List instead of scalar
566 |             "Another Valid": 0.85,
567 |         }
568 | 
569 |         result = generate_performance_dashboard(problematic_metrics)
570 |         assert isinstance(result, str)
571 | 
572 | 
573 | class TestVisualizationIntegration:
574 |     """Integration tests for visualization functions working together."""
575 | 
576 |     def test_consistent_theming_across_charts(self):
577 |         """Test that theming is consistent across different chart types."""
578 |         # Create sample data for different chart types
579 |         dates = pd.date_range(start="2023-01-01", end="2023-06-30", freq="D")
580 |         returns = pd.Series(
581 |             np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
582 |         )
583 | 
584 |         allocation = {"AAPL": 0.4, "GOOGL": 0.3, "MSFT": 0.3}
585 |         metrics = {"Return": 0.15, "Sharpe": 1.2, "Drawdown": -0.08}
586 | 
587 |         # Generate charts with same theme
588 |         equity_chart = generate_equity_curve(returns, theme="dark")
589 |         allocation_chart = generate_portfolio_allocation(allocation, theme="dark")
590 |         dashboard_chart = generate_performance_dashboard(metrics, theme="dark")
591 | 
592 |         # All should generate valid base64 strings
593 |         charts = [equity_chart, allocation_chart, dashboard_chart]
594 |         for chart in charts:
595 |             assert isinstance(chart, str)
596 |             assert len(chart) > 100
597 | 
598 |             # Verify valid PNG
599 |             try:
600 |                 decoded_bytes = base64.b64decode(chart)
601 |                 assert decoded_bytes.startswith(b"\x89PNG")
602 |             except Exception as e:
603 |                 pytest.fail(f"Invalid PNG in themed charts: {e}")
604 | 
605 |     def test_memory_cleanup_after_chart_generation(self):
606 |         """Test that matplotlib figures are properly cleaned up."""
607 |         import matplotlib.pyplot as plt
608 | 
609 |         initial_figure_count = len(plt.get_fignums())
610 | 
611 |         # Generate multiple charts
612 |         dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
613 |         returns = pd.Series(
614 |             np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
615 |         )
616 | 
617 |         for i in range(10):
618 |             chart = generate_equity_curve(returns, title=f"Test Chart {i}")
619 |             assert len(chart) > 0
620 | 
621 |         final_figure_count = len(plt.get_fignums())
622 | 
623 |         # Figure count should not have increased (figures should be closed)
624 |         assert final_figure_count <= initial_figure_count + 1  # Allow for 1 open figure
625 | 
626 |     def test_chart_generation_performance_benchmark(self, benchmark_timer):
627 |         """Test chart generation performance benchmarks."""
628 |         # Create substantial dataset
629 |         dates = pd.date_range(
630 |             start="2023-01-01", end="2023-12-31", freq="H"
631 |         )  # Hourly data
632 |         returns = pd.Series(
633 |             np.cumprod(1 + np.random.normal(0.0001, 0.005, len(dates))), index=dates
634 |         )
635 | 
636 |         with benchmark_timer() as timer:
637 |             chart = generate_equity_curve(returns, title="Performance Test")
638 | 
639 |         # Should complete within reasonable time even with large dataset
640 |         assert timer.elapsed < 5.0  # < 5 seconds
641 |         assert len(chart) > 100  # Valid chart generated
642 | 
643 |     def test_concurrent_chart_generation(self):
644 |         """Test concurrent chart generation doesn't cause conflicts."""
645 |         import queue
646 |         import threading
647 | 
648 |         results_queue = queue.Queue()
649 |         error_queue = queue.Queue()
650 | 
651 |         def generate_chart(thread_id):
652 |             try:
653 |                 dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
654 |                 returns = pd.Series(
655 |                     np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))),
656 |                     index=dates,
657 |                 )
658 | 
659 |                 chart = generate_equity_curve(returns, title=f"Thread {thread_id}")
660 |                 results_queue.put((thread_id, len(chart)))
661 |             except Exception as e:
662 |                 error_queue.put(f"Thread {thread_id}: {e}")
663 | 
664 |         # Create multiple threads
665 |         threads = []
666 |         for i in range(5):
667 |             thread = threading.Thread(target=generate_chart, args=(i,))
668 |             threads.append(thread)
669 |             thread.start()
670 | 
671 |         # Wait for completion
672 |         for thread in threads:
673 |             thread.join(timeout=10)
674 | 
675 |         # Check results
676 |         assert error_queue.empty(), f"Errors: {list(error_queue.queue)}"
677 |         assert results_queue.qsize() == 5
678 | 
679 |         # All should have generated valid charts
680 |         while not results_queue.empty():
681 |             thread_id, chart_length = results_queue.get()
682 |             assert chart_length > 100
683 | 
684 | 
685 | if __name__ == "__main__":
686 |     # Run tests with detailed output
687 |     pytest.main([__file__, "-v", "--tb=short", "--asyncio-mode=auto"])
688 | 
```
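
The suite above pins down the public surface of `maverick_mcp.backtesting.visualization`: each `generate_*` helper takes a pandas Series or dict plus optional `title` and `theme` keyword arguments and returns a base64-encoded PNG string, with an empty string on failure. The snippet below is a minimal usage sketch based only on those observed call signatures; the synthetic series and output filename are illustrative assumptions, not code from the repository.

```python
# Minimal usage sketch; mirrors the call signatures exercised in the tests above.
# The synthetic data and output filename are assumptions for illustration only.
import base64

import numpy as np
import pandas as pd

from maverick_mcp.backtesting.visualization import generate_equity_curve

# Build a small synthetic equity curve, as the fixtures above do.
dates = pd.date_range(start="2023-01-01", end="2023-03-31", freq="D")
equity = pd.Series(
    np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
)

# The helper returns a base64-encoded PNG ("" on failure), so it can be
# decoded and written to disk or embedded directly in a report.
png_b64 = generate_equity_curve(equity, title="Sample Equity Curve", theme="dark")
if png_b64:
    with open("equity_curve.png", "wb") as fh:
        fh.write(base64.b64decode(png_b64))
```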

--------------------------------------------------------------------------------
/tests/integration/test_full_backtest_workflow_advanced.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Advanced End-to-End Integration Tests for VectorBT Backtesting Workflow.
  3 | 
  4 | This comprehensive test suite covers:
  5 | - Complete workflow integration from data fetch to results
  6 | - Coverage of all 15 strategies (9 traditional + 6 ML)
  7 | - Parallel execution capabilities
  8 | - Cache behavior and optimization
  9 | - Real production-like scenarios
 10 | - Error recovery and resilience
 11 | - Resource management and cleanup
 12 | """
 13 | 
 14 | import asyncio
 15 | import logging
 16 | import time
 17 | from unittest.mock import Mock
 18 | from uuid import UUID
 19 | 
 20 | import numpy as np
 21 | import pandas as pd
 22 | import pytest
 23 | 
 24 | from maverick_mcp.backtesting import (
 25 |     VectorBTEngine,
 26 | )
 27 | from maverick_mcp.backtesting.persistence import BacktestPersistenceManager
 28 | from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
 29 | from maverick_mcp.backtesting.visualization import (
 30 |     generate_equity_curve,
 31 |     generate_performance_dashboard,
 32 | )
 33 | 
 34 | logger = logging.getLogger(__name__)
 35 | 
 36 | # Strategy definitions for comprehensive testing
 37 | TRADITIONAL_STRATEGIES = [
 38 |     "sma_cross",
 39 |     "ema_cross",
 40 |     "rsi",
 41 |     "macd",
 42 |     "bollinger",
 43 |     "momentum",
 44 |     "breakout",
 45 |     "mean_reversion",
 46 |     "volume_momentum",
 47 | ]
 48 | 
 49 | ML_STRATEGIES = [
 50 |     "ml_predictor",
 51 |     "adaptive",
 52 |     "ensemble",
 53 |     "regime_aware",
 54 |     "online_learning",
 55 |     "reinforcement_learning",
 56 | ]
 57 | 
 58 | ALL_STRATEGIES = TRADITIONAL_STRATEGIES + ML_STRATEGIES
 59 | 
 60 | 
 61 | class TestAdvancedBacktestWorkflowIntegration:
 62 |     """Advanced integration tests for complete backtesting workflow."""
 63 | 
 64 |     @pytest.fixture
 65 |     async def enhanced_stock_data_provider(self):
 66 |         """Create enhanced mock stock data provider with realistic multi-year data."""
 67 |         provider = Mock()
 68 | 
 69 |         # Generate 3 years of realistic stock data with different market conditions
 70 |         dates = pd.date_range(start="2021-01-01", end="2023-12-31", freq="D")
 71 | 
 72 |         # Simulate different market regimes
 73 |         bull_period = len(dates) // 3  # First third: bull market
 74 |         sideways_period = len(dates) // 3  # Second third: sideways
 75 |         bear_period = len(dates) - bull_period - sideways_period  # Final: bear market
 76 | 
 77 |         # Generate returns for different regimes
 78 |         bull_returns = np.random.normal(0.0015, 0.015, bull_period)  # Positive drift
 79 |         sideways_returns = np.random.normal(0.0002, 0.02, sideways_period)  # Low drift
 80 |         bear_returns = np.random.normal(-0.001, 0.025, bear_period)  # Negative drift
 81 | 
 82 |         all_returns = np.concatenate([bull_returns, sideways_returns, bear_returns])
 83 |         prices = 100 * np.cumprod(1 + all_returns)  # Start at $100
 84 | 
 85 |         # Add realistic volume patterns
 86 |         volumes = np.random.randint(500000, 5000000, len(dates)).astype(float)
 87 |         volumes += np.random.normal(0, volumes * 0.1)  # Add volume volatility
 88 |         volumes = np.maximum(volumes, 100000)  # Minimum volume
 89 |         volumes = volumes.astype(int)  # Convert back to integers
 90 | 
 91 |         stock_data = pd.DataFrame(
 92 |             {
 93 |                 "Open": prices * np.random.uniform(0.995, 1.005, len(dates)),
 94 |                 "High": prices * np.random.uniform(1.002, 1.025, len(dates)),
 95 |                 "Low": prices * np.random.uniform(0.975, 0.998, len(dates)),
 96 |                 "Close": prices,
 97 |                 "Volume": volumes.astype(int),
 98 |                 "Adj Close": prices,
 99 |             },
100 |             index=dates,
101 |         )
102 | 
103 |         # Ensure OHLC constraints
104 |         stock_data["High"] = np.maximum(
105 |             stock_data["High"], np.maximum(stock_data["Open"], stock_data["Close"])
106 |         )
107 |         stock_data["Low"] = np.minimum(
108 |             stock_data["Low"], np.minimum(stock_data["Open"], stock_data["Close"])
109 |         )
110 | 
111 |         provider.get_stock_data.return_value = stock_data
112 |         return provider
113 | 
114 |     @pytest.fixture
115 |     async def complete_vectorbt_engine(self, enhanced_stock_data_provider):
116 |         """Create complete VectorBT engine with all strategies enabled."""
117 |         engine = VectorBTEngine(data_provider=enhanced_stock_data_provider)
118 |         return engine
119 | 
120 |     async def test_all_15_strategies_integration(
121 |         self, complete_vectorbt_engine, benchmark_timer
122 |     ):
123 |         """Test all 15 strategies (9 traditional + 6 ML) in complete workflow."""
124 |         results = {}
125 |         failed_strategies = []
126 | 
127 |         with benchmark_timer() as timer:
128 |             # Test traditional strategies
129 |             for strategy in TRADITIONAL_STRATEGIES:
130 |                 try:
131 |                     if strategy in STRATEGY_TEMPLATES:
132 |                         parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
133 |                         result = await complete_vectorbt_engine.run_backtest(
134 |                             symbol="COMPREHENSIVE_TEST",
135 |                             strategy_type=strategy,
136 |                             parameters=parameters,
137 |                             start_date="2022-01-01",
138 |                             end_date="2023-12-31",
139 |                         )
140 |                         results[strategy] = result
141 | 
142 |                         # Validate basic result structure
143 |                         assert "metrics" in result
144 |                         assert "trades" in result
145 |                         assert "equity_curve" in result
146 |                         assert result["symbol"] == "COMPREHENSIVE_TEST"
147 | 
148 |                         logger.info(f"✓ {strategy} strategy executed successfully")
149 |                     else:
150 |                         logger.warning(f"Strategy {strategy} not found in templates")
151 | 
152 |                 except Exception as e:
153 |                     failed_strategies.append(strategy)
154 |                     logger.error(f"✗ {strategy} strategy failed: {str(e)}")
155 | 
156 |             # Test ML strategies (mock implementation for integration test)
157 |             for strategy in ML_STRATEGIES:
158 |                 try:
159 |                     # Mock ML strategy execution
160 |                     mock_ml_result = {
161 |                         "symbol": "COMPREHENSIVE_TEST",
162 |                         "strategy_type": strategy,
163 |                         "metrics": {
164 |                             "total_return": np.random.uniform(-0.2, 0.3),
165 |                             "sharpe_ratio": np.random.uniform(0.5, 2.0),
166 |                             "max_drawdown": np.random.uniform(-0.3, -0.05),
167 |                             "total_trades": np.random.randint(10, 100),
168 |                         },
169 |                         "trades": [],
170 |                         "equity_curve": np.cumsum(
171 |                             np.random.normal(0.001, 0.02, 252)
172 |                         ).tolist(),
173 |                         "ml_specific": {
174 |                             "model_accuracy": np.random.uniform(0.55, 0.85),
175 |                             "feature_importance": {
176 |                                 "momentum": 0.3,
177 |                                 "volatility": 0.25,
178 |                                 "volume": 0.45,
179 |                             },
180 |                         },
181 |                     }
182 |                     results[strategy] = mock_ml_result
183 |                     logger.info(f"✓ {strategy} ML strategy simulated successfully")
184 | 
185 |                 except Exception as e:
186 |                     failed_strategies.append(strategy)
187 |                     logger.error(f"✗ {strategy} ML strategy failed: {str(e)}")
188 | 
189 |         execution_time = timer.elapsed
190 | 
191 |         # Validate overall results
192 |         successful_strategies = len(results)
193 |         total_strategies = len(ALL_STRATEGIES)
194 |         success_rate = successful_strategies / total_strategies
195 | 
196 |         # Performance requirements
197 |         assert execution_time < 180.0  # Should complete within 3 minutes
198 |         assert success_rate >= 0.8  # At least 80% success rate
199 |         assert successful_strategies >= 12  # At least 12 strategies should work
200 | 
201 |         # Log comprehensive results
202 |         logger.info(
203 |             f"Strategy Integration Test Results:\n"
204 |             f"  • Total Strategies: {total_strategies}\n"
205 |             f"  • Successful: {successful_strategies}\n"
206 |             f"  • Failed: {len(failed_strategies)}\n"
207 |             f"  • Success Rate: {success_rate:.1%}\n"
208 |             f"  • Execution Time: {execution_time:.2f}s\n"
209 |             f"  • Failed Strategies: {failed_strategies}"
210 |         )
211 | 
212 |         return {
213 |             "total_strategies": total_strategies,
214 |             "successful_strategies": successful_strategies,
215 |             "failed_strategies": failed_strategies,
216 |             "success_rate": success_rate,
217 |             "execution_time": execution_time,
218 |             "results": results,
219 |         }
220 | 
221 |     async def test_parallel_execution_capabilities(
222 |         self, complete_vectorbt_engine, benchmark_timer
223 |     ):
224 |         """Test parallel execution of multiple backtests."""
225 |         symbols = ["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN", "META", "NFLX", "NVDA"]
226 |         strategies = ["sma_cross", "rsi", "macd", "bollinger"]
227 | 
228 |         async def run_single_backtest(symbol, strategy):
229 |             """Run a single backtest."""
230 |             try:
231 |                 parameters = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})
232 |                 result = await complete_vectorbt_engine.run_backtest(
233 |                     symbol=symbol,
234 |                     strategy_type=strategy,
235 |                     parameters=parameters,
236 |                     start_date="2023-01-01",
237 |                     end_date="2023-12-31",
238 |                 )
239 |                 return {
240 |                     "symbol": symbol,
241 |                     "strategy": strategy,
242 |                     "result": result,
243 |                     "success": True,
244 |                 }
245 |             except Exception as e:
246 |                 return {
247 |                     "symbol": symbol,
248 |                     "strategy": strategy,
249 |                     "error": str(e),
250 |                     "success": False,
251 |                 }
252 | 
253 |         with benchmark_timer() as timer:
254 |             # Create all combinations
255 |             tasks = []
256 |             for symbol in symbols:
257 |                 for strategy in strategies:
258 |                     tasks.append(run_single_backtest(symbol, strategy))
259 | 
260 |             # Execute in parallel with semaphore to control concurrency
261 |             semaphore = asyncio.Semaphore(8)  # Max 8 concurrent executions
262 | 
263 |             async def run_with_semaphore(task):
264 |                 async with semaphore:
265 |                     return await task
266 | 
267 |             results = await asyncio.gather(
268 |                 *[run_with_semaphore(task) for task in tasks], return_exceptions=True
269 |             )
270 | 
271 |         execution_time = timer.elapsed
272 | 
273 |         # Analyze results
274 |         total_executions = len(tasks)
275 |         successful_executions = sum(
276 |             1 for r in results if isinstance(r, dict) and r.get("success", False)
277 |         )
278 |         failed_executions = total_executions - successful_executions
279 | 
280 |         # Performance assertions
281 |         assert execution_time < 300.0  # Should complete within 5 minutes
282 |         assert successful_executions >= total_executions * 0.7  # At least 70% success
283 | 
284 |         # Calculate average execution time per backtest
285 |         avg_time_per_backtest = execution_time / total_executions
286 | 
287 |         logger.info(
288 |             f"Parallel Execution Results:\n"
289 |             f"  • Total Executions: {total_executions}\n"
290 |             f"  • Successful: {successful_executions}\n"
291 |             f"  • Failed: {failed_executions}\n"
292 |             f"  • Success Rate: {successful_executions / total_executions:.1%}\n"
293 |             f"  • Total Time: {execution_time:.2f}s\n"
294 |             f"  • Avg Time/Backtest: {avg_time_per_backtest:.2f}s\n"
295 |             f"  • Throughput: {total_executions / execution_time:.2f} backtests/s"
296 |         )
297 | 
298 |         return {
299 |             "total_executions": total_executions,
300 |             "successful_executions": successful_executions,
301 |             "execution_time": execution_time,
302 |             "avg_time_per_backtest": avg_time_per_backtest,
303 |         }
304 | 
305 |     async def test_cache_behavior_and_optimization(self, complete_vectorbt_engine):
306 |         """Test cache behavior and optimization in integrated workflow."""
307 |         symbol = "CACHE_TEST_SYMBOL"
308 |         strategy = "sma_cross"
309 |         parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
310 | 
311 |         # First run - should populate cache
312 |         start_time = time.time()
313 |         result1 = await complete_vectorbt_engine.run_backtest(
314 |             symbol=symbol,
315 |             strategy_type=strategy,
316 |             parameters=parameters,
317 |             start_date="2023-01-01",
318 |             end_date="2023-12-31",
319 |         )
320 |         first_run_time = time.time() - start_time
321 | 
322 |         # Second run - should use cache
323 |         start_time = time.time()
324 |         result2 = await complete_vectorbt_engine.run_backtest(
325 |             symbol=symbol,
326 |             strategy_type=strategy,
327 |             parameters=parameters,
328 |             start_date="2023-01-01",
329 |             end_date="2023-12-31",
330 |         )
331 |         second_run_time = time.time() - start_time
332 | 
333 |         # Third run with different parameters - should not use cache
334 |         modified_parameters = {
335 |             **parameters,
336 |             "fast_period": parameters.get("fast_period", 10) + 5,
337 |         }
338 |         start_time = time.time()
339 |         await complete_vectorbt_engine.run_backtest(
340 |             symbol=symbol,
341 |             strategy_type=strategy,
342 |             parameters=modified_parameters,
343 |             start_date="2023-01-01",
344 |             end_date="2023-12-31",
345 |         )
346 |         third_run_time = time.time() - start_time
347 | 
348 |         # Validate results consistency (for cached runs)
349 |         assert result1["symbol"] == result2["symbol"]
350 |         assert result1["strategy_type"] == result2["strategy_type"]
351 | 
352 |         # Cache effectiveness check (second run might be faster, but not guaranteed)
353 |         cache_speedup = first_run_time / second_run_time if second_run_time > 0 else 1.0
354 | 
355 |         logger.info(
356 |             f"Cache Behavior Test Results:\n"
357 |             f"  • First Run: {first_run_time:.3f}s\n"
358 |             f"  • Second Run (cached): {second_run_time:.3f}s\n"
359 |             f"  • Third Run (different params): {third_run_time:.3f}s\n"
360 |             f"  • Cache Speedup: {cache_speedup:.2f}x\n"
361 |         )
362 | 
363 |         return {
364 |             "first_run_time": first_run_time,
365 |             "second_run_time": second_run_time,
366 |             "third_run_time": third_run_time,
367 |             "cache_speedup": cache_speedup,
368 |         }
369 | 
370 |     async def test_database_persistence_integration(
371 |         self, complete_vectorbt_engine, db_session
372 |     ):
373 |         """Test complete database persistence integration."""
374 |         # Generate test results
375 |         result = await complete_vectorbt_engine.run_backtest(
376 |             symbol="PERSISTENCE_TEST",
377 |             strategy_type="sma_cross",
378 |             parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
379 |             start_date="2023-01-01",
380 |             end_date="2023-12-31",
381 |         )
382 | 
383 |         # Test persistence workflow
384 |         with BacktestPersistenceManager(session=db_session) as persistence:
385 |             # Save backtest result
386 |             backtest_id = persistence.save_backtest_result(
387 |                 vectorbt_results=result,
388 |                 execution_time=2.5,
389 |                 notes="Integration test - complete persistence workflow",
390 |             )
391 | 
392 |             # Validate saved data
393 |             assert backtest_id is not None
394 |             assert UUID(backtest_id)  # Valid UUID
395 | 
396 |             # Retrieve and validate
397 |             saved_result = persistence.get_backtest_by_id(backtest_id)
398 |             assert saved_result is not None
399 |             assert saved_result.symbol == "PERSISTENCE_TEST"
400 |             assert saved_result.strategy_type == "sma_cross"
401 |             assert saved_result.execution_time == 2.5
402 | 
403 |             # Test batch operations
404 |             batch_results = []
405 |             for i in range(5):
406 |                 batch_result = await complete_vectorbt_engine.run_backtest(
407 |                     symbol=f"BATCH_TEST_{i}",
408 |                     strategy_type="rsi",
409 |                     parameters=STRATEGY_TEMPLATES["rsi"]["parameters"],
410 |                     start_date="2023-06-01",
411 |                     end_date="2023-12-31",
412 |                 )
413 |                 batch_results.append(batch_result)
414 | 
415 |             # Save batch results
416 |             batch_ids = []
417 |             for i, batch_result in enumerate(batch_results):
418 |                 batch_id = persistence.save_backtest_result(
419 |                     vectorbt_results=batch_result,
420 |                     execution_time=1.8 + i * 0.1,
421 |                     notes=f"Batch test #{i + 1}",
422 |                 )
423 |                 batch_ids.append(batch_id)
424 | 
425 |             # Query saved batch results
426 |             saved_batch = [persistence.get_backtest_by_id(bid) for bid in batch_ids]
427 |             assert all(saved is not None for saved in saved_batch)
428 |             assert len(saved_batch) == 5
429 | 
430 |             # Test filtering and querying
431 |             rsi_results = persistence.get_backtests_by_strategy("rsi")
432 |             assert len(rsi_results) >= 5  # At least our batch results
433 | 
434 |         logger.info("Database persistence test completed successfully")
435 |         return {"batch_ids": batch_ids, "single_id": backtest_id}
436 | 
437 |     async def test_visualization_integration_complete(self, complete_vectorbt_engine):
438 |         """Test complete visualization integration workflow."""
439 |         # Run backtest to get data for visualization
440 |         result = await complete_vectorbt_engine.run_backtest(
441 |             symbol="VIZ_TEST",
442 |             strategy_type="macd",
443 |             parameters=STRATEGY_TEMPLATES["macd"]["parameters"],
444 |             start_date="2023-01-01",
445 |             end_date="2023-12-31",
446 |         )
447 | 
448 |         # Test all visualization components
449 |         visualizations = {}
450 | 
451 |         # 1. Equity curve visualization
452 |         equity_data = pd.Series(result["equity_curve"])
453 |         drawdown_data = pd.Series(result["drawdown_series"])
454 | 
455 |         equity_chart = generate_equity_curve(
456 |             equity_data,
457 |             drawdown=drawdown_data,
458 |             title="Complete Integration Test - Equity Curve",
459 |         )
460 |         visualizations["equity_curve"] = equity_chart
461 | 
462 |         # 2. Performance dashboard
463 |         dashboard_chart = generate_performance_dashboard(
464 |             result["metrics"], title="Complete Integration Test - Performance Dashboard"
465 |         )
466 |         visualizations["dashboard"] = dashboard_chart
467 | 
468 |         # 3. Validate all visualizations
469 |         for viz_name, viz_data in visualizations.items():
470 |             assert isinstance(viz_data, str), f"{viz_name} should return string"
471 |             assert len(viz_data) > 100, f"{viz_name} should have substantial content"
472 | 
473 |             # Try to decode as base64 (should be valid image)
474 |             try:
475 |                 import base64
476 | 
477 |                 decoded = base64.b64decode(viz_data)
478 |                 assert len(decoded) > 0, f"{viz_name} should have valid image data"
479 |                 logger.info(f"✓ {viz_name} visualization generated successfully")
480 |             except Exception as e:
481 |                 logger.error(f"✗ {viz_name} visualization failed: {e}")
482 |                 raise
483 | 
484 |         return visualizations
485 | 
486 |     async def test_error_recovery_comprehensive(self, complete_vectorbt_engine):
487 |         """Test comprehensive error recovery across the workflow."""
488 |         recovery_results = {}
489 | 
490 |         # 1. Invalid symbol handling
491 |         try:
492 |             result = await complete_vectorbt_engine.run_backtest(
493 |                 symbol="",  # Empty symbol
494 |                 strategy_type="sma_cross",
495 |                 parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
496 |                 start_date="2023-01-01",
497 |                 end_date="2023-12-31",
498 |             )
499 |             recovery_results["empty_symbol"] = {"recovered": True, "result": result}
500 |         except Exception as e:
501 |             recovery_results["empty_symbol"] = {"recovered": False, "error": str(e)}
502 | 
503 |         # 2. Invalid date range handling
504 |         try:
505 |             result = await complete_vectorbt_engine.run_backtest(
506 |                 symbol="ERROR_TEST",
507 |                 strategy_type="sma_cross",
508 |                 parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
509 |                 start_date="2025-01-01",  # Future date
510 |                 end_date="2025-12-31",
511 |             )
512 |             recovery_results["future_dates"] = {"recovered": True, "result": result}
513 |         except Exception as e:
514 |             recovery_results["future_dates"] = {"recovered": False, "error": str(e)}
515 | 
516 |         # 3. Invalid strategy parameters
517 |         try:
518 |             invalid_params = {
519 |                 "fast_period": -10,
520 |                 "slow_period": -20,
521 |             }  # Invalid negative values
522 |             result = await complete_vectorbt_engine.run_backtest(
523 |                 symbol="ERROR_TEST",
524 |                 strategy_type="sma_cross",
525 |                 parameters=invalid_params,
526 |                 start_date="2023-01-01",
527 |                 end_date="2023-12-31",
528 |             )
529 |             recovery_results["invalid_params"] = {"recovered": True, "result": result}
530 |         except Exception as e:
531 |             recovery_results["invalid_params"] = {"recovered": False, "error": str(e)}
532 | 
533 |         # 4. Unknown strategy handling
534 |         try:
535 |             result = await complete_vectorbt_engine.run_backtest(
536 |                 symbol="ERROR_TEST",
537 |                 strategy_type="nonexistent_strategy",
538 |                 parameters={},
539 |                 start_date="2023-01-01",
540 |                 end_date="2023-12-31",
541 |             )
542 |             recovery_results["unknown_strategy"] = {"recovered": True, "result": result}
543 |         except Exception as e:
544 |             recovery_results["unknown_strategy"] = {"recovered": False, "error": str(e)}
545 | 
546 |         # Analyze recovery effectiveness
547 |         total_tests = len(recovery_results)
548 |         recovered_tests = sum(
549 |             1 for r in recovery_results.values() if r.get("recovered", False)
550 |         )
551 |         recovery_rate = recovered_tests / total_tests if total_tests > 0 else 0
552 | 
553 |         logger.info(
554 |             f"Error Recovery Test Results:\n"
555 |             f"  • Total Error Scenarios: {total_tests}\n"
556 |             f"  • Successfully Recovered: {recovered_tests}\n"
557 |             f"  • Recovery Rate: {recovery_rate:.1%}\n"
558 |         )
559 | 
560 |         for scenario, result in recovery_results.items():
561 |             status = "✓ RECOVERED" if result.get("recovered") else "✗ FAILED"
562 |             logger.info(f"  • {scenario}: {status}")
563 | 
564 |         return recovery_results
565 | 
566 |     async def test_resource_management_comprehensive(self, complete_vectorbt_engine):
567 |         """Test comprehensive resource management across workflow."""
568 |         import os
569 | 
570 |         import psutil
571 | 
572 |         process = psutil.Process(os.getpid())
573 | 
574 |         # Baseline measurements
575 |         initial_memory = process.memory_info().rss / 1024 / 1024  # MB
576 |         initial_threads = process.num_threads()
577 |         resource_snapshots = []
578 | 
579 |         # Run multiple backtests while monitoring resources
580 |         for i in range(10):
581 |             await complete_vectorbt_engine.run_backtest(
582 |                 symbol=f"RESOURCE_TEST_{i}",
583 |                 strategy_type="sma_cross",
584 |                 parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
585 |                 start_date="2023-01-01",
586 |                 end_date="2023-12-31",
587 |             )
588 | 
589 |             # Take resource snapshot
590 |             current_memory = process.memory_info().rss / 1024 / 1024  # MB
591 |             current_threads = process.num_threads()
592 |             current_cpu = process.cpu_percent()
593 | 
594 |             resource_snapshots.append(
595 |                 {
596 |                     "iteration": i + 1,
597 |                     "memory_mb": current_memory,
598 |                     "threads": current_threads,
599 |                     "cpu_percent": current_cpu,
600 |                 }
601 |             )
602 | 
603 |         # Final measurements
604 |         final_memory = process.memory_info().rss / 1024 / 1024  # MB
605 |         final_threads = process.num_threads()
606 | 
607 |         # Calculate resource growth
608 |         memory_growth = final_memory - initial_memory
609 |         thread_growth = final_threads - initial_threads
610 |         peak_memory = max(snapshot["memory_mb"] for snapshot in resource_snapshots)
611 |         avg_threads = sum(snapshot["threads"] for snapshot in resource_snapshots) / len(
612 |             resource_snapshots
613 |         )
614 | 
615 |         # Resource management assertions
616 |         assert memory_growth < 500, (
617 |             f"Memory growth too high: {memory_growth:.1f}MB"
618 |         )  # Max 500MB growth
619 |         assert thread_growth <= 10, (
620 |             f"Thread growth too high: {thread_growth}"
621 |         )  # Max 10 additional threads
622 |         assert peak_memory < initial_memory + 1000, (
623 |             f"Peak memory too high: {peak_memory:.1f}MB"
624 |         )  # Peak within 1GB of initial
625 | 
626 |         logger.info(
627 |             f"Resource Management Test Results:\n"
628 |             f"  • Initial Memory: {initial_memory:.1f}MB\n"
629 |             f"  • Final Memory: {final_memory:.1f}MB\n"
630 |             f"  • Memory Growth: {memory_growth:.1f}MB\n"
631 |             f"  • Peak Memory: {peak_memory:.1f}MB\n"
632 |             f"  • Initial Threads: {initial_threads}\n"
633 |             f"  • Final Threads: {final_threads}\n"
634 |             f"  • Thread Growth: {thread_growth}\n"
635 |             f"  • Avg Threads: {avg_threads:.1f}"
636 |         )
637 | 
638 |         return {
639 |             "memory_growth": memory_growth,
640 |             "thread_growth": thread_growth,
641 |             "peak_memory": peak_memory,
642 |             "resource_snapshots": resource_snapshots,
643 |         }
644 | 
645 | 
646 | if __name__ == "__main__":
647 |     # Run advanced integration tests
648 |     pytest.main(
649 |         [
650 |             __file__,
651 |             "-v",
652 |             "--tb=short",
653 |             "--asyncio-mode=auto",
654 |             "--timeout=600",  # 10 minute timeout for comprehensive tests
655 |             "-x",  # Stop on first failure
656 |             "--durations=10",  # Show 10 slowest tests
657 |         ]
658 |     )
659 | 
```

--------------------------------------------------------------------------------
/tests/test_tool_estimation_config.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Comprehensive tests for ToolEstimationConfig.
  3 | 
  4 | This module tests the centralized tool cost estimation configuration that replaces
  5 | magic numbers scattered throughout the codebase. Tests cover:
  6 | - All tool-specific estimates
  7 | - Confidence levels and estimation basis
  8 | - Monitoring thresholds and alert conditions
  9 | - Edge cases and error handling
 10 | - Integration with server.py patterns
 11 | """
 12 | 
 13 | from unittest.mock import patch
 14 | 
 15 | import pytest
 16 | 
 17 | from maverick_mcp.config.tool_estimation import (
 18 |     EstimationBasis,
 19 |     MonitoringThresholds,
 20 |     ToolComplexity,
 21 |     ToolEstimate,
 22 |     ToolEstimationConfig,
 23 |     get_tool_estimate,
 24 |     get_tool_estimation_config,
 25 |     should_alert_for_usage,
 26 | )
 27 | 
 28 | 
 29 | class TestToolEstimate:
 30 |     """Test ToolEstimate model validation and behavior."""
 31 | 
 32 |     def test_valid_tool_estimate(self):
 33 |         """Test creating a valid ToolEstimate."""
 34 |         estimate = ToolEstimate(
 35 |             llm_calls=5,
 36 |             total_tokens=8000,
 37 |             confidence=0.8,
 38 |             based_on=EstimationBasis.EMPIRICAL,
 39 |             complexity=ToolComplexity.COMPLEX,
 40 |             notes="Test estimate",
 41 |         )
 42 | 
 43 |         assert estimate.llm_calls == 5
 44 |         assert estimate.total_tokens == 8000
 45 |         assert estimate.confidence == 0.8
 46 |         assert estimate.based_on == EstimationBasis.EMPIRICAL
 47 |         assert estimate.complexity == ToolComplexity.COMPLEX
 48 |         assert estimate.notes == "Test estimate"
 49 | 
 50 |     def test_confidence_validation(self):
 51 |         """Test confidence level validation."""
 52 |         # Valid confidence levels
 53 |         for confidence in [0.0, 0.5, 1.0]:
 54 |             estimate = ToolEstimate(
 55 |                 llm_calls=1,
 56 |                 total_tokens=100,
 57 |                 confidence=confidence,
 58 |                 based_on=EstimationBasis.EMPIRICAL,
 59 |                 complexity=ToolComplexity.SIMPLE,
 60 |             )
 61 |             assert estimate.confidence == confidence
 62 | 
 63 |         # Invalid confidence levels - Pydantic ValidationError
 64 |         from pydantic import ValidationError
 65 | 
 66 |         with pytest.raises(ValidationError):
 67 |             ToolEstimate(
 68 |                 llm_calls=1,
 69 |                 total_tokens=100,
 70 |                 confidence=-0.1,
 71 |                 based_on=EstimationBasis.EMPIRICAL,
 72 |                 complexity=ToolComplexity.SIMPLE,
 73 |             )
 74 | 
 75 |         with pytest.raises(ValidationError):
 76 |             ToolEstimate(
 77 |                 llm_calls=1,
 78 |                 total_tokens=100,
 79 |                 confidence=1.1,
 80 |                 based_on=EstimationBasis.EMPIRICAL,
 81 |                 complexity=ToolComplexity.SIMPLE,
 82 |             )
 83 | 
 84 |     def test_negative_values_validation(self):
 85 |         """Test that negative values are not allowed."""
 86 |         from pydantic import ValidationError
 87 | 
 88 |         with pytest.raises(ValidationError):
 89 |             ToolEstimate(
 90 |                 llm_calls=-1,
 91 |                 total_tokens=100,
 92 |                 confidence=0.8,
 93 |                 based_on=EstimationBasis.EMPIRICAL,
 94 |                 complexity=ToolComplexity.SIMPLE,
 95 |             )
 96 | 
 97 |         with pytest.raises(ValidationError):
 98 |             ToolEstimate(
 99 |                 llm_calls=1,
100 |                 total_tokens=-100,
101 |                 confidence=0.8,
102 |                 based_on=EstimationBasis.EMPIRICAL,
103 |                 complexity=ToolComplexity.SIMPLE,
104 |             )
105 | 
106 | 
107 | class TestMonitoringThresholds:
108 |     """Test MonitoringThresholds model and validation."""
109 | 
110 |     def test_default_thresholds(self):
111 |         """Test default monitoring thresholds."""
112 |         thresholds = MonitoringThresholds()
113 | 
114 |         assert thresholds.llm_calls_warning == 15
115 |         assert thresholds.llm_calls_critical == 25
116 |         assert thresholds.tokens_warning == 20000
117 |         assert thresholds.tokens_critical == 35000
118 |         assert thresholds.variance_warning == 0.5
119 |         assert thresholds.variance_critical == 1.0
120 | 
121 |     def test_custom_thresholds(self):
122 |         """Test custom monitoring thresholds."""
123 |         thresholds = MonitoringThresholds(
124 |             llm_calls_warning=10,
125 |             llm_calls_critical=20,
126 |             tokens_warning=15000,
127 |             tokens_critical=30000,
128 |             variance_warning=0.3,
129 |             variance_critical=0.8,
130 |         )
131 | 
132 |         assert thresholds.llm_calls_warning == 10
133 |         assert thresholds.llm_calls_critical == 20
134 |         assert thresholds.tokens_warning == 15000
135 |         assert thresholds.tokens_critical == 30000
136 |         assert thresholds.variance_warning == 0.3
137 |         assert thresholds.variance_critical == 0.8
138 | 
139 | 
140 | class TestToolEstimationConfig:
141 |     """Test the main ToolEstimationConfig class."""
142 | 
143 |     def test_default_configuration(self):
144 |         """Test default configuration initialization."""
145 |         config = ToolEstimationConfig()
146 | 
147 |         # Test default estimates by complexity
148 |         assert config.simple_default.complexity == ToolComplexity.SIMPLE
149 |         assert config.standard_default.complexity == ToolComplexity.STANDARD
150 |         assert config.complex_default.complexity == ToolComplexity.COMPLEX
151 |         assert config.premium_default.complexity == ToolComplexity.PREMIUM
152 | 
153 |         # Test unknown tool fallback
154 |         assert config.unknown_tool_estimate.complexity == ToolComplexity.STANDARD
155 |         assert config.unknown_tool_estimate.confidence == 0.3
156 |         assert config.unknown_tool_estimate.based_on == EstimationBasis.CONSERVATIVE
157 | 
158 |     def test_get_estimate_known_tools(self):
159 |         """Test getting estimates for known tools."""
160 |         config = ToolEstimationConfig()
161 | 
162 |         # Test simple tools
163 |         simple_tools = [
164 |             "get_stock_price",
165 |             "get_company_info",
166 |             "get_stock_info",
167 |             "calculate_sma",
168 |             "get_market_hours",
169 |             "get_chart_links",
170 |             "list_available_agents",
171 |             "clear_cache",
172 |             "get_cached_price_data",
173 |             "get_watchlist",
174 |             "generate_dev_token",
175 |         ]
176 | 
177 |         for tool in simple_tools:
178 |             estimate = config.get_estimate(tool)
179 |             assert estimate.complexity == ToolComplexity.SIMPLE
180 |             assert estimate.llm_calls <= 1  # Simple tools should have minimal LLM usage
181 |             assert estimate.confidence >= 0.8  # Should have high confidence
182 | 
183 |         # Test standard tools
184 |         standard_tools = [
185 |             "get_rsi_analysis",
186 |             "get_macd_analysis",
187 |             "get_support_resistance",
188 |             "fetch_stock_data",
189 |             "get_maverick_stocks",
190 |             "get_news_sentiment",
191 |             "get_economic_calendar",
192 |         ]
193 | 
194 |         for tool in standard_tools:
195 |             estimate = config.get_estimate(tool)
196 |             assert estimate.complexity == ToolComplexity.STANDARD
197 |             assert 1 <= estimate.llm_calls <= 5
198 |             assert estimate.confidence >= 0.7
199 | 
200 |         # Test complex tools
201 |         complex_tools = [
202 |             "get_full_technical_analysis",
203 |             "risk_adjusted_analysis",
204 |             "compare_tickers",
205 |             "portfolio_correlation_analysis",
206 |             "get_market_overview",
207 |             "get_all_screening_recommendations",
208 |         ]
209 | 
210 |         for tool in complex_tools:
211 |             estimate = config.get_estimate(tool)
212 |             assert estimate.complexity == ToolComplexity.COMPLEX
213 |             assert 4 <= estimate.llm_calls <= 8
214 |             assert estimate.confidence >= 0.7
215 | 
216 |         # Test premium tools
217 |         premium_tools = [
218 |             "analyze_market_with_agent",
219 |             "get_agent_streaming_analysis",
220 |             "compare_personas_analysis",
221 |         ]
222 | 
223 |         for tool in premium_tools:
224 |             estimate = config.get_estimate(tool)
225 |             assert estimate.complexity == ToolComplexity.PREMIUM
226 |             assert estimate.llm_calls >= 8
227 |             assert estimate.total_tokens >= 10000
228 | 
229 |     def test_get_estimate_unknown_tool(self):
230 |         """Test getting estimate for unknown tools."""
231 |         config = ToolEstimationConfig()
232 |         estimate = config.get_estimate("unknown_tool_name")
233 | 
234 |         assert estimate == config.unknown_tool_estimate
235 |         assert estimate.complexity == ToolComplexity.STANDARD
236 |         assert estimate.confidence == 0.3
237 |         assert estimate.based_on == EstimationBasis.CONSERVATIVE
238 | 
239 |     def test_get_default_for_complexity(self):
240 |         """Test getting default estimates by complexity."""
241 |         config = ToolEstimationConfig()
242 | 
243 |         simple = config.get_default_for_complexity(ToolComplexity.SIMPLE)
244 |         assert simple == config.simple_default
245 | 
246 |         standard = config.get_default_for_complexity(ToolComplexity.STANDARD)
247 |         assert standard == config.standard_default
248 | 
249 |         complex_est = config.get_default_for_complexity(ToolComplexity.COMPLEX)
250 |         assert complex_est == config.complex_default
251 | 
252 |         premium = config.get_default_for_complexity(ToolComplexity.PREMIUM)
253 |         assert premium == config.premium_default
254 | 
255 |     def test_should_alert_critical_thresholds(self):
256 |         """Test alert conditions for critical thresholds."""
257 |         config = ToolEstimationConfig()
258 | 
259 |         # Test critical LLM calls threshold
260 |         should_alert, message = config.should_alert("test_tool", 30, 5000)
261 |         assert should_alert
262 |         assert "Critical: LLM calls (30) exceeded threshold (25)" in message
263 | 
264 |         # Test critical token threshold
265 |         should_alert, message = config.should_alert("test_tool", 5, 40000)
266 |         assert should_alert
267 |         assert "Critical: Token usage (40000) exceeded threshold (35000)" in message
268 | 
269 |     def test_should_alert_variance_thresholds(self):
270 |         """Test alert conditions for variance thresholds."""
271 |         config = ToolEstimationConfig()
272 | 
273 |         # Test tool with known estimate for variance calculation
274 |         # get_stock_price: llm_calls=0, total_tokens=200
275 | 
276 |         # Test critical LLM variance (infinite variance since estimate is 0)
277 |         should_alert, message = config.should_alert("get_stock_price", 5, 200)
278 |         assert should_alert
279 |         assert "Critical: LLM call variance" in message
280 | 
281 |         # Test critical token variance (5x the estimate = 400% variance)
282 |         should_alert, message = config.should_alert("get_stock_price", 0, 1000)
283 |         assert should_alert
284 |         assert "Critical: Token variance" in message
285 | 
286 |     def test_should_alert_warning_thresholds(self):
287 |         """Test alert conditions for warning thresholds."""
288 |         config = ToolEstimationConfig()
289 | 
290 |         # Test warning LLM calls threshold (15-24 should trigger warning)
291 |         # Use unknown tool which has reasonable base estimates to avoid variance issues
292 |         should_alert, message = config.should_alert("unknown_tool", 18, 5000)
293 |         assert should_alert
294 |         assert (
295 |             "Warning" in message or "Critical" in message
296 |         )  # May trigger critical due to variance
297 | 
298 |         # Test warning token threshold with a tool that has known estimates
299 |         # get_rsi_analysis: llm_calls=2, total_tokens=3000
300 |         should_alert, message = config.should_alert("get_rsi_analysis", 2, 25000)
301 |         assert should_alert
302 |         assert (
303 |             "Warning" in message or "Critical" in message
304 |         )  # High token variance may trigger critical
305 | 
306 |     def test_should_alert_no_alert(self):
307 |         """Test cases where no alert should be triggered."""
308 |         config = ToolEstimationConfig()
309 | 
310 |         # Normal usage within expected ranges
311 |         should_alert, message = config.should_alert("get_stock_price", 0, 200)
312 |         assert not should_alert
313 |         assert message == ""
314 | 
315 |         # Slightly above estimate but within acceptable variance
316 |         should_alert, message = config.should_alert("get_stock_price", 0, 250)
317 |         assert not should_alert
318 |         assert message == ""
319 | 
320 |     def test_get_tools_by_complexity(self):
321 |         """Test filtering tools by complexity category."""
322 |         config = ToolEstimationConfig()
323 | 
324 |         simple_tools = config.get_tools_by_complexity(ToolComplexity.SIMPLE)
325 |         standard_tools = config.get_tools_by_complexity(ToolComplexity.STANDARD)
326 |         complex_tools = config.get_tools_by_complexity(ToolComplexity.COMPLEX)
327 |         premium_tools = config.get_tools_by_complexity(ToolComplexity.PREMIUM)
328 | 
329 |         # Verify all tools are categorized
330 |         all_tools = simple_tools + standard_tools + complex_tools + premium_tools
331 |         assert len(all_tools) == len(config.tool_estimates)
332 | 
333 |         # Verify no overlap between categories
334 |         assert len(set(all_tools)) == len(all_tools)
335 | 
336 |         # Verify specific known tools are in correct categories
337 |         assert "get_stock_price" in simple_tools
338 |         assert "get_rsi_analysis" in standard_tools
339 |         assert "get_full_technical_analysis" in complex_tools
340 |         assert "analyze_market_with_agent" in premium_tools
341 | 
342 |     def test_get_summary_stats(self):
343 |         """Test summary statistics generation."""
344 |         config = ToolEstimationConfig()
345 |         stats = config.get_summary_stats()
346 | 
347 |         # Verify structure
348 |         assert "total_tools" in stats
349 |         assert "by_complexity" in stats
350 |         assert "avg_llm_calls" in stats
351 |         assert "avg_tokens" in stats
352 |         assert "avg_confidence" in stats
353 |         assert "basis_distribution" in stats
354 | 
355 |         # Verify content
356 |         assert stats["total_tools"] > 0
357 |         assert len(stats["by_complexity"]) == 4  # All complexity levels
358 |         assert stats["avg_llm_calls"] >= 0
359 |         assert stats["avg_tokens"] > 0
360 |         assert 0 <= stats["avg_confidence"] <= 1
361 | 
362 |         # Verify complexity distribution adds up
363 |         complexity_sum = sum(stats["by_complexity"].values())
364 |         assert complexity_sum == stats["total_tools"]
365 | 
366 | 
367 | class TestModuleFunctions:
368 |     """Test module-level functions."""
369 | 
370 |     def test_get_tool_estimation_config_singleton(self):
371 |         """Test that get_tool_estimation_config returns a singleton."""
372 |         config1 = get_tool_estimation_config()
373 |         config2 = get_tool_estimation_config()
374 | 
375 |         # Should return the same instance
376 |         assert config1 is config2
377 | 
378 |     @patch("maverick_mcp.config.tool_estimation._config", None)
379 |     def test_get_tool_estimation_config_initialization(self):
380 |         """Test that configuration is initialized correctly."""
381 |         config = get_tool_estimation_config()
382 | 
383 |         assert isinstance(config, ToolEstimationConfig)
384 |         assert len(config.tool_estimates) > 0
385 | 
386 |     def test_get_tool_estimate_function(self):
387 |         """Test the get_tool_estimate convenience function."""
388 |         estimate = get_tool_estimate("get_stock_price")
389 | 
390 |         assert isinstance(estimate, ToolEstimate)
391 |         assert estimate.complexity == ToolComplexity.SIMPLE
392 | 
393 |         # Test unknown tool
394 |         unknown_estimate = get_tool_estimate("unknown_tool")
395 |         assert unknown_estimate.based_on == EstimationBasis.CONSERVATIVE
396 | 
397 |     def test_should_alert_for_usage_function(self):
398 |         """Test the should_alert_for_usage convenience function."""
399 |         should_alert, message = should_alert_for_usage("test_tool", 30, 5000)
400 | 
401 |         assert isinstance(should_alert, bool)
402 |         assert isinstance(message, str)
403 | 
404 |         # Should trigger alert for high LLM calls
405 |         assert should_alert
406 |         assert "Critical" in message
407 | 
408 | 
409 | class TestMagicNumberReplacement:
410 |     """Test that configuration correctly replaces magic numbers from server.py."""
411 | 
412 |     def test_all_usage_tier_tools_have_estimates(self):
413 |         """Test that all tools referenced in server.py have estimates."""
414 |         config = ToolEstimationConfig()
415 | 
416 |         # These are tools that were using magic numbers in server.py
417 |         # Based on the TOOL usage tier mapping pattern
418 |         critical_tools = [
419 |             # Simple tools (baseline tier)
420 |             "get_stock_price",
421 |             "get_company_info",
422 |             "get_stock_info",
423 |             "calculate_sma",
424 |             "get_market_hours",
425 |             "get_chart_links",
426 |             # Standard tools (core analysis tier)
427 |             "get_rsi_analysis",
428 |             "get_macd_analysis",
429 |             "get_support_resistance",
430 |             "fetch_stock_data",
431 |             "get_maverick_stocks",
432 |             "get_news_sentiment",
433 |             # Complex tools (advanced analysis tier)
434 |             "get_full_technical_analysis",
435 |             "risk_adjusted_analysis",
436 |             "compare_tickers",
437 |             "portfolio_correlation_analysis",
438 |             "get_market_overview",
439 |             # Premium tools (orchestration tier)
440 |             "analyze_market_with_agent",
441 |             "get_agent_streaming_analysis",
442 |             "compare_personas_analysis",
443 |         ]
444 | 
445 |         for tool in critical_tools:
446 |             estimate = config.get_estimate(tool)
447 |             # Should not get the fallback estimate
448 |             assert estimate != config.unknown_tool_estimate, (
449 |                 f"Tool {tool} missing specific estimate"
450 |             )
451 |             # Should have reasonable confidence
452 |             assert estimate.confidence > 0.5, f"Tool {tool} has low confidence estimate"
453 | 
454 |     def test_estimates_align_with_usage_tiers(self):
455 |         """Test that tool estimates align with usage complexity tiers."""
456 |         config = ToolEstimationConfig()
457 | 
458 |         # Simple tools should require minimal resources
459 |         simple_tools = [
460 |             "get_stock_price",
461 |             "get_company_info",
462 |             "get_stock_info",
463 |             "calculate_sma",
464 |             "get_market_hours",
465 |             "get_chart_links",
466 |         ]
467 | 
468 |         for tool in simple_tools:
469 |             estimate = config.get_estimate(tool)
470 |             assert estimate.complexity == ToolComplexity.SIMPLE
471 |             assert estimate.llm_calls <= 1  # Should require minimal/no LLM calls
472 | 
473 |         # Standard tools perform moderate analysis
474 |         standard_tools = [
475 |             "get_rsi_analysis",
476 |             "get_macd_analysis",
477 |             "get_support_resistance",
478 |             "fetch_stock_data",
479 |             "get_maverick_stocks",
480 |         ]
481 | 
482 |         for tool in standard_tools:
483 |             estimate = config.get_estimate(tool)
484 |             assert estimate.complexity == ToolComplexity.STANDARD
485 |             assert 1 <= estimate.llm_calls <= 5  # Moderate LLM usage
486 | 
487 |         # Complex tools orchestrate heavier workloads
488 |         complex_tools = [
489 |             "get_full_technical_analysis",
490 |             "risk_adjusted_analysis",
491 |             "compare_tickers",
492 |             "portfolio_correlation_analysis",
493 |         ]
494 | 
495 |         for tool in complex_tools:
496 |             estimate = config.get_estimate(tool)
497 |             assert estimate.complexity == ToolComplexity.COMPLEX
498 |             assert 4 <= estimate.llm_calls <= 8  # Multiple LLM interactions
499 | 
500 |         # Premium tools coordinate multi-stage workflows
501 |         premium_tools = [
502 |             "analyze_market_with_agent",
503 |             "get_agent_streaming_analysis",
504 |             "compare_personas_analysis",
505 |         ]
506 | 
507 |         for tool in premium_tools:
508 |             estimate = config.get_estimate(tool)
509 |             assert estimate.complexity == ToolComplexity.PREMIUM
510 |             assert estimate.llm_calls >= 8  # Extensive LLM coordination
511 | 
512 |     def test_no_hardcoded_estimates_remain(self):
513 |         """Test that estimates are data-driven, not hardcoded."""
514 |         config = ToolEstimationConfig()
515 | 
516 |         # All tool estimates should have basis information
517 |         for tool_name, estimate in config.tool_estimates.items():
518 |             assert estimate.based_on in EstimationBasis
519 |             assert estimate.complexity in ToolComplexity
520 |             assert estimate.notes is not None, f"Tool {tool_name} missing notes"
521 | 
522 |             # Empirical estimates should generally have reasonable confidence
523 |             if estimate.based_on == EstimationBasis.EMPIRICAL:
524 |                 assert estimate.confidence >= 0.6, (
525 |                     f"Empirical estimate for {tool_name} has very low confidence"
526 |                 )
527 | 
528 |             # Conservative estimates should have lower confidence
529 |             if estimate.based_on == EstimationBasis.CONSERVATIVE:
530 |                 assert estimate.confidence <= 0.6, (
531 |                     f"Conservative estimate for {tool_name} has unexpectedly high confidence"
532 |                 )
533 | 
534 | 
535 | class TestEdgeCases:
536 |     """Test edge cases and error conditions."""
537 | 
538 |     def test_empty_configuration(self):
539 |         """Test behavior with empty tool estimates."""
540 |         config = ToolEstimationConfig(tool_estimates={})
541 | 
542 |         # Should fall back to unknown tool estimate
543 |         estimate = config.get_estimate("any_tool")
544 |         assert estimate == config.unknown_tool_estimate
545 | 
546 |         # Summary stats should handle empty case
547 |         stats = config.get_summary_stats()
548 |         assert stats == {}
549 | 
550 |     def test_alert_with_zero_estimates(self):
551 |         """Test alert calculation when estimates are zero."""
552 |         config = ToolEstimationConfig()
553 | 
554 |         # Tool with zero LLM calls in estimate
555 |         should_alert, message = config.should_alert("get_stock_price", 1, 200)
556 |         # Should alert because variance is infinite (1 vs 0 expected)
557 |         assert should_alert
558 | 
559 |     def test_variance_calculation_edge_cases(self):
560 |         """Test variance calculation with edge cases."""
561 |         config = ToolEstimationConfig()
562 | 
563 |         # Perfect match should not alert
564 |         should_alert, message = config.should_alert("get_rsi_analysis", 2, 3000)
565 |         # get_rsi_analysis has: llm_calls=2, total_tokens=3000
566 |         assert not should_alert
567 | 
568 |     def test_performance_with_large_usage(self):
569 |         """Test performance and behavior with extremely large usage values."""
570 |         config = ToolEstimationConfig()
571 | 
572 |         # Very large values should still work
573 |         should_alert, message = config.should_alert("test_tool", 1000, 1000000)
574 |         assert should_alert
575 |         assert "Critical" in message
576 | 
577 |     def test_custom_monitoring_thresholds(self):
578 |         """Test configuration with custom monitoring thresholds."""
579 |         custom_monitoring = MonitoringThresholds(
580 |             llm_calls_warning=5,
581 |             llm_calls_critical=10,
582 |             tokens_warning=1000,
583 |             tokens_critical=5000,
584 |             variance_warning=0.1,
585 |             variance_critical=0.2,
586 |         )
587 | 
588 |         config = ToolEstimationConfig(monitoring=custom_monitoring)
589 | 
590 |         # Should use custom thresholds
591 |         # Test critical threshold first (easier to predict)
592 |         should_alert, message = config.should_alert("test_tool", 12, 500)
593 |         assert should_alert
594 |         assert "Critical" in message
595 | 
596 |         # Test LLM calls warning threshold
597 |         should_alert, message = config.should_alert(
598 |             "test_tool", 6, 100
599 |         )  # Lower tokens to avoid variance issues
600 |         assert should_alert
601 |         # May be warning or critical depending on variance calculation
602 | 
603 | 
604 | class TestIntegrationPatterns:
605 |     """Test patterns that match server.py integration."""
606 | 
607 |     def test_low_confidence_logging_pattern(self):
608 |         """Test identifying tools that need monitoring due to low confidence."""
609 |         config = ToolEstimationConfig()
610 | 
611 |         low_confidence_tools = []
612 |         for tool_name, estimate in config.tool_estimates.items():
613 |             if estimate.confidence < 0.8:
614 |                 low_confidence_tools.append(tool_name)
615 | 
616 |         # These tools should be logged for monitoring in production
617 |         assert len(low_confidence_tools) > 0
618 | 
619 |         # Verify these are typically more complex tools
620 |         for tool_name in low_confidence_tools:
621 |             estimate = config.get_estimate(tool_name)
622 |             # Low confidence tools should typically be complex, premium, or analytical standard tools
623 |             assert estimate.complexity in [
624 |                 ToolComplexity.STANDARD,
625 |                 ToolComplexity.COMPLEX,
626 |                 ToolComplexity.PREMIUM,
627 |             ], (
628 |                 f"Tool {tool_name} with low confidence has unexpected complexity {estimate.complexity}"
629 |             )
630 | 
631 |     def test_error_handling_fallback_pattern(self):
632 |         """Test the error handling pattern used in server.py."""
633 |         config = ToolEstimationConfig()
634 | 
635 |         # Simulate error case - should fall back to unknown tool estimate
636 |         try:
637 |             # This would be the pattern in server.py when get_tool_estimate fails
638 |             estimate = config.get_estimate("nonexistent_tool")
639 |             fallback_estimate = config.unknown_tool_estimate
640 | 
641 |             # Verify fallback has conservative characteristics
642 |             assert fallback_estimate.based_on == EstimationBasis.CONSERVATIVE
643 |             assert fallback_estimate.confidence == 0.3
644 |             assert fallback_estimate.complexity == ToolComplexity.STANDARD
645 | 
646 |             # Should be the same as what get_estimate returns for unknown tools
647 |             assert estimate == fallback_estimate
648 | 
649 |         except Exception:
650 |             # If estimation fails entirely, should be able to use fallback
651 |             fallback = config.unknown_tool_estimate
652 |             assert fallback.llm_calls > 0
653 |             assert fallback.total_tokens > 0
654 | 
655 |     def test_usage_logging_extra_fields(self):
656 |         """Test that estimates provide all fields needed for logging."""
657 |         config = ToolEstimationConfig()
658 | 
659 |         for _tool_name, estimate in config.tool_estimates.items():
660 |             # Verify all fields needed for server.py logging are present
661 |             assert hasattr(estimate, "confidence")
662 |             assert hasattr(estimate, "based_on")
663 |             assert hasattr(estimate, "complexity")
664 |             assert hasattr(estimate, "llm_calls")
665 |             assert hasattr(estimate, "total_tokens")
666 | 
667 |             # Verify fields have appropriate types for logging
668 |             assert isinstance(estimate.confidence, float)
669 |             assert isinstance(estimate.based_on, EstimationBasis)
670 |             assert isinstance(estimate.complexity, ToolComplexity)
671 |             assert isinstance(estimate.llm_calls, int)
672 |             assert isinstance(estimate.total_tokens, int)
673 | 
```

--------------------------------------------------------------------------------
/docs/PORTFOLIO_PERSONALIZATION_PLAN.md:
--------------------------------------------------------------------------------

```markdown
  1 | # PORTFOLIO PERSONALIZATION - EXECUTION PLAN
  2 | 
  3 | ## 1. Big Picture / Goal
  4 | 
  5 | **Objective:** Transform MaverickMCP's portfolio analysis tools from stateless, repetitive-input operations into an intelligent, personalized AI financial assistant through persistent portfolio storage and context-aware tool integration.
  6 | 
  7 | **Architectural Goal:** Implement a two-phase system that (1) adds persistent portfolio storage with cost basis tracking using established DDD patterns, and (2) intelligently enhances existing tools to auto-detect user holdings and provide personalized analysis without breaking the stateless MCP tool contract.
  8 | 
  9 | **Success Criteria (Mandatory):**
 10 | - **Phase 1 Complete:** 4 new MCP tools (`add_portfolio_position`, `get_my_portfolio`, `remove_portfolio_position`, `clear_my_portfolio`) and 1 MCP resource (`portfolio://my-holdings`) fully functional
 11 | - **Database Integration:** SQLAlchemy models with proper cost basis averaging, Alembic migration creating tables without conflicts
 12 | - **Phase 2 Integration:** 3 existing tools enhanced (`risk_adjusted_analysis`, `portfolio_correlation_analysis`, `compare_tickers`) with automatic portfolio detection
 13 | - **AI Context Injection:** Portfolio resource provides live P&L, diversification metrics, and position details to AI agents automatically
 14 | - **Test Coverage:** 85%+ test coverage with unit, integration, and domain tests passing
 15 | - **Code Quality:** Zero linting errors (ruff), full type annotations (ty), all hooks passing
 16 | - **Documentation:** PORTFOLIO.md guide, updated tool docstrings, usage examples in Claude Desktop
 17 | 
 18 | **Financial Disclaimer:** All portfolio features include educational disclaimers. No investment recommendations. Local-first storage only. No tax advice provided.
 19 | 
 20 | ## 2. To-Do List (High Level)
 21 | 
 22 | ### Phase 1: Persistent Portfolio Storage Foundation (4-5 days)
 23 | - [ ] **Spike 1:** Research cost basis averaging algorithms and edge cases (FIFO, average cost)
 24 | - [ ] **Domain Entities:** Create `Portfolio` and `Position` domain entities with business logic
 25 | - [ ] **Database Models:** Implement `UserPortfolio` and `PortfolioPosition` SQLAlchemy models
 26 | - [ ] **Migration:** Create Alembic migration with proper indexes and constraints
 27 | - [ ] **MCP Tools:** Implement 4 portfolio management tools with validation
 28 | - [ ] **MCP Resource:** Implement `portfolio://my-holdings` with live P&L calculations (see the sketch after this list)
 29 | - [ ] **Unit Tests:** Comprehensive domain entity and cost basis tests
 30 | - [ ] **Integration Tests:** Database operation and transaction tests
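
For orientation, the resource item above could be wired up roughly as follows. This is a hedged sketch assuming the standalone `fastmcp` package's decorator API; the server name and the static payload are illustrative placeholders, and the real implementation would load the persisted portfolio and live prices instead:

```python
import json

from fastmcp import FastMCP

mcp = FastMCP("maverick-portfolio-demo")  # illustrative server name, not the project's instance


@mcp.resource("portfolio://my-holdings")
def my_holdings() -> str:
    """Expose live holdings and P&L as context for AI agents."""
    # A static payload keeps the sketch self-contained; the real resource would
    # return Portfolio.calculate_portfolio_metrics(...) serialized to JSON.
    return json.dumps({"positions": [], "total_value": 0.0})
```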
 31 | 
 32 | ### Phase 2: Intelligent Tool Integration (2-3 days)
 33 | - [ ] **Risk Analysis Enhancement:** Add position awareness to `risk_adjusted_analysis`
 34 | - [ ] **Correlation Enhancement:** Enable `portfolio_correlation_analysis` with no arguments
 35 | - [ ] **Comparison Enhancement:** Enable `compare_tickers` with optional portfolio auto-fill
 36 | - [ ] **Resource Enhancement:** Add live market data to portfolio resource
 37 | - [ ] **Integration Tests:** Cross-tool functionality validation
 38 | - [ ] **Documentation:** Update existing tool docstrings with new capabilities
 39 | 
 40 | ### Phase 3: Polish & Documentation (1-2 days)
 41 | - [ ] **Manual Testing:** Claude Desktop end-to-end workflow validation
 42 | - [ ] **Error Handling:** Edge case coverage (partial sells, zero shares, invalid tickers)
 43 | - [ ] **Performance:** Query optimization, batch operations, caching strategy
 44 | - [ ] **Documentation:** Complete PORTFOLIO.md with examples and screenshots
 45 | - [ ] **Migration Testing:** Test upgrade/downgrade paths
 46 | 
 47 | ## 3. Plan Details (Spikes & Features)
 48 | 
 49 | ### Spike 1: Cost Basis Averaging Research
 50 | 
 51 | **Action:** Investigate cost basis calculation methods (FIFO, LIFO, average cost) and determine optimal approach for educational portfolio tracking.
 52 | 
 53 | **Steps:**
 54 | 1. Research IRS cost basis methods and educational best practices
 55 | 2. Analyze existing `PortfolioManager` tool (JSON-based, average cost) for patterns
 56 | 3. Design algorithm for averaging purchases and handling partial sells
 57 | 4. Create specification document for edge cases:
 58 |    - Multiple purchases at different prices
 59 |    - Partial position sales
 60 |    - Zero/negative share handling
 61 |    - Rounding and precision (financial data uses Numeric(12,4))
 62 | 5. Benchmark performance for 100+ positions with 1000+ transactions
 63 | 
 64 | **Expected Outcome:** Clear specification for cost basis implementation using **average cost method** (simplest for educational use, matches existing PortfolioManager), with edge case handling documented.
 65 | 
 66 | **Decision Rationale:** Average cost is simpler than FIFO/LIFO, appropriate for educational context, and avoids tax accounting complexity.
 67 | 
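A worked example of the average cost method (a minimal sketch; the share counts and prices are illustrative and use `Decimal` to match the Numeric(12,4) precision requirement above):

```python
from decimal import Decimal

# Two purchases at different prices
shares_1, price_1 = Decimal("10"), Decimal("100.00")   # 10 shares @ $100
shares_2, price_2 = Decimal("10"), Decimal("120.00")   # 10 shares @ $120

total_shares = shares_1 + shares_2                      # 20 shares
total_cost = shares_1 * price_1 + shares_2 * price_2    # $2,200.00
average_cost = total_cost / total_shares                # $110.00

# Partial sell of 5 shares: average cost is unchanged, cost basis scales down
remaining_shares = total_shares - Decimal("5")          # 15 shares
remaining_cost = remaining_shares * average_cost        # $1,650.00
```
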
 68 | ---
 69 | 
 70 | ### Feature A: Domain Entities (DDD Pattern)
 71 | 
 72 | **Goal:** Create pure business logic entities following MaverickMCP's established DDD patterns (similar to backtesting domain entities).
 73 | 
 74 | **Files to Create:**
 75 | - `maverick_mcp/domain/portfolio.py` - Core domain entities
 76 | - `maverick_mcp/domain/position.py` - Position value objects
 77 | 
 78 | **Domain Entity Design:**
 79 | 
 80 | ```python
 81 | # maverick_mcp/domain/portfolio.py
 82 | from dataclasses import dataclass
 83 | from datetime import UTC, datetime
 84 | from decimal import Decimal
 85 | from typing import List, Optional
 86 | 
 87 | @dataclass
 88 | class Position:
 89 |     """Value object representing a single portfolio position."""
 90 |     ticker: str
 91 |     shares: Decimal  # Use Decimal for precision
 92 |     average_cost_basis: Decimal
 93 |     total_cost: Decimal
 94 |     purchase_date: datetime  # Earliest purchase
 95 |     notes: Optional[str] = None
 96 | 
 97 |     def add_shares(self, shares: Decimal, price: Decimal, date: datetime) -> "Position":
 98 |         """Add shares with automatic cost basis averaging."""
 99 |         new_total_shares = self.shares + shares
100 |         new_total_cost = self.total_cost + (shares * price)
101 |         new_avg_cost = new_total_cost / new_total_shares
102 | 
103 |         return Position(
104 |             ticker=self.ticker,
105 |             shares=new_total_shares,
106 |             average_cost_basis=new_avg_cost,
107 |             total_cost=new_total_cost,
108 |             purchase_date=min(self.purchase_date, date),
109 |             notes=self.notes
110 |         )
111 | 
112 |     def remove_shares(self, shares: Decimal) -> Optional["Position"]:
113 |         """Remove shares, return None if position fully closed."""
114 |         if shares >= self.shares:
115 |             return None  # Full position close
116 | 
117 |         new_shares = self.shares - shares
118 |         new_total_cost = new_shares * self.average_cost_basis
119 | 
120 |         return Position(
121 |             ticker=self.ticker,
122 |             shares=new_shares,
123 |             average_cost_basis=self.average_cost_basis,
124 |             total_cost=new_total_cost,
125 |             purchase_date=self.purchase_date,
126 |             notes=self.notes
127 |         )
128 | 
129 |     def calculate_current_value(self, current_price: Decimal) -> dict:
130 |         """Calculate live P&L metrics."""
131 |         current_value = self.shares * current_price
132 |         unrealized_pnl = current_value - self.total_cost
133 |         pnl_percentage = (unrealized_pnl / self.total_cost * 100) if self.total_cost else Decimal(0)
134 | 
135 |         return {
136 |             "current_value": current_value,
137 |             "unrealized_pnl": unrealized_pnl,
138 |             "pnl_percentage": pnl_percentage
139 |         }
140 | 
141 | @dataclass
142 | class Portfolio:
143 |     """Aggregate root for user portfolio."""
144 |     portfolio_id: str  # UUID
145 |     user_id: str  # "default" for single-user
146 |     name: str
147 |     positions: List[Position]
148 |     created_at: datetime
149 |     updated_at: datetime
150 | 
151 |     def add_position(self, ticker: str, shares: Decimal, price: Decimal,
152 |                     date: datetime, notes: Optional[str] = None) -> None:
153 |         """Add or update position with automatic averaging."""
154 |         # Find existing position
155 |         for i, pos in enumerate(self.positions):
156 |             if pos.ticker == ticker:
157 |                 self.positions[i] = pos.add_shares(shares, price, date)
158 |                 self.updated_at = datetime.now(UTC)
159 |                 return
160 | 
161 |         # Create new position
162 |         new_position = Position(
163 |             ticker=ticker,
164 |             shares=shares,
165 |             average_cost_basis=price,
166 |             total_cost=shares * price,
167 |             purchase_date=date,
168 |             notes=notes
169 |         )
170 |         self.positions.append(new_position)
171 |         self.updated_at = datetime.now(UTC)
172 | 
173 |     def remove_position(self, ticker: str, shares: Optional[Decimal] = None) -> bool:
174 |         """Remove position or partial shares."""
175 |         for i, pos in enumerate(self.positions):
176 |             if pos.ticker == ticker:
177 |                 if shares is None or shares >= pos.shares:
178 |                     # Full position removal
179 |                     self.positions.pop(i)
180 |                 else:
181 |                     # Partial removal
182 |                     updated_pos = pos.remove_shares(shares)
183 |                     if updated_pos:
184 |                         self.positions[i] = updated_pos
185 |                     else:
186 |                         self.positions.pop(i)
187 | 
188 |                 self.updated_at = datetime.now(UTC)
189 |                 return True
190 |         return False
191 | 
192 |     def get_position(self, ticker: str) -> Optional[Position]:
193 |         """Get position by ticker."""
194 |         return next((pos for pos in self.positions if pos.ticker == ticker), None)
195 | 
196 |     def get_total_invested(self) -> Decimal:
197 |         """Calculate total capital invested."""
198 |         return sum((pos.total_cost for pos in self.positions), Decimal(0))
199 | 
200 |     def calculate_portfolio_metrics(self, current_prices: dict[str, Decimal]) -> dict:
201 |         """Calculate comprehensive portfolio metrics."""
202 |         total_value = Decimal(0)
203 |         total_cost = Decimal(0)
204 |         position_details = []
205 | 
206 |         for pos in self.positions:
207 |             current_price = current_prices.get(pos.ticker, pos.average_cost_basis)
208 |             metrics = pos.calculate_current_value(current_price)
209 | 
210 |             total_value += metrics["current_value"]
211 |             total_cost += pos.total_cost
212 | 
213 |             position_details.append({
214 |                 "ticker": pos.ticker,
215 |                 "shares": float(pos.shares),
216 |                 "cost_basis": float(pos.average_cost_basis),
217 |                 "current_price": float(current_price),
218 |                 **{k: float(v) for k, v in metrics.items()}
219 |             })
220 | 
221 |         total_pnl = total_value - total_cost
222 |         total_pnl_pct = (total_pnl / total_cost * 100) if total_cost else Decimal(0)
223 | 
224 |         return {
225 |             "total_value": float(total_value),
226 |             "total_invested": float(total_cost),
227 |             "total_pnl": float(total_pnl),
228 |             "total_pnl_percentage": float(total_pnl_pct),
229 |             "position_count": len(self.positions),
230 |             "positions": position_details
231 |         }
232 | ```
233 | 
234 | **Testing Strategy:**
235 | - Unit tests for cost basis averaging edge cases
236 | - Property-based tests for arithmetic precision (see the sketch below)
237 | - Edge case tests: zero shares, negative P&L, division by zero
238 | 
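A minimal sketch of such tests (illustrative only; assumes `pytest` and `hypothesis` as dev dependencies and the `Position` dataclass shown above; the test file path and `make_position` helper are hypothetical):

```python
# Hypothetical test module, e.g. tests/domain/test_portfolio_position.py
from datetime import UTC, datetime
from decimal import Decimal

from hypothesis import given
from hypothesis import strategies as st

from maverick_mcp.domain.portfolio import Position


def make_position(shares: str, price: str) -> Position:
    """Build a simple single-lot position for test cases."""
    shares_d, price_d = Decimal(shares), Decimal(price)
    return Position(
        ticker="AAPL",
        shares=shares_d,
        average_cost_basis=price_d,
        total_cost=shares_d * price_d,
        purchase_date=datetime(2025, 1, 2, tzinfo=UTC),
    )


def test_add_shares_averages_cost_basis():
    pos = make_position("10", "100.00").add_shares(
        Decimal("10"), Decimal("200.00"), datetime(2025, 2, 1, tzinfo=UTC)
    )
    assert pos.shares == Decimal("20")
    assert pos.average_cost_basis == Decimal("150")  # (1000 + 2000) / 20


@given(
    shares=st.decimals(min_value=Decimal("0.0001"), max_value=Decimal("1000"), places=4),
    price=st.decimals(min_value=Decimal("0.01"), max_value=Decimal("10000"), places=2),
)
def test_total_cost_tracks_shares_times_cost_basis(shares, price):
    pos = make_position("1", "50.00").add_shares(
        shares, price, datetime(2025, 3, 1, tzinfo=UTC)
    )
    # Invariant: total_cost stays consistent with shares * average_cost_basis
    assert abs(pos.total_cost - pos.shares * pos.average_cost_basis) < Decimal("0.01")
```
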
239 | ---
240 | 
241 | ### Feature B: Database Models (SQLAlchemy ORM)
242 | 
243 | **Goal:** Create persistent storage models following established patterns in `maverick_mcp/data/models.py`.
244 | 
245 | **Files to Modify:**
246 | - `maverick_mcp/data/models.py` - Add new models (lines ~1700+)
247 | 
248 | **Model Design:**
249 | 
250 | ```python
251 | # Add to maverick_mcp/data/models.py
252 | 
253 | class UserPortfolio(TimestampMixin, Base):
254 |     """
255 |     User portfolio for tracking investment holdings.
256 | 
257 |     Follows personal-use design: single user_id="default"
258 |     """
259 |     __tablename__ = "mcp_portfolios"
260 | 
261 |     id = Column(Uuid, primary_key=True, default=uuid.uuid4)
262 |     user_id = Column(String(50), nullable=False, default="default", index=True)
263 |     name = Column(String(200), nullable=False, default="My Portfolio")
264 | 
265 |     # Relationships
266 |     positions = relationship(
267 |         "PortfolioPosition",
268 |         back_populates="portfolio",
269 |         cascade="all, delete-orphan",
270 |         lazy="selectin"  # Efficient loading
271 |     )
272 | 
273 |     # Indexes for queries
274 |     __table_args__ = (
275 |         Index("idx_portfolio_user", "user_id"),
276 |         UniqueConstraint("user_id", "name", name="uq_user_portfolio_name"),
277 |     )
278 | 
279 |     def __repr__(self):
280 |         return f"<UserPortfolio(id={self.id}, name='{self.name}', positions={len(self.positions)})>"
281 | 
282 | 
283 | class PortfolioPosition(TimestampMixin, Base):
284 |     """
285 |     Individual position within a portfolio with cost basis tracking.
286 |     """
287 |     __tablename__ = "mcp_portfolio_positions"
288 | 
289 |     id = Column(Uuid, primary_key=True, default=uuid.uuid4)
290 |     portfolio_id = Column(Uuid, ForeignKey("mcp_portfolios.id", ondelete="CASCADE"), nullable=False)
291 | 
292 |     # Position details
293 |     ticker = Column(String(20), nullable=False, index=True)
294 |     shares = Column(Numeric(20, 8), nullable=False)  # High precision for fractional shares
295 |     average_cost_basis = Column(Numeric(12, 4), nullable=False)  # Financial precision
296 |     total_cost = Column(Numeric(20, 4), nullable=False)  # Total capital invested
297 |     purchase_date = Column(DateTime(timezone=True), nullable=False)  # Earliest purchase
298 |     notes = Column(Text, nullable=True)  # Optional user notes
299 | 
300 |     # Relationships
301 |     portfolio = relationship("UserPortfolio", back_populates="positions")
302 | 
303 |     # Indexes for efficient queries
304 |     __table_args__ = (
305 |         Index("idx_position_portfolio", "portfolio_id"),
306 |         Index("idx_position_ticker", "ticker"),
307 |         Index("idx_position_portfolio_ticker", "portfolio_id", "ticker"),
308 |         UniqueConstraint("portfolio_id", "ticker", name="uq_portfolio_position_ticker"),
309 |     )
310 | 
311 |     def __repr__(self):
312 |         return f"<PortfolioPosition(ticker='{self.ticker}', shares={self.shares}, cost_basis={self.average_cost_basis})>"
313 | ```
314 | 
315 | **Key Design Decisions:**
316 | 1. **Table Names:** `mcp_portfolios` and `mcp_portfolio_positions` (consistent with `mcp_*` pattern)
317 | 2. **user_id:** Default "default" for single-user personal use
318 | 3. **Numeric Precision:** Matches existing financial data patterns (12,4 for prices, 20,8 for shares)
319 | 4. **Cascade Delete:** Portfolio deletion removes all positions automatically
320 | 5. **Unique Constraint:** One position per ticker per portfolio
321 | 6. **Indexes:** Optimized for common queries (user lookup, ticker filtering)
322 | 
323 | ---
324 | 
325 | ### Feature C: Alembic Migration
326 | 
327 | **Goal:** Create a database migration following established patterns without creating conflicting migration heads.
328 | 
329 | **File to Create:**
330 | - `alembic/versions/014_add_portfolio_models.py`
331 | 
332 | **Migration Pattern:**
333 | 
334 | ```python
335 | """Add portfolio and position models
336 | 
337 | Revision ID: 014_add_portfolio_models
338 | Revises: 013_add_backtest_persistence_models
339 | Create Date: 2025-11-01 10:00:00.000000
340 | """
341 | from alembic import op
342 | import sqlalchemy as sa
343 | from sqlalchemy.dialects import postgresql
344 | 
345 | # revision identifiers
346 | revision = '014_add_portfolio_models'
347 | down_revision = '013_add_backtest_persistence_models'
348 | branch_labels = None
349 | depends_on = None
350 | 
351 | 
352 | def upgrade():
353 |     """Create portfolio management tables."""
354 | 
355 |     # Create portfolios table
356 |     op.create_table(
357 |         'mcp_portfolios',
358 |         sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
359 |         sa.Column('user_id', sa.String(50), nullable=False, server_default='default'),
360 |         sa.Column('name', sa.String(200), nullable=False, server_default='My Portfolio'),
361 |         sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
362 |         sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
363 |     )
364 | 
365 |     # Create indexes on portfolios
366 |     op.create_index('idx_portfolio_user', 'mcp_portfolios', ['user_id'])
367 |     op.create_unique_constraint('uq_user_portfolio_name', 'mcp_portfolios', ['user_id', 'name'])
368 | 
369 |     # Create positions table
370 |     op.create_table(
371 |         'mcp_portfolio_positions',
372 |         sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
373 |         sa.Column('portfolio_id', postgresql.UUID(as_uuid=True), nullable=False),
374 |         sa.Column('ticker', sa.String(20), nullable=False),
375 |         sa.Column('shares', sa.Numeric(20, 8), nullable=False),
376 |         sa.Column('average_cost_basis', sa.Numeric(12, 4), nullable=False),
377 |         sa.Column('total_cost', sa.Numeric(20, 4), nullable=False),
378 |         sa.Column('purchase_date', sa.DateTime(timezone=True), nullable=False),
379 |         sa.Column('notes', sa.Text, nullable=True),
380 |         sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
381 |         sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
382 |         sa.ForeignKeyConstraint(['portfolio_id'], ['mcp_portfolios.id'], ondelete='CASCADE'),
383 |     )
384 | 
385 |     # Create indexes on positions
386 |     op.create_index('idx_position_portfolio', 'mcp_portfolio_positions', ['portfolio_id'])
387 |     op.create_index('idx_position_ticker', 'mcp_portfolio_positions', ['ticker'])
388 |     op.create_index('idx_position_portfolio_ticker', 'mcp_portfolio_positions', ['portfolio_id', 'ticker'])
389 |     op.create_unique_constraint('uq_portfolio_position_ticker', 'mcp_portfolio_positions', ['portfolio_id', 'ticker'])
390 | 
391 | 
392 | def downgrade():
393 |     """Drop portfolio management tables."""
394 |     op.drop_table('mcp_portfolio_positions')
395 |     op.drop_table('mcp_portfolios')
396 | ```
397 | 
398 | **Testing:**
399 | - Test upgrade: `alembic upgrade head`
400 | - Test downgrade: `alembic downgrade -1`
401 | - Verify indexes created: SQL query inspection (see the example below)
402 | - Test with SQLite and PostgreSQL
403 | 
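For example, the created indexes can be spot-checked programmatically after `alembic upgrade head` (a sketch; the database URLs are placeholders for whichever backend the migration ran against):

```python
# Sketch: confirm the migration's indexes exist on the target database.
from sqlalchemy import create_engine, inspect

engine = create_engine("sqlite:///maverick_mcp.db")  # or a postgresql:// URL

inspector = inspect(engine)
for table in ("mcp_portfolios", "mcp_portfolio_positions"):
    index_names = [ix["name"] for ix in inspector.get_indexes(table)]
    print(f"{table}: {index_names}")
```
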
404 | ---
405 | 
406 | ### Feature D: MCP Tools Implementation
407 | 
408 | **Goal:** Implement 4 portfolio management tools following the `tool_registry.py` pattern.
409 | 
410 | **Files to Create:**
411 | - `maverick_mcp/api/routers/portfolio_management.py` - New tool implementations
412 | - `maverick_mcp/api/services/portfolio_persistence_service.py` - Service layer
413 | - `maverick_mcp/validation/portfolio_management.py` - Pydantic validation
414 | 
415 | **Service Layer Pattern:**
416 | 
417 | ```python
418 | # maverick_mcp/api/services/portfolio_persistence_service.py
419 | 
420 | class PortfolioPersistenceService(BaseService):
421 |     """Service for portfolio CRUD operations."""
422 | 
423 |     async def get_or_create_default_portfolio(self) -> UserPortfolio:
424 |         """Get the default portfolio, create if doesn't exist."""
425 |         pass
426 | 
427 |     async def add_position(self, ticker: str, shares: Decimal,
428 |                           price: Decimal, date: datetime,
429 |                           notes: Optional[str]) -> PortfolioPosition:
430 |         """Add or update position with cost averaging."""
431 |         pass
432 | 
433 |     async def get_portfolio_with_live_data(self) -> dict:
434 |         """Fetch portfolio with current market prices."""
435 |         pass
436 | 
437 |     async def remove_position(self, ticker: str,
438 |                             shares: Optional[Decimal]) -> bool:
439 |         """Remove position or partial shares."""
440 |         pass
441 | 
442 |     async def clear_portfolio(self) -> bool:
443 |         """Delete all positions."""
444 |         pass
445 | ```
446 | 
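As a rough illustration of the intended pattern, one method might look like the sketch below (it assumes `BaseService` exposes an async SQLAlchemy session as `self.session`; that attribute name and the transaction handling are assumptions, not the final API):

```python
# Sketch only: async get-or-create for the single default portfolio,
# intended as a method body inside PortfolioPersistenceService.
from sqlalchemy import select

from maverick_mcp.data.models import UserPortfolio


async def get_or_create_default_portfolio(self) -> UserPortfolio:
    result = await self.session.execute(
        select(UserPortfolio).where(UserPortfolio.user_id == "default")
    )
    portfolio = result.scalar_one_or_none()
    if portfolio is None:
        portfolio = UserPortfolio(user_id="default", name="My Portfolio")
        self.session.add(portfolio)
        await self.session.commit()
        await self.session.refresh(portfolio)
    return portfolio
```
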
447 | **Tool Registration:**
448 | 
449 | ```python
450 | # Add to maverick_mcp/api/routers/tool_registry.py
451 | 
452 | def register_portfolio_management_tools(mcp: FastMCP) -> None:
453 |     """Register portfolio management tools."""
454 |     from maverick_mcp.api.routers.portfolio_management import (
455 |         add_portfolio_position,
456 |         get_my_portfolio,
457 |         remove_portfolio_position,
458 |         clear_my_portfolio
459 |     )
460 | 
461 |     mcp.tool(name="portfolio_add_position")(add_portfolio_position)
462 |     mcp.tool(name="portfolio_get_my_portfolio")(get_my_portfolio)
463 |     mcp.tool(name="portfolio_remove_position")(remove_portfolio_position)
464 |     mcp.tool(name="portfolio_clear")(clear_my_portfolio)
465 | ```
466 | 
467 | ---
468 | 
469 | ### Feature E: MCP Resource Implementation
470 | 
471 | **Goal:** Create `portfolio://my-holdings` resource for automatic AI context injection.
472 | 
473 | **File to Modify:**
474 | - `maverick_mcp/api/server.py` - Add resource alongside existing health:// and dashboard:// resources
475 | 
476 | **Resource Implementation:**
477 | 
478 | ```python
479 | # Add to maverick_mcp/api/server.py (around line 823, near other resources)
480 | 
481 | @mcp.resource("portfolio://my-holdings")
482 | def portfolio_holdings_resource() -> dict[str, Any]:
483 |     """
484 |     Portfolio holdings resource for AI context injection.
485 | 
486 |     Provides comprehensive portfolio context to AI agents including:
487 |     - Current positions with live P&L
488 |     - Portfolio metrics and diversification
489 |     - Sector exposure analysis
490 |     - Top/bottom performers
491 | 
492 |     This resource is automatically available to AI agents during conversations,
493 |     enabling personalized analysis without requiring manual ticker input.
494 |     """
495 |     # Implementation using service layer with async handling
496 |     pass
497 | ```
498 | 
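A sketch of the sync-to-async bridging mentioned above (the `PortfolioPersistenceService` construction and the thread-pool fallback are illustrative assumptions; the stub above keeps the `@mcp.resource` decorator):

```python
# Sketch: bridge the sync resource function to the async service layer.
import asyncio
import concurrent.futures
from typing import Any

from maverick_mcp.api.services.portfolio_persistence_service import (
    PortfolioPersistenceService,
)


def portfolio_holdings_resource() -> dict[str, Any]:
    service = PortfolioPersistenceService()  # construction details are an assumption
    coro = service.get_portfolio_with_live_data()
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: run the coroutine directly
        return asyncio.run(coro)
    # A loop is already running (e.g. inside the server): execute the coroutine
    # on a worker thread with its own event loop and block for the result.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```
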
499 | ---
500 | 
501 | ### Feature F: Phase 2 Tool Enhancements
502 | 
503 | **Goal:** Enhance existing tools to auto-detect portfolio holdings.
504 | 
505 | **Files to Modify:**
506 | 1. `maverick_mcp/api/routers/portfolio.py` - Enhance 3 existing tools
507 | 2. `maverick_mcp/validation/portfolio.py` - Update validation to allow optional parameters
508 | 
509 | **Enhancement Pattern:**
510 | - Add optional parameters (tickers can be None; see the sketch after this list)
511 | - Check portfolio for holdings if no tickers provided
512 | - Add position awareness to analysis results
513 | - Maintain backward compatibility
514 | 
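A sketch of that fallback behavior (the tool name, signature, and service construction here are placeholders, not the final API):

```python
# Sketch: existing tool signature relaxed so tickers become optional.
from maverick_mcp.api.services.portfolio_persistence_service import (
    PortfolioPersistenceService,
)


async def compare_stocks(tickers: list[str] | None = None) -> dict:
    if not tickers:
        service = PortfolioPersistenceService()  # construction is an assumption
        portfolio = await service.get_or_create_default_portfolio()
        tickers = [position.ticker for position in portfolio.positions]
        if not tickers:
            return {"error": "No tickers provided and the portfolio is empty"}
    # ...existing comparison logic runs unchanged, preserving backward compatibility
    return {"tickers": tickers}
```
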
515 | ---
516 | 
517 | ## 4. Progress (Living Document Section)
518 | 
519 | | Date | Time | Item Completed / Status Update | Resulting Changes (LOC/Files) |
520 | |:-----|:-----|:------------------------------|:------------------------------|
521 | | 2025-11-01 | Start | Plan approved and documented | PORTFOLIO_PERSONALIZATION_PLAN.md created |
522 | | TBD | TBD | Implementation begins | - |
523 | 
524 | _(This section will be updated during implementation)_
525 | 
526 | ---
527 | 
528 | ## 5. Surprises and Discoveries
529 | 
530 | _(Technical issues discovered during implementation will be documented here)_
531 | 
532 | **Anticipated Challenges:**
533 | 1. **MCP Resource Async Context:** Resources are sync functions but need async database calls - solved with event loop management (see existing health_resource pattern)
534 | 2. **Cost Basis Precision:** Financial calculations require Decimal precision, not floats - use Numeric(12,4) for prices, Numeric(20,8) for shares
535 | 3. **Portfolio Resource Performance:** Live price fetching could be slow - implement caching strategy, consider async batching
536 | 4. **Single User Assumption:** No user authentication means all operations use user_id="default" - acceptable for personal use
537 | 
538 | ---
539 | 
540 | ## 6. Decision Log
541 | 
542 | | Date | Decision | Rationale |
543 | |:-----|:---------|:----------|
544 | | 2025-11-01 | **Cost Basis Method: Average Cost** | Simplest for educational use, matches existing PortfolioManager, avoids tax accounting complexity |
545 | | 2025-11-01 | **Table Names: mcp_portfolios, mcp_portfolio_positions** | Consistent with existing mcp_* naming convention for MCP-specific tables |
546 | | 2025-11-01 | **User ID: "default" for all users** | Single-user personal-use design, consistent with auth disabled architecture |
547 | | 2025-11-01 | **Numeric Precision: Numeric(12,4) for prices, Numeric(20,8) for shares** | Matches existing financial data patterns, supports fractional shares |
548 | | 2025-11-01 | **Optional tickers parameter for Phase 2** | Enables "just works" UX while maintaining backward compatibility |
549 | | 2025-11-01 | **MCP Resource for AI context** | Most elegant solution for automatic context injection without breaking tool contracts |
550 | | 2025-11-01 | **Domain-Driven Design pattern** | Follows established MaverickMCP architecture, clean separation of concerns |
551 | 
552 | ---
553 | 
554 | ## 7. Implementation Phases
555 | 
556 | ### Phase 1: Foundation (4-5 days)
557 | **Files Created:** 8 new files
558 | **Files Modified:** 3 existing files
559 | **Estimated LOC:** ~2,500 lines
560 | **Tests:** ~1,200 lines
561 | 
562 | ### Phase 2: Integration (2-3 days)
563 | **Files Modified:** 4 existing files
564 | **Estimated LOC:** ~800 lines additional
565 | **Tests:** ~600 lines additional
566 | 
567 | ### Phase 3: Polish (1-2 days)
568 | **Documentation:** PORTFOLIO.md (~300 lines)
569 | **Performance:** Query optimization
570 | **Testing:** Manual Claude Desktop validation
571 | 
572 | **Total Effort:** 7-10 days
573 | **Total New Code:** ~3,500 lines (including tests)
574 | **Total Tests:** ~1,800 lines
575 | 
576 | ---
577 | 
578 | ## 8. Risk Assessment
579 | 
580 | **Low Risk:**
581 | - ✅ Follows established patterns
582 | - ✅ No breaking changes to existing tools
583 | - ✅ Optional Phase 2 enhancements
584 | - ✅ Well-scoped feature
585 | 
586 | **Medium Risk:**
587 | - ⚠️ MCP resource performance with live prices
588 | - ⚠️ Migration compatibility (SQLite vs PostgreSQL)
589 | - ⚠️ Edge cases in cost basis averaging
590 | 
591 | **Mitigation Strategies:**
592 | 1. **Performance:** Implement caching, batch price fetches, add timeout protection
593 | 2. **Migration:** Test with both SQLite and PostgreSQL, provide rollback path
594 | 3. **Edge Cases:** Comprehensive unit tests, property-based testing for arithmetic
595 | 
596 | ---
597 | 
598 | ## 9. Testing Strategy
599 | 
600 | **Unit Tests (~60% of test code):**
601 | - Domain entity logic (Position, Portfolio)
602 | - Cost basis averaging edge cases
603 | - Numeric precision validation
604 | - Business logic validation
605 | 
606 | **Integration Tests (~30% of test code):**
607 | - Database CRUD operations
608 | - Migration upgrade/downgrade
609 | - Service layer with real database
610 | - Cross-tool functionality
611 | 
612 | **Manual Tests (~10% of effort):**
613 | - Claude Desktop end-to-end workflows
614 | - Natural language interactions
615 | - MCP resource visibility
616 | - Tool integration scenarios
617 | 
618 | **Test Coverage Target:** 85%+
619 | 
620 | ---
621 | 
622 | ## 10. Success Metrics
623 | 
624 | **Functional Success:**
625 | - [ ] All 4 new tools work in Claude Desktop
626 | - [ ] Portfolio resource visible to AI agents
627 | - [ ] Cost basis averaging accurate to 4 decimal places
628 | - [ ] Migration works on SQLite and PostgreSQL
629 | - [ ] 3 enhanced tools auto-detect portfolio
630 | 
631 | **Quality Success:**
632 | - [ ] 85%+ test coverage
633 | - [ ] Zero linting errors (ruff)
634 | - [ ] Full type annotations (ty check passes)
635 | - [ ] All pre-commit hooks pass
636 | 
637 | **UX Success:**
638 | - [ ] "Analyze my portfolio" works without ticker input
639 | - [ ] AI agents reference actual holdings in responses
640 | - [ ] Natural language interactions feel seamless
641 | - [ ] Error messages are clear and actionable
642 | 
643 | ---
644 | 
645 | ## 11. Related Documentation
646 | 
647 | - **Original Issue:** [#40 - Portfolio Personalization](https://github.com/wshobson/maverick-mcp/issues/40)
648 | - **User Documentation:** `docs/PORTFOLIO.md` (to be created)
649 | - **API Documentation:** Tool docstrings and MCP introspection
650 | - **Testing Guide:** `tests/README.md` (to be updated)
651 | 
652 | ---
653 | 
654 | This execution plan provides a comprehensive roadmap following the PLANS.md rubric structure. The implementation is well-scoped, follows established patterns, and delivers significant UX improvement while maintaining code quality and architectural integrity.
655 | 
```

--------------------------------------------------------------------------------
/maverick_mcp/backtesting/retraining_pipeline.py:
--------------------------------------------------------------------------------

```python
  1 | """Automated retraining pipeline for ML models with data drift detection."""
  2 | 
  3 | import logging
  4 | from collections.abc import Callable
  5 | from datetime import datetime
  6 | from typing import Any
  7 | 
  8 | import numpy as np
  9 | import pandas as pd
 10 | from scipy import stats
 11 | from sklearn.base import BaseEstimator
 12 | from sklearn.metrics import accuracy_score, classification_report
 13 | from sklearn.model_selection import train_test_split
 14 | from sklearn.preprocessing import StandardScaler
 15 | 
 16 | from .model_manager import ModelManager
 17 | 
 18 | logger = logging.getLogger(__name__)
 19 | 
 20 | 
 21 | class DataDriftDetector:
 22 |     """Detects data drift in features and targets."""
 23 | 
 24 |     def __init__(self, significance_level: float = 0.05):
 25 |         """Initialize drift detector.
 26 | 
 27 |         Args:
 28 |             significance_level: Statistical significance level for drift detection
 29 |         """
 30 |         self.significance_level = significance_level
 31 |         self.reference_data: pd.DataFrame | None = None
 32 |         self.reference_target: pd.Series | None = None
 33 |         self.feature_stats: dict[str, dict[str, float]] = {}
 34 | 
 35 |     def set_reference_data(
 36 |         self, features: pd.DataFrame, target: pd.Series | None = None
 37 |     ):
 38 |         """Set reference data for drift detection.
 39 | 
 40 |         Args:
 41 |             features: Reference feature data
 42 |             target: Reference target data (optional)
 43 |         """
 44 |         self.reference_data = features.copy()
 45 |         self.reference_target = target.copy() if target is not None else None
 46 | 
 47 |         # Calculate reference statistics
 48 |         self.feature_stats = {}
 49 |         for col in features.columns:
 50 |             if features[col].dtype in ["float64", "float32", "int64", "int32"]:
 51 |                 self.feature_stats[col] = {
 52 |                     "mean": features[col].mean(),
 53 |                     "std": features[col].std(),
 54 |                     "min": features[col].min(),
 55 |                     "max": features[col].max(),
 56 |                     "median": features[col].median(),
 57 |                 }
 58 | 
 59 |         logger.info(
 60 |             f"Set reference data with {len(features)} samples and {len(features.columns)} features"
 61 |         )
 62 | 
 63 |     def detect_feature_drift(
 64 |         self, new_features: pd.DataFrame
 65 |     ) -> dict[str, dict[str, Any]]:
 66 |         """Detect drift in features using statistical tests.
 67 | 
 68 |         Args:
 69 |             new_features: New feature data to compare
 70 | 
 71 |         Returns:
 72 |             Dictionary with drift detection results per feature
 73 |         """
 74 |         if self.reference_data is None:
 75 |             raise ValueError("Reference data not set")
 76 | 
 77 |         drift_results = {}
 78 | 
 79 |         for col in new_features.columns:
 80 |             if col not in self.reference_data.columns:
 81 |                 continue
 82 | 
 83 |             if new_features[col].dtype not in ["float64", "float32", "int64", "int32"]:
 84 |                 continue
 85 | 
 86 |             ref_data = self.reference_data[col].dropna()
 87 |             new_data = new_features[col].dropna()
 88 | 
 89 |             if len(ref_data) == 0 or len(new_data) == 0:
 90 |                 continue
 91 | 
 92 |             # Perform statistical tests
 93 |             drift_detected = False
 94 |             test_results = {}
 95 | 
 96 |             try:
 97 |                 # Kolmogorov-Smirnov test for distribution change
 98 |                 ks_statistic, ks_p_value = stats.ks_2samp(ref_data, new_data)
 99 |                 test_results["ks_statistic"] = ks_statistic
100 |                 test_results["ks_p_value"] = ks_p_value
101 |                 ks_drift = ks_p_value < self.significance_level
102 | 
103 |                 # Mann-Whitney U test for location shift
104 |                 mw_statistic, mw_p_value = stats.mannwhitneyu(
105 |                     ref_data, new_data, alternative="two-sided"
106 |                 )
107 |                 test_results["mw_statistic"] = mw_statistic
108 |                 test_results["mw_p_value"] = mw_p_value
109 |                 mw_drift = mw_p_value < self.significance_level
110 | 
111 |                 # Levene test for variance change
112 |                 levene_statistic, levene_p_value = stats.levene(ref_data, new_data)
113 |                 test_results["levene_statistic"] = levene_statistic
114 |                 test_results["levene_p_value"] = levene_p_value
115 |                 levene_drift = levene_p_value < self.significance_level
116 | 
117 |                 # Overall drift detection
118 |                 drift_detected = ks_drift or mw_drift or levene_drift
119 | 
120 |                 # Calculate effect sizes
121 |                 test_results["mean_diff"] = new_data.mean() - ref_data.mean()
122 |                 test_results["std_ratio"] = new_data.std() / (ref_data.std() + 1e-8)
123 | 
124 |             except Exception as e:
125 |                 logger.warning(f"Error in drift detection for {col}: {e}")
126 |                 test_results["error"] = str(e)
127 | 
128 |             drift_results[col] = {
129 |                 "drift_detected": drift_detected,
130 |                 "test_results": test_results,
131 |                 "reference_stats": self.feature_stats.get(col, {}),
132 |                 "new_stats": {
133 |                     "mean": new_data.mean(),
134 |                     "std": new_data.std(),
135 |                     "min": new_data.min(),
136 |                     "max": new_data.max(),
137 |                     "median": new_data.median(),
138 |                 },
139 |             }
140 | 
141 |         return drift_results
142 | 
143 |     def detect_target_drift(self, new_target: pd.Series) -> dict[str, Any]:
144 |         """Detect drift in target variable.
145 | 
146 |         Args:
147 |             new_target: New target data to compare
148 | 
149 |         Returns:
150 |             Dictionary with target drift results
151 |         """
152 |         if self.reference_target is None:
153 |             logger.warning("No reference target data set")
154 |             return {"drift_detected": False, "reason": "no_reference_target"}
155 | 
156 |         ref_target = self.reference_target.dropna()
157 |         new_target = new_target.dropna()
158 | 
159 |         if len(ref_target) == 0 or len(new_target) == 0:
160 |             return {"drift_detected": False, "reason": "insufficient_data"}
161 | 
162 |         drift_results = {"drift_detected": False}
163 | 
164 |         try:
165 |             # For categorical targets, use chi-square test
166 |             if ref_target.dtype == "object" or ref_target.nunique() < 10:
167 |                 ref_counts = ref_target.value_counts()
168 |                 new_counts = new_target.value_counts()
169 | 
170 |                 # Align the categories
171 |                 all_categories = set(ref_counts.index) | set(new_counts.index)
172 |                 ref_aligned = [ref_counts.get(cat, 0) for cat in all_categories]
173 |                 new_aligned = [new_counts.get(cat, 0) for cat in all_categories]
174 | 
175 |                 if sum(ref_aligned) > 0 and sum(new_aligned) > 0:
176 |                     chi2_stat, chi2_p_value = stats.chisquare(new_aligned, np.asarray(ref_aligned) * (sum(new_aligned) / sum(ref_aligned)))
177 |                     drift_results.update(
178 |                         {
179 |                             "test_type": "chi_square",
180 |                             "chi2_statistic": chi2_stat,
181 |                             "chi2_p_value": chi2_p_value,
182 |                             "drift_detected": chi2_p_value < self.significance_level,
183 |                         }
184 |                     )
185 | 
186 |             # For continuous targets
187 |             else:
188 |                 ks_statistic, ks_p_value = stats.ks_2samp(ref_target, new_target)
189 |                 drift_results.update(
190 |                     {
191 |                         "test_type": "kolmogorov_smirnov",
192 |                         "ks_statistic": ks_statistic,
193 |                         "ks_p_value": ks_p_value,
194 |                         "drift_detected": ks_p_value < self.significance_level,
195 |                     }
196 |                 )
197 | 
198 |         except Exception as e:
199 |             logger.warning(f"Error in target drift detection: {e}")
200 |             drift_results["error"] = str(e)
201 | 
202 |         return drift_results
203 | 
204 |     def get_drift_summary(
205 |         self, feature_drift: dict[str, dict], target_drift: dict[str, Any]
206 |     ) -> dict[str, Any]:
207 |         """Get summary of drift detection results.
208 | 
209 |         Args:
210 |             feature_drift: Feature drift results
211 |             target_drift: Target drift results
212 | 
213 |         Returns:
214 |             Summary dictionary
215 |         """
216 |         total_features = len(feature_drift)
217 |         drifted_features = sum(
218 |             1 for result in feature_drift.values() if result["drift_detected"]
219 |         )
220 |         target_drift_detected = target_drift.get("drift_detected", False)
221 | 
222 |         drift_severity = "none"
223 |         if target_drift_detected or drifted_features > total_features * 0.5:
224 |             drift_severity = "high"
225 |         elif drifted_features > total_features * 0.2:
226 |             drift_severity = "medium"
227 |         elif drifted_features > 0:
228 |             drift_severity = "low"
229 | 
230 |         return {
231 |             "total_features": total_features,
232 |             "drifted_features": drifted_features,
233 |             "drift_percentage": drifted_features / max(total_features, 1) * 100,
234 |             "target_drift_detected": target_drift_detected,
235 |             "drift_severity": drift_severity,
236 |             "recommendation": self._get_retraining_recommendation(
237 |                 drift_severity, target_drift_detected
238 |             ),
239 |         }
240 | 
241 |     def _get_retraining_recommendation(
242 |         self, drift_severity: str, target_drift: bool
243 |     ) -> str:
244 |         """Get retraining recommendation based on drift severity."""
245 |         if target_drift:
246 |             return "immediate_retraining"
247 |         elif drift_severity == "high":
248 |             return "urgent_retraining"
249 |         elif drift_severity == "medium":
250 |             return "scheduled_retraining"
251 |         elif drift_severity == "low":
252 |             return "monitor_closely"
253 |         else:
254 |             return "no_action_needed"
255 | 
256 | 
257 | class ModelPerformanceMonitor:
258 |     """Monitors model performance and detects degradation."""
259 | 
260 |     def __init__(self, performance_threshold: float = 0.05):
261 |         """Initialize performance monitor.
262 | 
263 |         Args:
264 |             performance_threshold: Threshold for performance degradation detection
265 |         """
266 |         self.performance_threshold = performance_threshold
267 |         self.baseline_metrics: dict[str, float] = {}
268 |         self.performance_history: list[dict[str, Any]] = []
269 | 
270 |     def set_baseline_performance(self, metrics: dict[str, float]):
271 |         """Set baseline performance metrics.
272 | 
273 |         Args:
274 |             metrics: Baseline performance metrics
275 |         """
276 |         self.baseline_metrics = metrics.copy()
277 |         logger.info(f"Set baseline performance: {metrics}")
278 | 
279 |     def evaluate_performance(
280 |         self,
281 |         model: BaseEstimator,
282 |         X_test: pd.DataFrame,
283 |         y_test: pd.Series,
284 |         additional_metrics: dict[str, float] | None = None,
285 |     ) -> dict[str, Any]:
286 |         """Evaluate current model performance.
287 | 
288 |         Args:
289 |             model: Trained model
290 |             X_test: Test features
291 |             y_test: Test targets
292 |             additional_metrics: Additional metrics to include
293 | 
294 |         Returns:
295 |             Performance evaluation results
296 |         """
297 |         try:
298 |             # Make predictions
299 |             y_pred = model.predict(X_test)
300 | 
301 |             # Calculate metrics
302 |             metrics = {
303 |                 "accuracy": accuracy_score(y_test, y_pred),
304 |                 "timestamp": datetime.now().isoformat(),
305 |             }
306 | 
307 |             # Add additional metrics if provided
308 |             if additional_metrics:
309 |                 metrics.update(additional_metrics)
310 | 
311 |             # Detect performance degradation
312 |             degradation_detected = False
313 |             degradation_details = {}
314 | 
315 |             for metric_name, current_value in metrics.items():
316 |                 if metric_name in self.baseline_metrics and metric_name != "timestamp":
317 |                     baseline_value = self.baseline_metrics[metric_name]
318 |                     degradation = (baseline_value - current_value) / abs(baseline_value)
319 | 
320 |                     if degradation > self.performance_threshold:
321 |                         degradation_detected = True
322 |                         degradation_details[metric_name] = {
323 |                             "baseline": baseline_value,
324 |                             "current": current_value,
325 |                             "degradation": degradation,
326 |                         }
327 | 
328 |             evaluation_result = {
329 |                 "metrics": metrics,
330 |                 "degradation_detected": degradation_detected,
331 |                 "degradation_details": degradation_details,
332 |                 "classification_report": classification_report(
333 |                     y_test, y_pred, output_dict=True
334 |                 ),
335 |             }
336 | 
337 |             # Store in history
338 |             self.performance_history.append(evaluation_result)
339 | 
340 |             # Keep only recent history
341 |             if len(self.performance_history) > 100:
342 |                 self.performance_history = self.performance_history[-100:]
343 | 
344 |             return evaluation_result
345 | 
346 |         except Exception as e:
347 |             logger.error(f"Error evaluating model performance: {e}")
348 |             return {"error": str(e)}
349 | 
350 |     def get_performance_trend(self, metric_name: str = "accuracy") -> dict[str, Any]:
351 |         """Analyze performance trend over time.
352 | 
353 |         Args:
354 |             metric_name: Metric to analyze
355 | 
356 |         Returns:
357 |             Trend analysis results
358 |         """
359 |         if not self.performance_history:
360 |             return {"trend": "no_data"}
361 | 
362 |         values = []
363 |         timestamps = []
364 | 
365 |         for record in self.performance_history:
366 |             if metric_name in record["metrics"]:
367 |                 values.append(record["metrics"][metric_name])
368 |                 timestamps.append(record["metrics"]["timestamp"])
369 | 
370 |         if len(values) < 3:
371 |             return {"trend": "insufficient_data"}
372 | 
373 |         # Calculate trend
374 |         x = np.arange(len(values))
375 |         slope, _, r_value, p_value, _ = stats.linregress(x, values)
376 | 
377 |         trend_direction = "stable"
378 |         if p_value < 0.05:  # Statistically significant trend
379 |             if slope > 0:
380 |                 trend_direction = "improving"
381 |             else:
382 |                 trend_direction = "degrading"
383 | 
384 |         return {
385 |             "trend": trend_direction,
386 |             "slope": slope,
387 |             "r_squared": r_value**2,
388 |             "p_value": p_value,
389 |             "recent_values": values[-5:],
390 |             "timestamps": timestamps[-5:],
391 |         }
392 | 
393 | 
394 | class AutoRetrainingPipeline:
395 |     """Automated pipeline for model retraining with drift detection and performance monitoring."""
396 | 
397 |     def __init__(
398 |         self,
399 |         model_manager: ModelManager,
400 |         model_factory: Callable[[], BaseEstimator],
401 |         feature_extractor: Callable[[pd.DataFrame], pd.DataFrame],
402 |         target_extractor: Callable[[pd.DataFrame], pd.Series],
403 |         retraining_schedule_hours: int = 24,
404 |         min_samples_for_retraining: int = 100,
405 |     ):
406 |         """Initialize auto-retraining pipeline.
407 | 
408 |         Args:
409 |             model_manager: Model manager instance
410 |             model_factory: Function that creates new model instances
411 |             feature_extractor: Function to extract features from data
412 |             target_extractor: Function to extract targets from data
413 |             retraining_schedule_hours: Hours between scheduled retraining checks
414 |             min_samples_for_retraining: Minimum samples required for retraining
415 |         """
416 |         self.model_manager = model_manager
417 |         self.model_factory = model_factory
418 |         self.feature_extractor = feature_extractor
419 |         self.target_extractor = target_extractor
420 |         self.retraining_schedule_hours = retraining_schedule_hours
421 |         self.min_samples_for_retraining = min_samples_for_retraining
422 | 
423 |         self.drift_detector = DataDriftDetector()
424 |         self.performance_monitor = ModelPerformanceMonitor()
425 | 
426 |         self.last_retraining: dict[str, datetime] = {}
427 |         self.retraining_history: list[dict[str, Any]] = []
428 | 
429 |     def should_retrain(
430 |         self,
431 |         model_id: str,
432 |         new_data: pd.DataFrame,
433 |         force_check: bool = False,
434 |     ) -> tuple[bool, str]:
435 |         """Determine if a model should be retrained.
436 | 
437 |         Args:
438 |             model_id: Model identifier
439 |             new_data: New data for evaluation
440 |             force_check: Force retraining check regardless of schedule
441 | 
442 |         Returns:
443 |             Tuple of (should_retrain, reason)
444 |         """
445 |         # Check schedule
446 |         last_retrain = self.last_retraining.get(model_id)
447 |         if not force_check and last_retrain is not None:
448 |             time_since_retrain = datetime.now() - last_retrain
449 |             if (
450 |                 time_since_retrain.total_seconds()
451 |                 < self.retraining_schedule_hours * 3600
452 |             ):
453 |                 return False, "schedule_not_due"
454 | 
455 |         # Check minimum samples
456 |         if len(new_data) < self.min_samples_for_retraining:
457 |             return False, "insufficient_samples"
458 | 
459 |         # Extract features and targets
460 |         try:
461 |             features = self.feature_extractor(new_data)
462 |             targets = self.target_extractor(new_data)
463 |         except Exception as e:
464 |             logger.error(f"Error extracting features/targets: {e}")
465 |             return False, f"extraction_error: {e}"
466 | 
467 |         # Check for data drift
468 |         if self.drift_detector.reference_data is not None:
469 |             feature_drift = self.drift_detector.detect_feature_drift(features)
470 |             target_drift = self.drift_detector.detect_target_drift(targets)
471 |             drift_summary = self.drift_detector.get_drift_summary(
472 |                 feature_drift, target_drift
473 |             )
474 | 
475 |             if drift_summary["recommendation"] in [
476 |                 "immediate_retraining",
477 |                 "urgent_retraining",
478 |             ]:
479 |                 return True, f"data_drift_{drift_summary['drift_severity']}"
480 | 
481 |         # Check performance degradation
482 |         current_model = self.model_manager.load_model(model_id)
483 |         if current_model is not None and current_model.model is not None:
484 |             try:
485 |                 # Split data for evaluation
486 |                 X_train, X_test, y_train, y_test = train_test_split(
487 |                     features, targets, test_size=0.3, random_state=42, stratify=targets
488 |                 )
489 | 
490 |                 # Scale features if scaler is available
491 |                 if current_model.scaler is not None:
492 |                     X_test_scaled = current_model.scaler.transform(X_test)
493 |                 else:
494 |                     X_test_scaled = X_test
495 | 
496 |                 # Evaluate performance
497 |                 performance_result = self.performance_monitor.evaluate_performance(
498 |                     current_model.model, X_test_scaled, y_test
499 |                 )
500 | 
501 |                 if performance_result.get("degradation_detected", False):
502 |                     return True, "performance_degradation"
503 | 
504 |             except Exception as e:
505 |                 logger.warning(f"Error evaluating model performance: {e}")
506 | 
507 |         return False, "no_triggers"
508 | 
509 |     def retrain_model(
510 |         self,
511 |         model_id: str,
512 |         training_data: pd.DataFrame,
513 |         validation_split: float = 0.2,
514 |         **model_params,
515 |     ) -> str | None:
516 |         """Retrain a model with new data.
517 | 
518 |         Args:
519 |             model_id: Model identifier
520 |             training_data: Training data
521 |             validation_split: Fraction of data to use for validation
522 |             **model_params: Additional parameters for model training
523 | 
524 |         Returns:
525 |             New model version string if successful, None otherwise
526 |         """
527 |         try:
528 |             # Extract features and targets
529 |             features = self.feature_extractor(training_data)
530 |             targets = self.target_extractor(training_data)
531 | 
532 |             # Split data
533 |             X_train, X_val, y_train, y_val = train_test_split(
534 |                 features,
535 |                 targets,
536 |                 test_size=validation_split,
537 |                 random_state=42,
538 |                 stratify=targets,
539 |             )
540 | 
541 |             # Scale features
542 |             scaler = StandardScaler()
543 |             X_train_scaled = scaler.fit_transform(X_train)
544 |             X_val_scaled = scaler.transform(X_val)
545 | 
546 |             # Create and train new model
547 |             model = self.model_factory()
548 |             model.set_params(**model_params)
549 |             model.fit(X_train_scaled, y_train)
550 | 
551 |             # Evaluate model
552 |             train_score = model.score(X_train_scaled, y_train)
553 |             val_score = model.score(X_val_scaled, y_val)
554 | 
555 |             # Create version string
556 |             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
557 |             new_version = f"v_{timestamp}"
558 | 
559 |             # Prepare metadata
560 |             metadata = {
561 |                 "training_samples": len(X_train),
562 |                 "validation_samples": len(X_val),
563 |                 "features_count": X_train.shape[1],
564 |                 "model_params": model_params,
565 |                 "retraining_trigger": "automated",
566 |             }
567 | 
568 |             # Prepare performance metrics
569 |             performance_metrics = {
570 |                 "train_accuracy": train_score,
571 |                 "validation_accuracy": val_score,
572 |                 "overfitting_gap": train_score - val_score,
573 |             }
574 | 
575 |             # Save model
576 |             success = self.model_manager.save_model(
577 |                 model_id=model_id,
578 |                 version=new_version,
579 |                 model=model,
580 |                 scaler=scaler,
581 |                 metadata=metadata,
582 |                 performance_metrics=performance_metrics,
583 |                 set_as_active=True,  # Set as active if validation performance is good
584 |             )
585 | 
586 |             if success:
587 |                 # Update retraining history
588 |                 self.last_retraining[model_id] = datetime.now()
589 |                 self.retraining_history.append(
590 |                     {
591 |                         "model_id": model_id,
592 |                         "version": new_version,
593 |                         "timestamp": datetime.now().isoformat(),
594 |                         "training_samples": len(X_train),
595 |                         "validation_accuracy": val_score,
596 |                     }
597 |                 )
598 | 
599 |                 # Update drift detector reference data
600 |                 self.drift_detector.set_reference_data(features, targets)
601 | 
602 |                 # Update performance monitor baseline
603 |                 self.performance_monitor.set_baseline_performance(performance_metrics)
604 | 
605 |                 logger.info(
606 |                     f"Successfully retrained model {model_id} -> {new_version} (val_acc: {val_score:.4f})"
607 |                 )
608 |                 return new_version
609 |             else:
610 |                 logger.error(f"Failed to save retrained model {model_id}")
611 |                 return None
612 | 
613 |         except Exception as e:
614 |             logger.error(f"Error retraining model {model_id}: {e}")
615 |             return None
616 | 
617 |     def run_retraining_check(
618 |         self, model_id: str, new_data: pd.DataFrame
619 |     ) -> dict[str, Any]:
620 |         """Run complete retraining check and execute if needed.
621 | 
622 |         Args:
623 |             model_id: Model identifier
624 |             new_data: New data for evaluation
625 | 
626 |         Returns:
627 |             Dictionary with check results and actions taken
628 |         """
629 |         start_time = datetime.now()
630 | 
631 |         try:
632 |             # Check if retraining is needed
633 |             should_retrain, reason = self.should_retrain(model_id, new_data)
634 | 
635 |             result = {
636 |                 "model_id": model_id,
637 |                 "timestamp": start_time.isoformat(),
638 |                 "should_retrain": should_retrain,
639 |                 "reason": reason,
640 |                 "data_samples": len(new_data),
641 |                 "new_version": None,
642 |                 "success": False,
643 |             }
644 | 
645 |             if should_retrain:
646 |                 logger.info(f"Retraining {model_id} due to: {reason}")
647 |                 new_version = self.retrain_model(model_id, new_data)
648 | 
649 |                 if new_version:
650 |                     result.update(
651 |                         {
652 |                             "new_version": new_version,
653 |                             "success": True,
654 |                             "action": "retrained",
655 |                         }
656 |                     )
657 |                 else:
658 |                     result.update(
659 |                         {
660 |                             "action": "retrain_failed",
661 |                             "error": "Model retraining failed",
662 |                         }
663 |                     )
664 |             else:
665 |                 result.update(
666 |                     {
667 |                         "action": "no_retrain",
668 |                         "success": True,
669 |                     }
670 |                 )
671 | 
672 |             # Calculate execution time
673 |             execution_time = (datetime.now() - start_time).total_seconds()
674 |             result["execution_time_seconds"] = execution_time
675 | 
676 |             return result
677 | 
678 |         except Exception as e:
679 |             logger.error(f"Error in retraining check for {model_id}: {e}")
680 |             return {
681 |                 "model_id": model_id,
682 |                 "timestamp": start_time.isoformat(),
683 |                 "should_retrain": False,
684 |                 "reason": "check_error",
685 |                 "error": str(e),
686 |                 "success": False,
687 |                 "execution_time_seconds": (datetime.now() - start_time).total_seconds(),
688 |             }
689 | 
690 |     def get_retraining_summary(self) -> dict[str, Any]:
691 |         """Get summary of retraining pipeline status.
692 | 
693 |         Returns:
694 |             Summary dictionary
695 |         """
696 |         return {
697 |             "total_models_managed": len(self.last_retraining),
698 |             "total_retrainings": len(self.retraining_history),
699 |             "recent_retrainings": self.retraining_history[-10:],
700 |             "last_retraining_times": {
701 |                 model_id: timestamp.isoformat()
702 |                 for model_id, timestamp in self.last_retraining.items()
703 |             },
704 |             "retraining_schedule_hours": self.retraining_schedule_hours,
705 |             "min_samples_for_retraining": self.min_samples_for_retraining,
706 |         }
707 | 
708 | 
709 | # Alias for backward compatibility
710 | RetrainingPipeline = AutoRetrainingPipeline
711 | 
712 | # Ensure all expected names are available
713 | __all__ = [
714 |     "DataDriftDetector",
715 |     "ModelPerformanceMonitor",
716 |     "AutoRetrainingPipeline",
717 |     "RetrainingPipeline",  # Alias for backward compatibility
718 | ]
719 | 
```
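A minimal usage sketch for the drift-detection side of this module (the synthetic frames and column names below are invented purely for illustration):

```python
# Sketch: compare a reference feature window against a more recent one.
import numpy as np
import pandas as pd

from maverick_mcp.backtesting.retraining_pipeline import DataDriftDetector

rng = np.random.default_rng(42)
reference = pd.DataFrame({"rsi": rng.normal(50, 10, 500), "atr": rng.normal(2.0, 0.5, 500)})
recent = pd.DataFrame({"rsi": rng.normal(58, 12, 500), "atr": rng.normal(2.1, 0.5, 500)})

detector = DataDriftDetector(significance_level=0.05)
detector.set_reference_data(reference)

feature_drift = detector.detect_feature_drift(recent)
summary = detector.get_drift_summary(feature_drift, {"drift_detected": False})
print(summary["drift_severity"], summary["recommendation"])
```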

--------------------------------------------------------------------------------
/maverick_mcp/data/performance.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Performance optimization utilities for Maverick-MCP.
  3 | 
  4 | This module provides Redis connection pooling, request caching,
  5 | and query optimization features to improve application performance.
  6 | """
  7 | 
  8 | import hashlib
  9 | import json
 10 | import logging
 11 | import time
 12 | from collections.abc import Callable
 13 | from contextlib import asynccontextmanager
 14 | from functools import wraps
 15 | from typing import Any, TypeVar, cast
 16 | 
 17 | import redis.asyncio as redis
 18 | from redis.asyncio.client import Pipeline
 19 | from sqlalchemy import text
 20 | from sqlalchemy.ext.asyncio import AsyncSession
 21 | 
 22 | from maverick_mcp.config.settings import get_settings
 23 | from maverick_mcp.data.session_management import get_async_db_session
 24 | 
 25 | settings = get_settings()
 26 | logger = logging.getLogger(__name__)
 27 | 
 28 | # Type variables for generic typing
 29 | F = TypeVar("F", bound=Callable[..., Any])
 30 | 
 31 | 
 32 | class RedisConnectionManager:
 33 |     """
 34 |     Centralized Redis connection manager with connection pooling.
 35 | 
 36 |     This manager provides:
 37 |     - Connection pooling with configurable limits
 38 |     - Automatic failover and retry logic
 39 |     - Health monitoring and metrics
 40 |     - Graceful degradation when Redis is unavailable
 41 |     """
 42 | 
 43 |     def __init__(self):
 44 |         self._pool: redis.ConnectionPool | None = None
 45 |         self._client: redis.Redis | None = None
 46 |         self._initialized = False
 47 |         self._healthy = False
 48 |         self._last_health_check = 0
 49 |         self._health_check_interval = 30  # seconds
 50 | 
 51 |         # Connection pool configuration
 52 |         self._max_connections = settings.db.redis_max_connections
 53 |         self._retry_on_timeout = settings.db.redis_retry_on_timeout
 54 |         self._socket_timeout = settings.db.redis_socket_timeout
 55 |         self._socket_connect_timeout = settings.db.redis_socket_connect_timeout
 56 |         self._health_check_interval_sec = 30
 57 | 
 58 |         # Metrics
 59 |         self._metrics = {
 60 |             "connections_created": 0,
 61 |             "connections_closed": 0,
 62 |             "commands_executed": 0,
 63 |             "errors": 0,
 64 |             "health_checks": 0,
 65 |             "last_error": None,
 66 |         }
 67 | 
 68 |     async def initialize(self) -> bool:
 69 |         """
 70 |         Initialize Redis connection pool.
 71 | 
 72 |         Returns:
 73 |             bool: True if initialization successful, False otherwise
 74 |         """
 75 |         if self._initialized:
 76 |             return self._healthy
 77 | 
 78 |         try:
 79 |             # Create connection pool
 80 |             self._pool = redis.ConnectionPool.from_url(
 81 |                 settings.redis.url,
 82 |                 max_connections=self._max_connections,
 83 |                 retry_on_timeout=self._retry_on_timeout,
 84 |                 socket_timeout=self._socket_timeout,
 85 |                 socket_connect_timeout=self._socket_connect_timeout,
 86 |                 decode_responses=True,
 87 |                 health_check_interval=self._health_check_interval_sec,
 88 |             )
 89 | 
 90 |             # Create Redis client
 91 |             self._client = redis.Redis(connection_pool=self._pool)
 92 | 
 93 |             client = self._client
 94 |             if client is None:  # Defensive guard for static type checking
 95 |                 msg = "Redis client initialization failed"
 96 |                 raise RuntimeError(msg)
 97 | 
 98 |             # Test connection
 99 |             await client.ping()
100 | 
101 |             self._healthy = True
102 |             self._initialized = True
103 |             self._metrics["connections_created"] += 1
104 | 
105 |             logger.info(
106 |                 f"Redis connection pool initialized: "
107 |                 f"max_connections={self._max_connections}, "
108 |                 f"url={settings.redis.url}"
109 |             )
110 | 
111 |             return True
112 | 
113 |         except Exception as e:
114 |             logger.error(f"Failed to initialize Redis connection pool: {e}")
115 |             self._metrics["errors"] += 1
116 |             self._metrics["last_error"] = str(e)
117 |             self._healthy = False
118 |             return False
119 | 
120 |     async def get_client(self) -> redis.Redis | None:
121 |         """
122 |         Get Redis client from the connection pool.
123 | 
124 |         Returns:
125 |             Redis client or None if unavailable
126 |         """
127 |         if not self._initialized:
128 |             await self.initialize()
129 | 
130 |         if not self._healthy:
131 |             await self._health_check()
132 | 
133 |         return self._client if self._healthy else None
134 | 
135 |     async def _health_check(self) -> bool:
136 |         """
137 |         Perform health check on Redis connection.
138 | 
139 |         Returns:
140 |             bool: True if healthy, False otherwise
141 |         """
142 |         current_time = time.time()
143 | 
144 |         # Skip health check if recently performed
145 |         if (current_time - self._last_health_check) < self._health_check_interval:
146 |             return self._healthy
147 | 
148 |         self._last_health_check = current_time
149 |         self._metrics["health_checks"] += 1
150 | 
151 |         try:
152 |             if self._client:
153 |                 await self._client.ping()
154 |                 self._healthy = True
155 |                 logger.debug("Redis health check passed")
156 |             else:
157 |                 self._healthy = False
158 | 
159 |         except Exception as e:
160 |             logger.warning(f"Redis health check failed: {e}")
161 |             self._healthy = False
162 |             self._metrics["errors"] += 1
163 |             self._metrics["last_error"] = str(e)
164 | 
165 |             # Reset the flag so initialize() rebuilds the pool instead of returning early
166 |             self._initialized = False
167 |             await self.initialize()
168 |         return self._healthy
169 | 
170 |     async def execute_command(self, command: str, *args, **kwargs) -> Any:
171 |         """
172 |         Execute Redis command with error handling and metrics.
173 | 
174 |         Args:
175 |             command: Redis command name
176 |             *args: Command arguments
177 |             **kwargs: Command keyword arguments
178 | 
179 |         Returns:
180 |             Command result or None if failed
181 |         """
182 |         client = await self.get_client()
183 |         if not client:
184 |             return None
185 | 
186 |         try:
187 |             self._metrics["commands_executed"] += 1
188 |             result = await getattr(client, command)(*args, **kwargs)
189 |             return result
190 | 
191 |         except Exception as e:
192 |             logger.error(f"Redis command '{command}' failed: {e}")
193 |             self._metrics["errors"] += 1
194 |             self._metrics["last_error"] = str(e)
195 |             return None
196 | 
197 |     async def pipeline(self) -> Pipeline | None:
198 |         """
199 |         Create Redis pipeline for batch operations.
200 | 
201 |         Returns:
202 |             Redis pipeline or None if unavailable
203 |         """
204 |         client = await self.get_client()
205 |         if not client:
206 |             return None
207 | 
208 |         return client.pipeline()
209 | 
210 |     def get_metrics(self) -> dict[str, Any]:
211 |         """Get connection pool metrics."""
212 |         metrics = self._metrics.copy()
213 |         metrics.update(
214 |             {
215 |                 "healthy": self._healthy,
216 |                 "initialized": self._initialized,
217 |                 "pool_size": self._max_connections,
218 |                 "pool_created": bool(self._pool),
219 |             }
220 |         )
221 | 
222 |         if self._pool:
223 |             # Safely get pool metrics with fallbacks for missing attributes
224 |             try:
225 |                 metrics["pool_created_connections"] = getattr(
226 |                     self._pool, "created_connections", 0
227 |                 )
228 |             except AttributeError:
229 |                 metrics["pool_created_connections"] = 0
230 | 
231 |             try:
232 |                 metrics["pool_available_connections"] = len(
233 |                     getattr(self._pool, "_available_connections", [])
234 |                 )
235 |             except (AttributeError, TypeError):
236 |                 metrics["pool_available_connections"] = 0
237 | 
238 |             try:
239 |                 metrics["pool_in_use_connections"] = len(
240 |                     getattr(self._pool, "_in_use_connections", [])
241 |                 )
242 |             except (AttributeError, TypeError):
243 |                 metrics["pool_in_use_connections"] = 0
244 | 
245 |         return metrics
246 | 
247 |     async def close(self):
248 |         """Close connection pool gracefully."""
249 |         if self._client:
250 |             # Use aclose() instead of close() to avoid deprecation warning
251 |             # aclose() is the new async close method in redis-py 5.0+
252 |             if hasattr(self._client, "aclose"):
253 |                 await self._client.aclose()
254 |             else:
255 |                 # Fallback for older versions
256 |                 await self._client.close()
257 |             self._metrics["connections_closed"] += 1
258 | 
259 |         if self._pool:
260 |             await self._pool.disconnect()
261 | 
262 |         self._initialized = False
263 |         self._healthy = False
264 |         logger.info("Redis connection pool closed")
265 | 
266 | 
267 | # Global Redis connection manager instance
268 | redis_manager = RedisConnectionManager()
269 | 
270 | 
271 | class RequestCache:
272 |     """
273 |     Smart request-level caching system.
274 | 
275 |     This system provides:
276 |     - Automatic cache key generation based on function signature
277 |     - TTL strategies for different data types
278 |     - Cache invalidation mechanisms
279 |     - Hit/miss metrics and monitoring
280 |     """
281 | 
282 |     def __init__(self):
283 |         self._hit_count = 0
284 |         self._miss_count = 0
285 |         self._error_count = 0
286 | 
287 |         # Default TTL values for different data types (in seconds)
288 |         self._default_ttls = {
289 |             "stock_data": 3600,  # 1 hour for stock data
290 |             "technical_analysis": 1800,  # 30 minutes for technical indicators
291 |             "market_data": 300,  # 5 minutes for market data
292 |             "screening": 7200,  # 2 hours for screening results
293 |             "portfolio": 1800,  # 30 minutes for portfolio analysis
294 |             "macro_data": 3600,  # 1 hour for macro data
295 |             "default": 900,  # 15 minutes default
296 |         }
297 | 
298 |     def _generate_cache_key(self, prefix: str, *args, **kwargs) -> str:
299 |         """
300 |         Generate cache key from function arguments.
301 | 
302 |         Args:
303 |             prefix: Cache key prefix
304 |             *args: Function arguments
305 |             **kwargs: Function keyword arguments
306 | 
307 |         Returns:
308 |             Generated cache key
309 |         """
310 |         # Create a hash of the arguments
311 |         key_data = {
312 |             "args": args,
313 |             "kwargs": sorted(kwargs.items()),
314 |         }
315 | 
316 |         key_hash = hashlib.sha256(
317 |             json.dumps(key_data, sort_keys=True, default=str).encode()
318 |         ).hexdigest()[:16]  # Use first 16 chars for brevity
319 | 
320 |         return f"cache:{prefix}:{key_hash}"
321 | 
322 |     def _get_ttl(self, data_type: str) -> int:
323 |         """Get TTL for data type."""
324 |         return self._default_ttls.get(data_type, self._default_ttls["default"])
325 | 
326 |     async def get(self, key: str) -> Any | None:
327 |         """
328 |         Get value from cache.
329 | 
330 |         Args:
331 |             key: Cache key
332 | 
333 |         Returns:
334 |             Cached value or None if not found
335 |         """
336 |         try:
337 |             client = await redis_manager.get_client()
338 |             if not client:
339 |                 return None
340 | 
341 |             data = await client.get(key)
342 |             if data:
343 |                 self._hit_count += 1
344 |                 logger.debug(f"Cache hit for key: {key}")
345 |                 return json.loads(data)
346 |             else:
347 |                 self._miss_count += 1
348 |                 logger.debug(f"Cache miss for key: {key}")
349 |                 return None
350 | 
351 |         except Exception as e:
352 |             logger.error(f"Error getting from cache: {e}")
353 |             self._error_count += 1
354 |             return None
355 | 
356 |     async def set(
357 |         self, key: str, value: Any, ttl: int | None = None, data_type: str = "default"
358 |     ) -> bool:
359 |         """
360 |         Set value in cache.
361 | 
362 |         Args:
363 |             key: Cache key
364 |             value: Value to cache
365 |             ttl: Time to live in seconds
366 |             data_type: Data type for TTL determination
367 | 
368 |         Returns:
369 |             True if successful, False otherwise
370 |         """
371 |         try:
372 |             client = await redis_manager.get_client()
373 |             if not client:
374 |                 return False
375 | 
376 |             if ttl is None:
377 |                 ttl = self._get_ttl(data_type)
378 | 
379 |             serialized_value = json.dumps(value, default=str)
380 |             success = await client.setex(key, ttl, serialized_value)
381 | 
382 |             if success:
383 |                 logger.debug(f"Cached value for key: {key} (TTL: {ttl}s)")
384 | 
385 |             return bool(success)
386 | 
387 |         except Exception as e:
388 |             logger.error(f"Error setting cache: {e}")
389 |             self._error_count += 1
390 |             return False
391 | 
392 |     async def delete(self, key: str) -> bool:
393 |         """Delete key from cache."""
394 |         try:
395 |             client = await redis_manager.get_client()
396 |             if not client:
397 |                 return False
398 | 
399 |             result = await client.delete(key)
400 |             return bool(result)
401 | 
402 |         except Exception as e:
403 |             logger.error(f"Error deleting from cache: {e}")
404 |             self._error_count += 1
405 |             return False
406 | 
407 |     async def delete_pattern(self, pattern: str) -> int:
408 |         """Delete all keys matching pattern."""
409 |         try:
410 |             client = await redis_manager.get_client()
411 |             if not client:
412 |                 return 0
413 | 
414 |             keys = await client.keys(pattern)
415 |             if keys:
416 |                 result = await client.delete(*keys)
417 |                 logger.info(f"Deleted {result} keys matching pattern: {pattern}")
418 |                 return result
419 | 
420 |             return 0
421 | 
422 |         except Exception as e:
423 |             logger.error(f"Error deleting pattern: {e}")
424 |             self._error_count += 1
425 |             return 0
426 | 
427 |     def get_metrics(self) -> dict[str, Any]:
428 |         """Get cache metrics."""
429 |         total_requests = self._hit_count + self._miss_count
430 |         hit_rate = (self._hit_count / total_requests) if total_requests > 0 else 0
431 | 
432 |         return {
433 |             "hit_count": self._hit_count,
434 |             "miss_count": self._miss_count,
435 |             "error_count": self._error_count,
436 |             "total_requests": total_requests,
437 |             "hit_rate": hit_rate,
438 |             "ttl_config": self._default_ttls,
439 |         }
440 | 
441 | 
442 | # Global request cache instance
443 | request_cache = RequestCache()
444 | 
445 | 
446 | def cached(
447 |     data_type: str = "default",
448 |     ttl: int | None = None,
449 |     key_prefix: str | None = None,
450 |     invalidate_patterns: list[str] | None = None,
451 | ):
452 |     """
453 |     Decorator for automatic function result caching.
454 | 
455 |     Args:
456 |         data_type: Data type for TTL determination
457 |         ttl: Custom TTL in seconds
458 |         key_prefix: Custom cache key prefix
459 |         invalidate_patterns: Patterns to invalidate on update
460 | 
461 |     Example:
462 |         @cached(data_type="stock_data", ttl=3600)
463 |         async def get_stock_price(symbol: str) -> float:
464 |             # Expensive operation
465 |             return price
466 |     """
467 | 
468 |     def decorator(func: F) -> F:
469 |         @wraps(func)
470 |         async def wrapper(*args, **kwargs):
471 |             # Generate cache key
472 |             prefix = key_prefix or f"{func.__module__}.{func.__name__}"
473 |             cache_key = request_cache._generate_cache_key(prefix, *args, **kwargs)
474 | 
475 |             # Try to get from cache
476 |             cached_result = await request_cache.get(cache_key)
477 |             if cached_result is not None:
478 |                 return cached_result
479 | 
480 |             # Execute function
481 |             result = await func(*args, **kwargs)
482 | 
483 |             # Cache result
484 |             if result is not None:
485 |                 await request_cache.set(cache_key, result, ttl, data_type)
486 | 
487 |             return result
488 | 
489 |         # Add cache invalidation method
490 |         async def invalidate_cache(*args, **kwargs):
491 |             """Invalidate cache for this function."""
492 |             prefix = key_prefix or f"{func.__module__}.{func.__name__}"
493 |             cache_key = request_cache._generate_cache_key(prefix, *args, **kwargs)
494 |             await request_cache.delete(cache_key)
495 | 
496 |             # Invalidate patterns if specified
497 |             if invalidate_patterns:
498 |                 for pattern in invalidate_patterns:
499 |                     await request_cache.delete_pattern(pattern)
500 | 
501 |         typed_wrapper = cast(F, wrapper)
502 |         cast(Any, typed_wrapper).invalidate_cache = invalidate_cache
503 |         return typed_wrapper
504 | 
505 |     return decorator
506 | 
507 | 
508 | class QueryOptimizer:
509 |     """
510 |     Database query optimization utilities.
511 | 
512 |     This class provides:
513 |     - Query performance monitoring
514 |     - Index recommendations
515 |     - N+1 query detection
516 |     - Connection pool monitoring
517 |     """
518 | 
519 |     def __init__(self):
520 |         self._query_stats = {}
521 |         self._slow_query_threshold = 1.0  # seconds
522 |         self._slow_queries = []
523 | 
524 |     def monitor_query(self, query_name: str):
525 |         """
526 |         Decorator for monitoring query performance.
527 | 
528 |         Args:
529 |             query_name: Name for the query (for metrics)
530 |         """
531 | 
532 |         def decorator(func: F) -> F:
533 |             @wraps(func)
534 |             async def wrapper(*args, **kwargs):
535 |                 start_time = time.time()
536 | 
537 |                 try:
538 |                     result = await func(*args, **kwargs)
539 |                     execution_time = time.time() - start_time
540 | 
541 |                     # Update statistics
542 |                     if query_name not in self._query_stats:
543 |                         self._query_stats[query_name] = {
544 |                             "count": 0,
545 |                             "total_time": 0,
546 |                             "avg_time": 0,
547 |                             "max_time": 0,
548 |                             "min_time": float("inf"),
549 |                         }
550 | 
551 |                     stats = self._query_stats[query_name]
552 |                     stats["count"] += 1
553 |                     stats["total_time"] += execution_time
554 |                     stats["avg_time"] = stats["total_time"] / stats["count"]
555 |                     stats["max_time"] = max(stats["max_time"], execution_time)
556 |                     stats["min_time"] = min(stats["min_time"], execution_time)
557 | 
558 |                     # Track slow queries
559 |                     if execution_time > self._slow_query_threshold:
560 |                         self._slow_queries.append(
561 |                             {
562 |                                 "query_name": query_name,
563 |                                 "execution_time": execution_time,
564 |                                 "timestamp": time.time(),
565 |                                 "args": str(args)[:200],  # Truncate long args
566 |                             }
567 |                         )
568 | 
569 |                         # Keep only last 100 slow queries
570 |                         if len(self._slow_queries) > 100:
571 |                             self._slow_queries = self._slow_queries[-100:]
572 | 
573 |                         logger.warning(
574 |                             f"Slow query detected: {query_name} took {execution_time:.2f}s"
575 |                         )
576 | 
577 |                     return result
578 | 
579 |                 except Exception as e:
580 |                     execution_time = time.time() - start_time
581 |                     logger.error(
582 |                         f"Query {query_name} failed after {execution_time:.2f}s: {e}"
583 |                     )
584 |                     raise
585 | 
586 |             return cast(F, wrapper)
587 | 
588 |         return decorator
589 | 
590 |     def get_query_stats(self) -> dict[str, Any]:
591 |         """Get query performance statistics."""
592 |         return {
593 |             "query_stats": self._query_stats,
594 |             "slow_queries": self._slow_queries[-10:],  # Last 10 slow queries
595 |             "slow_query_threshold": self._slow_query_threshold,
596 |         }
597 | 
598 |     async def analyze_missing_indexes(
599 |         self, session: AsyncSession
600 |     ) -> list[dict[str, Any]]:
601 |         """
602 |         Analyze database for missing indexes.
603 | 
604 |         Args:
605 |             session: Database session
606 | 
607 |         Returns:
608 |             List of recommended indexes
609 |         """
610 |         recommendations = []
611 | 
612 |         try:
613 |             # Check for common missing indexes
614 |             queries = [
615 |                 # PriceCache table analysis
616 |                 {
617 |                     "name": "PriceCache date range queries",
618 |                     "query": """
619 |                         SELECT schemaname, tablename, attname, n_distinct, correlation
620 |                         FROM pg_stats
621 |                         WHERE tablename = 'stocks_pricecache'
622 |                         AND attname IN ('date', 'stock_id', 'volume')
623 |                     """,
624 |                     "recommendation": "Consider composite index on (stock_id, date) if not exists",
625 |                 },
626 |                 # Stock lookup performance
627 |                 {
628 |                     "name": "Stock ticker lookups",
629 |                     "query": """
630 |                         SELECT schemaname, tablename, attname, n_distinct, correlation
631 |                         FROM pg_stats
632 |                         WHERE tablename = 'stocks_stock'
633 |                         AND attname = 'ticker_symbol'
634 |                     """,
635 |                     "recommendation": "Ensure unique index on ticker_symbol exists",
636 |                 },
637 |                 # Screening tables
638 |                 {
639 |                     "name": "Maverick screening queries",
640 |                     "query": """
641 |                         SELECT schemaname, tablename, attname, n_distinct
642 |                         FROM pg_stats
643 |                         WHERE tablename IN ('stocks_maverickstocks', 'stocks_maverickbearstocks', 'stocks_supply_demand_breakouts')
644 |                         AND attname IN ('score', 'rank', 'date_analyzed')
645 |                     """,
646 |                     "recommendation": "Consider indexes on score, rank, and date_analyzed columns",
647 |                 },
648 |             ]
649 | 
650 |             for query_info in queries:
651 |                 try:
652 |                     result = await session.execute(text(query_info["query"]))
653 |                     rows = result.fetchall()
654 | 
655 |                     if rows:
656 |                         recommendations.append(
657 |                             {
658 |                                 "analysis": query_info["name"],
659 |                                 "recommendation": query_info["recommendation"],
660 |                                 "stats": [dict(row._mapping) for row in rows],
661 |                             }
662 |                         )
663 | 
664 |                 except Exception as e:
665 |                     logger.error(f"Failed to analyze {query_info['name']}: {e}")
666 | 
667 |             # Check for tables without proper indexes
668 |             missing_indexes_query = """
669 |                 SELECT
670 |                     schemaname,
671 |                     tablename,
672 |                     seq_scan,
673 |                     seq_tup_read,
674 |                     idx_scan,
675 |                     idx_tup_fetch,
676 |                     CASE
677 |                         WHEN seq_scan = 0 THEN 0
678 |                         ELSE seq_tup_read / seq_scan
679 |                     END as avg_seq_read
680 |                 FROM pg_stat_user_tables
681 |                 WHERE schemaname = 'public'
682 |                 AND tablename LIKE 'stocks_%'
683 |                 ORDER BY seq_tup_read DESC
684 |             """
685 | 
686 |             result = await session.execute(text(missing_indexes_query))
687 |             scan_stats = result.fetchall()
688 | 
689 |             for row in scan_stats:
690 |                 if row.seq_scan > 100 and row.avg_seq_read > 1000:
691 |                     recommendations.append(
692 |                         {
693 |                             "analysis": f"High sequential scans on {row.tablename}",
694 |                             "recommendation": f"Consider adding indexes to reduce {row.seq_tup_read} sequential reads",
695 |                             "stats": dict(row._mapping),
696 |                         }
697 |                     )
698 | 
699 |         except Exception as e:
700 |             logger.error(f"Error analyzing missing indexes: {e}")
701 | 
702 |         return recommendations
703 | 
704 | 
705 | # Global query optimizer instance
706 | query_optimizer = QueryOptimizer()
707 | 
708 | 
709 | async def initialize_performance_systems():
710 |     """Initialize all performance optimization systems."""
711 |     logger.info("Initializing performance optimization systems...")
712 | 
713 |     # Initialize Redis connection manager
714 |     redis_success = await redis_manager.initialize()
715 | 
716 |     logger.info(
717 |         f"Performance systems initialized: Redis={'✓' if redis_success else '✗'}"
718 |     )
719 | 
720 |     return {
721 |         "redis_manager": redis_success,
722 |         "request_cache": True,
723 |         "query_optimizer": True,
724 |     }
725 | 
726 | 
727 | async def get_performance_metrics() -> dict[str, Any]:
728 |     """Get comprehensive performance metrics."""
729 |     return {
730 |         "redis_manager": redis_manager.get_metrics(),
731 |         "request_cache": request_cache.get_metrics(),
732 |         "query_optimizer": query_optimizer.get_query_stats(),
733 |         "timestamp": time.time(),
734 |     }
735 | 
736 | 
737 | async def cleanup_performance_systems():
738 |     """Cleanup performance systems gracefully."""
739 |     logger.info("Cleaning up performance optimization systems...")
740 | 
741 |     await redis_manager.close()
742 | 
743 |     logger.info("Performance systems cleanup completed")
744 | 
745 | 
746 | # Context manager for database session with query monitoring
747 | @asynccontextmanager
748 | async def monitored_db_session(query_name: str = "unknown"):
749 |     """
750 |     Context manager for database sessions with automatic query monitoring.
751 | 
752 |     Args:
753 |         query_name: Name for the query (for metrics)
754 | 
755 |     Example:
756 |         async with monitored_db_session("get_stock_data") as session:
757 |             result = await session.execute(
758 |                 text("SELECT * FROM stocks_stock WHERE ticker_symbol = :symbol"),
759 |                 {"symbol": "AAPL"},
760 |             )
761 |             stock = result.first()
762 |     """
763 |     async with get_async_db_session() as session:
764 |         start_time = time.time()
765 | 
766 |         try:
767 |             yield session
768 | 
769 |             # Record successful query
770 |             execution_time = time.time() - start_time
771 |             if query_name not in query_optimizer._query_stats:
772 |                 query_optimizer._query_stats[query_name] = {
773 |                     "count": 0,
774 |                     "total_time": 0,
775 |                     "avg_time": 0,
776 |                     "max_time": 0,
777 |                     "min_time": float("inf"),
778 |                 }
779 | 
780 |             stats = query_optimizer._query_stats[query_name]
781 |             stats["count"] += 1
782 |             stats["total_time"] += execution_time
783 |             stats["avg_time"] = stats["total_time"] / stats["count"]
784 |             stats["max_time"] = max(stats["max_time"], execution_time)
785 |             stats["min_time"] = min(stats["min_time"], execution_time)
786 | 
787 |         except Exception as e:
788 |             execution_time = time.time() - start_time
789 |             logger.error(
790 |                 f"Database query '{query_name}' failed after {execution_time:.2f}s: {e}"
791 |             )
792 |             raise
793 | 
```
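
Below is a minimal usage sketch of the `cached` decorator and its attached `invalidate_cache` helper. The import path `maverick_mcp.data.performance` and the `get_latest_close` function are assumptions for illustration (the module's real location is not shown on this page); the only requirements implied by the source above are that the decorated function is async and that a Redis instance is reachable, otherwise calls fall through to the uncached behavior.

```python
# Minimal sketch, not the repository's documented API surface.
# Assumptions: the module is importable as maverick_mcp.data.performance,
# and the configured Redis URL points at a reachable instance.
import asyncio

from maverick_mcp.data.performance import cached, request_cache  # path is a guess


@cached(
    data_type="stock_data",            # 1h TTL from the defaults above
    key_prefix="stock_price",          # keys become cache:stock_price:<hash>
    invalidate_patterns=["cache:stock_price:*"],
)
async def get_latest_close(symbol: str) -> float:
    # Stand-in for an expensive provider call.
    await asyncio.sleep(0.1)
    return 123.45


async def main() -> None:
    await get_latest_close("AAPL")                    # miss: executes and caches
    price = await get_latest_close("AAPL")            # hit: served from Redis
    await get_latest_close.invalidate_cache("AAPL")   # drops the key and the pattern
    print(price, request_cache.get_metrics()["hit_rate"])


asyncio.run(main())
```

Because `key_prefix` is set explicitly, every generated key takes the form `cache:stock_price:<hash>`, which is exactly what the `cache:stock_price:*` invalidation pattern targets.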
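
And a sketch of the end-to-end lifecycle: initialize the performance systems at startup, run a monitored query (the `stocks_stock` lookup mirrors the example in the `monitored_db_session` docstring), read the aggregated metrics, and shut down cleanly. The import path is again a guess, and a configured database plus Redis are assumed.

```python
# Minimal lifecycle sketch under the same module-path assumption.
import asyncio

from sqlalchemy import text

from maverick_mcp.data.performance import (  # path is a guess
    cleanup_performance_systems,
    get_performance_metrics,
    initialize_performance_systems,
    monitored_db_session,
)


async def main() -> None:
    status = await initialize_performance_systems()   # {"redis_manager": True, ...}

    # Timing for this block is recorded under "lookup_stock" in the query optimizer.
    async with monitored_db_session("lookup_stock") as session:
        result = await session.execute(
            text("SELECT * FROM stocks_stock WHERE ticker_symbol = :symbol"),
            {"symbol": "AAPL"},
        )
        print(result.first())

    metrics = await get_performance_metrics()
    print(status, metrics["query_optimizer"]["query_stats"].get("lookup_stock"))

    await cleanup_performance_systems()


asyncio.run(main())
```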