This is page 17 of 29. Use http://codebase.md/wshobson/maverick-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/performance/test_load.py:
--------------------------------------------------------------------------------
```python
"""
Load Testing for Concurrent Users and Backtest Operations.
This test suite covers:
- Concurrent user load testing (10, 50, 100 users)
- Response time and throughput measurement
- Memory usage under concurrent load
- Database performance with multiple connections
- API rate limiting behavior
- Queue management and task distribution
- System stability under sustained load
"""
import asyncio
import logging
import os
import random
import statistics
import time
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock
import numpy as np
import pandas as pd
import psutil
import pytest
from maverick_mcp.backtesting import VectorBTEngine
from maverick_mcp.backtesting.persistence import BacktestPersistenceManager
from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
logger = logging.getLogger(__name__)
@dataclass
class LoadTestResult:
"""Data class for load test results."""
concurrent_users: int
total_requests: int
successful_requests: int
failed_requests: int
total_duration: float
avg_response_time: float
min_response_time: float
max_response_time: float
p50_response_time: float
p95_response_time: float
p99_response_time: float
requests_per_second: float
errors_per_second: float
memory_usage_mb: float
cpu_usage_percent: float
class LoadTestRunner:
"""Load test runner with realistic user simulation."""
def __init__(self, data_provider):
self.data_provider = data_provider
self.results = []
self.active_requests = 0
async def simulate_user_session(
self, user_id: int, session_config: dict[str, Any]
) -> dict[str, Any]:
"""Simulate a realistic user session with multiple backtests."""
session_start = time.time()
user_results = []
response_times = []
symbols = session_config.get("symbols", ["AAPL"])
strategies = session_config.get("strategies", ["sma_cross"])
think_time_range = session_config.get("think_time", (0.5, 2.0))
engine = VectorBTEngine(data_provider=self.data_provider)
for symbol in symbols:
for strategy in strategies:
self.active_requests += 1
request_start = time.time()
try:
parameters = STRATEGY_TEMPLATES.get(strategy, {}).get(
"parameters", {}
)
result = await engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
request_time = time.time() - request_start
response_times.append(request_time)
user_results.append(
{
"symbol": symbol,
"strategy": strategy,
"success": True,
"response_time": request_time,
"result_size": len(str(result)),
}
)
except Exception as e:
request_time = time.time() - request_start
response_times.append(request_time)
user_results.append(
{
"symbol": symbol,
"strategy": strategy,
"success": False,
"response_time": request_time,
"error": str(e),
}
)
finally:
self.active_requests -= 1
# Simulate think time between requests
think_time = random.uniform(*think_time_range)
await asyncio.sleep(think_time)
session_time = time.time() - session_start
return {
"user_id": user_id,
"session_time": session_time,
"results": user_results,
"response_times": response_times,
"success_count": sum(1 for r in user_results if r["success"]),
"failure_count": sum(1 for r in user_results if not r["success"]),
}
def calculate_percentiles(self, response_times: list[float]) -> dict[str, float]:
"""Calculate response time percentiles."""
if not response_times:
return {"p50": 0, "p95": 0, "p99": 0}
sorted_times = sorted(response_times)
return {
"p50": np.percentile(sorted_times, 50),
"p95": np.percentile(sorted_times, 95),
"p99": np.percentile(sorted_times, 99),
}
async def run_load_test(
self,
concurrent_users: int,
session_config: dict[str, Any],
duration_seconds: int = 60,
) -> LoadTestResult:
"""Run load test with specified concurrent users."""
logger.info(
f"Starting load test: {concurrent_users} concurrent users for {duration_seconds}s"
)
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
start_time = time.time()
all_response_times = []
all_user_results = []
# Create semaphore to control concurrency
semaphore = asyncio.Semaphore(concurrent_users)
async def run_user_with_semaphore(user_id: int):
async with semaphore:
return await self.simulate_user_session(user_id, session_config)
# Generate user sessions
user_tasks = []
for user_id in range(concurrent_users):
task = run_user_with_semaphore(user_id)
user_tasks.append(task)
# Execute all user sessions concurrently
try:
user_results = await asyncio.wait_for(
asyncio.gather(*user_tasks, return_exceptions=True),
timeout=duration_seconds + 30, # Add buffer to test timeout
)
except TimeoutError:
logger.warning(f"Load test timed out after {duration_seconds + 30}s")
user_results = []
end_time = time.time()
actual_duration = end_time - start_time
# Process results
successful_sessions = []
failed_sessions = []
for result in user_results:
if isinstance(result, Exception):
failed_sessions.append(str(result))
elif isinstance(result, dict):
successful_sessions.append(result)
all_response_times.extend(result.get("response_times", []))
all_user_results.extend(result.get("results", []))
# Calculate metrics
total_requests = len(all_user_results)
successful_requests = sum(
1 for r in all_user_results if r.get("success", False)
)
failed_requests = total_requests - successful_requests
# Response time statistics
percentiles = self.calculate_percentiles(all_response_times)
avg_response_time = (
statistics.mean(all_response_times) if all_response_times else 0
)
min_response_time = min(all_response_times) if all_response_times else 0
max_response_time = max(all_response_times) if all_response_times else 0
# Throughput metrics
requests_per_second = (
total_requests / actual_duration if actual_duration > 0 else 0
)
errors_per_second = (
failed_requests / actual_duration if actual_duration > 0 else 0
)
# Resource usage
final_memory = process.memory_info().rss / 1024 / 1024
memory_usage = final_memory - initial_memory
cpu_usage = process.cpu_percent()
result = LoadTestResult(
concurrent_users=concurrent_users,
total_requests=total_requests,
successful_requests=successful_requests,
failed_requests=failed_requests,
total_duration=actual_duration,
avg_response_time=avg_response_time,
min_response_time=min_response_time,
max_response_time=max_response_time,
p50_response_time=percentiles["p50"],
p95_response_time=percentiles["p95"],
p99_response_time=percentiles["p99"],
requests_per_second=requests_per_second,
errors_per_second=errors_per_second,
memory_usage_mb=memory_usage,
cpu_usage_percent=cpu_usage,
)
logger.info(
f"Load Test Results ({concurrent_users} users):\n"
f" • Total Requests: {total_requests}\n"
f" • Success Rate: {successful_requests / total_requests * 100:.1f}%\n"
f" • Avg Response Time: {avg_response_time:.3f}s\n"
f" • 95th Percentile: {percentiles['p95']:.3f}s\n"
f" • Throughput: {requests_per_second:.1f} req/s\n"
f" • Memory Usage: {memory_usage:.1f}MB\n"
f" • Duration: {actual_duration:.1f}s"
)
return result
class TestLoadTesting:
"""Load testing suite for concurrent users."""
@pytest.fixture
async def optimized_data_provider(self):
"""Create optimized data provider for load testing."""
provider = Mock()
# Pre-generate data for common symbols to reduce computation
symbol_data_cache = {}
def get_cached_data(symbol: str) -> pd.DataFrame:
"""Get or generate cached data for symbol."""
if symbol not in symbol_data_cache:
# Generate deterministic data based on symbol hash
seed = hash(symbol) % 1000
np.random.seed(seed)
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
returns = np.random.normal(0.001, 0.02, len(dates))
prices = 100 * np.cumprod(1 + returns)
symbol_data_cache[symbol] = pd.DataFrame(
{
"Open": prices * np.random.uniform(0.99, 1.01, len(dates)),
"High": prices * np.random.uniform(1.01, 1.03, len(dates)),
"Low": prices * np.random.uniform(0.97, 0.99, len(dates)),
"Close": prices,
"Volume": np.random.randint(1000000, 10000000, len(dates)),
"Adj Close": prices,
},
index=dates,
)
# Ensure OHLC constraints
data = symbol_data_cache[symbol]
data["High"] = np.maximum(
data["High"], np.maximum(data["Open"], data["Close"])
)
data["Low"] = np.minimum(
data["Low"], np.minimum(data["Open"], data["Close"])
)
return symbol_data_cache[symbol].copy()
provider.get_stock_data.side_effect = get_cached_data
return provider
async def test_concurrent_users_10(self, optimized_data_provider, benchmark_timer):
"""Test load with 10 concurrent users."""
load_runner = LoadTestRunner(optimized_data_provider)
session_config = {
"symbols": ["AAPL", "GOOGL"],
"strategies": ["sma_cross", "rsi"],
"think_time": (0.1, 0.5), # Faster think time for testing
}
with benchmark_timer():
result = await load_runner.run_load_test(
concurrent_users=10, session_config=session_config, duration_seconds=30
)
# Performance assertions for 10 users
assert result.requests_per_second >= 2.0, (
f"Throughput too low: {result.requests_per_second:.1f} req/s"
)
assert result.avg_response_time <= 5.0, (
f"Response time too high: {result.avg_response_time:.2f}s"
)
assert result.p95_response_time <= 10.0, (
f"95th percentile too high: {result.p95_response_time:.2f}s"
)
assert result.successful_requests / result.total_requests >= 0.9, (
"Success rate too low"
)
assert result.memory_usage_mb <= 500, (
f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
)
return result
async def test_concurrent_users_50(self, optimized_data_provider, benchmark_timer):
"""Test load with 50 concurrent users."""
load_runner = LoadTestRunner(optimized_data_provider)
session_config = {
"symbols": ["AAPL", "MSFT", "GOOGL"],
"strategies": ["sma_cross", "rsi", "macd"],
"think_time": (0.2, 1.0),
}
with benchmark_timer():
result = await load_runner.run_load_test(
concurrent_users=50, session_config=session_config, duration_seconds=60
)
# Performance assertions for 50 users
assert result.requests_per_second >= 5.0, (
f"Throughput too low: {result.requests_per_second:.1f} req/s"
)
assert result.avg_response_time <= 8.0, (
f"Response time too high: {result.avg_response_time:.2f}s"
)
assert result.p95_response_time <= 15.0, (
f"95th percentile too high: {result.p95_response_time:.2f}s"
)
assert result.successful_requests / result.total_requests >= 0.85, (
"Success rate too low"
)
assert result.memory_usage_mb <= 1000, (
f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
)
return result
async def test_concurrent_users_100(self, optimized_data_provider, benchmark_timer):
"""Test load with 100 concurrent users."""
load_runner = LoadTestRunner(optimized_data_provider)
session_config = {
"symbols": ["AAPL", "MSFT", "GOOGL", "AMZN"],
"strategies": ["sma_cross", "rsi"], # Reduced strategies for higher load
"think_time": (0.5, 1.5),
}
with benchmark_timer():
result = await load_runner.run_load_test(
concurrent_users=100, session_config=session_config, duration_seconds=90
)
# More relaxed performance assertions for 100 users
assert result.requests_per_second >= 3.0, (
f"Throughput too low: {result.requests_per_second:.1f} req/s"
)
assert result.avg_response_time <= 15.0, (
f"Response time too high: {result.avg_response_time:.2f}s"
)
assert result.p95_response_time <= 30.0, (
f"95th percentile too high: {result.p95_response_time:.2f}s"
)
assert result.successful_requests / result.total_requests >= 0.8, (
"Success rate too low"
)
assert result.memory_usage_mb <= 2000, (
f"Memory usage too high: {result.memory_usage_mb:.1f}MB"
)
return result
async def test_load_scalability_analysis(self, optimized_data_provider):
"""Analyze how performance scales with user load."""
load_runner = LoadTestRunner(optimized_data_provider)
session_config = {
"symbols": ["AAPL", "GOOGL"],
"strategies": ["sma_cross"],
"think_time": (0.3, 0.7),
}
user_loads = [5, 10, 20, 40]
scalability_results = []
for user_count in user_loads:
logger.info(f"Testing scalability with {user_count} users")
result = await load_runner.run_load_test(
concurrent_users=user_count,
session_config=session_config,
duration_seconds=30,
)
scalability_results.append(result)
# Analyze scalability metrics
throughput_efficiency = []
response_time_degradation = []
baseline_rps = scalability_results[0].requests_per_second
baseline_response_time = scalability_results[0].avg_response_time
for i, result in enumerate(scalability_results):
expected_rps = baseline_rps * user_loads[i] / user_loads[0]
actual_efficiency = (
result.requests_per_second / expected_rps if expected_rps > 0 else 0
)
throughput_efficiency.append(actual_efficiency)
response_degradation = (
result.avg_response_time / baseline_response_time
if baseline_response_time > 0
else 1
)
response_time_degradation.append(response_degradation)
logger.info(
f"Scalability Analysis ({user_loads[i]} users):\n"
f" • RPS: {result.requests_per_second:.2f}\n"
f" • RPS Efficiency: {actual_efficiency:.2%}\n"
f" • Response Time: {result.avg_response_time:.3f}s\n"
f" • Response Degradation: {response_degradation:.2f}x\n"
f" • Memory: {result.memory_usage_mb:.1f}MB"
)
# Scalability assertions
avg_efficiency = statistics.mean(throughput_efficiency)
max_response_degradation = max(response_time_degradation)
assert avg_efficiency >= 0.5, (
f"Average throughput efficiency too low: {avg_efficiency:.2%}"
)
assert max_response_degradation <= 5.0, (
f"Response time degradation too high: {max_response_degradation:.1f}x"
)
return {
"user_loads": user_loads,
"results": scalability_results,
"throughput_efficiency": throughput_efficiency,
"response_time_degradation": response_time_degradation,
"avg_efficiency": avg_efficiency,
}
async def test_sustained_load_stability(self, optimized_data_provider):
"""Test stability under sustained load."""
load_runner = LoadTestRunner(optimized_data_provider)
session_config = {
"symbols": ["AAPL", "MSFT"],
"strategies": ["sma_cross", "rsi"],
"think_time": (0.5, 1.0),
}
# Run sustained load for longer duration
result = await load_runner.run_load_test(
concurrent_users=25,
session_config=session_config,
duration_seconds=300, # 5 minutes
)
# Stability assertions
assert result.errors_per_second <= 0.1, (
f"Error rate too high: {result.errors_per_second:.3f} err/s"
)
assert result.successful_requests / result.total_requests >= 0.95, (
"Success rate degraded over time"
)
assert result.memory_usage_mb <= 800, (
f"Memory usage grew too much: {result.memory_usage_mb:.1f}MB"
)
# Check for performance consistency (no significant degradation)
assert result.p99_response_time / result.p50_response_time <= 5.0, (
"Response time variance too high"
)
logger.info(
f"Sustained Load Results (25 users, 5 minutes):\n"
f" • Total Requests: {result.total_requests}\n"
f" • Success Rate: {result.successful_requests / result.total_requests * 100:.2f}%\n"
f" • Avg Throughput: {result.requests_per_second:.2f} req/s\n"
f" • Response Time (50/95/99): {result.p50_response_time:.2f}s/"
f"{result.p95_response_time:.2f}s/{result.p99_response_time:.2f}s\n"
f" • Memory Growth: {result.memory_usage_mb:.1f}MB\n"
f" • Error Rate: {result.errors_per_second:.4f} err/s"
)
return result
async def test_database_connection_pooling_under_load(
self, optimized_data_provider, db_session
):
"""Test database connection pooling under concurrent load."""
# Generate backtest results to save to database
engine = VectorBTEngine(data_provider=optimized_data_provider)
test_symbols = ["DB_LOAD_1", "DB_LOAD_2", "DB_LOAD_3"]
# Pre-generate results for database testing
backtest_results = []
for symbol in test_symbols:
result = await engine.run_backtest(
symbol=symbol,
strategy_type="sma_cross",
parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
backtest_results.append(result)
# Test concurrent database operations
async def concurrent_database_operations(operation_id: int) -> dict[str, Any]:
"""Simulate concurrent database save/retrieve operations."""
start_time = time.time()
operations_completed = 0
errors = []
try:
with BacktestPersistenceManager(session=db_session) as persistence:
# Save operations
for result in backtest_results:
try:
backtest_id = persistence.save_backtest_result(
vectorbt_results=result,
execution_time=2.0,
notes=f"Load test operation {operation_id}",
)
operations_completed += 1
# Retrieve operation
retrieved = persistence.get_backtest_by_id(backtest_id)
if retrieved:
operations_completed += 1
except Exception as e:
errors.append(str(e))
except Exception as e:
errors.append(f"Session error: {str(e)}")
operation_time = time.time() - start_time
return {
"operation_id": operation_id,
"operations_completed": operations_completed,
"errors": errors,
"operation_time": operation_time,
}
# Run concurrent database operations
concurrent_operations = 20
db_tasks = [
concurrent_database_operations(i) for i in range(concurrent_operations)
]
start_time = time.time()
db_results = await asyncio.gather(*db_tasks, return_exceptions=True)
total_time = time.time() - start_time
# Analyze database performance under load
successful_operations = [r for r in db_results if isinstance(r, dict)]
failed_operations = len(db_results) - len(successful_operations)
total_operations = sum(r["operations_completed"] for r in successful_operations)
total_errors = sum(len(r["errors"]) for r in successful_operations)
avg_operation_time = statistics.mean(
[r["operation_time"] for r in successful_operations]
)
db_throughput = total_operations / total_time if total_time > 0 else 0
error_rate = total_errors / total_operations if total_operations > 0 else 0
logger.info(
f"Database Load Test Results:\n"
f" • Concurrent Operations: {concurrent_operations}\n"
f" • Successful Sessions: {len(successful_operations)}\n"
f" • Failed Sessions: {failed_operations}\n"
f" • Total DB Operations: {total_operations}\n"
f" • DB Throughput: {db_throughput:.2f} ops/s\n"
f" • Error Rate: {error_rate:.3%}\n"
f" • Avg Operation Time: {avg_operation_time:.3f}s"
)
# Database performance assertions
assert len(successful_operations) / len(db_results) >= 0.9, (
"DB session success rate too low"
)
assert error_rate <= 0.05, f"DB error rate too high: {error_rate:.3%}"
assert db_throughput >= 5.0, f"DB throughput too low: {db_throughput:.2f} ops/s"
return {
"concurrent_operations": concurrent_operations,
"db_throughput": db_throughput,
"error_rate": error_rate,
"avg_operation_time": avg_operation_time,
}
if __name__ == "__main__":
# Run load testing suite
pytest.main(
[
__file__,
"-v",
"--tb=short",
"--asyncio-mode=auto",
"--timeout=600", # 10 minute timeout for load tests
"--durations=10",
]
)
```
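The load runner above can also be driven outside pytest for quick local experiments. The snippet below is a minimal sketch, assuming `LoadTestRunner` is importable from `tests.performance.test_load` and that `VectorBTEngine` pulls bars through `provider.get_stock_data(...)`, as the fixtures above do; the `make_provider` helper and its synthetic OHLCV frame are illustrative, not part of the repository.

```python
# Sketch only: drive LoadTestRunner directly with a mocked data provider.
import asyncio
from unittest.mock import Mock

import numpy as np
import pandas as pd

from tests.performance.test_load import LoadTestRunner  # assumed import path


def make_provider() -> Mock:
    """Mock provider returning synthetic OHLCV bars (mirrors the fixture above)."""
    provider = Mock()

    def get_data(symbol: str, *args, **kwargs) -> pd.DataFrame:
        dates = pd.date_range("2023-01-01", "2023-12-31", freq="D")
        rng = np.random.default_rng(abs(hash(symbol)) % 1000)  # per-symbol seed
        close = 100 * np.cumprod(1 + rng.normal(0.001, 0.02, len(dates)))
        return pd.DataFrame(
            {
                "Open": close * 0.999,
                "High": close * 1.01,
                "Low": close * 0.99,
                "Close": close,
                "Volume": rng.integers(1_000_000, 10_000_000, len(dates)),
                "Adj Close": close,
            },
            index=dates,
        )

    provider.get_stock_data.side_effect = get_data
    return provider


async def main() -> None:
    runner = LoadTestRunner(make_provider())
    result = await runner.run_load_test(
        concurrent_users=5,
        session_config={
            "symbols": ["AAPL"],
            "strategies": ["sma_cross"],
            "think_time": (0.1, 0.3),
        },
        duration_seconds=15,
    )
    print(
        f"{result.requests_per_second:.2f} req/s, "
        f"p95={result.p95_response_time:.2f}s, "
        f"success={result.successful_requests}/{result.total_requests}"
    )


if __name__ == "__main__":
    asyncio.run(main())
```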
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/strategies/ml/ensemble.py:
--------------------------------------------------------------------------------
```python
"""Strategy ensemble methods for combining multiple trading strategies."""
import logging
from typing import Any
import numpy as np
import pandas as pd
from pandas import DataFrame, Series
from maverick_mcp.backtesting.strategies.base import Strategy
logger = logging.getLogger(__name__)
class StrategyEnsemble(Strategy):
"""Ensemble strategy that combines multiple strategies with dynamic weighting."""
def __init__(
self,
strategies: list[Strategy],
weighting_method: str = "performance",
lookback_period: int = 50,
rebalance_frequency: int = 20,
parameters: dict[str, Any] = None,
):
"""Initialize strategy ensemble.
Args:
strategies: List of base strategies to combine
weighting_method: Method for calculating weights ('performance', 'equal', 'volatility')
lookback_period: Period for calculating performance metrics
rebalance_frequency: How often to update weights
parameters: Additional parameters
"""
super().__init__(parameters)
self.strategies = strategies
self.weighting_method = weighting_method
self.lookback_period = lookback_period
self.rebalance_frequency = rebalance_frequency
# Initialize strategy weights
self.weights = np.ones(len(strategies)) / len(strategies)
self.strategy_returns = {}
self.strategy_signals = {}
self.last_rebalance = 0
@property
def name(self) -> str:
"""Get strategy name."""
strategy_names = [s.name for s in self.strategies]
return f"Ensemble({','.join(strategy_names)})"
@property
def description(self) -> str:
"""Get strategy description."""
return f"Dynamic ensemble combining {len(self.strategies)} strategies using {self.weighting_method} weighting"
def calculate_performance_weights(self, data: DataFrame) -> np.ndarray:
"""Calculate performance-based weights for strategies.
Args:
data: Price data for performance calculation
Returns:
Array of strategy weights
"""
if len(self.strategy_returns) < 2:
return self.weights
# Calculate Sharpe ratios for each strategy
sharpe_ratios = []
for i, _strategy in enumerate(self.strategies):
if (
i in self.strategy_returns
and len(self.strategy_returns[i]) >= self.lookback_period
):
returns = pd.Series(self.strategy_returns[i][-self.lookback_period :])
sharpe = returns.mean() / (returns.std() + 1e-8) * np.sqrt(252)
sharpe_ratios.append(max(0, sharpe)) # Ensure non-negative
else:
sharpe_ratios.append(0.1) # Small positive weight for new strategies
# Convert to weights (softmax-like normalization)
sharpe_array = np.array(sharpe_ratios)
# Fix: Properly check for empty array and zero sum conditions
if sharpe_array.size == 0 or np.sum(sharpe_array) == 0:
weights = np.ones(len(self.strategies)) / len(self.strategies)
else:
# Exponential weighting to emphasize better performers
exp_sharpe = np.exp(sharpe_array * 2)
weights = exp_sharpe / exp_sharpe.sum()
return weights
def calculate_volatility_weights(self, data: DataFrame) -> np.ndarray:
"""Calculate inverse volatility weights for strategies.
Args:
data: Price data for volatility calculation
Returns:
Array of strategy weights
"""
if len(self.strategy_returns) < 2:
return self.weights
# Calculate volatilities for each strategy
volatilities = []
for i, _strategy in enumerate(self.strategies):
if (
i in self.strategy_returns
and len(self.strategy_returns[i]) >= self.lookback_period
):
returns = pd.Series(self.strategy_returns[i][-self.lookback_period :])
vol = returns.std() * np.sqrt(252)
volatilities.append(max(0.01, vol)) # Minimum volatility
else:
volatilities.append(0.2) # Default volatility assumption
# Inverse volatility weighting
vol_array = np.array(volatilities)
inv_vol = 1.0 / vol_array
weights = inv_vol / inv_vol.sum()
return weights
def update_weights(self, data: DataFrame, current_index: int) -> None:
"""Update strategy weights based on recent performance.
Args:
data: Price data
current_index: Current position in data
"""
# Check if it's time to rebalance
if current_index - self.last_rebalance < self.rebalance_frequency:
return
try:
if self.weighting_method == "performance":
self.weights = self.calculate_performance_weights(data)
elif self.weighting_method == "volatility":
self.weights = self.calculate_volatility_weights(data)
elif self.weighting_method == "equal":
self.weights = np.ones(len(self.strategies)) / len(self.strategies)
else:
logger.warning(f"Unknown weighting method: {self.weighting_method}")
self.last_rebalance = current_index
logger.debug(
f"Updated ensemble weights: {dict(zip([s.name for s in self.strategies], self.weights, strict=False))}"
)
except Exception as e:
logger.error(f"Error updating weights: {e}")
def generate_individual_signals(
self, data: DataFrame
) -> dict[int, tuple[Series, Series]]:
"""Generate signals from all individual strategies with enhanced error handling.
Args:
data: Price data
Returns:
Dictionary mapping strategy index to (entry_signals, exit_signals)
"""
signals = {}
failed_strategies = []
for i, strategy in enumerate(self.strategies):
try:
                # Generate signals; failures are caught below and replaced with fallback signals
entry_signals, exit_signals = strategy.generate_signals(data)
# Validate signals
if not isinstance(entry_signals, pd.Series) or not isinstance(
exit_signals, pd.Series
):
raise ValueError(
f"Strategy {strategy.name} returned invalid signal types"
)
if len(entry_signals) != len(data) or len(exit_signals) != len(data):
raise ValueError(
f"Strategy {strategy.name} returned signals with wrong length"
)
if not entry_signals.dtype == bool or not exit_signals.dtype == bool:
# Convert to boolean if necessary
entry_signals = entry_signals.astype(bool)
exit_signals = exit_signals.astype(bool)
signals[i] = (entry_signals, exit_signals)
# Calculate strategy returns for weight updates (with error handling)
try:
positions = entry_signals.astype(int) - exit_signals.astype(int)
price_returns = data["close"].pct_change()
returns = positions.shift(1) * price_returns
# Remove invalid returns
valid_returns = returns.dropna()
valid_returns = valid_returns[np.isfinite(valid_returns)]
if i not in self.strategy_returns:
self.strategy_returns[i] = []
if len(valid_returns) > 0:
self.strategy_returns[i].extend(valid_returns.tolist())
# Keep only recent returns for performance calculation
if len(self.strategy_returns[i]) > self.lookback_period * 2:
self.strategy_returns[i] = self.strategy_returns[i][
-self.lookback_period * 2 :
]
except Exception as return_error:
logger.debug(
f"Error calculating returns for strategy {strategy.name}: {return_error}"
)
logger.debug(
f"Strategy {strategy.name}: {entry_signals.sum()} entries, {exit_signals.sum()} exits"
)
except Exception as e:
logger.error(
f"Error generating signals for strategy {strategy.name}: {e}"
)
failed_strategies.append(i)
# Create safe fallback signals
try:
signals[i] = (
pd.Series(False, index=data.index),
pd.Series(False, index=data.index),
)
except Exception:
# If even creating empty signals fails, skip this strategy
logger.error(f"Cannot create fallback signals for strategy {i}")
continue
# Log summary of strategy performance
if failed_strategies:
failed_names = [self.strategies[i].name for i in failed_strategies]
logger.warning(f"Failed strategies: {failed_names}")
successful_strategies = len(signals) - len(failed_strategies)
logger.info(
f"Successfully generated signals from {successful_strategies}/{len(self.strategies)} strategies"
)
return signals
def combine_signals(
self, individual_signals: dict[int, tuple[Series, Series]]
) -> tuple[Series, Series]:
"""Combine individual strategy signals using enhanced weighted voting.
Args:
individual_signals: Dictionary of individual strategy signals
Returns:
Tuple of combined (entry_signals, exit_signals)
"""
if not individual_signals:
# Return empty series with minimal index when no individual signals available
empty_index = pd.Index([])
return pd.Series(False, index=empty_index), pd.Series(
False, index=empty_index
)
# Get data index from first strategy
first_signals = next(iter(individual_signals.values()))
data_index = first_signals[0].index
# Initialize voting arrays
entry_votes = np.zeros(len(data_index))
exit_votes = np.zeros(len(data_index))
total_weights = 0
        # Collect weighted votes from strategies with positive weight
valid_strategies = 0
for i, (entry_signals, exit_signals) in individual_signals.items():
weight = self.weights[i] if i < len(self.weights) else 0
if weight > 0:
# Add weighted votes
entry_votes += weight * entry_signals.astype(float)
exit_votes += weight * exit_signals.astype(float)
total_weights += weight
valid_strategies += 1
if total_weights == 0 or valid_strategies == 0:
logger.warning("No valid strategies with positive weights")
return pd.Series(False, index=data_index), pd.Series(
False, index=data_index
)
# Normalize votes by total weights
entry_votes = entry_votes / total_weights
exit_votes = exit_votes / total_weights
# Enhanced voting mechanisms
voting_method = self.parameters.get("voting_method", "weighted")
if voting_method == "majority":
# Simple majority vote (more than half of strategies agree)
entry_threshold = 0.5
exit_threshold = 0.5
elif voting_method == "supermajority":
# Require 2/3 agreement
entry_threshold = 0.67
exit_threshold = 0.67
elif voting_method == "consensus":
# Require near-unanimous agreement
entry_threshold = 0.8
exit_threshold = 0.8
else: # weighted (default)
entry_threshold = self.parameters.get("entry_threshold", 0.5)
exit_threshold = self.parameters.get("exit_threshold", 0.5)
# Anti-conflict mechanism: don't signal entry and exit simultaneously
combined_entry = entry_votes > entry_threshold
combined_exit = exit_votes > exit_threshold
# Resolve conflicts (simultaneous entry and exit signals)
conflicts = combined_entry & combined_exit
# Fix: Check array size and ensure it's not empty before evaluating boolean truth
if conflicts.size > 0 and np.any(conflicts):
logger.debug(f"Resolving {conflicts.sum()} signal conflicts")
# In case of conflict, use the stronger signal
entry_strength = entry_votes[conflicts]
exit_strength = exit_votes[conflicts]
# Keep only the stronger signal
stronger_entry = entry_strength > exit_strength
combined_entry[conflicts] = stronger_entry
combined_exit[conflicts] = ~stronger_entry
# Quality filter: require minimum signal strength
min_signal_strength = self.parameters.get("min_signal_strength", 0.1)
weak_entry_signals = (combined_entry) & (entry_votes < min_signal_strength)
weak_exit_signals = (combined_exit) & (exit_votes < min_signal_strength)
# Fix: Ensure arrays are not empty before boolean indexing
if weak_entry_signals.size > 0:
combined_entry[weak_entry_signals] = False
if weak_exit_signals.size > 0:
combined_exit[weak_exit_signals] = False
# Convert to pandas Series
combined_entry = pd.Series(combined_entry, index=data_index)
combined_exit = pd.Series(combined_exit, index=data_index)
return combined_entry, combined_exit
def generate_signals(self, data: DataFrame) -> tuple[Series, Series]:
"""Generate ensemble trading signals.
Args:
data: Price data with OHLCV columns
Returns:
Tuple of (entry_signals, exit_signals) as boolean Series
"""
try:
# Generate signals from all individual strategies
individual_signals = self.generate_individual_signals(data)
if not individual_signals:
return pd.Series(False, index=data.index), pd.Series(
False, index=data.index
)
# Update weights periodically
for idx in range(
self.rebalance_frequency, len(data), self.rebalance_frequency
):
self.update_weights(data.iloc[:idx], idx)
# Combine signals
entry_signals, exit_signals = self.combine_signals(individual_signals)
logger.info(
f"Generated ensemble signals: {entry_signals.sum()} entries, {exit_signals.sum()} exits"
)
return entry_signals, exit_signals
except Exception as e:
logger.error(f"Error generating ensemble signals: {e}")
return pd.Series(False, index=data.index), pd.Series(
False, index=data.index
)
def get_strategy_weights(self) -> dict[str, float]:
"""Get current strategy weights.
Returns:
Dictionary mapping strategy names to weights
"""
return dict(zip([s.name for s in self.strategies], self.weights, strict=False))
def get_strategy_performance(self) -> dict[str, dict[str, float]]:
"""Get performance metrics for individual strategies.
Returns:
Dictionary mapping strategy names to performance metrics
"""
performance = {}
for i, strategy in enumerate(self.strategies):
if i in self.strategy_returns and len(self.strategy_returns[i]) > 0:
returns = pd.Series(self.strategy_returns[i])
performance[strategy.name] = {
"total_return": returns.sum(),
"annual_return": returns.mean() * 252,
"volatility": returns.std() * np.sqrt(252),
"sharpe_ratio": returns.mean()
/ (returns.std() + 1e-8)
* np.sqrt(252),
"max_drawdown": (
returns.cumsum() - returns.cumsum().expanding().max()
).min(),
"win_rate": (returns > 0).mean(),
"current_weight": self.weights[i],
}
else:
performance[strategy.name] = {
"total_return": 0.0,
"annual_return": 0.0,
"volatility": 0.0,
"sharpe_ratio": 0.0,
"max_drawdown": 0.0,
"win_rate": 0.0,
"current_weight": self.weights[i] if i < len(self.weights) else 0.0,
}
return performance
def validate_parameters(self) -> bool:
"""Validate ensemble parameters.
Returns:
True if parameters are valid
"""
if not self.strategies:
return False
if self.weighting_method not in ["performance", "equal", "volatility"]:
return False
if self.lookback_period <= 0 or self.rebalance_frequency <= 0:
return False
# Validate individual strategies
for strategy in self.strategies:
if not strategy.validate_parameters():
return False
return True
def get_default_parameters(self) -> dict[str, Any]:
"""Get default ensemble parameters.
Returns:
Dictionary of default parameters
"""
return {
"weighting_method": "performance",
"lookback_period": 50,
"rebalance_frequency": 20,
"entry_threshold": 0.5,
"exit_threshold": 0.5,
"voting_method": "weighted", # weighted, majority, supermajority, consensus
"min_signal_strength": 0.1, # Minimum signal strength to avoid weak signals
"conflict_resolution": "stronger", # How to resolve entry/exit conflicts
}
def to_dict(self) -> dict[str, Any]:
"""Convert ensemble to dictionary representation.
Returns:
Dictionary with ensemble details
"""
base_dict = super().to_dict()
base_dict.update(
{
"strategies": [s.to_dict() for s in self.strategies],
"current_weights": self.get_strategy_weights(),
"weighting_method": self.weighting_method,
"lookback_period": self.lookback_period,
"rebalance_frequency": self.rebalance_frequency,
}
)
return base_dict
class RiskAdjustedEnsemble(StrategyEnsemble):
"""Risk-adjusted ensemble with position sizing and risk management."""
def __init__(
self,
strategies: list[Strategy],
max_position_size: float = 0.1,
max_correlation: float = 0.7,
risk_target: float = 0.15,
**kwargs,
):
"""Initialize risk-adjusted ensemble.
Args:
strategies: List of base strategies
max_position_size: Maximum position size per strategy
max_correlation: Maximum correlation between strategies
risk_target: Target portfolio volatility
**kwargs: Additional parameters for base ensemble
"""
super().__init__(strategies, **kwargs)
self.max_position_size = max_position_size
self.max_correlation = max_correlation
self.risk_target = risk_target
def calculate_correlation_matrix(self) -> pd.DataFrame:
"""Calculate correlation matrix between strategy returns.
Returns:
Correlation matrix as DataFrame
"""
if len(self.strategy_returns) < 2:
return pd.DataFrame()
# Create returns DataFrame
min_length = min(
len(returns)
for returns in self.strategy_returns.values()
if len(returns) > 0
)
if min_length == 0:
return pd.DataFrame()
returns_data = {}
for i, strategy in enumerate(self.strategies):
if (
i in self.strategy_returns
and len(self.strategy_returns[i]) >= min_length
):
returns_data[strategy.name] = self.strategy_returns[i][-min_length:]
if not returns_data:
return pd.DataFrame()
returns_df = pd.DataFrame(returns_data)
return returns_df.corr()
def adjust_weights_for_correlation(self, weights: np.ndarray) -> np.ndarray:
"""Adjust weights to account for strategy correlation.
Args:
weights: Original weights
Returns:
Correlation-adjusted weights
"""
corr_matrix = self.calculate_correlation_matrix()
if corr_matrix.empty:
return weights
try:
# Penalize highly correlated strategies
adjusted_weights = weights.copy()
for i, strategy_i in enumerate(self.strategies):
for j, strategy_j in enumerate(self.strategies):
if (
i != j
and strategy_i.name in corr_matrix.index
and strategy_j.name in corr_matrix.columns
):
correlation = abs(
corr_matrix.loc[strategy_i.name, strategy_j.name]
)
if correlation > self.max_correlation:
# Reduce weight of both strategies
penalty = (correlation - self.max_correlation) * 0.5
adjusted_weights[i] *= 1 - penalty
adjusted_weights[j] *= 1 - penalty
# Renormalize weights
# Fix: Check array size and sum properly before normalization
if adjusted_weights.size > 0 and np.sum(adjusted_weights) > 0:
adjusted_weights /= adjusted_weights.sum()
else:
adjusted_weights = np.ones(len(self.strategies)) / len(self.strategies)
return adjusted_weights
except Exception as e:
logger.error(f"Error adjusting weights for correlation: {e}")
return weights
def calculate_risk_adjusted_weights(self, data: DataFrame) -> np.ndarray:
"""Calculate risk-adjusted weights based on target volatility.
Args:
data: Price data
Returns:
Risk-adjusted weights
"""
# Start with performance-based weights
base_weights = self.calculate_performance_weights(data)
# Adjust for correlation
corr_adjusted_weights = self.adjust_weights_for_correlation(base_weights)
# Apply position size limits
position_adjusted_weights = np.minimum(
corr_adjusted_weights, self.max_position_size
)
# Renormalize
# Fix: Check array size and sum properly before normalization
if position_adjusted_weights.size > 0 and np.sum(position_adjusted_weights) > 0:
position_adjusted_weights /= position_adjusted_weights.sum()
else:
position_adjusted_weights = np.ones(len(self.strategies)) / len(
self.strategies
)
return position_adjusted_weights
def update_weights(self, data: DataFrame, current_index: int) -> None:
"""Update risk-adjusted weights.
Args:
data: Price data
current_index: Current position in data
"""
if current_index - self.last_rebalance < self.rebalance_frequency:
return
try:
self.weights = self.calculate_risk_adjusted_weights(data)
self.last_rebalance = current_index
logger.debug(
f"Updated risk-adjusted weights: {dict(zip([s.name for s in self.strategies], self.weights, strict=False))}"
)
except Exception as e:
logger.error(f"Error updating risk-adjusted weights: {e}")
@property
def name(self) -> str:
"""Get strategy name."""
return f"RiskAdjusted{super().name}"
@property
def description(self) -> str:
"""Get strategy description."""
return "Risk-adjusted ensemble with correlation control and position sizing"
```
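Under `performance` weighting, each strategy's weight is proportional to `exp(2 * Sharpe_i)` measured over the lookback window, with equal weights used until enough return history accumulates; `RiskAdjustedEnsemble` additionally caps per-strategy weight at `max_position_size` and penalizes pairs whose return correlation exceeds `max_correlation`. The sketch below wires `StrategyEnsemble` to two toy crossover strategies; `ToySMACross` is illustrative only and implements just the base-class members this module exercises (`name`, `description`, `generate_signals`, `validate_parameters`, `get_default_parameters`), so the real `Strategy` interface in `strategies/base.py` may require more.

```python
# Sketch only: toy Strategy subclass plus ensemble wiring (not repository code).
import numpy as np
import pandas as pd

from maverick_mcp.backtesting.strategies.base import Strategy
from maverick_mcp.backtesting.strategies.ml.ensemble import StrategyEnsemble


class ToySMACross(Strategy):
    """Toy moving-average crossover used only to demonstrate ensemble wiring."""

    @property
    def name(self) -> str:
        return f"ToySMACross({self.parameters['fast']},{self.parameters['slow']})"

    @property
    def description(self) -> str:
        return "Toy SMA crossover for the ensemble sketch"

    def get_default_parameters(self) -> dict:
        return {"fast": 10, "slow": 30}

    def validate_parameters(self) -> bool:
        return self.parameters["fast"] < self.parameters["slow"]

    def generate_signals(self, data: pd.DataFrame) -> tuple[pd.Series, pd.Series]:
        fast = data["close"].rolling(self.parameters["fast"]).mean()
        slow = data["close"].rolling(self.parameters["slow"]).mean()
        entries = (fast > slow) & (fast.shift(1) <= slow.shift(1))
        exits = (fast < slow) & (fast.shift(1) >= slow.shift(1))
        return entries, exits


# Synthetic daily closes; the ensemble reads only the "close" column here.
rng = np.random.default_rng(7)
dates = pd.date_range("2023-01-01", periods=252, freq="B")
close = pd.Series(
    100 * np.cumprod(1 + rng.normal(0.0005, 0.01, len(dates))), index=dates
)
data = pd.DataFrame({"close": close})

ensemble = StrategyEnsemble(
    strategies=[
        ToySMACross({"fast": 10, "slow": 30}),
        ToySMACross({"fast": 20, "slow": 60}),
    ],
    weighting_method="performance",  # weights ~ exp(2 * Sharpe), renormalized
    lookback_period=50,
    rebalance_frequency=20,
)
entries, exits = ensemble.generate_signals(data)
print(ensemble.get_strategy_weights(), int(entries.sum()), int(exits.sum()))
```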
--------------------------------------------------------------------------------
/maverick_mcp/agents/technical_analysis.py:
--------------------------------------------------------------------------------
```python
"""
Technical Analysis Agent with pattern recognition and multi-timeframe analysis.
"""
import logging
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.tools import BaseTool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from maverick_mcp.agents.circuit_breaker import circuit_manager
from maverick_mcp.langchain_tools import get_tool_registry
from maverick_mcp.memory import ConversationStore
from maverick_mcp.tools.risk_management import TechnicalStopsTool
from maverick_mcp.workflows.state import TechnicalAnalysisState
from .base import PersonaAwareAgent
logger = logging.getLogger(__name__)
class TechnicalAnalysisAgent(PersonaAwareAgent):
"""
Professional technical analysis agent with pattern recognition.
Features:
- Chart pattern detection (head & shoulders, triangles, flags)
- Multi-timeframe analysis
- Indicator confluence scoring
- Support/resistance clustering
- Volume profile analysis
- LLM-powered technical narratives
"""
def __init__(
self,
llm,
persona: str = "moderate",
ttl_hours: int = 1,
):
"""
Initialize technical analysis agent.
Args:
llm: Language model
persona: Investor persona
ttl_hours: Cache TTL in hours
"""
# Store persona temporarily for tool configuration
self._temp_persona = persona
# Get technical analysis tools
tools = self._get_technical_tools()
# Initialize with MemorySaver
super().__init__(
llm=llm,
tools=tools,
persona=persona,
checkpointer=MemorySaver(),
ttl_hours=ttl_hours,
)
# Initialize conversation store
self.conversation_store = ConversationStore(ttl_hours=ttl_hours)
def _get_technical_tools(self) -> list[BaseTool]:
"""Get comprehensive technical analysis tools."""
registry = get_tool_registry()
# Core technical tools
technical_tools = [
registry.get_tool("get_technical_indicators"),
registry.get_tool("calculate_support_resistance"),
registry.get_tool("detect_chart_patterns"),
registry.get_tool("calculate_moving_averages"),
registry.get_tool("calculate_oscillators"),
]
# Price action tools
price_tools = [
registry.get_tool("get_stock_price"),
registry.get_tool("get_stock_history"),
registry.get_tool("get_intraday_data"),
]
# Volume analysis tools
volume_tools = [
registry.get_tool("analyze_volume_profile"),
registry.get_tool("detect_volume_patterns"),
]
# Risk tools
risk_tools = [
TechnicalStopsTool(),
]
# Combine and filter
all_tools = technical_tools + price_tools + volume_tools + risk_tools
tools = [t for t in all_tools if t is not None]
# Configure persona for PersonaAwareTools
for tool in tools:
if hasattr(tool, "set_persona"):
tool.set_persona(self._temp_persona)
if not tools:
logger.warning("No technical tools available, using mock tools")
tools = self._create_mock_tools()
return tools
def get_state_schema(self) -> type:
"""Return enhanced state schema for technical analysis."""
return TechnicalAnalysisState
def _build_system_prompt(self) -> str:
"""Build comprehensive system prompt for technical analysis."""
base_prompt = super()._build_system_prompt()
technical_prompt = f"""
You are a professional technical analyst specializing in pattern recognition and multi-timeframe analysis.
Current date: {datetime.now().strftime("%Y-%m-%d")}
## Core Responsibilities:
1. **Pattern Recognition**:
- Chart patterns: Head & Shoulders, Triangles, Flags, Wedges
- Candlestick patterns: Doji, Hammer, Engulfing, etc.
- Support/Resistance: Dynamic and static levels
- Trend lines and channels
2. **Multi-Timeframe Analysis**:
- Align signals across daily, hourly, and 5-minute charts
- Identify confluences between timeframes
- Spot divergences early
- Time entries based on lower timeframe setups
3. **Indicator Analysis**:
- Trend: Moving averages, ADX, MACD
- Momentum: RSI, Stochastic, CCI
- Volume: OBV, Volume Profile, VWAP
- Volatility: Bollinger Bands, ATR, Keltner Channels
4. **Trade Setup Construction**:
- Entry points with specific triggers
- Stop loss placement using ATR or structure
- Profit targets based on measured moves
- Risk/Reward ratio calculation
## Analysis Framework by Persona:
**Conservative ({self.persona.name if self.persona.name == "Conservative" else "N/A"})**:
- Wait for confirmed patterns only
- Use wider stops above/below structure
- Target 1.5:1 risk/reward minimum
- Focus on daily/weekly timeframes
**Moderate ({self.persona.name if self.persona.name == "Moderate" else "N/A"})**:
- Balance pattern quality with opportunity
- Standard ATR-based stops
- Target 2:1 risk/reward
- Use daily/4H timeframes
**Aggressive ({self.persona.name if self.persona.name == "Aggressive" else "N/A"})**:
- Trade emerging patterns
- Tighter stops for larger positions
- Target 3:1+ risk/reward
- Include intraday timeframes
**Day Trader ({self.persona.name if self.persona.name == "Day Trader" else "N/A"})**:
- Focus on intraday patterns
- Use tick/volume charts
- Quick scalps with tight stops
- Multiple entries/exits
## Technical Analysis Process:
1. **Market Structure**: Identify trend direction and strength
2. **Key Levels**: Map support/resistance zones
3. **Pattern Search**: Scan for actionable patterns
4. **Indicator Confluence**: Check for agreement
5. **Volume Confirmation**: Validate with volume
6. **Risk Definition**: Calculate stops and targets
7. **Setup Quality**: Rate A+ to C based on confluence
Remember to:
- Be specific with price levels
- Explain pattern psychology
- Highlight invalidation levels
- Consider market context
- Provide clear action plans
"""
return base_prompt + technical_prompt
def _build_graph(self):
"""Build enhanced graph with technical analysis nodes."""
workflow = StateGraph(TechnicalAnalysisState)
# Add specialized nodes with unique names
workflow.add_node("analyze_structure", self._analyze_structure)
workflow.add_node("detect_patterns", self._detect_patterns)
workflow.add_node("analyze_indicators", self._analyze_indicators)
workflow.add_node("construct_trade_setup", self._construct_trade_setup)
workflow.add_node("agent", self._agent_node)
# Create tool node if tools available
if self.tools:
from langgraph.prebuilt import ToolNode
tool_node = ToolNode(self.tools)
workflow.add_node("tools", tool_node)
# Define flow
workflow.add_edge(START, "analyze_structure")
workflow.add_edge("analyze_structure", "detect_patterns")
workflow.add_edge("detect_patterns", "analyze_indicators")
workflow.add_edge("analyze_indicators", "construct_trade_setup")
workflow.add_edge("construct_trade_setup", "agent")
if self.tools:
workflow.add_conditional_edges(
"agent",
self._should_continue,
{
"continue": "tools",
"end": END,
},
)
workflow.add_edge("tools", "agent")
else:
workflow.add_edge("agent", END)
return workflow.compile(checkpointer=self.checkpointer)
async def _analyze_structure(self, state: TechnicalAnalysisState) -> dict[str, Any]:
"""Analyze market structure and identify key levels."""
try:
# Get support/resistance tool
sr_tool = next(
(t for t in self.tools if "support_resistance" in t.name), None
)
if sr_tool and state.get("symbol"):
circuit_breaker = await circuit_manager.get_or_create("technical")
async def get_levels():
return await sr_tool.ainvoke(
{
"symbol": state["symbol"],
"lookback_days": state.get("lookback_days", 20),
}
)
levels_data = await circuit_breaker.call(get_levels)
# Extract support/resistance levels
if isinstance(levels_data, dict):
state["support_levels"] = levels_data.get("support_levels", [])
state["resistance_levels"] = levels_data.get(
"resistance_levels", []
)
# Determine trend based on structure
if levels_data.get("trend"):
state["trend_direction"] = levels_data["trend"]
else:
# Simple trend determination
current = state.get("current_price", 0)
ma_50 = levels_data.get("ma_50", current)
state["trend_direction"] = (
"bullish" if current > ma_50 else "bearish"
)
except Exception as e:
logger.error(f"Error analyzing structure: {e}")
state["api_calls_made"] = state.get("api_calls_made", 0) + 1
return {
"support_levels": state.get("support_levels", []),
"resistance_levels": state.get("resistance_levels", []),
"trend_direction": state.get("trend_direction", "neutral"),
}
async def _detect_patterns(self, state: TechnicalAnalysisState) -> dict[str, Any]:
"""Detect chart patterns."""
try:
# Get pattern detection tool
pattern_tool = next((t for t in self.tools if "pattern" in t.name), None)
if pattern_tool and state.get("symbol"):
circuit_breaker = await circuit_manager.get_or_create("technical")
async def detect():
return await pattern_tool.ainvoke(
{
"symbol": state["symbol"],
"timeframe": state.get("timeframe", "1d"),
}
)
pattern_data = await circuit_breaker.call(detect)
if isinstance(pattern_data, dict) and "patterns" in pattern_data:
patterns = pattern_data["patterns"]
state["patterns"] = patterns
# Calculate pattern confidence scores
pattern_confidence = {}
for pattern in patterns:
name = pattern.get("name", "Unknown")
confidence = pattern.get("confidence", 50)
pattern_confidence[name] = confidence
state["pattern_confidence"] = pattern_confidence
except Exception as e:
logger.error(f"Error detecting patterns: {e}")
state["api_calls_made"] = state.get("api_calls_made", 0) + 1
return {
"patterns": state.get("patterns", []),
"pattern_confidence": state.get("pattern_confidence", {}),
}
async def _analyze_indicators(
self, state: TechnicalAnalysisState
) -> dict[str, Any]:
"""Analyze technical indicators."""
try:
# Get indicators tool
indicators_tool = next(
(t for t in self.tools if "technical_indicators" in t.name), None
)
if indicators_tool and state.get("symbol"):
circuit_breaker = await circuit_manager.get_or_create("technical")
indicators = state.get("indicators", ["RSI", "MACD", "BB"])
async def get_indicators():
return await indicators_tool.ainvoke(
{
"symbol": state["symbol"],
"indicators": indicators,
"period": state.get("lookback_days", 20),
}
)
indicator_data = await circuit_breaker.call(get_indicators)
if isinstance(indicator_data, dict):
# Store indicator values
state["indicator_values"] = indicator_data.get("values", {})
# Generate indicator signals
signals = self._generate_indicator_signals(indicator_data)
state["indicator_signals"] = signals
# Check for divergences
divergences = self._check_divergences(
state.get("price_history", {}), indicator_data
)
state["divergences"] = divergences
except Exception as e:
logger.error(f"Error analyzing indicators: {e}")
state["api_calls_made"] = state.get("api_calls_made", 0) + 1
return {
"indicator_values": state.get("indicator_values", {}),
"indicator_signals": state.get("indicator_signals", {}),
"divergences": state.get("divergences", []),
}
async def _construct_trade_setup(
self, state: TechnicalAnalysisState
) -> dict[str, Any]:
"""Construct complete trade setup."""
try:
current_price = state.get("current_price", 0)
if current_price > 0:
# Calculate entry points based on patterns and levels
entry_points = self._calculate_entry_points(state)
state["entry_points"] = entry_points
# Get stop loss recommendation
stops_tool = next(
(t for t in self.tools if isinstance(t, TechnicalStopsTool)), None
)
if stops_tool:
stops_data = await stops_tool.ainvoke(
{
"symbol": state["symbol"],
"lookback_days": 20,
}
)
if isinstance(stops_data, dict):
stop_loss = stops_data.get(
"recommended_stop", current_price * 0.95
)
else:
stop_loss = current_price * 0.95
else:
stop_loss = current_price * 0.95
state["stop_loss"] = stop_loss
# Calculate profit targets
risk = current_price - stop_loss
targets = [
current_price + (risk * 1.5), # 1.5R
current_price + (risk * 2.0), # 2R
current_price + (risk * 3.0), # 3R
]
state["profit_targets"] = targets
# Calculate risk/reward
state["risk_reward_ratio"] = 2.0 # Default target
# Rate setup quality
quality = self._rate_setup_quality(state)
state["setup_quality"] = quality
# Calculate confidence score
confidence = self._calculate_confidence_score(state)
state["confidence_score"] = confidence
except Exception as e:
logger.error(f"Error constructing trade setup: {e}")
return {
"entry_points": state.get("entry_points", []),
"stop_loss": state.get("stop_loss", 0),
"profit_targets": state.get("profit_targets", []),
"risk_reward_ratio": state.get("risk_reward_ratio", 0),
"setup_quality": state.get("setup_quality", "C"),
"confidence_score": state.get("confidence_score", 0),
}
def _generate_indicator_signals(self, indicator_data: dict) -> dict[str, str]:
"""Generate buy/sell/hold signals from indicators."""
signals = {}
# RSI signals
rsi = indicator_data.get("RSI", {}).get("value", 50)
if rsi < 30:
signals["RSI"] = "buy"
elif rsi > 70:
signals["RSI"] = "sell"
else:
signals["RSI"] = "hold"
# MACD signals
macd = indicator_data.get("MACD", {})
if macd.get("histogram", 0) > 0 and macd.get("signal_cross", "") == "bullish":
signals["MACD"] = "buy"
elif macd.get("histogram", 0) < 0 and macd.get("signal_cross", "") == "bearish":
signals["MACD"] = "sell"
else:
signals["MACD"] = "hold"
return signals
def _check_divergences(
self, price_history: dict, indicator_data: dict
) -> list[dict[str, Any]]:
"""Check for price/indicator divergences."""
divergences: list[dict[str, Any]] = []
# Simplified divergence detection
# In production, would use more sophisticated analysis
return divergences
def _calculate_entry_points(self, state: TechnicalAnalysisState) -> list[float]:
"""Calculate optimal entry points."""
current_price = state.get("current_price", 0)
support_levels = state.get("support_levels", [])
patterns = state.get("patterns", [])
entries = []
# Pattern-based entries
for pattern in patterns:
if pattern.get("entry_price"):
entries.append(pattern["entry_price"])
# Support-based entries
for support in support_levels:
if support < current_price:
# Entry just above support
entries.append(support * 1.01)
# Current price entry if momentum
if state.get("trend_direction") == "bullish":
entries.append(current_price)
return sorted(set(entries))[:3] # Top 3 unique entries
def _rate_setup_quality(self, state: TechnicalAnalysisState) -> str:
"""Rate the quality of the trade setup."""
score = 0
# Pattern quality
if state.get("patterns"):
max_confidence = max(p.get("confidence", 0) for p in state["patterns"])
if max_confidence > 80:
score += 30
elif max_confidence > 60:
score += 20
else:
score += 10
# Indicator confluence
signals = state.get("indicator_signals", {})
buy_signals = sum(1 for s in signals.values() if s == "buy")
if buy_signals >= 3:
score += 30
elif buy_signals >= 2:
score += 20
else:
score += 10
# Risk/Reward
rr = state.get("risk_reward_ratio", 0)
if rr >= 3:
score += 20
elif rr >= 2:
score += 15
else:
score += 5
# Volume confirmation (would check in real implementation)
score += 10
# Market alignment (would check in real implementation)
score += 10
# Convert score to grade
if score >= 85:
return "A+"
elif score >= 75:
return "A"
elif score >= 65:
return "B"
else:
return "C"
def _calculate_confidence_score(self, state: TechnicalAnalysisState) -> float:
"""Calculate overall confidence score for the setup."""
factors = []
# Pattern confidence
if state.get("pattern_confidence"):
factors.append(max(state["pattern_confidence"].values()) / 100)
# Indicator agreement
signals = state.get("indicator_signals", {})
if signals:
buy_count = sum(1 for s in signals.values() if s == "buy")
factors.append(buy_count / len(signals))
# Setup quality
quality_scores = {"A+": 1.0, "A": 0.85, "B": 0.70, "C": 0.50}
factors.append(quality_scores.get(state.get("setup_quality", "C"), 0.5))
# Average confidence
return round(sum(factors) / len(factors) * 100, 1) if factors else 50.0
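    # Illustrative walk-through (comments only, values assumed): with a best
    # pattern confidence of 85, two "buy" signals out of two indicators, and
    # the default 2:1 risk/reward, _rate_setup_quality scores
    # 30 + 20 + 15 + 10 + 10 = 85 -> "A+", and _calculate_confidence_score
    # then averages 0.85, 1.0 and 1.0 -> 95.0.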
async def analyze_stock(
self,
symbol: str,
timeframe: str = "1d",
indicators: list[str] | None = None,
**kwargs,
) -> dict[str, Any]:
"""
Perform comprehensive technical analysis on a stock.
Args:
symbol: Stock symbol
timeframe: Chart timeframe
indicators: List of indicators to analyze
**kwargs: Additional parameters
Returns:
Complete technical analysis with trade setup
"""
start_time = datetime.now()
# Default indicators
if indicators is None:
indicators = ["RSI", "MACD", "BB", "EMA", "VWAP"]
# Prepare query
query = f"Analyze {symbol} on {timeframe} timeframe with focus on patterns and trade setup"
# Initial state
initial_state = {
"messages": [HumanMessage(content=query)],
"symbol": symbol,
"timeframe": timeframe,
"indicators": indicators,
"lookback_days": kwargs.get("lookback_days", 20),
"pattern_detection": True,
"multi_timeframe": kwargs.get("multi_timeframe", False),
"persona": self.persona.name,
"session_id": kwargs.get(
"session_id", f"{symbol}_{datetime.now().timestamp()}"
),
"timestamp": datetime.now(),
"api_calls_made": 0,
}
# Run analysis
result = await self.ainvoke(
query, initial_state["session_id"], initial_state=initial_state
)
# Calculate execution time
execution_time = (datetime.now() - start_time).total_seconds() * 1000
# Extract results
return self._format_analysis_results(result, execution_time)
def _format_analysis_results(
self, result: dict[str, Any], execution_time: float
) -> dict[str, Any]:
"""Format technical analysis results."""
state = result.get("state", {})
messages = result.get("messages", [])
return {
"status": "success",
"timestamp": datetime.now().isoformat(),
"execution_time_ms": execution_time,
"symbol": state.get("symbol", ""),
"analysis": {
"market_structure": {
"trend": state.get("trend_direction", "neutral"),
"support_levels": state.get("support_levels", []),
"resistance_levels": state.get("resistance_levels", []),
},
"patterns": {
"detected": state.get("patterns", []),
"confidence": state.get("pattern_confidence", {}),
},
"indicators": {
"values": state.get("indicator_values", {}),
"signals": state.get("indicator_signals", {}),
"divergences": state.get("divergences", []),
},
"trade_setup": {
"entries": state.get("entry_points", []),
"stop_loss": state.get("stop_loss", 0),
"targets": state.get("profit_targets", []),
"risk_reward": state.get("risk_reward_ratio", 0),
"quality": state.get("setup_quality", "C"),
"confidence": state.get("confidence_score", 0),
},
},
"recommendation": messages[-1].content if messages else "",
"persona_adjusted": True,
"risk_profile": self.persona.name,
}
def _create_mock_tools(self) -> list:
"""Create mock tools for testing."""
from langchain_core.tools import tool
@tool
def mock_technical_indicators(symbol: str, indicators: list[str]) -> dict:
"""Mock technical indicators tool."""
return {
"RSI": {"value": 45, "trend": "neutral"},
"MACD": {"histogram": 0.5, "signal_cross": "bullish"},
"BB": {"upper": 150, "middle": 145, "lower": 140},
}
@tool
def mock_support_resistance(symbol: str) -> dict:
"""Mock support/resistance tool."""
return {
"support_levels": [140, 135, 130],
"resistance_levels": [150, 155, 160],
"trend": "bullish",
}
return [mock_technical_indicators, mock_support_resistance]
```
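The indicator-signal rules in `_generate_indicator_signals` can be sanity-checked outside the agent. The standalone sketch below simply mirrors those rules (it is illustrative and not part of the module above); feeding it the payload returned by `mock_technical_indicators` shows the expected hold/buy outcome.
```python
from typing import Any


def indicator_signals(indicator_data: dict[str, Any]) -> dict[str, str]:
    """Mirror the RSI/MACD rules from _generate_indicator_signals above."""
    signals: dict[str, str] = {}
    rsi = indicator_data.get("RSI", {}).get("value", 50)
    signals["RSI"] = "buy" if rsi < 30 else "sell" if rsi > 70 else "hold"
    macd = indicator_data.get("MACD", {})
    if macd.get("histogram", 0) > 0 and macd.get("signal_cross", "") == "bullish":
        signals["MACD"] = "buy"
    elif macd.get("histogram", 0) < 0 and macd.get("signal_cross", "") == "bearish":
        signals["MACD"] = "sell"
    else:
        signals["MACD"] = "hold"
    return signals


print(indicator_signals({"RSI": {"value": 45}, "MACD": {"histogram": 0.5, "signal_cross": "bullish"}}))
# -> {'RSI': 'hold', 'MACD': 'buy'}
```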
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/ab_testing.py:
--------------------------------------------------------------------------------
```python
"""A/B testing framework for comparing ML model performance."""
import logging
import random
from datetime import datetime
from typing import Any
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from .model_manager import ModelManager
logger = logging.getLogger(__name__)
class ABTestGroup:
"""Represents a group in an A/B test."""
def __init__(
self,
group_id: str,
model_id: str,
model_version: str,
traffic_allocation: float,
description: str = "",
):
"""Initialize A/B test group.
Args:
group_id: Unique identifier for the group
model_id: Model identifier
model_version: Model version
traffic_allocation: Fraction of traffic allocated to this group (0-1)
description: Description of the group
"""
self.group_id = group_id
self.model_id = model_id
self.model_version = model_version
self.traffic_allocation = traffic_allocation
self.description = description
self.created_at = datetime.now()
# Performance tracking
self.predictions: list[Any] = []
self.actual_values: list[Any] = []
self.prediction_timestamps: list[datetime] = []
self.prediction_confidence: list[float] = []
def add_prediction(
self,
prediction: Any,
actual: Any,
confidence: float = 1.0,
timestamp: datetime | None = None,
):
"""Add a prediction result to the group.
Args:
prediction: Model prediction
actual: Actual value
confidence: Prediction confidence score
timestamp: Prediction timestamp
"""
self.predictions.append(prediction)
self.actual_values.append(actual)
self.prediction_confidence.append(confidence)
self.prediction_timestamps.append(timestamp or datetime.now())
def get_metrics(self) -> dict[str, float]:
"""Calculate performance metrics for the group.
Returns:
Dictionary of performance metrics
"""
if not self.predictions or not self.actual_values:
return {}
try:
predictions = np.array(self.predictions)
actuals = np.array(self.actual_values)
metrics = {
"sample_count": len(predictions),
"accuracy": accuracy_score(actuals, predictions),
"precision": precision_score(
actuals, predictions, average="weighted", zero_division=0
),
"recall": recall_score(
actuals, predictions, average="weighted", zero_division=0
),
"f1_score": f1_score(
actuals, predictions, average="weighted", zero_division=0
),
"avg_confidence": np.mean(self.prediction_confidence),
}
# Add confusion matrix for binary/multiclass
unique_labels = np.unique(np.concatenate([predictions, actuals]))
if len(unique_labels) <= 10: # Reasonable number of classes
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(actuals, predictions, labels=unique_labels)
metrics["confusion_matrix"] = cm.tolist()
metrics["unique_labels"] = unique_labels.tolist()
return metrics
except Exception as e:
logger.error(f"Error calculating metrics for group {self.group_id}: {e}")
return {"error": str(e)}
def to_dict(self) -> dict[str, Any]:
"""Convert group to dictionary representation."""
return {
"group_id": self.group_id,
"model_id": self.model_id,
"model_version": self.model_version,
"traffic_allocation": self.traffic_allocation,
"description": self.description,
"created_at": self.created_at.isoformat(),
"metrics": self.get_metrics(),
}
class ABTest:
"""Manages an A/B test between different model versions."""
def __init__(
self,
test_id: str,
name: str,
description: str = "",
random_seed: int | None = None,
):
"""Initialize A/B test.
Args:
test_id: Unique identifier for the test
name: Human-readable name for the test
description: Description of the test
random_seed: Random seed for reproducible traffic splitting
"""
self.test_id = test_id
self.name = name
self.description = description
self.created_at = datetime.now()
self.started_at: datetime | None = None
self.ended_at: datetime | None = None
self.status = "created" # created, running, completed, cancelled
# Groups in the test
self.groups: dict[str, ABTestGroup] = {}
# Traffic allocation
self.traffic_splitter = TrafficSplitter(random_seed)
# Test configuration
self.min_samples_per_group = 100
self.confidence_level = 0.95
self.minimum_detectable_effect = 0.05
def add_group(
self,
group_id: str,
model_id: str,
model_version: str,
traffic_allocation: float,
description: str = "",
) -> bool:
"""Add a group to the A/B test.
Args:
group_id: Unique identifier for the group
model_id: Model identifier
model_version: Model version
traffic_allocation: Fraction of traffic (0-1)
description: Description of the group
Returns:
True if successful
"""
if self.status != "created":
logger.error(
f"Cannot add group to test {self.test_id} - test already started"
)
return False
if group_id in self.groups:
logger.error(f"Group {group_id} already exists in test {self.test_id}")
return False
# Validate traffic allocation
current_total = sum(g.traffic_allocation for g in self.groups.values())
if (
current_total + traffic_allocation > 1.0001
): # Small tolerance for floating point
logger.error(
f"Traffic allocation would exceed 100%: {current_total + traffic_allocation}"
)
return False
group = ABTestGroup(
group_id=group_id,
model_id=model_id,
model_version=model_version,
traffic_allocation=traffic_allocation,
description=description,
)
self.groups[group_id] = group
self.traffic_splitter.update_allocation(
{gid: g.traffic_allocation for gid, g in self.groups.items()}
)
logger.info(f"Added group {group_id} to test {self.test_id}")
return True
def start_test(self) -> bool:
"""Start the A/B test.
Returns:
True if successful
"""
if self.status != "created":
logger.error(
f"Cannot start test {self.test_id} - invalid status: {self.status}"
)
return False
if len(self.groups) < 2:
logger.error(f"Cannot start test {self.test_id} - need at least 2 groups")
return False
# Validate traffic allocation sums to approximately 1.0
total_allocation = sum(g.traffic_allocation for g in self.groups.values())
if abs(total_allocation - 1.0) > 0.01:
logger.error(f"Traffic allocation does not sum to 1.0: {total_allocation}")
return False
self.status = "running"
self.started_at = datetime.now()
logger.info(f"Started A/B test {self.test_id} with {len(self.groups)} groups")
return True
def assign_traffic(self, user_id: str | None = None) -> str | None:
"""Assign traffic to a group.
Args:
user_id: User identifier for consistent assignment
Returns:
Group ID or None if test not running
"""
if self.status != "running":
return None
return self.traffic_splitter.assign_group(user_id)
def record_prediction(
self,
group_id: str,
prediction: Any,
actual: Any,
confidence: float = 1.0,
timestamp: datetime | None = None,
) -> bool:
"""Record a prediction result for a group.
Args:
group_id: Group identifier
prediction: Model prediction
actual: Actual value
confidence: Prediction confidence
timestamp: Prediction timestamp
Returns:
True if successful
"""
if group_id not in self.groups:
logger.error(f"Group {group_id} not found in test {self.test_id}")
return False
self.groups[group_id].add_prediction(prediction, actual, confidence, timestamp)
return True
def get_results(self) -> dict[str, Any]:
"""Get current A/B test results.
Returns:
Dictionary with test results
"""
results = {
"test_id": self.test_id,
"name": self.name,
"description": self.description,
"status": self.status,
"created_at": self.created_at.isoformat(),
"started_at": self.started_at.isoformat() if self.started_at else None,
"ended_at": self.ended_at.isoformat() if self.ended_at else None,
"groups": {},
"statistical_analysis": {},
}
# Group results
for group_id, group in self.groups.items():
results["groups"][group_id] = group.to_dict()
# Statistical analysis
if len(self.groups) >= 2:
results["statistical_analysis"] = self._perform_statistical_analysis()
return results
def _perform_statistical_analysis(self) -> dict[str, Any]:
"""Perform statistical analysis of A/B test results.
Returns:
Statistical analysis results
"""
analysis = {
"ready_for_analysis": True,
"sample_size_adequate": True,
"statistical_significance": {},
"effect_sizes": {},
"recommendations": [],
}
# Check sample sizes
sample_sizes = {
group_id: len(group.predictions) for group_id, group in self.groups.items()
}
min_samples = min(sample_sizes.values()) if sample_sizes else 0
if min_samples < self.min_samples_per_group:
analysis["ready_for_analysis"] = False
analysis["sample_size_adequate"] = False
analysis["recommendations"].append(
f"Need at least {self.min_samples_per_group} samples per group (current min: {min_samples})"
)
if not analysis["ready_for_analysis"]:
return analysis
# Pairwise comparisons
group_ids = list(self.groups.keys())
for i, group_a_id in enumerate(group_ids):
for group_b_id in group_ids[i + 1 :]:
comparison_key = f"{group_a_id}_vs_{group_b_id}"
try:
group_a = self.groups[group_a_id]
group_b = self.groups[group_b_id]
# Compare accuracy scores
accuracy_a = accuracy_score(
group_a.actual_values, group_a.predictions
)
accuracy_b = accuracy_score(
group_b.actual_values, group_b.predictions
)
# Perform statistical test
# For classification accuracy, we can use a proportion test
n_correct_a = sum(
np.array(group_a.predictions) == np.array(group_a.actual_values)
)
n_correct_b = sum(
np.array(group_b.predictions) == np.array(group_b.actual_values)
)
n_total_a = len(group_a.predictions)
n_total_b = len(group_b.predictions)
# Two-proportion z-test
p_combined = (n_correct_a + n_correct_b) / (n_total_a + n_total_b)
se = np.sqrt(
p_combined * (1 - p_combined) * (1 / n_total_a + 1 / n_total_b)
)
if se > 0:
z_score = (accuracy_a - accuracy_b) / se
p_value = 2 * (1 - stats.norm.cdf(abs(z_score)))
# Effect size (Cohen's h for proportions)
h = 2 * (
np.arcsin(np.sqrt(accuracy_a))
- np.arcsin(np.sqrt(accuracy_b))
)
analysis["statistical_significance"][comparison_key] = {
"accuracy_a": accuracy_a,
"accuracy_b": accuracy_b,
"difference": accuracy_a - accuracy_b,
"z_score": z_score,
"p_value": p_value,
"significant": p_value < (1 - self.confidence_level),
"effect_size_h": h,
}
# Recommendations based on results
if p_value < (1 - self.confidence_level):
if accuracy_a > accuracy_b:
analysis["recommendations"].append(
f"Group {group_a_id} significantly outperforms {group_b_id} "
f"(p={p_value:.4f}, effect_size={h:.4f})"
)
else:
analysis["recommendations"].append(
f"Group {group_b_id} significantly outperforms {group_a_id} "
f"(p={p_value:.4f}, effect_size={h:.4f})"
)
else:
analysis["recommendations"].append(
f"No significant difference between {group_a_id} and {group_b_id} "
f"(p={p_value:.4f})"
)
except Exception as e:
logger.error(
f"Error in statistical analysis for {comparison_key}: {e}"
)
analysis["statistical_significance"][comparison_key] = {
"error": str(e)
}
return analysis
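    # Illustrative check of the two-proportion z-test above (numbers assumed):
    # with 120/200 correct in group A (accuracy 0.60) and 100/200 in group B
    # (0.50), p_combined = 220/400 = 0.55, se = sqrt(0.55*0.45*(1/200+1/200))
    # ≈ 0.0497, z ≈ 2.01, two-sided p ≈ 0.044 < 0.05, Cohen's h ≈ 0.20, so the
    # difference is flagged as significant at the default 95% confidence level.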
def stop_test(self, reason: str = "completed") -> bool:
"""Stop the A/B test.
Args:
reason: Reason for stopping
Returns:
True if successful
"""
if self.status != "running":
logger.error(f"Cannot stop test {self.test_id} - not running")
return False
self.status = "completed" if reason == "completed" else "cancelled"
self.ended_at = datetime.now()
logger.info(f"Stopped A/B test {self.test_id}: {reason}")
return True
def to_dict(self) -> dict[str, Any]:
"""Convert test to dictionary representation."""
return {
"test_id": self.test_id,
"name": self.name,
"description": self.description,
"status": self.status,
"created_at": self.created_at.isoformat(),
"started_at": self.started_at.isoformat() if self.started_at else None,
"ended_at": self.ended_at.isoformat() if self.ended_at else None,
"groups": {gid: g.to_dict() for gid, g in self.groups.items()},
"configuration": {
"min_samples_per_group": self.min_samples_per_group,
"confidence_level": self.confidence_level,
"minimum_detectable_effect": self.minimum_detectable_effect,
},
}
class TrafficSplitter:
"""Handles traffic splitting for A/B tests."""
def __init__(self, random_seed: int | None = None):
"""Initialize traffic splitter.
Args:
random_seed: Random seed for reproducible splitting
"""
self.random_seed = random_seed
self.group_allocation: dict[str, float] = {}
self.cumulative_allocation: list[tuple[str, float]] = []
def update_allocation(self, allocation: dict[str, float]):
"""Update group traffic allocation.
Args:
allocation: Dictionary mapping group_id to allocation fraction
"""
self.group_allocation = allocation.copy()
# Create cumulative distribution for sampling
self.cumulative_allocation = []
cumulative = 0.0
for group_id, fraction in allocation.items():
cumulative += fraction
self.cumulative_allocation.append((group_id, cumulative))
def assign_group(self, user_id: str | None = None) -> str | None:
"""Assign a user to a group.
Args:
user_id: User identifier for consistent assignment
Returns:
Group ID or None if no groups configured
"""
if not self.cumulative_allocation:
return None
# Generate random value
if user_id is not None:
# Hash user_id for consistent assignment
import hashlib
hash_object = hashlib.md5(user_id.encode())
hash_int = int(hash_object.hexdigest(), 16)
rand_value = (hash_int % 10000) / 10000.0 # Normalize to [0, 1)
else:
if self.random_seed is not None:
random.seed(self.random_seed)
rand_value = random.random()
# Find group based on cumulative allocation
for group_id, cumulative_threshold in self.cumulative_allocation:
if rand_value <= cumulative_threshold:
return group_id
# Fallback to last group
return self.cumulative_allocation[-1][0] if self.cumulative_allocation else None
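# Illustrative walk-through (comments only, allocation assumed): with
# {"control": 0.5, "variant": 0.5}, cumulative_allocation becomes
# [("control", 0.5), ("variant", 1.0)]. A hashed rand_value of 0.37 falls at
# or below 0.5 and maps to "control"; 0.81 maps to "variant". Hashing user_id
# with MD5 keeps each user's assignment stable across repeated calls.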
class ABTestManager:
"""Manages multiple A/B tests."""
def __init__(self, model_manager: ModelManager):
"""Initialize A/B test manager.
Args:
model_manager: Model manager instance
"""
self.model_manager = model_manager
self.tests: dict[str, ABTest] = {}
def create_test(
self,
test_id: str,
name: str,
description: str = "",
random_seed: int | None = None,
) -> ABTest:
"""Create a new A/B test.
Args:
test_id: Unique identifier for the test
name: Human-readable name
description: Description
random_seed: Random seed for reproducible splitting
Returns:
ABTest instance
"""
if test_id in self.tests:
raise ValueError(f"Test {test_id} already exists")
test = ABTest(test_id, name, description, random_seed)
self.tests[test_id] = test
logger.info(f"Created A/B test {test_id}: {name}")
return test
def get_test(self, test_id: str) -> ABTest | None:
"""Get an A/B test by ID.
Args:
test_id: Test identifier
Returns:
ABTest instance or None
"""
return self.tests.get(test_id)
def list_tests(self, status_filter: str | None = None) -> list[dict[str, Any]]:
"""List all A/B tests.
Args:
status_filter: Filter by status (created, running, completed, cancelled)
Returns:
List of test summaries
"""
tests = []
for test in self.tests.values():
if status_filter is None or test.status == status_filter:
tests.append(
{
"test_id": test.test_id,
"name": test.name,
"status": test.status,
"groups_count": len(test.groups),
"created_at": test.created_at.isoformat(),
"started_at": test.started_at.isoformat()
if test.started_at
else None,
}
)
# Sort by creation time (newest first)
tests.sort(key=lambda x: x["created_at"], reverse=True)
return tests
def run_model_comparison(
self,
test_name: str,
model_versions: list[tuple[str, str]], # (model_id, version)
test_data: pd.DataFrame,
feature_extractor: Any,
target_extractor: Any,
traffic_allocation: list[float] | None = None,
test_duration_hours: int = 24,
) -> str:
"""Run a model comparison A/B test.
Args:
test_name: Name for the test
model_versions: List of (model_id, version) tuples to compare
test_data: Test data for evaluation
feature_extractor: Function to extract features
target_extractor: Function to extract targets
traffic_allocation: Custom traffic allocation (defaults to equal split)
test_duration_hours: Duration to run the test
Returns:
Test ID
"""
# Generate unique test ID
test_id = f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# Create test
test = self.create_test(
test_id=test_id,
name=test_name,
description=f"Comparing {len(model_versions)} model versions",
)
# Default equal traffic allocation
if traffic_allocation is None:
allocation_per_group = 1.0 / len(model_versions)
traffic_allocation = [allocation_per_group] * len(model_versions)
# Add groups
for i, (model_id, version) in enumerate(model_versions):
group_id = f"group_{i}_{model_id}_{version}"
test.add_group(
group_id=group_id,
model_id=model_id,
model_version=version,
traffic_allocation=traffic_allocation[i],
description=f"Model {model_id} version {version}",
)
# Start test
test.start_test()
# Extract features and targets
features = feature_extractor(test_data)
targets = target_extractor(test_data)
# Simulate predictions for each group
for _, row in features.iterrows():
# Assign traffic
group_id = test.assign_traffic(str(row.name)) # Use row index as user_id
if group_id is None:
continue
# Get corresponding group's model
group = test.groups[group_id]
model_version = self.model_manager.load_model(
group.model_id, group.model_version
)
if model_version is None or model_version.model is None:
logger.warning(f"Could not load model for group {group_id}")
continue
try:
# Make prediction
X = row.values.reshape(1, -1)
if model_version.scaler is not None:
X = model_version.scaler.transform(X)
prediction = model_version.model.predict(X)[0]
# Get confidence if available
confidence = 1.0
if hasattr(model_version.model, "predict_proba"):
proba = model_version.model.predict_proba(X)[0]
confidence = max(proba)
# Get actual value
actual = targets.loc[row.name]
# Record prediction
test.record_prediction(group_id, prediction, actual, confidence)
except Exception as e:
logger.warning(f"Error making prediction for group {group_id}: {e}")
logger.info(f"Completed model comparison test {test_id}")
return test_id
def get_test_summary(self) -> dict[str, Any]:
"""Get summary of all A/B tests.
Returns:
Summary dictionary
"""
total_tests = len(self.tests)
status_counts = {}
for test in self.tests.values():
status_counts[test.status] = status_counts.get(test.status, 0) + 1
recent_tests = sorted(
[
{
"test_id": test.test_id,
"name": test.name,
"status": test.status,
"created_at": test.created_at.isoformat(),
}
for test in self.tests.values()
],
key=lambda x: x["created_at"],
reverse=True,
)[:10]
return {
"total_tests": total_tests,
"status_distribution": status_counts,
"recent_tests": recent_tests,
}
```
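A minimal end-to-end sketch of the `ABTest` API above; group names, model identifiers, and recorded outcomes are illustrative, and the import path follows the file location shown.
```python
from maverick_mcp.backtesting.ab_testing import ABTest

test = ABTest(test_id="demo_test", name="Champion vs. challenger", random_seed=42)
test.add_group("champion", model_id="clf", model_version="1.0", traffic_allocation=0.5)
test.add_group("challenger", model_id="clf", model_version="2.0", traffic_allocation=0.5)
test.start_test()

# The same user_id always lands in the same group (MD5-hashed assignment).
for user_id, prediction, actual in [("u1", 1, 1), ("u2", 0, 1), ("u3", 1, 1)]:
    group_id = test.assign_traffic(user_id)
    if group_id is not None:
        test.record_prediction(group_id, prediction, actual, confidence=0.8)

results = test.get_results()
# With only a few samples per group, the statistical analysis reports that
# more data is needed (min_samples_per_group defaults to 100).
print(results["statistical_analysis"].get("ready_for_analysis"))
```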
--------------------------------------------------------------------------------
/tests/test_visualization.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive tests for backtesting visualization module.
Tests cover:
- Chart generation and base64 encoding with matplotlib
- Equity curve plotting with drawdown subplots
- Trade scatter plots on price charts
- Parameter optimization heatmaps
- Portfolio allocation pie charts
- Strategy comparison line charts
- Performance dashboard table generation
- Theme support (light/dark modes)
- Image resolution and size optimization
- Error handling for malformed data
"""
import base64
from unittest.mock import patch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
from maverick_mcp.backtesting.visualization import (
generate_equity_curve,
generate_optimization_heatmap,
generate_performance_dashboard,
generate_portfolio_allocation,
generate_strategy_comparison,
generate_trade_scatter,
image_to_base64,
set_chart_style,
)
class TestVisualizationUtilities:
"""Test suite for visualization utility functions."""
def test_set_chart_style_light_theme(self):
"""Test light theme styling configuration."""
set_chart_style("light")
# Test that matplotlib parameters are set correctly
assert plt.rcParams["axes.facecolor"] == "white"
assert plt.rcParams["figure.facecolor"] == "white"
assert plt.rcParams["font.size"] == 10
assert plt.rcParams["text.color"] == "black"
assert plt.rcParams["axes.labelcolor"] == "black"
assert plt.rcParams["xtick.color"] == "black"
assert plt.rcParams["ytick.color"] == "black"
def test_set_chart_style_dark_theme(self):
"""Test dark theme styling configuration."""
set_chart_style("dark")
# Test that matplotlib parameters are set correctly
assert plt.rcParams["axes.facecolor"] == "#1E1E1E"
assert plt.rcParams["figure.facecolor"] == "#121212"
assert plt.rcParams["font.size"] == 10
assert plt.rcParams["text.color"] == "white"
assert plt.rcParams["axes.labelcolor"] == "white"
assert plt.rcParams["xtick.color"] == "white"
assert plt.rcParams["ytick.color"] == "white"
def test_image_to_base64_conversion(self):
"""Test image to base64 conversion with proper formatting."""
# Create a simple test figure
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
ax.set_title("Test Chart")
# Convert to base64
base64_str = image_to_base64(fig, dpi=100)
# Test base64 string properties
assert isinstance(base64_str, str)
assert len(base64_str) > 100 # Should contain substantial data
# Test that it's valid base64
try:
decoded_bytes = base64.b64decode(base64_str)
assert len(decoded_bytes) > 0
except Exception as e:
pytest.fail(f"Invalid base64 encoding: {e}")
def test_image_to_base64_size_optimization(self):
"""Test image size optimization and aspect ratio preservation."""
# Create large figure
fig, ax = plt.subplots(figsize=(20, 15)) # Large size
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
original_width, original_height = fig.get_size_inches()
original_aspect = original_height / original_width
# Convert with size constraint
base64_str = image_to_base64(fig, dpi=100, max_width=800)
# Test that resizing occurred
final_width, final_height = fig.get_size_inches()
final_aspect = final_height / final_width
assert final_width <= 8.0 # 800px / 100dpi = 8 inches
assert abs(final_aspect - original_aspect) < 0.01 # Aspect ratio preserved
assert len(base64_str) > 0
def test_image_to_base64_error_handling(self):
"""Test error handling in base64 conversion."""
with patch(
"matplotlib.figure.Figure.savefig", side_effect=Exception("Save error")
):
fig, ax = plt.subplots()
ax.plot([1, 2, 3])
result = image_to_base64(fig)
assert result == "" # Should return empty string on error
class TestEquityCurveGeneration:
"""Test suite for equity curve chart generation."""
@pytest.fixture
def sample_returns_data(self):
"""Create sample returns data for testing."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
returns = np.random.normal(0.001, 0.02, len(dates))
cumulative_returns = np.cumprod(1 + returns)
# Create drawdown series
running_max = np.maximum.accumulate(cumulative_returns)
drawdown = (cumulative_returns - running_max) / running_max * 100
return pd.Series(cumulative_returns, index=dates), pd.Series(
drawdown, index=dates
)
def test_generate_equity_curve_basic(self, sample_returns_data):
"""Test basic equity curve generation."""
returns, drawdown = sample_returns_data
base64_str = generate_equity_curve(returns, title="Test Equity Curve")
assert isinstance(base64_str, str)
assert len(base64_str) > 100
# Test that it's valid base64 image
try:
decoded_bytes = base64.b64decode(base64_str)
assert decoded_bytes.startswith(b"\x89PNG") # PNG header
except Exception as e:
pytest.fail(f"Invalid PNG image: {e}")
def test_generate_equity_curve_with_drawdown(self, sample_returns_data):
"""Test equity curve generation with drawdown subplot."""
returns, drawdown = sample_returns_data
base64_str = generate_equity_curve(
returns, drawdown=drawdown, title="Equity Curve with Drawdown", theme="dark"
)
assert isinstance(base64_str, str)
assert len(base64_str) > 100
# Should be larger image due to subplot
base64_no_dd = generate_equity_curve(returns, title="No Drawdown")
assert len(base64_str) >= len(base64_no_dd)
def test_generate_equity_curve_themes(self, sample_returns_data):
"""Test equity curve generation with different themes."""
returns, _ = sample_returns_data
light_chart = generate_equity_curve(returns, theme="light")
dark_chart = generate_equity_curve(returns, theme="dark")
assert len(light_chart) > 100
assert len(dark_chart) > 100
# Different themes should produce different images
assert light_chart != dark_chart
def test_generate_equity_curve_error_handling(self):
"""Test error handling in equity curve generation."""
# Test with invalid data
invalid_returns = pd.Series([]) # Empty series
result = generate_equity_curve(invalid_returns)
assert result == ""
# Test with NaN data
nan_returns = pd.Series([np.nan, np.nan, np.nan])
result = generate_equity_curve(nan_returns)
assert result == ""
class TestTradeScatterGeneration:
"""Test suite for trade scatter plot generation."""
@pytest.fixture
def sample_trade_data(self):
"""Create sample trade data for testing."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
        prices = pd.Series(
            100 + np.cumsum(np.random.normal(0, 1, len(dates))), index=dates
        )
# Create sample trades
trade_dates = dates[::30] # Every 30 days
trades = []
for i, trade_date in enumerate(trade_dates):
if i % 2 == 0: # Entry
trades.append(
{
"date": trade_date,
"price": prices.loc[trade_date],
"type": "entry",
}
)
else: # Exit
trades.append(
{
"date": trade_date,
"price": prices.loc[trade_date],
"type": "exit",
}
)
trades_df = pd.DataFrame(trades).set_index("date")
return prices, trades_df
def test_generate_trade_scatter_basic(self, sample_trade_data):
"""Test basic trade scatter plot generation."""
prices, trades = sample_trade_data
base64_str = generate_trade_scatter(prices, trades, title="Trade Scatter Plot")
assert isinstance(base64_str, str)
assert len(base64_str) > 100
# Verify valid PNG
try:
decoded_bytes = base64.b64decode(base64_str)
assert decoded_bytes.startswith(b"\x89PNG")
except Exception as e:
pytest.fail(f"Invalid PNG image: {e}")
def test_generate_trade_scatter_themes(self, sample_trade_data):
"""Test trade scatter plots with different themes."""
prices, trades = sample_trade_data
light_chart = generate_trade_scatter(prices, trades, theme="light")
dark_chart = generate_trade_scatter(prices, trades, theme="dark")
assert len(light_chart) > 100
assert len(dark_chart) > 100
assert light_chart != dark_chart
def test_generate_trade_scatter_empty_trades(self, sample_trade_data):
"""Test trade scatter plot with empty trade data."""
prices, _ = sample_trade_data
empty_trades = pd.DataFrame(columns=["price", "type"])
result = generate_trade_scatter(prices, empty_trades)
assert result == ""
def test_generate_trade_scatter_error_handling(self):
"""Test error handling in trade scatter generation."""
# Test with mismatched data
prices = pd.Series([1, 2, 3])
trades = pd.DataFrame({"price": [10, 20], "type": ["entry", "exit"]})
# Should handle gracefully
result = generate_trade_scatter(prices, trades)
# Might return empty string or valid chart depending on implementation
assert isinstance(result, str)
class TestOptimizationHeatmapGeneration:
"""Test suite for parameter optimization heatmap generation."""
@pytest.fixture
def sample_optimization_data(self):
"""Create sample optimization results for testing."""
parameters = ["param1", "param2", "param3"]
results = {}
for p1 in parameters:
results[p1] = {}
for p2 in parameters:
# Create some performance metric
results[p1][p2] = np.random.uniform(0.5, 2.0)
return results
def test_generate_optimization_heatmap_basic(self, sample_optimization_data):
"""Test basic optimization heatmap generation."""
base64_str = generate_optimization_heatmap(
sample_optimization_data, title="Parameter Optimization Heatmap"
)
assert isinstance(base64_str, str)
assert len(base64_str) > 100
# Verify valid PNG
try:
decoded_bytes = base64.b64decode(base64_str)
assert decoded_bytes.startswith(b"\x89PNG")
except Exception as e:
pytest.fail(f"Invalid PNG image: {e}")
def test_generate_optimization_heatmap_themes(self, sample_optimization_data):
"""Test optimization heatmap with different themes."""
light_chart = generate_optimization_heatmap(
sample_optimization_data, theme="light"
)
dark_chart = generate_optimization_heatmap(
sample_optimization_data, theme="dark"
)
assert len(light_chart) > 100
assert len(dark_chart) > 100
assert light_chart != dark_chart
def test_generate_optimization_heatmap_empty_data(self):
"""Test heatmap generation with empty data."""
empty_data = {}
result = generate_optimization_heatmap(empty_data)
assert result == ""
def test_generate_optimization_heatmap_error_handling(self):
"""Test error handling in heatmap generation."""
# Test with malformed data
malformed_data = {"param1": "not_a_dict"}
result = generate_optimization_heatmap(malformed_data)
assert result == ""
class TestPortfolioAllocationGeneration:
"""Test suite for portfolio allocation chart generation."""
@pytest.fixture
def sample_allocation_data(self):
"""Create sample allocation data for testing."""
return {
"AAPL": 0.25,
"GOOGL": 0.20,
"MSFT": 0.15,
"TSLA": 0.15,
"AMZN": 0.10,
"Cash": 0.15,
}
def test_generate_portfolio_allocation_basic(self, sample_allocation_data):
"""Test basic portfolio allocation chart generation."""
base64_str = generate_portfolio_allocation(
sample_allocation_data, title="Portfolio Allocation"
)
assert isinstance(base64_str, str)
assert len(base64_str) > 100
# Verify valid PNG
try:
decoded_bytes = base64.b64decode(base64_str)
assert decoded_bytes.startswith(b"\x89PNG")
except Exception as e:
pytest.fail(f"Invalid PNG image: {e}")
def test_generate_portfolio_allocation_themes(self, sample_allocation_data):
"""Test portfolio allocation with different themes."""
light_chart = generate_portfolio_allocation(
sample_allocation_data, theme="light"
)
dark_chart = generate_portfolio_allocation(sample_allocation_data, theme="dark")
assert len(light_chart) > 100
assert len(dark_chart) > 100
assert light_chart != dark_chart
def test_generate_portfolio_allocation_empty_data(self):
"""Test allocation chart with empty data."""
empty_data = {}
result = generate_portfolio_allocation(empty_data)
assert result == ""
def test_generate_portfolio_allocation_single_asset(self):
"""Test allocation chart with single asset."""
single_asset = {"AAPL": 1.0}
result = generate_portfolio_allocation(single_asset)
assert isinstance(result, str)
assert len(result) > 100 # Should still generate valid chart
def test_generate_portfolio_allocation_error_handling(self):
"""Test error handling in allocation chart generation."""
# Test with invalid allocation values
invalid_data = {"AAPL": "invalid_value"}
result = generate_portfolio_allocation(invalid_data)
assert result == ""
class TestStrategyComparisonGeneration:
"""Test suite for strategy comparison chart generation."""
@pytest.fixture
def sample_strategy_data(self):
"""Create sample strategy comparison data."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
strategies = {
"Momentum": pd.Series(
np.cumprod(1 + np.random.normal(0.0008, 0.015, len(dates))), index=dates
),
"Mean Reversion": pd.Series(
np.cumprod(1 + np.random.normal(0.0005, 0.012, len(dates))), index=dates
),
"Breakout": pd.Series(
np.cumprod(1 + np.random.normal(0.0012, 0.020, len(dates))), index=dates
),
}
return strategies
def test_generate_strategy_comparison_basic(self, sample_strategy_data):
"""Test basic strategy comparison chart generation."""
base64_str = generate_strategy_comparison(
sample_strategy_data, title="Strategy Performance Comparison"
)
assert isinstance(base64_str, str)
assert len(base64_str) > 100
# Verify valid PNG
try:
decoded_bytes = base64.b64decode(base64_str)
assert decoded_bytes.startswith(b"\x89PNG")
except Exception as e:
pytest.fail(f"Invalid PNG image: {e}")
def test_generate_strategy_comparison_themes(self, sample_strategy_data):
"""Test strategy comparison with different themes."""
light_chart = generate_strategy_comparison(sample_strategy_data, theme="light")
dark_chart = generate_strategy_comparison(sample_strategy_data, theme="dark")
assert len(light_chart) > 100
assert len(dark_chart) > 100
assert light_chart != dark_chart
def test_generate_strategy_comparison_single_strategy(self):
"""Test comparison chart with single strategy."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
single_strategy = {
"Only Strategy": pd.Series(
np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
)
}
result = generate_strategy_comparison(single_strategy)
assert isinstance(result, str)
assert len(result) > 100
def test_generate_strategy_comparison_empty_data(self):
"""Test comparison chart with empty data."""
empty_data = {}
result = generate_strategy_comparison(empty_data)
assert result == ""
def test_generate_strategy_comparison_error_handling(self):
"""Test error handling in strategy comparison generation."""
# Test with mismatched data lengths
dates1 = pd.date_range(start="2023-01-01", end="2023-06-30", freq="D")
dates2 = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
mismatched_data = {
"Strategy1": pd.Series(np.random.normal(0, 1, len(dates1)), index=dates1),
"Strategy2": pd.Series(np.random.normal(0, 1, len(dates2)), index=dates2),
}
# Should handle gracefully
result = generate_strategy_comparison(mismatched_data)
assert isinstance(result, str) # Might be empty or valid
class TestPerformanceDashboardGeneration:
"""Test suite for performance dashboard generation."""
@pytest.fixture
def sample_metrics_data(self):
"""Create sample performance metrics for testing."""
return {
"Total Return": 0.156,
"Sharpe Ratio": 1.25,
"Max Drawdown": -0.082,
"Win Rate": 0.583,
"Profit Factor": 1.35,
"Total Trades": 24,
"Annualized Return": 0.18,
"Volatility": 0.16,
"Calmar Ratio": 1.10,
"Best Trade": 0.12,
}
def test_generate_performance_dashboard_basic(self, sample_metrics_data):
"""Test basic performance dashboard generation."""
base64_str = generate_performance_dashboard(
sample_metrics_data, title="Performance Dashboard"
)
assert isinstance(base64_str, str)
assert len(base64_str) > 100
# Verify valid PNG
try:
decoded_bytes = base64.b64decode(base64_str)
assert decoded_bytes.startswith(b"\x89PNG")
except Exception as e:
pytest.fail(f"Invalid PNG image: {e}")
def test_generate_performance_dashboard_themes(self, sample_metrics_data):
"""Test performance dashboard with different themes."""
light_chart = generate_performance_dashboard(sample_metrics_data, theme="light")
dark_chart = generate_performance_dashboard(sample_metrics_data, theme="dark")
assert len(light_chart) > 100
assert len(dark_chart) > 100
assert light_chart != dark_chart
def test_generate_performance_dashboard_mixed_data_types(self):
"""Test dashboard with mixed data types."""
mixed_metrics = {
"Total Return": 0.156,
"Strategy": "Momentum",
"Symbol": "AAPL",
"Duration": "365 days",
"Sharpe Ratio": 1.25,
"Status": "Completed",
}
result = generate_performance_dashboard(mixed_metrics)
assert isinstance(result, str)
assert len(result) > 100
def test_generate_performance_dashboard_empty_data(self):
"""Test dashboard with empty metrics."""
empty_metrics = {}
result = generate_performance_dashboard(empty_metrics)
assert result == ""
def test_generate_performance_dashboard_large_dataset(self):
"""Test dashboard with large number of metrics."""
large_metrics = {f"Metric_{i}": np.random.uniform(-1, 2) for i in range(50)}
result = generate_performance_dashboard(large_metrics)
assert isinstance(result, str)
# Might be empty if table becomes too large, or valid if handled properly
def test_generate_performance_dashboard_error_handling(self):
"""Test error handling in dashboard generation."""
# Test with invalid data that might cause table generation to fail
problematic_metrics = {
"Valid Metric": 1.25,
"Problematic": [1, 2, 3], # List instead of scalar
"Another Valid": 0.85,
}
result = generate_performance_dashboard(problematic_metrics)
assert isinstance(result, str)
class TestVisualizationIntegration:
"""Integration tests for visualization functions working together."""
def test_consistent_theming_across_charts(self):
"""Test that theming is consistent across different chart types."""
# Create sample data for different chart types
dates = pd.date_range(start="2023-01-01", end="2023-06-30", freq="D")
returns = pd.Series(
np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
)
allocation = {"AAPL": 0.4, "GOOGL": 0.3, "MSFT": 0.3}
metrics = {"Return": 0.15, "Sharpe": 1.2, "Drawdown": -0.08}
# Generate charts with same theme
equity_chart = generate_equity_curve(returns, theme="dark")
allocation_chart = generate_portfolio_allocation(allocation, theme="dark")
dashboard_chart = generate_performance_dashboard(metrics, theme="dark")
# All should generate valid base64 strings
charts = [equity_chart, allocation_chart, dashboard_chart]
for chart in charts:
assert isinstance(chart, str)
assert len(chart) > 100
# Verify valid PNG
try:
decoded_bytes = base64.b64decode(chart)
assert decoded_bytes.startswith(b"\x89PNG")
except Exception as e:
pytest.fail(f"Invalid PNG in themed charts: {e}")
def test_memory_cleanup_after_chart_generation(self):
"""Test that matplotlib figures are properly cleaned up."""
import matplotlib.pyplot as plt
initial_figure_count = len(plt.get_fignums())
# Generate multiple charts
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
returns = pd.Series(
np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates
)
for i in range(10):
chart = generate_equity_curve(returns, title=f"Test Chart {i}")
assert len(chart) > 0
final_figure_count = len(plt.get_fignums())
# Figure count should not have increased (figures should be closed)
assert final_figure_count <= initial_figure_count + 1 # Allow for 1 open figure
def test_chart_generation_performance_benchmark(self, benchmark_timer):
"""Test chart generation performance benchmarks."""
# Create substantial dataset
dates = pd.date_range(
start="2023-01-01", end="2023-12-31", freq="H"
) # Hourly data
returns = pd.Series(
np.cumprod(1 + np.random.normal(0.0001, 0.005, len(dates))), index=dates
)
with benchmark_timer() as timer:
chart = generate_equity_curve(returns, title="Performance Test")
# Should complete within reasonable time even with large dataset
assert timer.elapsed < 5.0 # < 5 seconds
assert len(chart) > 100 # Valid chart generated
def test_concurrent_chart_generation(self):
"""Test concurrent chart generation doesn't cause conflicts."""
import queue
import threading
results_queue = queue.Queue()
error_queue = queue.Queue()
def generate_chart(thread_id):
try:
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
returns = pd.Series(
np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))),
index=dates,
)
chart = generate_equity_curve(returns, title=f"Thread {thread_id}")
results_queue.put((thread_id, len(chart)))
except Exception as e:
error_queue.put(f"Thread {thread_id}: {e}")
# Create multiple threads
threads = []
for i in range(5):
thread = threading.Thread(target=generate_chart, args=(i,))
threads.append(thread)
thread.start()
# Wait for completion
for thread in threads:
thread.join(timeout=10)
# Check results
assert error_queue.empty(), f"Errors: {list(error_queue.queue)}"
assert results_queue.qsize() == 5
# All should have generated valid charts
while not results_queue.empty():
thread_id, chart_length = results_queue.get()
assert chart_length > 100
if __name__ == "__main__":
# Run tests with detailed output
pytest.main([__file__, "-v", "--tb=short", "--asyncio-mode=auto"])
```
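For a quick manual check outside pytest, the helpers exercised above can be called directly. The sketch below is illustrative (the output filename is arbitrary) and writes one themed equity-curve PNG to disk.
```python
import base64

import numpy as np
import pandas as pd

from maverick_mcp.backtesting.visualization import generate_equity_curve

dates = pd.date_range("2023-01-01", "2023-12-31", freq="D")
equity = pd.Series(np.cumprod(1 + np.random.normal(0.001, 0.02, len(dates))), index=dates)

chart_b64 = generate_equity_curve(equity, title="Sample Equity Curve", theme="dark")
with open("equity_curve.png", "wb") as f:  # arbitrary output path
    f.write(base64.b64decode(chart_b64))
```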
--------------------------------------------------------------------------------
/tests/integration/test_full_backtest_workflow_advanced.py:
--------------------------------------------------------------------------------
```python
"""
Advanced End-to-End Integration Tests for VectorBT Backtesting Workflow.
This comprehensive test suite covers:
- Complete workflow integration from data fetch to results
- All 15 strategies (9 traditional + 6 ML) testing
- Parallel execution capabilities
- Cache behavior and optimization
- Real production-like scenarios
- Error recovery and resilience
- Resource management and cleanup
"""
import asyncio
import logging
import time
from unittest.mock import Mock
from uuid import UUID
import numpy as np
import pandas as pd
import pytest
from maverick_mcp.backtesting import (
VectorBTEngine,
)
from maverick_mcp.backtesting.persistence import BacktestPersistenceManager
from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
from maverick_mcp.backtesting.visualization import (
generate_equity_curve,
generate_performance_dashboard,
)
logger = logging.getLogger(__name__)
# Strategy definitions for comprehensive testing
TRADITIONAL_STRATEGIES = [
"sma_cross",
"ema_cross",
"rsi",
"macd",
"bollinger",
"momentum",
"breakout",
"mean_reversion",
"volume_momentum",
]
ML_STRATEGIES = [
"ml_predictor",
"adaptive",
"ensemble",
"regime_aware",
"online_learning",
"reinforcement_learning",
]
ALL_STRATEGIES = TRADITIONAL_STRATEGIES + ML_STRATEGIES
class TestAdvancedBacktestWorkflowIntegration:
"""Advanced integration tests for complete backtesting workflow."""
@pytest.fixture
async def enhanced_stock_data_provider(self):
"""Create enhanced mock stock data provider with realistic multi-year data."""
provider = Mock()
# Generate 3 years of realistic stock data with different market conditions
dates = pd.date_range(start="2021-01-01", end="2023-12-31", freq="D")
# Simulate different market regimes
bull_period = len(dates) // 3 # First third: bull market
sideways_period = len(dates) // 3 # Second third: sideways
bear_period = len(dates) - bull_period - sideways_period # Final: bear market
# Generate returns for different regimes
bull_returns = np.random.normal(0.0015, 0.015, bull_period) # Positive drift
sideways_returns = np.random.normal(0.0002, 0.02, sideways_period) # Low drift
bear_returns = np.random.normal(-0.001, 0.025, bear_period) # Negative drift
all_returns = np.concatenate([bull_returns, sideways_returns, bear_returns])
prices = 100 * np.cumprod(1 + all_returns) # Start at $100
# Add realistic volume patterns
volumes = np.random.randint(500000, 5000000, len(dates)).astype(float)
volumes += np.random.normal(0, volumes * 0.1) # Add volume volatility
volumes = np.maximum(volumes, 100000) # Minimum volume
volumes = volumes.astype(int) # Convert back to integers
stock_data = pd.DataFrame(
{
"Open": prices * np.random.uniform(0.995, 1.005, len(dates)),
"High": prices * np.random.uniform(1.002, 1.025, len(dates)),
"Low": prices * np.random.uniform(0.975, 0.998, len(dates)),
"Close": prices,
"Volume": volumes.astype(int),
"Adj Close": prices,
},
index=dates,
)
# Ensure OHLC constraints
stock_data["High"] = np.maximum(
stock_data["High"], np.maximum(stock_data["Open"], stock_data["Close"])
)
stock_data["Low"] = np.minimum(
stock_data["Low"], np.minimum(stock_data["Open"], stock_data["Close"])
)
provider.get_stock_data.return_value = stock_data
return provider
@pytest.fixture
async def complete_vectorbt_engine(self, enhanced_stock_data_provider):
"""Create complete VectorBT engine with all strategies enabled."""
engine = VectorBTEngine(data_provider=enhanced_stock_data_provider)
return engine
async def test_all_15_strategies_integration(
self, complete_vectorbt_engine, benchmark_timer
):
"""Test all 15 strategies (9 traditional + 6 ML) in complete workflow."""
results = {}
failed_strategies = []
with benchmark_timer() as timer:
# Test traditional strategies
for strategy in TRADITIONAL_STRATEGIES:
try:
if strategy in STRATEGY_TEMPLATES:
parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
result = await complete_vectorbt_engine.run_backtest(
symbol="COMPREHENSIVE_TEST",
strategy_type=strategy,
parameters=parameters,
start_date="2022-01-01",
end_date="2023-12-31",
)
results[strategy] = result
# Validate basic result structure
assert "metrics" in result
assert "trades" in result
assert "equity_curve" in result
assert result["symbol"] == "COMPREHENSIVE_TEST"
logger.info(f"✓ {strategy} strategy executed successfully")
else:
logger.warning(f"Strategy {strategy} not found in templates")
except Exception as e:
failed_strategies.append(strategy)
logger.error(f"✗ {strategy} strategy failed: {str(e)}")
# Test ML strategies (mock implementation for integration test)
for strategy in ML_STRATEGIES:
try:
# Mock ML strategy execution
mock_ml_result = {
"symbol": "COMPREHENSIVE_TEST",
"strategy_type": strategy,
"metrics": {
"total_return": np.random.uniform(-0.2, 0.3),
"sharpe_ratio": np.random.uniform(0.5, 2.0),
"max_drawdown": np.random.uniform(-0.3, -0.05),
"total_trades": np.random.randint(10, 100),
},
"trades": [],
"equity_curve": np.random.cumsum(
np.random.normal(0.001, 0.02, 252)
).tolist(),
"ml_specific": {
"model_accuracy": np.random.uniform(0.55, 0.85),
"feature_importance": {
"momentum": 0.3,
"volatility": 0.25,
"volume": 0.45,
},
},
}
results[strategy] = mock_ml_result
logger.info(f"✓ {strategy} ML strategy simulated successfully")
except Exception as e:
failed_strategies.append(strategy)
logger.error(f"✗ {strategy} ML strategy failed: {str(e)}")
execution_time = timer.elapsed
# Validate overall results
successful_strategies = len(results)
total_strategies = len(ALL_STRATEGIES)
success_rate = successful_strategies / total_strategies
# Performance requirements
assert execution_time < 180.0 # Should complete within 3 minutes
assert success_rate >= 0.8 # At least 80% success rate
assert successful_strategies >= 12 # At least 12 strategies should work
# Log comprehensive results
logger.info(
f"Strategy Integration Test Results:\n"
f" • Total Strategies: {total_strategies}\n"
f" • Successful: {successful_strategies}\n"
f" • Failed: {len(failed_strategies)}\n"
f" • Success Rate: {success_rate:.1%}\n"
f" • Execution Time: {execution_time:.2f}s\n"
f" • Failed Strategies: {failed_strategies}"
)
return {
"total_strategies": total_strategies,
"successful_strategies": successful_strategies,
"failed_strategies": failed_strategies,
"success_rate": success_rate,
"execution_time": execution_time,
"results": results,
}
async def test_parallel_execution_capabilities(
self, complete_vectorbt_engine, benchmark_timer
):
"""Test parallel execution of multiple backtests."""
symbols = ["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN", "META", "NFLX", "NVDA"]
strategies = ["sma_cross", "rsi", "macd", "bollinger"]
async def run_single_backtest(symbol, strategy):
"""Run a single backtest."""
try:
parameters = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})
result = await complete_vectorbt_engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
return {
"symbol": symbol,
"strategy": strategy,
"result": result,
"success": True,
}
except Exception as e:
return {
"symbol": symbol,
"strategy": strategy,
"error": str(e),
"success": False,
}
with benchmark_timer() as timer:
# Create all combinations
tasks = []
for symbol in symbols:
for strategy in strategies:
tasks.append(run_single_backtest(symbol, strategy))
# Execute in parallel with semaphore to control concurrency
semaphore = asyncio.Semaphore(8) # Max 8 concurrent executions
async def run_with_semaphore(task):
async with semaphore:
return await task
results = await asyncio.gather(
*[run_with_semaphore(task) for task in tasks], return_exceptions=True
)
execution_time = timer.elapsed
# Analyze results
total_executions = len(tasks)
successful_executions = sum(
1 for r in results if isinstance(r, dict) and r.get("success", False)
)
failed_executions = total_executions - successful_executions
# Performance assertions
assert execution_time < 300.0 # Should complete within 5 minutes
assert successful_executions >= total_executions * 0.7 # At least 70% success
# Calculate average execution time per backtest
avg_time_per_backtest = execution_time / total_executions
logger.info(
f"Parallel Execution Results:\n"
f" • Total Executions: {total_executions}\n"
f" • Successful: {successful_executions}\n"
f" • Failed: {failed_executions}\n"
f" • Success Rate: {successful_executions / total_executions:.1%}\n"
f" • Total Time: {execution_time:.2f}s\n"
f" • Avg Time/Backtest: {avg_time_per_backtest:.2f}s\n"
f" • Parallel Speedup: ~{total_executions * avg_time_per_backtest / execution_time:.1f}x"
)
return {
"total_executions": total_executions,
"successful_executions": successful_executions,
"execution_time": execution_time,
"avg_time_per_backtest": avg_time_per_backtest,
}
async def test_cache_behavior_and_optimization(self, complete_vectorbt_engine):
"""Test cache behavior and optimization in integrated workflow."""
symbol = "CACHE_TEST_SYMBOL"
strategy = "sma_cross"
parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
# First run - should populate cache
start_time = time.time()
result1 = await complete_vectorbt_engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
first_run_time = time.time() - start_time
# Second run - should use cache
start_time = time.time()
result2 = await complete_vectorbt_engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
second_run_time = time.time() - start_time
# Third run with different parameters - should not use cache
modified_parameters = {
**parameters,
"fast_period": parameters.get("fast_period", 10) + 5,
}
start_time = time.time()
await complete_vectorbt_engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=modified_parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
third_run_time = time.time() - start_time
# Validate results consistency (for cached runs)
assert result1["symbol"] == result2["symbol"]
assert result1["strategy_type"] == result2["strategy_type"]
# Cache effectiveness check (second run might be faster, but not guaranteed)
cache_speedup = first_run_time / second_run_time if second_run_time > 0 else 1.0
logger.info(
f"Cache Behavior Test Results:\n"
f" • First Run: {first_run_time:.3f}s\n"
f" • Second Run (cached): {second_run_time:.3f}s\n"
f" • Third Run (different params): {third_run_time:.3f}s\n"
f" • Cache Speedup: {cache_speedup:.2f}x\n"
)
return {
"first_run_time": first_run_time,
"second_run_time": second_run_time,
"third_run_time": third_run_time,
"cache_speedup": cache_speedup,
}
async def test_database_persistence_integration(
self, complete_vectorbt_engine, db_session
):
"""Test complete database persistence integration."""
# Generate test results
result = await complete_vectorbt_engine.run_backtest(
symbol="PERSISTENCE_TEST",
strategy_type="sma_cross",
parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
# Test persistence workflow
with BacktestPersistenceManager(session=db_session) as persistence:
# Save backtest result
backtest_id = persistence.save_backtest_result(
vectorbt_results=result,
execution_time=2.5,
notes="Integration test - complete persistence workflow",
)
# Validate saved data
assert backtest_id is not None
assert UUID(backtest_id) # Valid UUID
# Retrieve and validate
saved_result = persistence.get_backtest_by_id(backtest_id)
assert saved_result is not None
assert saved_result.symbol == "PERSISTENCE_TEST"
assert saved_result.strategy_type == "sma_cross"
assert saved_result.execution_time == 2.5
# Test batch operations
batch_results = []
for i in range(5):
batch_result = await complete_vectorbt_engine.run_backtest(
symbol=f"BATCH_TEST_{i}",
strategy_type="rsi",
parameters=STRATEGY_TEMPLATES["rsi"]["parameters"],
start_date="2023-06-01",
end_date="2023-12-31",
)
batch_results.append(batch_result)
# Save batch results
batch_ids = []
for i, batch_result in enumerate(batch_results):
batch_id = persistence.save_backtest_result(
vectorbt_results=batch_result,
execution_time=1.8 + i * 0.1,
notes=f"Batch test #{i + 1}",
)
batch_ids.append(batch_id)
# Query saved batch results
saved_batch = [persistence.get_backtest_by_id(bid) for bid in batch_ids]
assert all(saved is not None for saved in saved_batch)
assert len(saved_batch) == 5
# Test filtering and querying
rsi_results = persistence.get_backtests_by_strategy("rsi")
assert len(rsi_results) >= 5 # At least our batch results
logger.info("Database persistence test completed successfully")
return {"batch_ids": batch_ids, "single_id": backtest_id}
async def test_visualization_integration_complete(self, complete_vectorbt_engine):
"""Test complete visualization integration workflow."""
# Run backtest to get data for visualization
result = await complete_vectorbt_engine.run_backtest(
symbol="VIZ_TEST",
strategy_type="macd",
parameters=STRATEGY_TEMPLATES["macd"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
# Test all visualization components
visualizations = {}
# 1. Equity curve visualization
equity_data = pd.Series(result["equity_curve"])
drawdown_data = pd.Series(result["drawdown_series"])
equity_chart = generate_equity_curve(
equity_data,
drawdown=drawdown_data,
title="Complete Integration Test - Equity Curve",
)
visualizations["equity_curve"] = equity_chart
# 2. Performance dashboard
dashboard_chart = generate_performance_dashboard(
result["metrics"], title="Complete Integration Test - Performance Dashboard"
)
visualizations["dashboard"] = dashboard_chart
# 3. Validate all visualizations
for viz_name, viz_data in visualizations.items():
assert isinstance(viz_data, str), f"{viz_name} should return string"
assert len(viz_data) > 100, f"{viz_name} should have substantial content"
# Try to decode as base64 (should be valid image)
try:
import base64
decoded = base64.b64decode(viz_data)
assert len(decoded) > 0, f"{viz_name} should have valid image data"
logger.info(f"✓ {viz_name} visualization generated successfully")
except Exception as e:
logger.error(f"✗ {viz_name} visualization failed: {e}")
raise
return visualizations
async def test_error_recovery_comprehensive(self, complete_vectorbt_engine):
"""Test comprehensive error recovery across the workflow."""
recovery_results = {}
# 1. Invalid symbol handling
try:
result = await complete_vectorbt_engine.run_backtest(
symbol="", # Empty symbol
strategy_type="sma_cross",
parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
recovery_results["empty_symbol"] = {"recovered": True, "result": result}
except Exception as e:
recovery_results["empty_symbol"] = {"recovered": False, "error": str(e)}
# 2. Invalid date range handling
try:
result = await complete_vectorbt_engine.run_backtest(
symbol="ERROR_TEST",
strategy_type="sma_cross",
parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
start_date="2025-01-01", # Future date
end_date="2025-12-31",
)
recovery_results["future_dates"] = {"recovered": True, "result": result}
except Exception as e:
recovery_results["future_dates"] = {"recovered": False, "error": str(e)}
# 3. Invalid strategy parameters
try:
invalid_params = {
"fast_period": -10,
"slow_period": -20,
} # Invalid negative values
result = await complete_vectorbt_engine.run_backtest(
symbol="ERROR_TEST",
strategy_type="sma_cross",
parameters=invalid_params,
start_date="2023-01-01",
end_date="2023-12-31",
)
recovery_results["invalid_params"] = {"recovered": True, "result": result}
except Exception as e:
recovery_results["invalid_params"] = {"recovered": False, "error": str(e)}
# 4. Unknown strategy handling
try:
result = await complete_vectorbt_engine.run_backtest(
symbol="ERROR_TEST",
strategy_type="nonexistent_strategy",
parameters={},
start_date="2023-01-01",
end_date="2023-12-31",
)
recovery_results["unknown_strategy"] = {"recovered": True, "result": result}
except Exception as e:
recovery_results["unknown_strategy"] = {"recovered": False, "error": str(e)}
# Analyze recovery effectiveness
total_tests = len(recovery_results)
recovered_tests = sum(
1 for r in recovery_results.values() if r.get("recovered", False)
)
recovery_rate = recovered_tests / total_tests if total_tests > 0 else 0
logger.info(
f"Error Recovery Test Results:\n"
f" • Total Error Scenarios: {total_tests}\n"
f" • Successfully Recovered: {recovered_tests}\n"
f" • Recovery Rate: {recovery_rate:.1%}\n"
)
for scenario, result in recovery_results.items():
status = "✓ RECOVERED" if result.get("recovered") else "✗ FAILED"
logger.info(f" • {scenario}: {status}")
return recovery_results
async def test_resource_management_comprehensive(self, complete_vectorbt_engine):
"""Test comprehensive resource management across workflow."""
import os
import psutil
process = psutil.Process(os.getpid())
# Baseline measurements
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
initial_threads = process.num_threads()
resource_snapshots = []
# Run multiple backtests while monitoring resources
for i in range(10):
await complete_vectorbt_engine.run_backtest(
symbol=f"RESOURCE_TEST_{i}",
strategy_type="sma_cross",
parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
# Take resource snapshot
current_memory = process.memory_info().rss / 1024 / 1024 # MB
current_threads = process.num_threads()
current_cpu = process.cpu_percent()
resource_snapshots.append(
{
"iteration": i + 1,
"memory_mb": current_memory,
"threads": current_threads,
"cpu_percent": current_cpu,
}
)
# Final measurements
final_memory = process.memory_info().rss / 1024 / 1024 # MB
final_threads = process.num_threads()
# Calculate resource growth
memory_growth = final_memory - initial_memory
thread_growth = final_threads - initial_threads
peak_memory = max(snapshot["memory_mb"] for snapshot in resource_snapshots)
avg_threads = sum(snapshot["threads"] for snapshot in resource_snapshots) / len(
resource_snapshots
)
# Resource management assertions
assert memory_growth < 500, (
f"Memory growth too high: {memory_growth:.1f}MB"
) # Max 500MB growth
assert thread_growth <= 10, (
f"Thread growth too high: {thread_growth}"
) # Max 10 additional threads
assert peak_memory < initial_memory + 1000, (
f"Peak memory too high: {peak_memory:.1f}MB"
) # Peak within 1GB of initial
logger.info(
f"Resource Management Test Results:\n"
f" • Initial Memory: {initial_memory:.1f}MB\n"
f" • Final Memory: {final_memory:.1f}MB\n"
f" • Memory Growth: {memory_growth:.1f}MB\n"
f" • Peak Memory: {peak_memory:.1f}MB\n"
f" • Initial Threads: {initial_threads}\n"
f" • Final Threads: {final_threads}\n"
f" • Thread Growth: {thread_growth}\n"
f" • Avg Threads: {avg_threads:.1f}"
)
return {
"memory_growth": memory_growth,
"thread_growth": thread_growth,
"peak_memory": peak_memory,
"resource_snapshots": resource_snapshots,
}
if __name__ == "__main__":
# Run advanced integration tests
pytest.main(
[
__file__,
"-v",
"--tb=short",
"--asyncio-mode=auto",
"--timeout=600", # 10 minute timeout for comprehensive tests
"-x", # Stop on first failure
"--durations=10", # Show 10 slowest tests
]
)
```
--------------------------------------------------------------------------------
/tests/test_tool_estimation_config.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive tests for ToolEstimationConfig.
This module tests the centralized tool cost estimation configuration that replaces
magic numbers scattered throughout the codebase. Tests cover:
- All tool-specific estimates
- Confidence levels and estimation basis
- Monitoring thresholds and alert conditions
- Edge cases and error handling
- Integration with server.py patterns
"""
from unittest.mock import patch
import pytest
from maverick_mcp.config.tool_estimation import (
EstimationBasis,
MonitoringThresholds,
ToolComplexity,
ToolEstimate,
ToolEstimationConfig,
get_tool_estimate,
get_tool_estimation_config,
should_alert_for_usage,
)
class TestToolEstimate:
"""Test ToolEstimate model validation and behavior."""
def test_valid_tool_estimate(self):
"""Test creating a valid ToolEstimate."""
estimate = ToolEstimate(
llm_calls=5,
total_tokens=8000,
confidence=0.8,
based_on=EstimationBasis.EMPIRICAL,
complexity=ToolComplexity.COMPLEX,
notes="Test estimate",
)
assert estimate.llm_calls == 5
assert estimate.total_tokens == 8000
assert estimate.confidence == 0.8
assert estimate.based_on == EstimationBasis.EMPIRICAL
assert estimate.complexity == ToolComplexity.COMPLEX
assert estimate.notes == "Test estimate"
def test_confidence_validation(self):
"""Test confidence level validation."""
# Valid confidence levels
for confidence in [0.0, 0.5, 1.0]:
estimate = ToolEstimate(
llm_calls=1,
total_tokens=100,
confidence=confidence,
based_on=EstimationBasis.EMPIRICAL,
complexity=ToolComplexity.SIMPLE,
)
assert estimate.confidence == confidence
# Invalid confidence levels - Pydantic ValidationError
from pydantic import ValidationError
with pytest.raises(ValidationError):
ToolEstimate(
llm_calls=1,
total_tokens=100,
confidence=-0.1,
based_on=EstimationBasis.EMPIRICAL,
complexity=ToolComplexity.SIMPLE,
)
with pytest.raises(ValidationError):
ToolEstimate(
llm_calls=1,
total_tokens=100,
confidence=1.1,
based_on=EstimationBasis.EMPIRICAL,
complexity=ToolComplexity.SIMPLE,
)
def test_negative_values_validation(self):
"""Test that negative values are not allowed."""
from pydantic import ValidationError
with pytest.raises(ValidationError):
ToolEstimate(
llm_calls=-1,
total_tokens=100,
confidence=0.8,
based_on=EstimationBasis.EMPIRICAL,
complexity=ToolComplexity.SIMPLE,
)
with pytest.raises(ValidationError):
ToolEstimate(
llm_calls=1,
total_tokens=-100,
confidence=0.8,
based_on=EstimationBasis.EMPIRICAL,
complexity=ToolComplexity.SIMPLE,
)
class TestMonitoringThresholds:
"""Test MonitoringThresholds model and validation."""
def test_default_thresholds(self):
"""Test default monitoring thresholds."""
thresholds = MonitoringThresholds()
assert thresholds.llm_calls_warning == 15
assert thresholds.llm_calls_critical == 25
assert thresholds.tokens_warning == 20000
assert thresholds.tokens_critical == 35000
assert thresholds.variance_warning == 0.5
assert thresholds.variance_critical == 1.0
def test_custom_thresholds(self):
"""Test custom monitoring thresholds."""
thresholds = MonitoringThresholds(
llm_calls_warning=10,
llm_calls_critical=20,
tokens_warning=15000,
tokens_critical=30000,
variance_warning=0.3,
variance_critical=0.8,
)
assert thresholds.llm_calls_warning == 10
assert thresholds.llm_calls_critical == 20
assert thresholds.tokens_warning == 15000
assert thresholds.tokens_critical == 30000
assert thresholds.variance_warning == 0.3
assert thresholds.variance_critical == 0.8
class TestToolEstimationConfig:
"""Test the main ToolEstimationConfig class."""
def test_default_configuration(self):
"""Test default configuration initialization."""
config = ToolEstimationConfig()
# Test default estimates by complexity
assert config.simple_default.complexity == ToolComplexity.SIMPLE
assert config.standard_default.complexity == ToolComplexity.STANDARD
assert config.complex_default.complexity == ToolComplexity.COMPLEX
assert config.premium_default.complexity == ToolComplexity.PREMIUM
# Test unknown tool fallback
assert config.unknown_tool_estimate.complexity == ToolComplexity.STANDARD
assert config.unknown_tool_estimate.confidence == 0.3
assert config.unknown_tool_estimate.based_on == EstimationBasis.CONSERVATIVE
def test_get_estimate_known_tools(self):
"""Test getting estimates for known tools."""
config = ToolEstimationConfig()
# Test simple tools
simple_tools = [
"get_stock_price",
"get_company_info",
"get_stock_info",
"calculate_sma",
"get_market_hours",
"get_chart_links",
"list_available_agents",
"clear_cache",
"get_cached_price_data",
"get_watchlist",
"generate_dev_token",
]
for tool in simple_tools:
estimate = config.get_estimate(tool)
assert estimate.complexity == ToolComplexity.SIMPLE
assert estimate.llm_calls <= 1 # Simple tools should have minimal LLM usage
assert estimate.confidence >= 0.8 # Should have high confidence
# Test standard tools
standard_tools = [
"get_rsi_analysis",
"get_macd_analysis",
"get_support_resistance",
"fetch_stock_data",
"get_maverick_stocks",
"get_news_sentiment",
"get_economic_calendar",
]
for tool in standard_tools:
estimate = config.get_estimate(tool)
assert estimate.complexity == ToolComplexity.STANDARD
assert 1 <= estimate.llm_calls <= 5
assert estimate.confidence >= 0.7
# Test complex tools
complex_tools = [
"get_full_technical_analysis",
"risk_adjusted_analysis",
"compare_tickers",
"portfolio_correlation_analysis",
"get_market_overview",
"get_all_screening_recommendations",
]
for tool in complex_tools:
estimate = config.get_estimate(tool)
assert estimate.complexity == ToolComplexity.COMPLEX
assert 4 <= estimate.llm_calls <= 8
assert estimate.confidence >= 0.7
# Test premium tools
premium_tools = [
"analyze_market_with_agent",
"get_agent_streaming_analysis",
"compare_personas_analysis",
]
for tool in premium_tools:
estimate = config.get_estimate(tool)
assert estimate.complexity == ToolComplexity.PREMIUM
assert estimate.llm_calls >= 8
assert estimate.total_tokens >= 10000
def test_get_estimate_unknown_tool(self):
"""Test getting estimate for unknown tools."""
config = ToolEstimationConfig()
estimate = config.get_estimate("unknown_tool_name")
assert estimate == config.unknown_tool_estimate
assert estimate.complexity == ToolComplexity.STANDARD
assert estimate.confidence == 0.3
assert estimate.based_on == EstimationBasis.CONSERVATIVE
def test_get_default_for_complexity(self):
"""Test getting default estimates by complexity."""
config = ToolEstimationConfig()
simple = config.get_default_for_complexity(ToolComplexity.SIMPLE)
assert simple == config.simple_default
standard = config.get_default_for_complexity(ToolComplexity.STANDARD)
assert standard == config.standard_default
complex_est = config.get_default_for_complexity(ToolComplexity.COMPLEX)
assert complex_est == config.complex_default
premium = config.get_default_for_complexity(ToolComplexity.PREMIUM)
assert premium == config.premium_default
def test_should_alert_critical_thresholds(self):
"""Test alert conditions for critical thresholds."""
config = ToolEstimationConfig()
# Test critical LLM calls threshold
should_alert, message = config.should_alert("test_tool", 30, 5000)
assert should_alert
assert "Critical: LLM calls (30) exceeded threshold (25)" in message
# Test critical token threshold
should_alert, message = config.should_alert("test_tool", 5, 40000)
assert should_alert
assert "Critical: Token usage (40000) exceeded threshold (35000)" in message
def test_should_alert_variance_thresholds(self):
"""Test alert conditions for variance thresholds."""
config = ToolEstimationConfig()
# Test tool with known estimate for variance calculation
# get_stock_price: llm_calls=0, total_tokens=200
# Test critical LLM variance (infinite variance since estimate is 0)
should_alert, message = config.should_alert("get_stock_price", 5, 200)
assert should_alert
assert "Critical: LLM call variance" in message
# Test critical token variance (5x the estimate = 400% variance)
should_alert, message = config.should_alert("get_stock_price", 0, 1000)
assert should_alert
assert "Critical: Token variance" in message
def test_should_alert_warning_thresholds(self):
"""Test alert conditions for warning thresholds."""
config = ToolEstimationConfig()
# Test warning LLM calls threshold (15-24 should trigger warning)
# Use unknown tool which has reasonable base estimates to avoid variance issues
should_alert, message = config.should_alert("unknown_tool", 18, 5000)
assert should_alert
assert (
"Warning" in message or "Critical" in message
) # May trigger critical due to variance
# Test warning token threshold with a tool that has known estimates
# get_rsi_analysis: llm_calls=2, total_tokens=3000
should_alert, message = config.should_alert("get_rsi_analysis", 2, 25000)
assert should_alert
assert (
"Warning" in message or "Critical" in message
) # High token variance may trigger critical
def test_should_alert_no_alert(self):
"""Test cases where no alert should be triggered."""
config = ToolEstimationConfig()
# Normal usage within expected ranges
should_alert, message = config.should_alert("get_stock_price", 0, 200)
assert not should_alert
assert message == ""
# Slightly above estimate but within acceptable variance
should_alert, message = config.should_alert("get_stock_price", 0, 250)
assert not should_alert
assert message == ""
def test_get_tools_by_complexity(self):
"""Test filtering tools by complexity category."""
config = ToolEstimationConfig()
simple_tools = config.get_tools_by_complexity(ToolComplexity.SIMPLE)
standard_tools = config.get_tools_by_complexity(ToolComplexity.STANDARD)
complex_tools = config.get_tools_by_complexity(ToolComplexity.COMPLEX)
premium_tools = config.get_tools_by_complexity(ToolComplexity.PREMIUM)
# Verify all tools are categorized
all_tools = simple_tools + standard_tools + complex_tools + premium_tools
assert len(all_tools) == len(config.tool_estimates)
# Verify no overlap between categories
assert len(set(all_tools)) == len(all_tools)
# Verify specific known tools are in correct categories
assert "get_stock_price" in simple_tools
assert "get_rsi_analysis" in standard_tools
assert "get_full_technical_analysis" in complex_tools
assert "analyze_market_with_agent" in premium_tools
def test_get_summary_stats(self):
"""Test summary statistics generation."""
config = ToolEstimationConfig()
stats = config.get_summary_stats()
# Verify structure
assert "total_tools" in stats
assert "by_complexity" in stats
assert "avg_llm_calls" in stats
assert "avg_tokens" in stats
assert "avg_confidence" in stats
assert "basis_distribution" in stats
# Verify content
assert stats["total_tools"] > 0
assert len(stats["by_complexity"]) == 4 # All complexity levels
assert stats["avg_llm_calls"] >= 0
assert stats["avg_tokens"] > 0
assert 0 <= stats["avg_confidence"] <= 1
# Verify complexity distribution adds up
complexity_sum = sum(stats["by_complexity"].values())
assert complexity_sum == stats["total_tools"]
class TestModuleFunctions:
"""Test module-level functions."""
def test_get_tool_estimation_config_singleton(self):
"""Test that get_tool_estimation_config returns a singleton."""
config1 = get_tool_estimation_config()
config2 = get_tool_estimation_config()
# Should return the same instance
assert config1 is config2
@patch("maverick_mcp.config.tool_estimation._config", None)
def test_get_tool_estimation_config_initialization(self):
"""Test that configuration is initialized correctly."""
config = get_tool_estimation_config()
assert isinstance(config, ToolEstimationConfig)
assert len(config.tool_estimates) > 0
def test_get_tool_estimate_function(self):
"""Test the get_tool_estimate convenience function."""
estimate = get_tool_estimate("get_stock_price")
assert isinstance(estimate, ToolEstimate)
assert estimate.complexity == ToolComplexity.SIMPLE
# Test unknown tool
unknown_estimate = get_tool_estimate("unknown_tool")
assert unknown_estimate.based_on == EstimationBasis.CONSERVATIVE
def test_should_alert_for_usage_function(self):
"""Test the should_alert_for_usage convenience function."""
should_alert, message = should_alert_for_usage("test_tool", 30, 5000)
assert isinstance(should_alert, bool)
assert isinstance(message, str)
# Should trigger alert for high LLM calls
assert should_alert
assert "Critical" in message
class TestMagicNumberReplacement:
"""Test that configuration correctly replaces magic numbers from server.py."""
def test_all_usage_tier_tools_have_estimates(self):
"""Test that all tools referenced in server.py have estimates."""
config = ToolEstimationConfig()
# These are tools that were using magic numbers in server.py
# Based on the TOOL usage tier mapping pattern
critical_tools = [
# Simple tools (baseline tier)
"get_stock_price",
"get_company_info",
"get_stock_info",
"calculate_sma",
"get_market_hours",
"get_chart_links",
# Standard tools (core analysis tier)
"get_rsi_analysis",
"get_macd_analysis",
"get_support_resistance",
"fetch_stock_data",
"get_maverick_stocks",
"get_news_sentiment",
# Complex tools (advanced analysis tier)
"get_full_technical_analysis",
"risk_adjusted_analysis",
"compare_tickers",
"portfolio_correlation_analysis",
"get_market_overview",
# Premium tools (orchestration tier)
"analyze_market_with_agent",
"get_agent_streaming_analysis",
"compare_personas_analysis",
]
for tool in critical_tools:
estimate = config.get_estimate(tool)
# Should not get the fallback estimate
assert estimate != config.unknown_tool_estimate, (
f"Tool {tool} missing specific estimate"
)
# Should have reasonable confidence
assert estimate.confidence > 0.5, f"Tool {tool} has low confidence estimate"
def test_estimates_align_with_usage_tiers(self):
"""Test that tool estimates align with usage complexity tiers."""
config = ToolEstimationConfig()
# Simple tools should require minimal resources
simple_tools = [
"get_stock_price",
"get_company_info",
"get_stock_info",
"calculate_sma",
"get_market_hours",
"get_chart_links",
]
for tool in simple_tools:
estimate = config.get_estimate(tool)
assert estimate.complexity == ToolComplexity.SIMPLE
assert estimate.llm_calls <= 1 # Should require minimal/no LLM calls
# Standard tools perform moderate analysis
standard_tools = [
"get_rsi_analysis",
"get_macd_analysis",
"get_support_resistance",
"fetch_stock_data",
"get_maverick_stocks",
]
for tool in standard_tools:
estimate = config.get_estimate(tool)
assert estimate.complexity == ToolComplexity.STANDARD
assert 1 <= estimate.llm_calls <= 5 # Moderate LLM usage
# Complex tools orchestrate heavier workloads
complex_tools = [
"get_full_technical_analysis",
"risk_adjusted_analysis",
"compare_tickers",
"portfolio_correlation_analysis",
]
for tool in complex_tools:
estimate = config.get_estimate(tool)
assert estimate.complexity == ToolComplexity.COMPLEX
assert 4 <= estimate.llm_calls <= 8 # Multiple LLM interactions
# Premium tools coordinate multi-stage workflows
premium_tools = [
"analyze_market_with_agent",
"get_agent_streaming_analysis",
"compare_personas_analysis",
]
for tool in premium_tools:
estimate = config.get_estimate(tool)
assert estimate.complexity == ToolComplexity.PREMIUM
assert estimate.llm_calls >= 8 # Extensive LLM coordination
def test_no_hardcoded_estimates_remain(self):
"""Test that estimates are data-driven, not hardcoded."""
config = ToolEstimationConfig()
# All tool estimates should have basis information
for tool_name, estimate in config.tool_estimates.items():
assert estimate.based_on in EstimationBasis
assert estimate.complexity in ToolComplexity
assert estimate.notes is not None, f"Tool {tool_name} missing notes"
# Empirical estimates should generally have reasonable confidence
if estimate.based_on == EstimationBasis.EMPIRICAL:
assert estimate.confidence >= 0.6, (
f"Empirical estimate for {tool_name} has very low confidence"
)
# Conservative estimates should have lower confidence
if estimate.based_on == EstimationBasis.CONSERVATIVE:
assert estimate.confidence <= 0.6, (
f"Conservative estimate for {tool_name} has unexpectedly high confidence"
)
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_empty_configuration(self):
"""Test behavior with empty tool estimates."""
config = ToolEstimationConfig(tool_estimates={})
# Should fall back to unknown tool estimate
estimate = config.get_estimate("any_tool")
assert estimate == config.unknown_tool_estimate
# Summary stats should handle empty case
stats = config.get_summary_stats()
assert stats == {}
def test_alert_with_zero_estimates(self):
"""Test alert calculation when estimates are zero."""
config = ToolEstimationConfig()
# Tool with zero LLM calls in estimate
should_alert, message = config.should_alert("get_stock_price", 1, 200)
# Should alert because variance is infinite (1 vs 0 expected)
assert should_alert
def test_variance_calculation_edge_cases(self):
"""Test variance calculation with edge cases."""
config = ToolEstimationConfig()
# Perfect match should not alert
should_alert, message = config.should_alert("get_rsi_analysis", 2, 3000)
# get_rsi_analysis has: llm_calls=2, total_tokens=3000
assert not should_alert
def test_performance_with_large_usage(self):
"""Test performance and behavior with extremely large usage values."""
config = ToolEstimationConfig()
# Very large values should still work
should_alert, message = config.should_alert("test_tool", 1000, 1000000)
assert should_alert
assert "Critical" in message
def test_custom_monitoring_thresholds(self):
"""Test configuration with custom monitoring thresholds."""
custom_monitoring = MonitoringThresholds(
llm_calls_warning=5,
llm_calls_critical=10,
tokens_warning=1000,
tokens_critical=5000,
variance_warning=0.1,
variance_critical=0.2,
)
config = ToolEstimationConfig(monitoring=custom_monitoring)
# Should use custom thresholds
# Test critical threshold first (easier to predict)
should_alert, message = config.should_alert("test_tool", 12, 500)
assert should_alert
assert "Critical" in message
# Test LLM calls warning threshold
should_alert, message = config.should_alert(
"test_tool", 6, 100
) # Lower tokens to avoid variance issues
assert should_alert
# May be warning or critical depending on variance calculation
class TestIntegrationPatterns:
"""Test patterns that match server.py integration."""
def test_low_confidence_logging_pattern(self):
"""Test identifying tools that need monitoring due to low confidence."""
config = ToolEstimationConfig()
low_confidence_tools = []
for tool_name, estimate in config.tool_estimates.items():
if estimate.confidence < 0.8:
low_confidence_tools.append(tool_name)
# These tools should be logged for monitoring in production
assert len(low_confidence_tools) > 0
# Verify these are typically more complex tools
for tool_name in low_confidence_tools:
estimate = config.get_estimate(tool_name)
# Low confidence tools should typically be complex, premium, or analytical standard tools
assert estimate.complexity in [
ToolComplexity.STANDARD,
ToolComplexity.COMPLEX,
ToolComplexity.PREMIUM,
], (
f"Tool {tool_name} with low confidence has unexpected complexity {estimate.complexity}"
)
def test_error_handling_fallback_pattern(self):
"""Test the error handling pattern used in server.py."""
config = ToolEstimationConfig()
# Simulate error case - should fall back to unknown tool estimate
try:
# This would be the pattern in server.py when get_tool_estimate fails
estimate = config.get_estimate("nonexistent_tool")
fallback_estimate = config.unknown_tool_estimate
# Verify fallback has conservative characteristics
assert fallback_estimate.based_on == EstimationBasis.CONSERVATIVE
assert fallback_estimate.confidence == 0.3
assert fallback_estimate.complexity == ToolComplexity.STANDARD
# Should be the same as what get_estimate returns for unknown tools
assert estimate == fallback_estimate
except Exception:
# If estimation fails entirely, should be able to use fallback
fallback = config.unknown_tool_estimate
assert fallback.llm_calls > 0
assert fallback.total_tokens > 0
def test_usage_logging_extra_fields(self):
"""Test that estimates provide all fields needed for logging."""
config = ToolEstimationConfig()
for _tool_name, estimate in config.tool_estimates.items():
# Verify all fields needed for server.py logging are present
assert hasattr(estimate, "confidence")
assert hasattr(estimate, "based_on")
assert hasattr(estimate, "complexity")
assert hasattr(estimate, "llm_calls")
assert hasattr(estimate, "total_tokens")
# Verify fields have appropriate types for logging
assert isinstance(estimate.confidence, float)
assert isinstance(estimate.based_on, EstimationBasis)
assert isinstance(estimate.complexity, ToolComplexity)
assert isinstance(estimate.llm_calls, int)
assert isinstance(estimate.total_tokens, int)
```
--------------------------------------------------------------------------------
/docs/PORTFOLIO_PERSONALIZATION_PLAN.md:
--------------------------------------------------------------------------------
```markdown
# PORTFOLIO PERSONALIZATION - EXECUTION PLAN
## 1. Big Picture / Goal
**Objective:** Transform MaverickMCP's portfolio analysis tools from stateless, repetitive-input operations into an intelligent, personalized AI financial assistant through persistent portfolio storage and context-aware tool integration.
**Architectural Goal:** Implement a two-phase system that (1) adds persistent portfolio storage with cost basis tracking using established DDD patterns, and (2) intelligently enhances existing tools to auto-detect user holdings and provide personalized analysis without breaking the stateless MCP tool contract.
**Success Criteria (Mandatory):**
- **Phase 1 Complete:** 4 new MCP tools (`add_portfolio_position`, `get_my_portfolio`, `remove_portfolio_position`, `clear_my_portfolio`) and 1 MCP resource (`portfolio://my-holdings`) fully functional
- **Database Integration:** SQLAlchemy models with proper cost basis averaging, Alembic migration creating tables without conflicts
- **Phase 2 Integration:** 3 existing tools enhanced (`risk_adjusted_analysis`, `portfolio_correlation_analysis`, `compare_tickers`) with automatic portfolio detection
- **AI Context Injection:** Portfolio resource provides live P&L, diversification metrics, and position details to AI agents automatically
- **Test Coverage:** 85%+ test coverage with unit, integration, and domain tests passing
- **Code Quality:** Zero linting errors (ruff), full type annotations (ty), all hooks passing
- **Documentation:** PORTFOLIO.md guide, updated tool docstrings, usage examples in Claude Desktop
**Financial Disclaimer:** All portfolio features include educational disclaimers. No investment recommendations. Local-first storage only. No tax advice provided.
## 2. To-Do List (High Level)
### Phase 1: Persistent Portfolio Storage Foundation (4-5 days)
- [ ] **Spike 1:** Research cost basis averaging algorithms and edge cases (FIFO, average cost)
- [ ] **Domain Entities:** Create `Portfolio` and `Position` domain entities with business logic
- [ ] **Database Models:** Implement `UserPortfolio` and `PortfolioPosition` SQLAlchemy models
- [ ] **Migration:** Create Alembic migration with proper indexes and constraints
- [ ] **MCP Tools:** Implement 4 portfolio management tools with validation
- [ ] **MCP Resource:** Implement `portfolio://my-holdings` with live P&L calculations
- [ ] **Unit Tests:** Comprehensive domain entity and cost basis tests
- [ ] **Integration Tests:** Database operation and transaction tests
### Phase 2: Intelligent Tool Integration (2-3 days)
- [ ] **Risk Analysis Enhancement:** Add position awareness to `risk_adjusted_analysis`
- [ ] **Correlation Enhancement:** Enable `portfolio_correlation_analysis` with no arguments
- [ ] **Comparison Enhancement:** Enable `compare_tickers` with optional portfolio auto-fill
- [ ] **Resource Enhancement:** Add live market data to portfolio resource
- [ ] **Integration Tests:** Cross-tool functionality validation
- [ ] **Documentation:** Update existing tool docstrings with new capabilities
### Phase 3: Polish & Documentation (1-2 days)
- [ ] **Manual Testing:** Claude Desktop end-to-end workflow validation
- [ ] **Error Handling:** Edge case coverage (partial sells, zero shares, invalid tickers)
- [ ] **Performance:** Query optimization, batch operations, caching strategy
- [ ] **Documentation:** Complete PORTFOLIO.md with examples and screenshots
- [ ] **Migration Testing:** Test upgrade/downgrade paths
## 3. Plan Details (Spikes & Features)
### Spike 1: Cost Basis Averaging Research
**Action:** Investigate cost basis calculation methods (FIFO, LIFO, average cost) and determine optimal approach for educational portfolio tracking.
**Steps:**
1. Research IRS cost basis methods and educational best practices
2. Analyze existing `PortfolioManager` tool (JSON-based, average cost) for patterns
3. Design algorithm for averaging purchases and handling partial sells
4. Create specification document for edge cases:
- Multiple purchases at different prices
- Partial position sales
- Zero/negative share handling
- Rounding and precision (financial data uses Numeric(12,4))
5. Benchmark performance for 100+ positions with 1000+ transactions
**Expected Outcome:** Clear specification for cost basis implementation using **average cost method** (simplest for educational use, matches existing PortfolioManager), with edge case handling documented.
**Decision Rationale:** Average cost is simpler than FIFO/LIFO, appropriate for educational context, and avoids tax accounting complexity.
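A worked example of the chosen method (illustrative numbers only; the spike will formalize the edge cases):
```python
from decimal import Decimal
# Two purchases of the same ticker at different prices.
shares_1, price_1 = Decimal("10"), Decimal("150.00")
shares_2, price_2 = Decimal("5"), Decimal("180.00")
total_shares = shares_1 + shares_2                      # 15
total_cost = shares_1 * price_1 + shares_2 * price_2    # 2400.00
average_cost = total_cost / total_shares                # 160.00
# Under average cost, a partial sale of 6 shares keeps the per-share basis
# unchanged and reduces the remaining cost proportionally.
remaining_shares = total_shares - Decimal("6")          # 9
remaining_cost = remaining_shares * average_cost        # 1440.00
```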
---
### Feature A: Domain Entities (DDD Pattern)
**Goal:** Create pure business logic entities following MaverickMCP's established DDD patterns (similar to backtesting domain entities).
**Files to Create:**
- `maverick_mcp/domain/portfolio.py` - Core domain entities
- `maverick_mcp/domain/position.py` - Position value objects
**Domain Entity Design:**
```python
# maverick_mcp/domain/portfolio.py
from dataclasses import dataclass
from datetime import UTC, datetime
from decimal import Decimal
from typing import List, Optional
@dataclass
class Position:
"""Value object representing a single portfolio position."""
ticker: str
shares: Decimal # Use Decimal for precision
average_cost_basis: Decimal
total_cost: Decimal
purchase_date: datetime # Earliest purchase
notes: Optional[str] = None
def add_shares(self, shares: Decimal, price: Decimal, date: datetime) -> "Position":
"""Add shares with automatic cost basis averaging."""
new_total_shares = self.shares + shares
new_total_cost = self.total_cost + (shares * price)
new_avg_cost = new_total_cost / new_total_shares
return Position(
ticker=self.ticker,
shares=new_total_shares,
average_cost_basis=new_avg_cost,
total_cost=new_total_cost,
purchase_date=min(self.purchase_date, date),
notes=self.notes
)
def remove_shares(self, shares: Decimal) -> Optional["Position"]:
"""Remove shares, return None if position fully closed."""
if shares >= self.shares:
return None # Full position close
new_shares = self.shares - shares
new_total_cost = new_shares * self.average_cost_basis
return Position(
ticker=self.ticker,
shares=new_shares,
average_cost_basis=self.average_cost_basis,
total_cost=new_total_cost,
purchase_date=self.purchase_date,
notes=self.notes
)
def calculate_current_value(self, current_price: Decimal) -> dict:
"""Calculate live P&L metrics."""
current_value = self.shares * current_price
unrealized_pnl = current_value - self.total_cost
pnl_percentage = (unrealized_pnl / self.total_cost * 100) if self.total_cost else Decimal(0)
return {
"current_value": current_value,
"unrealized_pnl": unrealized_pnl,
"pnl_percentage": pnl_percentage
}
@dataclass
class Portfolio:
"""Aggregate root for user portfolio."""
portfolio_id: str # UUID
user_id: str # "default" for single-user
name: str
positions: List[Position]
created_at: datetime
updated_at: datetime
def add_position(self, ticker: str, shares: Decimal, price: Decimal,
date: datetime, notes: Optional[str] = None) -> None:
"""Add or update position with automatic averaging."""
# Find existing position
for i, pos in enumerate(self.positions):
if pos.ticker == ticker:
self.positions[i] = pos.add_shares(shares, price, date)
self.updated_at = datetime.now(UTC)
return
# Create new position
new_position = Position(
ticker=ticker,
shares=shares,
average_cost_basis=price,
total_cost=shares * price,
purchase_date=date,
notes=notes
)
self.positions.append(new_position)
self.updated_at = datetime.now(UTC)
def remove_position(self, ticker: str, shares: Optional[Decimal] = None) -> bool:
"""Remove position or partial shares."""
for i, pos in enumerate(self.positions):
if pos.ticker == ticker:
if shares is None or shares >= pos.shares:
# Full position removal
self.positions.pop(i)
else:
# Partial removal
updated_pos = pos.remove_shares(shares)
if updated_pos:
self.positions[i] = updated_pos
else:
self.positions.pop(i)
self.updated_at = datetime.now(UTC)
return True
return False
def get_position(self, ticker: str) -> Optional[Position]:
"""Get position by ticker."""
return next((pos for pos in self.positions if pos.ticker == ticker), None)
def get_total_invested(self) -> Decimal:
"""Calculate total capital invested."""
return sum(pos.total_cost for pos in self.positions)
def calculate_portfolio_metrics(self, current_prices: dict[str, Decimal]) -> dict:
"""Calculate comprehensive portfolio metrics."""
total_value = Decimal(0)
total_cost = Decimal(0)
position_details = []
for pos in self.positions:
current_price = current_prices.get(pos.ticker, pos.average_cost_basis)
metrics = pos.calculate_current_value(current_price)
total_value += metrics["current_value"]
total_cost += pos.total_cost
position_details.append({
"ticker": pos.ticker,
"shares": float(pos.shares),
"cost_basis": float(pos.average_cost_basis),
"current_price": float(current_price),
**{k: float(v) for k, v in metrics.items()}
})
total_pnl = total_value - total_cost
total_pnl_pct = (total_pnl / total_cost * 100) if total_cost else Decimal(0)
return {
"total_value": float(total_value),
"total_invested": float(total_cost),
"total_pnl": float(total_pnl),
"total_pnl_percentage": float(total_pnl_pct),
"position_count": len(self.positions),
"positions": position_details
}
```
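A short usage sketch of the entities above (illustrative values; the assertions follow directly from `add_shares` and `calculate_current_value`):
```python
from datetime import UTC, datetime
from decimal import Decimal
from maverick_mcp.domain.portfolio import Position  # planned module above
pos = Position(
    ticker="NVDA",
    shares=Decimal("10"),
    average_cost_basis=Decimal("100.00"),
    total_cost=Decimal("1000.00"),
    purchase_date=datetime(2025, 1, 2, tzinfo=UTC),
)
pos = pos.add_shares(Decimal("10"), Decimal("120.00"), datetime(2025, 6, 2, tzinfo=UTC))
assert pos.average_cost_basis == Decimal("110.00")  # (1000 + 1200) / 20
metrics = pos.calculate_current_value(Decimal("130.00"))
assert metrics["unrealized_pnl"] == Decimal("400.00")  # 20 * 130 - 2200
```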
**Testing Strategy:**
- Unit tests for cost basis averaging edge cases
- Property-based tests for arithmetic precision (see the sketch after this list)
- Edge case tests: zero shares, negative P&L, division by zero
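A property-based sketch of the precision tests (assumes `hypothesis` as a dev dependency and the `Position` entity above):
```python
from datetime import UTC, datetime
from decimal import Decimal
from hypothesis import given
from hypothesis import strategies as st
from maverick_mcp.domain.portfolio import Position  # planned module from Feature A
shares_st = st.decimals(min_value=Decimal("0.0001"), max_value=Decimal("10000"), places=4)
price_st = st.decimals(min_value=Decimal("0.01"), max_value=Decimal("10000"), places=2)
@given(s1=shares_st, p1=price_st, s2=shares_st, p2=price_st)
def test_add_shares_preserves_capital(s1, p1, s2, p2):
    """Averaging never loses capital: totals are exact sums of the purchases."""
    now = datetime.now(UTC)
    pos = Position(ticker="AAPL", shares=s1, average_cost_basis=p1, total_cost=s1 * p1, purchase_date=now)
    merged = pos.add_shares(s2, p2, now)
    assert merged.shares == s1 + s2
    assert merged.total_cost == s1 * p1 + s2 * p2
```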
---
### Feature B: Database Models (SQLAlchemy ORM)
**Goal:** Create persistent storage models following established patterns in `maverick_mcp/data/models.py`.
**Files to Modify:**
- `maverick_mcp/data/models.py` - Add new models (lines ~1700+)
**Model Design:**
```python
# Add to maverick_mcp/data/models.py
class UserPortfolio(TimestampMixin, Base):
"""
User portfolio for tracking investment holdings.
Follows personal-use design: single user_id="default"
"""
__tablename__ = "mcp_portfolios"
id = Column(Uuid, primary_key=True, default=uuid.uuid4)
user_id = Column(String(50), nullable=False, default="default", index=True)
name = Column(String(200), nullable=False, default="My Portfolio")
# Relationships
positions = relationship(
"PortfolioPosition",
back_populates="portfolio",
cascade="all, delete-orphan",
lazy="selectin" # Efficient loading
)
# Indexes for queries
__table_args__ = (
Index("idx_portfolio_user", "user_id"),
UniqueConstraint("user_id", "name", name="uq_user_portfolio_name"),
)
def __repr__(self):
return f"<UserPortfolio(id={self.id}, name='{self.name}', positions={len(self.positions)})>"
class PortfolioPosition(TimestampMixin, Base):
"""
Individual position within a portfolio with cost basis tracking.
"""
__tablename__ = "mcp_portfolio_positions"
id = Column(Uuid, primary_key=True, default=uuid.uuid4)
portfolio_id = Column(Uuid, ForeignKey("mcp_portfolios.id", ondelete="CASCADE"), nullable=False)
# Position details
ticker = Column(String(20), nullable=False, index=True)
shares = Column(Numeric(20, 8), nullable=False) # High precision for fractional shares
average_cost_basis = Column(Numeric(12, 4), nullable=False) # Financial precision
total_cost = Column(Numeric(20, 4), nullable=False) # Total capital invested
purchase_date = Column(DateTime(timezone=True), nullable=False) # Earliest purchase
notes = Column(Text, nullable=True) # Optional user notes
# Relationships
portfolio = relationship("UserPortfolio", back_populates="positions")
# Indexes for efficient queries
__table_args__ = (
Index("idx_position_portfolio", "portfolio_id"),
Index("idx_position_ticker", "ticker"),
Index("idx_position_portfolio_ticker", "portfolio_id", "ticker"),
UniqueConstraint("portfolio_id", "ticker", name="uq_portfolio_position_ticker"),
)
def __repr__(self):
return f"<PortfolioPosition(ticker='{self.ticker}', shares={self.shares}, cost_basis={self.average_cost_basis})>"
```
**Key Design Decisions:**
1. **Table Names:** `mcp_portfolios` and `mcp_portfolio_positions` (consistent with `mcp_*` pattern)
2. **user_id:** Default "default" for single-user personal use
3. **Numeric Precision:** Matches existing financial data patterns (12,4 for prices, 20,8 for shares)
4. **Cascade Delete:** Portfolio deletion removes all positions automatically
5. **Unique Constraint:** One position per ticker per portfolio
6. **Indexes:** Optimized for common queries (user lookup, ticker filtering); a query sketch follows below
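A query sketch against these models (the `session` argument is the project's standard SQLAlchemy session; this is illustrative, not the final service code):
```python
from sqlalchemy import select
from sqlalchemy.orm import Session
def load_default_portfolio(session: Session) -> "UserPortfolio | None":
    """Fetch the single-user portfolio; positions arrive via the selectin relationship."""
    stmt = select(UserPortfolio).where(UserPortfolio.user_id == "default")
    return session.execute(stmt).scalar_one_or_none()
```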
---
### Feature C: Alembic Migration
**Goal:** Create database migration following established patterns without conflicts.
**File to Create:**
- `alembic/versions/014_add_portfolio_models.py`
**Migration Pattern:**
```python
"""Add portfolio and position models
Revision ID: 014_add_portfolio_models
Revises: 013_add_backtest_persistence_models
Create Date: 2025-11-01 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = '014_add_portfolio_models'
down_revision = '013_add_backtest_persistence_models'
branch_labels = None
depends_on = None
def upgrade():
"""Create portfolio management tables."""
    # Create portfolios table (unique constraints declared inline and the portable
    # sa.Uuid type used so the same DDL runs on SQLite and PostgreSQL)
    op.create_table(
        'mcp_portfolios',
        sa.Column('id', sa.Uuid(), primary_key=True),
        sa.Column('user_id', sa.String(50), nullable=False, server_default='default'),
        sa.Column('name', sa.String(200), nullable=False, server_default='My Portfolio'),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.UniqueConstraint('user_id', 'name', name='uq_user_portfolio_name'),
    )
    # Create indexes on portfolios
    op.create_index('idx_portfolio_user', 'mcp_portfolios', ['user_id'])
    # Create positions table
    op.create_table(
        'mcp_portfolio_positions',
        sa.Column('id', sa.Uuid(), primary_key=True),
        sa.Column('portfolio_id', sa.Uuid(), nullable=False),
        sa.Column('ticker', sa.String(20), nullable=False),
        sa.Column('shares', sa.Numeric(20, 8), nullable=False),
        sa.Column('average_cost_basis', sa.Numeric(12, 4), nullable=False),
        sa.Column('total_cost', sa.Numeric(20, 4), nullable=False),
        sa.Column('purchase_date', sa.DateTime(timezone=True), nullable=False),
        sa.Column('notes', sa.Text, nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.ForeignKeyConstraint(['portfolio_id'], ['mcp_portfolios.id'], ondelete='CASCADE'),
        sa.UniqueConstraint('portfolio_id', 'ticker', name='uq_portfolio_position_ticker'),
    )
    # Create indexes on positions
    op.create_index('idx_position_portfolio', 'mcp_portfolio_positions', ['portfolio_id'])
    op.create_index('idx_position_ticker', 'mcp_portfolio_positions', ['ticker'])
    op.create_index('idx_position_portfolio_ticker', 'mcp_portfolio_positions', ['portfolio_id', 'ticker'])
def downgrade():
"""Drop portfolio management tables."""
op.drop_table('mcp_portfolio_positions')
op.drop_table('mcp_portfolios')
```
**Testing:**
- Test upgrade: `alembic upgrade head`
- Test downgrade: `alembic downgrade -1`
- Verify indexes created: SQL query inspection
- Test with SQLite and PostgreSQL (see the round-trip sketch below)
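A round-trip sketch of the migration test (assumes `alembic.ini` at the repo root; if the project's `env.py` reads the database URL from the environment instead, set that variable rather than `sqlalchemy.url`):
```python
from alembic import command
from alembic.config import Config
def test_portfolio_migration_round_trip(tmp_path):
    """Upgrade, downgrade, and re-upgrade against a throwaway SQLite database."""
    cfg = Config("alembic.ini")
    cfg.set_main_option("sqlalchemy.url", f"sqlite:///{tmp_path / 'migration_test.db'}")
    command.upgrade(cfg, "head")   # applies 014_add_portfolio_models and its predecessors
    command.downgrade(cfg, "-1")   # drops the two portfolio tables
    command.upgrade(cfg, "head")   # proves the path is repeatable
```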
---
### Feature D: MCP Tools Implementation
**Goal:** Implement 4 portfolio management tools following tool_registry.py pattern.
**Files to Create:**
- `maverick_mcp/api/routers/portfolio_management.py` - New tool implementations
- `maverick_mcp/api/services/portfolio_persistence_service.py` - Service layer
- `maverick_mcp/validation/portfolio_management.py` - Pydantic validation
**Service Layer Pattern:**
```python
# maverick_mcp/api/services/portfolio_persistence_service.py
class PortfolioPersistenceService(BaseService):
"""Service for portfolio CRUD operations."""
async def get_or_create_default_portfolio(self) -> UserPortfolio:
"""Get the default portfolio, create if doesn't exist."""
pass
async def add_position(self, ticker: str, shares: Decimal,
price: Decimal, date: datetime,
notes: Optional[str]) -> PortfolioPosition:
"""Add or update position with cost averaging."""
pass
async def get_portfolio_with_live_data(self) -> dict:
"""Fetch portfolio with current market prices."""
pass
async def remove_position(self, ticker: str,
shares: Optional[Decimal]) -> bool:
"""Remove position or partial shares."""
pass
async def clear_portfolio(self) -> bool:
"""Delete all positions."""
pass
```
**Tool Registration:**
```python
# Add to maverick_mcp/api/routers/tool_registry.py
def register_portfolio_management_tools(mcp: FastMCP) -> None:
"""Register portfolio management tools."""
from maverick_mcp.api.routers.portfolio_management import (
add_portfolio_position,
get_my_portfolio,
remove_portfolio_position,
clear_my_portfolio
)
mcp.tool(name="portfolio_add_position")(add_portfolio_position)
mcp.tool(name="portfolio_get_my_portfolio")(get_my_portfolio)
mcp.tool(name="portfolio_remove_position")(remove_portfolio_position)
mcp.tool(name="portfolio_clear")(clear_my_portfolio)
```
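A sketch of one tool implementation (the parameter names and the bare service constructor are assumptions to be firmed up against the validation models in `maverick_mcp/validation/portfolio_management.py`):
```python
# maverick_mcp/api/routers/portfolio_management.py (sketch)
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from maverick_mcp.api.services.portfolio_persistence_service import (
    PortfolioPersistenceService,
)
async def add_portfolio_position(
    ticker: str,
    shares: float,
    purchase_price: float,
    purchase_date: str | None = None,
    notes: str | None = None,
) -> dict[str, Any]:
    """Add shares to the default portfolio, averaging into any existing position."""
    service = PortfolioPersistenceService()
    date = (
        datetime.fromisoformat(purchase_date).replace(tzinfo=UTC)
        if purchase_date
        else datetime.now(UTC)
    )
    position = await service.add_position(
        ticker=ticker.upper(),
        shares=Decimal(str(shares)),
        price=Decimal(str(purchase_price)),
        date=date,
        notes=notes,
    )
    return {
        "status": "ok",
        "ticker": position.ticker,
        "shares": float(position.shares),
        "average_cost_basis": float(position.average_cost_basis),
        "disclaimer": "Educational tool only. Not investment advice.",
    }
```
The other three tools follow the same shape, delegating all persistence and cost-basis logic to the service layer.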
---
### Feature E: MCP Resource Implementation
**Goal:** Create `portfolio://my-holdings` resource for automatic AI context injection.
**File to Modify:**
- `maverick_mcp/api/server.py` - Add resource alongside existing health:// and dashboard:// resources
**Resource Implementation:**
```python
# Add to maverick_mcp/api/server.py (around line 823, near other resources)
@mcp.resource("portfolio://my-holdings")
def portfolio_holdings_resource() -> dict[str, Any]:
"""
Portfolio holdings resource for AI context injection.
Provides comprehensive portfolio context to AI agents including:
- Current positions with live P&L
- Portfolio metrics and diversification
- Sector exposure analysis
- Top/bottom performers
This resource is automatically available to AI agents during conversations,
enabling personalized analysis without requiring manual ticker input.
"""
# Implementation using service layer with async handling
pass
```
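A minimal sketch of one way to bridge the synchronous resource body into the async service layer (Section 5 notes that the existing `health://` resource already handles this same mismatch):
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from maverick_mcp.api.services.portfolio_persistence_service import (
    PortfolioPersistenceService,
)
def _load_holdings() -> dict[str, Any]:
    """Run the async service call from the synchronous resource body."""
    coro = PortfolioPersistenceService().get_portfolio_with_live_data()
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: create one just for this call.
        return asyncio.run(coro)
    # A loop is already running (typical under the MCP server), so run the
    # coroutine on a worker thread with its own loop to avoid re-entrancy.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```
The resource body then simply becomes `return _load_holdings()`.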
---
### Feature F: Phase 2 Tool Enhancements
**Goal:** Enhance existing tools to auto-detect portfolio holdings.
**Files to Modify:**
1. `maverick_mcp/api/routers/portfolio.py` - Enhance 3 existing tools
2. `maverick_mcp/validation/portfolio.py` - Update validation to allow optional parameters
**Enhancement Pattern:**
- Add optional parameters (tickers can be None)
- Check portfolio for holdings if no tickers provided
- Add position awareness to analysis results
- Maintain backward compatibility (callers that pass tickers explicitly are unchanged; see the sketch below)
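A sketch of the shared auto-detection helper (the helper name and the dict shape returned by the service are assumptions):
```python
from maverick_mcp.api.services.portfolio_persistence_service import (
    PortfolioPersistenceService,
)
async def _resolve_tickers(tickers: list[str] | None) -> tuple[list[str], bool]:
    """Use explicit tickers when given; otherwise fall back to saved holdings."""
    if tickers:
        return tickers, False
    portfolio = await PortfolioPersistenceService().get_portfolio_with_live_data()
    held = [pos["ticker"] for pos in portfolio.get("positions", [])]
    return held, True  # the flag lets each tool note "analyzing your portfolio"
```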
---
## 4. Progress (Living Document Section)
| Date | Time | Item Completed / Status Update | Resulting Changes (LOC/Files) |
|:-----|:-----|:------------------------------|:------------------------------|
| 2025-11-01 | Start | Plan approved and documented | PORTFOLIO_PERSONALIZATION_PLAN.md created |
| TBD | TBD | Implementation begins | - |
_(This section will be updated during implementation)_
---
## 5. Surprises and Discoveries
_(Technical issues discovered during implementation will be documented here)_
**Anticipated Challenges:**
1. **MCP Resource Async Context:** Resources are sync functions but need async database calls - solved with event loop management (see existing health_resource pattern)
2. **Cost Basis Precision:** Financial calculations require Decimal precision, not floats - use Numeric(12,4) for prices, Numeric(20,8) for shares
3. **Portfolio Resource Performance:** Live price fetching could be slow - implement caching strategy, consider async batching
4. **Single User Assumption:** No user authentication means all operations use user_id="default" - acceptable for personal use
---
## 6. Decision Log
| Date | Decision | Rationale |
|:-----|:---------|:----------|
| 2025-11-01 | **Cost Basis Method: Average Cost** | Simplest for educational use, matches existing PortfolioManager, avoids tax accounting complexity |
| 2025-11-01 | **Table Names: mcp_portfolios, mcp_portfolio_positions** | Consistent with existing mcp_* naming convention for MCP-specific tables |
| 2025-11-01 | **User ID: "default" for all users** | Single-user personal-use design, consistent with auth disabled architecture |
| 2025-11-01 | **Numeric Precision: Numeric(12,4) for prices, Numeric(20,8) for shares** | Matches existing financial data patterns, supports fractional shares |
| 2025-11-01 | **Optional tickers parameter for Phase 2** | Enables "just works" UX while maintaining backward compatibility |
| 2025-11-01 | **MCP Resource for AI context** | Most elegant solution for automatic context injection without breaking tool contracts |
| 2025-11-01 | **Domain-Driven Design pattern** | Follows established MaverickMCP architecture, clean separation of concerns |
---
## 7. Implementation Phases
### Phase 1: Foundation (4-5 days)
**Files Created:** 8 new files
**Files Modified:** 3 existing files
**Estimated LOC:** ~2,500 lines
**Tests:** ~1,200 lines
### Phase 2: Integration (2-3 days)
**Files Modified:** 4 existing files
**Estimated LOC:** ~800 lines additional
**Tests:** ~600 lines additional
### Phase 3: Polish (1-2 days)
**Documentation:** PORTFOLIO.md (~300 lines)
**Performance:** Query optimization
**Testing:** Manual Claude Desktop validation
**Total Effort:** 7-10 days
**Total New Code:** ~3,500 lines (including tests)
**Total Tests:** ~1,800 lines
---
## 8. Risk Assessment
**Low Risk:**
- ✅ Follows established patterns
- ✅ No breaking changes to existing tools
- ✅ Optional Phase 2 enhancements
- ✅ Well-scoped feature
**Medium Risk:**
- ⚠️ MCP resource performance with live prices
- ⚠️ Migration compatibility (SQLite vs PostgreSQL)
- ⚠️ Edge cases in cost basis averaging
**Mitigation Strategies:**
1. **Performance:** Implement caching, batch price fetches, add timeout protection
2. **Migration:** Test with both SQLite and PostgreSQL, provide rollback path
3. **Edge Cases:** Comprehensive unit tests, property-based testing for arithmetic
---
## 9. Testing Strategy
**Unit Tests (~60% of test code):**
- Domain entity logic (Position, Portfolio)
- Cost basis averaging edge cases
- Numeric precision validation
- Business logic validation
**Integration Tests (~30% of test code):**
- Database CRUD operations
- Migration upgrade/downgrade
- Service layer with real database
- Cross-tool functionality
**Manual Tests (~10% of effort):**
- Claude Desktop end-to-end workflows
- Natural language interactions
- MCP resource visibility
- Tool integration scenarios
**Test Coverage Target:** 85%+
---
## 10. Success Metrics
**Functional Success:**
- [ ] All 4 new tools work in Claude Desktop
- [ ] Portfolio resource visible to AI agents
- [ ] Cost basis averaging accurate to 4 decimal places
- [ ] Migration works on SQLite and PostgreSQL
- [ ] 3 enhanced tools auto-detect portfolio
**Quality Success:**
- [ ] 85%+ test coverage
- [ ] Zero linting errors (ruff)
- [ ] Full type annotations (ty check passes)
- [ ] All pre-commit hooks pass
**UX Success:**
- [ ] "Analyze my portfolio" works without ticker input
- [ ] AI agents reference actual holdings in responses
- [ ] Natural language interactions feel seamless
- [ ] Error messages are clear and actionable
---
## 11. Related Documentation
- **Original Issue:** [#40 - Portfolio Personalization](https://github.com/wshobson/maverick-mcp/issues/40)
- **User Documentation:** `docs/PORTFOLIO.md` (to be created)
- **API Documentation:** Tool docstrings and MCP introspection
- **Testing Guide:** `tests/README.md` (to be updated)
---
This execution plan provides a comprehensive roadmap following the PLANS.md rubric structure. The implementation is well-scoped, follows established patterns, and delivers significant UX improvement while maintaining code quality and architectural integrity.
```