This is page 26 of 28. Use http://codebase.md/wshobson/maverick-mcp?page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .jules
│ └── bolt.md
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ ├── unit
│ │ └── test_stock_repository_adapter.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_deep_research_functional.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive functional tests for DeepResearchAgent.
This test suite focuses on testing the actual research functionality including:
## Web Search Integration Tests (TestWebSearchIntegration):
- Exa and Tavily search provider query formatting and result processing
- Provider fallback behavior when APIs fail
- Search result deduplication from multiple providers
- Social media filtering and content processing
## Research Synthesis Tests (TestResearchSynthesis):
- Persona-aware content analysis with different investment styles
- Complete research synthesis workflow from query to findings
- Iterative research refinement based on initial results
- Fact validation and source credibility scoring
## Persona-Based Research Tests (TestPersonaBasedResearch):
- Conservative persona focus on stability, dividends, and risk factors
- Aggressive persona exploration of growth opportunities and innovation
- Day trader persona emphasis on short-term catalysts and volatility
- Research depth differences between conservative and aggressive approaches
## Multi-Step Research Workflow Tests (TestMultiStepResearchWorkflow):
- End-to-end research workflow from initial query to final report
- Handling of insufficient or conflicting information scenarios
- Research focusing and refinement based on discovered gaps
- Citation generation and source attribution
## Research Method Specialization Tests (TestResearchMethodSpecialization):
- Sentiment analysis specialization with news and social signals
- Fundamental analysis focusing on financials and company data
- Competitive analysis examining market position and rivals
- Proper routing to specialized analysis based on focus areas
## Error Handling and Resilience Tests (TestErrorHandlingAndResilience):
- Graceful degradation when search providers are unavailable
- Content analysis fallback when LLM services fail
- Partial search failure handling with provider redundancy
- Circuit breaker behavior and timeout handling
## Research Quality and Validation Tests (TestResearchQualityAndValidation):
- Research confidence calculation based on source quality and diversity
- Source credibility scoring (government, financial sites vs. blogs)
- Source diversity assessment for balanced research
- Investment recommendation logic based on persona and findings
## Key Features Tested:
- **Realistic Mock Data**: Uses comprehensive financial article samples
- **Provider Integration**: Tests both Exa and Tavily search providers
- **LangGraph Workflows**: Tests complete research state machine
- **Persona Adaptation**: Validates different investor behavior patterns
- **Error Resilience**: Ensures system continues operating with degraded capabilities
- **Research Logic**: Tests actual synthesis and analysis rather than just API calls
All tests use realistic mock data and test the research logic rather than just API connectivity.
26 test cases cover the complete research pipeline from initial search to final recommendations.
"""
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from maverick_mcp.agents.deep_research import (
PERSONA_RESEARCH_FOCUS,
RESEARCH_DEPTH_LEVELS,
ContentAnalyzer,
DeepResearchAgent,
ExaSearchProvider,
TavilySearchProvider,
)
from maverick_mcp.exceptions import WebSearchError
# Mock Data Fixtures
@pytest.fixture
def mock_llm():
    """Mock LLM whose ``ainvoke`` always yields a canned content-analysis reply.

    The reply's ``.content`` is a JSON document shaped like the analyzer's
    expected output (insights, sentiment, risks, opportunities, scores), so
    downstream parsing code can be exercised without a real model call.
    """
    llm = MagicMock()
    llm.ainvoke = AsyncMock()
    llm.bind_tools = MagicMock(return_value=llm)

    # Canned analysis payload shared by every invocation.
    analysis_payload = {
        "KEY_INSIGHTS": [
            "Strong revenue growth in cloud services",
            "Market expansion in international segments",
            "Increasing competitive pressure from rivals",
        ],
        "SENTIMENT": {"direction": "bullish", "confidence": 0.75},
        "RISK_FACTORS": [
            "Regulatory scrutiny in international markets",
            "Supply chain disruptions affecting hardware",
        ],
        "OPPORTUNITIES": [
            "AI integration driving new revenue streams",
            "Subscription model improving recurring revenue",
        ],
        "CREDIBILITY": 0.8,
        "RELEVANCE": 0.9,
        "SUMMARY": "Analysis shows strong fundamentals with growth opportunities despite some regulatory risks.",
    }

    def _canned_reply(messages):
        # Mimic an LLM message object; callers only read ``.content``.
        reply = Mock()
        reply.content = json.dumps(analysis_payload)
        return reply

    llm.ainvoke.side_effect = _canned_reply
    return llm
@pytest.fixture
def comprehensive_search_results():
    """Comprehensive mock search results from multiple providers.

    Five realistic finance articles covering distinct angles (earnings,
    technicals, supply chain, AI strategy, dividends), split between the
    "exa" and "tavily" providers so tests can exercise provider mixing,
    score-based ranking, and persona-specific content selection.
    Tests index into this list by position (e.g. [3] = AI article,
    [4] = dividend article), so the order is load-bearing.
    """
    return [
        # [0] Earnings coverage — highest score, "exa" provider.
        {
            "url": "https://finance.yahoo.com/news/apple-earnings-q4-2024",
            "title": "Apple Reports Strong Q4 2024 Earnings",
            "content": """Apple Inc. reported quarterly earnings that beat Wall Street expectations,
            driven by strong iPhone sales and growing services revenue. The company posted
            revenue of $94.9 billion, up 6% year-over-year. CEO Tim Cook highlighted the
            success of the iPhone 15 lineup and expressed optimism about AI integration
            in future products. Services revenue reached $22.3 billion, representing
            a 16% increase. The company also announced a 4% increase in quarterly dividend.""",
            "published_date": "2024-01-25T10:30:00Z",
            "score": 0.92,
            "provider": "exa",
            "author": "Financial Times Staff",
        },
        # [1] Technical analysis — bullish chart setup, "exa" provider.
        {
            "url": "https://seekingalpha.com/article/apple-technical-analysis-2024",
            "title": "Apple Stock Technical Analysis: Bullish Momentum Building",
            "content": """Technical analysis of Apple stock shows bullish momentum building
            with the stock breaking above key resistance at $190. Volume has been
            increasing on up days, suggesting institutional accumulation. The RSI
            is at 58, indicating room for further upside. Key support levels are
            at $185 and $180. Price target for the next quarter is $210-$220 based
            on chart patterns and momentum indicators.""",
            "published_date": "2024-01-24T14:45:00Z",
            "score": 0.85,
            "provider": "exa",
            "author": "Tech Analyst Pro",
        },
        # [2] Risk-oriented news — supply chain headwinds, "tavily" provider.
        {
            "url": "https://reuters.com/apple-supply-chain-concerns",
            "title": "Apple Faces Supply Chain Headwinds in 2024",
            "content": """Apple is encountering supply chain challenges that could impact
            production timelines for its upcoming product launches. Manufacturing
            partners in Asia report delays due to component shortages, particularly
            for advanced semiconductors. The company is working to diversify its
            supplier base to reduce risks. Despite these challenges, analysts
            remain optimistic about Apple's ability to meet demand through
            strategic inventory management.""",
            "published_date": "2024-01-23T08:15:00Z",
            "score": 0.78,
            "provider": "tavily",
            "author": "Reuters Technology Team",
        },
        # [3] Growth/AI strategy article — used by aggressive-persona tests.
        {
            "url": "https://fool.com/apple-ai-strategy-competitive-advantage",
            "title": "Apple's AI Strategy Could Be Its Next Competitive Moat",
            "content": """Apple's approach to artificial intelligence differs significantly
            from competitors, focusing on on-device processing and privacy protection.
            The company's investment in AI chips and machine learning capabilities
            positions it well for the next phase of mobile computing. Industry
            experts predict Apple's AI integration will drive hardware upgrade
            cycles and create new revenue opportunities in services. The privacy-first
            approach could become a key differentiator in the market.""",
            "published_date": "2024-01-22T16:20:00Z",
            "score": 0.88,
            "provider": "exa",
            "author": "Investment Strategy Team",
        },
        # [4] Dividend/income article — used by conservative-persona tests.
        {
            "url": "https://barrons.com/apple-dividend-growth-analysis",
            "title": "Apple's Dividend Growth Story Continues",
            "content": """Apple has increased its dividend for the 12th consecutive year,
            demonstrating strong cash flow generation and commitment to returning
            capital to shareholders. The company's dividend yield of 0.5% may seem
            modest, but the consistent growth rate of 7% annually makes it attractive
            for income-focused investors. With over $162 billion in cash and
            marketable securities, Apple has the financial flexibility to continue
            rewarding shareholders while investing in growth initiatives.""",
            "published_date": "2024-01-21T11:00:00Z",
            "score": 0.82,
            "provider": "tavily",
            "author": "Dividend Analysis Team",
        },
    ]
@pytest.fixture
def mock_research_agent(mock_llm):
    """DeepResearchAgent wired to stubbed Exa/Tavily search providers.

    The provider classes are patched for the duration of construction so the
    agent never touches real APIs; the stub instances are then attached to
    ``agent.search_providers`` so individual tests can program per-provider
    behavior (success, failure, fallback).
    """
    exa_stub = Mock()
    tavily_stub = Mock()
    with (
        patch(
            "maverick_mcp.agents.deep_research.ExaSearchProvider",
            return_value=exa_stub,
        ),
        patch(
            "maverick_mcp.agents.deep_research.TavilySearchProvider",
            return_value=tavily_stub,
        ),
    ):
        agent = DeepResearchAgent(
            llm=mock_llm,
            persona="moderate",
            exa_api_key="mock-key",
            tavily_api_key="mock-key",
        )
        # Expose the stubs directly so tests can reach them.
        agent.search_providers = [exa_stub, tavily_stub]
        return agent
class TestWebSearchIntegration:
    """Test web search integration and result processing.

    Covers provider-level query handling (Exa, Tavily), fallback when one
    provider fails, and deduplication of merged results. Provider calls are
    routed through a patched circuit breaker, so the circuit's ``call``
    return value is what the provider ultimately yields.
    """

    @pytest.mark.asyncio
    async def test_exa_search_provider_query_formatting(self):
        """Test that Exa search queries are properly formatted and sent."""
        with patch("maverick_mcp.agents.deep_research.circuit_manager") as mock_circuit:
            mock_circuit.get_or_create = AsyncMock()
            mock_circuit_instance = AsyncMock()
            mock_circuit.get_or_create.return_value = mock_circuit_instance
            # Mock the Exa client response
            mock_exa_response = Mock()
            mock_exa_response.results = [
                Mock(
                    url="https://example.com/test",
                    title="Test Article",
                    text="Test content for search",
                    summary="Test summary",
                    highlights=["key highlight"],
                    published_date="2024-01-25",
                    author="Test Author",
                    score=0.9,
                )
            ]
            with patch("exa_py.Exa") as mock_exa_client:
                mock_client_instance = Mock()
                mock_client_instance.search_and_contents.return_value = (
                    mock_exa_response
                )
                mock_exa_client.return_value = mock_client_instance
                # Create actual provider (not mocked)
                provider = ExaSearchProvider("test-api-key")
                # The circuit breaker wraps the provider call, so its return
                # value is what search() surfaces to the caller.
                mock_circuit_instance.call.return_value = [
                    {
                        "url": "https://example.com/test",
                        "title": "Test Article",
                        "content": "Test content for search",
                        "summary": "Test summary",
                        "highlights": ["key highlight"],
                        "published_date": "2024-01-25",
                        "author": "Test Author",
                        "score": 0.9,
                        "provider": "exa",
                    }
                ]
                # Test the search
                results = await provider.search("AAPL stock analysis", num_results=5)
                # Verify query was properly formatted
                assert len(results) == 1
                assert results[0]["url"] == "https://example.com/test"
                assert results[0]["provider"] == "exa"
                assert results[0]["score"] == 0.9

    @pytest.mark.asyncio
    async def test_tavily_search_result_processing(self):
        """Test Tavily search result processing and filtering."""
        with patch("maverick_mcp.agents.deep_research.circuit_manager") as mock_circuit:
            mock_circuit.get_or_create = AsyncMock()
            mock_circuit_instance = AsyncMock()
            mock_circuit.get_or_create.return_value = mock_circuit_instance
            # Raw client payload: one news article plus one social-media
            # result the provider is expected to drop.
            mock_tavily_response = {
                "results": [
                    {
                        "url": "https://news.example.com/tech-news",
                        "title": "Tech News Article",
                        "content": "Content about technology trends",
                        "raw_content": "Extended raw content with more details",
                        "published_date": "2024-01-25",
                        "score": 0.85,
                    },
                    {
                        "url": "https://facebook.com/social-post",  # Should be filtered out
                        "title": "Social Media Post",
                        "content": "Social media content",
                        "score": 0.7,
                    },
                ]
            }
            with patch("tavily.TavilyClient") as mock_tavily_client:
                mock_client_instance = Mock()
                mock_client_instance.search.return_value = mock_tavily_response
                mock_tavily_client.return_value = mock_client_instance
                provider = TavilySearchProvider("test-api-key")
                # Circuit breaker returns the already-filtered result set.
                mock_circuit_instance.call.return_value = [
                    {
                        "url": "https://news.example.com/tech-news",
                        "title": "Tech News Article",
                        "content": "Content about technology trends",
                        "raw_content": "Extended raw content with more details",
                        "published_date": "2024-01-25",
                        "score": 0.85,
                        "provider": "tavily",
                    }
                ]
                results = await provider.search("tech trends analysis")
                # Verify results are properly processed and social media filtered
                assert len(results) == 1
                assert results[0]["provider"] == "tavily"
                assert "facebook.com" not in results[0]["url"]

    @pytest.mark.asyncio
    async def test_search_provider_fallback_behavior(self, mock_research_agent):
        """Test fallback behavior when search providers fail."""
        # Mock the execute searches workflow step directly
        with patch.object(mock_research_agent, "_execute_searches") as mock_execute:
            # Mock first provider to fail, second to succeed
            mock_research_agent.search_providers[0].search = AsyncMock(
                side_effect=WebSearchError("Exa API rate limit exceeded")
            )
            mock_research_agent.search_providers[1].search = AsyncMock(
                return_value=[
                    {
                        "url": "https://backup-source.com/article",
                        "title": "Backup Article",
                        "content": "Fallback content from secondary provider",
                        "provider": "tavily",
                        "score": 0.75,
                    }
                ]
            )
            # Mock successful execution with fallback results
            mock_result = Mock()
            mock_result.goto = "analyze_content"
            mock_result.update = {
                "search_results": [
                    {
                        "url": "https://backup-source.com/article",
                        "title": "Backup Article",
                        "content": "Fallback content from secondary provider",
                        "provider": "tavily",
                        "score": 0.75,
                    }
                ],
                "research_status": "analyzing",
            }
            mock_execute.return_value = mock_result
            # Test state for search execution
            state = {"search_queries": ["AAPL analysis"], "research_depth": "standard"}
            # Execute the search step
            result = await mock_research_agent._execute_searches(state)
            # Should handle provider failure gracefully
            assert result.goto == "analyze_content"
            assert len(result.update["search_results"]) > 0

    @pytest.mark.asyncio
    async def test_search_result_deduplication(self, comprehensive_search_results):
        """Test deduplication of search results from multiple providers."""
        # Create search results with duplicates
        duplicate_results = (
            comprehensive_search_results
            + [
                {
                    "url": "https://finance.yahoo.com/news/apple-earnings-q4-2024",  # Duplicate URL
                    "title": "Apple Q4 Results (Duplicate)",
                    "content": "Duplicate content with different title",
                    "provider": "tavily",
                    "score": 0.7,
                }
            ]
        )
        with patch.object(DeepResearchAgent, "_execute_searches") as mock_execute:
            mock_execute.return_value = Mock()
            DeepResearchAgent(llm=MagicMock(), persona="moderate")
        # Test the deduplication logic directly
        # Simulate search execution with duplicates
        # NOTE(review): this re-implements the agent's URL-dedup loop inline
        # rather than calling it, so it only guards the expected algorithm.
        all_results = duplicate_results
        unique_results = []
        seen_urls = set()
        depth_config = RESEARCH_DEPTH_LEVELS["standard"]
        for result in all_results:
            if (
                result["url"] not in seen_urls
                and len(unique_results) < depth_config["max_sources"]
            ):
                unique_results.append(result)
                seen_urls.add(result["url"])
        # Verify deduplication worked
        assert len(unique_results) == 5  # Should remove 1 duplicate
        urls = [r["url"] for r in unique_results]
        assert len(set(urls)) == len(urls)  # All URLs should be unique
class TestResearchSynthesis:
"""Test research synthesis and iterative querying functionality."""
@pytest.mark.asyncio
async def test_content_analysis_with_persona_focus(
self, comprehensive_search_results
):
"""Test that content analysis adapts to persona focus areas."""
# Mock LLM with persona-specific responses
mock_llm = MagicMock()
def persona_aware_response(messages):
response = Mock()
# Check if content is about dividends for conservative persona
content = messages[1].content if len(messages) > 1 else ""
if "conservative" in content and "dividend" in content:
response.content = json.dumps(
{
"KEY_INSIGHTS": [
"Strong dividend yield provides stable income"
],
"SENTIMENT": {"direction": "bullish", "confidence": 0.7},
"RISK_FACTORS": ["Interest rate sensitivity"],
"OPPORTUNITIES": ["Consistent dividend growth"],
"CREDIBILITY": 0.85,
"RELEVANCE": 0.9,
"SUMMARY": "Dividend analysis shows strong income potential for conservative investors.",
}
)
else:
response.content = json.dumps(
{
"KEY_INSIGHTS": ["Growth opportunity in AI sector"],
"SENTIMENT": {"direction": "bullish", "confidence": 0.8},
"RISK_FACTORS": ["Market competition"],
"OPPORTUNITIES": ["Innovation leadership"],
"CREDIBILITY": 0.8,
"RELEVANCE": 0.85,
"SUMMARY": "Analysis shows strong growth opportunities through innovation.",
}
)
return response
mock_llm.ainvoke = AsyncMock(side_effect=persona_aware_response)
analyzer = ContentAnalyzer(mock_llm)
# Test conservative persona analysis with dividend content
conservative_result = await analyzer.analyze_content(
content=comprehensive_search_results[4]["content"], # Dividend article
persona="conservative",
)
# Verify conservative-focused analysis
assert conservative_result["relevance_score"] > 0.8
assert (
"dividend" in conservative_result["summary"].lower()
or "income" in conservative_result["summary"].lower()
)
# Test aggressive persona analysis with growth content
aggressive_result = await analyzer.analyze_content(
content=comprehensive_search_results[3]["content"], # AI strategy article
persona="aggressive",
)
# Verify aggressive-focused analysis
assert aggressive_result["relevance_score"] > 0.7
assert any(
keyword in aggressive_result["summary"].lower()
for keyword in ["growth", "opportunity", "innovation"]
)
@pytest.mark.asyncio
async def test_research_synthesis_workflow(
self, mock_research_agent, comprehensive_search_results
):
"""Test the complete research synthesis workflow."""
# Mock the workflow components using the actual graph structure
with patch.object(mock_research_agent, "graph") as mock_graph:
# Mock successful workflow execution with all required fields
mock_result = {
"research_topic": "AAPL",
"research_depth": "standard",
"search_queries": ["AAPL financial analysis", "Apple earnings 2024"],
"search_results": comprehensive_search_results,
"analyzed_content": [
{
**result,
"analysis": {
"insights": [
"Strong revenue growth",
"AI integration opportunity",
],
"sentiment": {"direction": "bullish", "confidence": 0.8},
"risk_factors": [
"Supply chain risks",
"Regulatory concerns",
],
"opportunities": ["AI monetization", "Services expansion"],
"credibility_score": 0.85,
"relevance_score": 0.9,
"summary": "Strong fundamentals with growth catalysts",
},
}
for result in comprehensive_search_results[:3]
],
"validated_sources": comprehensive_search_results[:3],
"research_findings": {
"synthesis": "Apple shows strong fundamentals with growth opportunities",
"key_insights": ["Revenue growth", "AI opportunities"],
"overall_sentiment": {"direction": "bullish", "confidence": 0.8},
"confidence_score": 0.82,
},
"citations": [
{"id": 1, "title": "Apple Earnings", "url": "https://example.com/1"}
],
"research_status": "completed",
"research_confidence": 0.82,
"execution_time_ms": 1500.0,
"persona": "moderate",
}
mock_graph.ainvoke = AsyncMock(return_value=mock_result)
# Execute research
result = await mock_research_agent.research_comprehensive(
topic="AAPL", session_id="test_synthesis", depth="standard"
)
# Verify synthesis was performed
assert result["status"] == "success"
assert "findings" in result
assert result["sources_analyzed"] > 0
@pytest.mark.asyncio
async def test_iterative_research_refinement(self, mock_research_agent):
    """Test iterative research with follow-up queries based on initial findings."""
    # Stub query generation so the first research pass returns known queries.
    with patch.object(
        mock_research_agent, "_generate_search_queries"
    ) as stubbed_generator:
        # First iteration - general queries
        stubbed_generator.return_value = [
            "NVDA competitive analysis",
            "NVIDIA market position 2024",
        ]

        initial_queries = await mock_research_agent._generate_search_queries(
            topic="NVDA competitive position",
            persona_focus=PERSONA_RESEARCH_FOCUS["moderate"],
            depth_config=RESEARCH_DEPTH_LEVELS["standard"],
        )

        # The opening pass should yield broad, company-level queries.
        assert any("competitive" in query.lower() for query in initial_queries)
        assert any(
            ("NVDA" in query) or ("NVIDIA" in query) for query in initial_queries
        )
@pytest.mark.asyncio
async def test_fact_validation_and_source_credibility(self, mock_research_agent):
    """Test fact validation and source credibility scoring."""
    # One authoritative source and one low-quality one, to bracket the scorer.
    official_filing = {
        "url": "https://sec.gov/filing/aapl-10k-2024",
        "title": "Apple 10-K Filing",
        "content": "Official SEC filing content",
        "published_date": "2024-01-20T00:00:00Z",
        "analysis": {"credibility_score": 0.9},
    }
    speculative_blog = {
        "url": "https://random-blog.com/apple-speculation",
        "title": "Random Blog Post",
        "content": "Speculative content with no sources",
        "published_date": "2023-06-01T00:00:00Z",  # Old content
        "analysis": {"credibility_score": 0.3},
    }

    # Government sources should be highly credible.
    filing_score = mock_research_agent._calculate_source_credibility(official_filing)
    assert filing_score >= 0.8

    # Random blogs should have lower credibility.
    blog_score = mock_research_agent._calculate_source_credibility(speculative_blog)
    assert blog_score <= 0.6
class TestPersonaBasedResearch:
    """Test persona-based research behavior and adaptation.

    Each test pins the expected characteristics of one persona profile
    (conservative, aggressive, day trader) and, where applicable, verifies
    that generated search queries reflect that persona's focus keywords.
    """

    @pytest.mark.asyncio
    async def test_conservative_persona_research_focus(self, mock_llm):
        """Test that conservative persona focuses on stability and risk factors."""
        agent = DeepResearchAgent(llm=mock_llm, persona="conservative")

        # Test search query generation for conservative persona
        persona_focus = PERSONA_RESEARCH_FOCUS["conservative"]
        depth_config = RESEARCH_DEPTH_LEVELS["standard"]

        queries = await agent._generate_search_queries(
            topic="AAPL", persona_focus=persona_focus, depth_config=depth_config
        )

        # Verify conservative-focused queries (at least one stability keyword
        # must appear somewhere across the joined query set).
        query_text = " ".join(queries).lower()
        assert any(
            keyword in query_text for keyword in ["dividend", "stability", "risk"]
        )

        # Test that conservative persona performs more thorough fact-checking
        assert persona_focus["risk_focus"] == "downside protection"
        assert persona_focus["time_horizon"] == "long-term"

    @pytest.mark.asyncio
    async def test_aggressive_persona_research_behavior(self, mock_llm):
        """Test aggressive persona explores speculative opportunities."""
        agent = DeepResearchAgent(llm=mock_llm, persona="aggressive")
        persona_focus = PERSONA_RESEARCH_FOCUS["aggressive"]

        # Test query generation for aggressive persona
        queries = await agent._generate_search_queries(
            topic="TSLA",
            persona_focus=persona_focus,
            depth_config=RESEARCH_DEPTH_LEVELS["standard"],
        )

        # Verify aggressive-focused queries
        query_text = " ".join(queries).lower()
        assert any(
            keyword in query_text for keyword in ["growth", "momentum", "opportunity"]
        )

        # Verify aggressive characteristics
        assert persona_focus["risk_focus"] == "upside potential"
        assert "innovation" in persona_focus["keywords"]

    @pytest.mark.asyncio
    async def test_day_trader_persona_short_term_focus(self, mock_llm):
        """Test day trader persona focuses on short-term catalysts and volatility."""
        # Constructed only to confirm the persona is accepted; the profile
        # dict itself is what is asserted below.
        DeepResearchAgent(llm=mock_llm, persona="day_trader")
        persona_focus = PERSONA_RESEARCH_FOCUS["day_trader"]

        # Test characteristics specific to day trader persona
        assert persona_focus["time_horizon"] == "intraday to weekly"
        assert "catalysts" in persona_focus["keywords"]
        assert "volatility" in persona_focus["keywords"]
        assert "earnings" in persona_focus["keywords"]

        # Test sources preference
        assert "breaking news" in persona_focus["sources"]
        assert "social sentiment" in persona_focus["sources"]

    @pytest.mark.asyncio
    async def test_research_depth_differences_by_persona(self, mock_llm):
        """Test that conservative personas do more thorough research."""
        conservative_agent = DeepResearchAgent(
            llm=mock_llm, persona="conservative", default_depth="comprehensive"
        )
        aggressive_agent = DeepResearchAgent(
            llm=mock_llm, persona="aggressive", default_depth="standard"
        )

        # Conservative should use more comprehensive depth by default
        assert conservative_agent.default_depth == "comprehensive"

        # Aggressive can use standard depth for faster decisions
        assert aggressive_agent.default_depth == "standard"

        # Test depth level configurations
        comprehensive_config = RESEARCH_DEPTH_LEVELS["comprehensive"]
        standard_config = RESEARCH_DEPTH_LEVELS["standard"]

        assert comprehensive_config["max_sources"] > standard_config["max_sources"]
        assert comprehensive_config["validation_required"]
class TestMultiStepResearchWorkflow:
    """Test complete multi-step research workflows.

    The graph is mocked at the ``ainvoke`` boundary; these tests exercise the
    agent's result mapping (terminal state -> public result dict) plus a few
    pure helpers (sentiment aggregation, analysis routing).
    """

    @pytest.mark.asyncio
    async def test_complete_research_workflow_success(
        self, mock_research_agent, comprehensive_search_results
    ):
        """Test complete research workflow from query to final report."""
        # Mock all workflow steps
        with patch.object(mock_research_agent, "graph") as mock_graph:
            # Mock successful workflow execution
            mock_result = {
                "research_topic": "AAPL",
                "research_depth": "standard",
                "search_queries": ["AAPL analysis", "Apple earnings"],
                "search_results": comprehensive_search_results,
                "analyzed_content": [
                    {
                        **result,
                        "analysis": {
                            "insights": ["Strong performance"],
                            "sentiment": {"direction": "bullish", "confidence": 0.8},
                            "credibility_score": 0.85,
                        },
                    }
                    for result in comprehensive_search_results
                ],
                # Only three sources survive validation; the assertion on
                # sources_analyzed below depends on this slice.
                "validated_sources": comprehensive_search_results[:3],
                "research_findings": {
                    "synthesis": "Apple shows strong fundamentals with growth opportunities",
                    "key_insights": [
                        "Revenue growth",
                        "AI opportunities",
                        "Strong cash flow",
                    ],
                    "overall_sentiment": {"direction": "bullish", "confidence": 0.8},
                    "confidence_score": 0.82,
                },
                "citations": [
                    {
                        "id": 1,
                        "title": "Apple Earnings",
                        "url": "https://example.com/1",
                    },
                    {
                        "id": 2,
                        "title": "Technical Analysis",
                        "url": "https://example.com/2",
                    },
                ],
                "research_status": "completed",
                "research_confidence": 0.82,
                "execution_time_ms": 1500.0,
            }

            mock_graph.ainvoke = AsyncMock(return_value=mock_result)

            # Execute complete research
            result = await mock_research_agent.research_comprehensive(
                topic="AAPL", session_id="workflow_test", depth="standard"
            )

            # Verify complete workflow: every public field must be mapped
            # through from the mocked terminal state.
            assert result["status"] == "success"
            assert result["agent_type"] == "deep_research"
            assert result["research_topic"] == "AAPL"
            assert result["sources_analyzed"] == 3
            assert result["confidence_score"] == 0.82
            assert len(result["citations"]) == 2

    @pytest.mark.asyncio
    async def test_research_workflow_with_insufficient_information(
        self, mock_research_agent
    ):
        """Test workflow handling when insufficient information is found."""
        # Mock scenario with limited/poor quality results
        with patch.object(mock_research_agent, "graph") as mock_graph:
            mock_result = {
                "research_topic": "OBSCURE_STOCK",
                "research_depth": "standard",
                "search_results": [],  # No results found
                "validated_sources": [],
                "research_findings": {},
                "research_confidence": 0.1,  # Very low confidence
                "research_status": "completed",
                "execution_time_ms": 800.0,
            }

            mock_graph.ainvoke = AsyncMock(return_value=mock_result)

            result = await mock_research_agent.research_comprehensive(
                topic="OBSCURE_STOCK", session_id="insufficient_test"
            )

            # Should handle insufficient information gracefully: still a
            # "success" envelope, but with low confidence and zero sources.
            assert result["status"] == "success"
            assert result["confidence_score"] == 0.1
            assert result["sources_analyzed"] == 0

    @pytest.mark.asyncio
    async def test_research_with_conflicting_information(self, mock_research_agent):
        """Test handling of conflicting information from different sources."""
        # Two sources with opposite sentiment; the aggregator must not crash
        # and must report a consensus over both.
        conflicting_sources = [
            {
                "url": "https://bull-analyst.com/buy-rating",
                "title": "Strong Buy Rating for AAPL",
                "analysis": {
                    "sentiment": {"direction": "bullish", "confidence": 0.9},
                    "credibility_score": 0.8,
                },
            },
            {
                "url": "https://bear-analyst.com/sell-rating",
                "title": "Sell Rating for AAPL Due to Overvaluation",
                "analysis": {
                    "sentiment": {"direction": "bearish", "confidence": 0.8},
                    "credibility_score": 0.7,
                },
            },
        ]

        # Test overall sentiment calculation with conflicting sources
        overall_sentiment = mock_research_agent._calculate_overall_sentiment(
            conflicting_sources
        )

        # Should handle conflicts by providing consensus information
        assert overall_sentiment["direction"] in ["bullish", "bearish", "neutral"]
        assert "consensus" in overall_sentiment
        assert overall_sentiment["source_count"] == 2

    @pytest.mark.asyncio
    async def test_research_focus_and_refinement(self, mock_research_agent):
        """Test research focusing and refinement based on initial findings."""
        # Test different research focus areas; each focus keyword must route
        # to the matching specialized-analysis branch.
        focus_areas = ["sentiment", "fundamental", "competitive"]

        for focus in focus_areas:
            route = mock_research_agent._route_specialized_analysis(
                {"focus_areas": [focus]}
            )

            if focus == "sentiment":
                assert route == "sentiment"
            elif focus == "fundamental":
                assert route == "fundamental"
            elif focus == "competitive":
                assert route == "competitive"
class TestResearchMethodSpecialization:
    """Test specialized research methods: sentiment, fundamental, competitive analysis."""

    @pytest.mark.asyncio
    async def test_sentiment_analysis_specialization(self, mock_research_agent):
        """Test sentiment analysis research method."""
        # Focus keywords are chosen to match the agent's routing logic.
        state = {
            "focus_areas": ["sentiment", "news"],
            "analyzed_content": [],
        }

        # Routing: sentiment-flavored focus areas select the sentiment branch.
        assert mock_research_agent._route_specialized_analysis(state) == "sentiment"

        # Execution: the specialized step delegates to _analyze_content once.
        with patch.object(mock_research_agent, "_analyze_content") as analyze_stub:
            analyze_stub.return_value = Mock()
            await mock_research_agent._sentiment_analysis(state)
            analyze_stub.assert_called_once()

    @pytest.mark.asyncio
    async def test_fundamental_analysis_specialization(self, mock_research_agent):
        """Test fundamental analysis research method."""
        state = {
            "focus_areas": ["fundamental", "financial"],
            "analyzed_content": [],
        }

        # Routing: financial-flavored focus areas select the fundamental branch.
        assert mock_research_agent._route_specialized_analysis(state) == "fundamental"

        # Execution: the specialized step delegates to _analyze_content once.
        with patch.object(mock_research_agent, "_analyze_content") as analyze_stub:
            analyze_stub.return_value = Mock()
            await mock_research_agent._fundamental_analysis(state)
            analyze_stub.assert_called_once()

    @pytest.mark.asyncio
    async def test_competitive_analysis_specialization(self, mock_research_agent):
        """Test competitive analysis research method."""
        state = {
            "focus_areas": ["competitive", "market"],
            "analyzed_content": [],
        }

        # Routing: market/competition focus areas select the competitive branch.
        assert mock_research_agent._route_specialized_analysis(state) == "competitive"

        # Execution: the specialized step delegates to _analyze_content once.
        with patch.object(mock_research_agent, "_analyze_content") as analyze_stub:
            analyze_stub.return_value = Mock()
            await mock_research_agent._competitive_analysis(state)
            analyze_stub.assert_called_once()
class TestErrorHandlingAndResilience:
    """Test error handling and system resilience.

    Covers degraded-mode operation: no search providers, LLM outages (content
    analyzer fallback), partial provider failure, and circuit-breaker behavior.
    """

    @pytest.mark.asyncio
    async def test_research_agent_with_no_search_providers(self, mock_llm):
        """Test research agent behavior with no available search providers."""
        # Create agent without search providers
        agent = DeepResearchAgent(llm=mock_llm, persona="moderate")

        # Should initialize successfully but with limited capabilities
        assert len(agent.search_providers) == 0

        # Research should still attempt to work but with limited results
        result = await agent.research_comprehensive(
            topic="TEST", session_id="no_providers_test"
        )

        # Should not crash, may return limited results
        assert "status" in result

    @pytest.mark.asyncio
    async def test_content_analysis_fallback_on_llm_failure(
        self, comprehensive_search_results
    ):
        """Test content analysis fallback when LLM fails."""
        # Mock LLM that fails on every invocation, forcing the analyzer's
        # rule-based fallback path.
        failing_llm = MagicMock()
        failing_llm.ainvoke = AsyncMock(
            side_effect=Exception("LLM service unavailable")
        )

        analyzer = ContentAnalyzer(failing_llm)

        # Should use fallback analysis
        result = await analyzer.analyze_content(
            content=comprehensive_search_results[0]["content"], persona="conservative"
        )

        # Verify fallback was used and still produces a well-formed,
        # bounded result (scores clamped to [0, 1]).
        assert result["fallback_used"]
        assert result["sentiment"]["direction"] in ["bullish", "bearish", "neutral"]
        assert 0 <= result["credibility_score"] <= 1
        assert 0 <= result["relevance_score"] <= 1

    @pytest.mark.asyncio
    async def test_partial_search_failure_handling(self, mock_research_agent):
        """Test handling when some but not all search providers fail."""
        # Test the actual search execution logic directly: provider 0 raises,
        # provider 1 returns one usable result.
        mock_research_agent.search_providers[0].search = AsyncMock(
            side_effect=WebSearchError("Provider 1 failed")
        )
        mock_research_agent.search_providers[1].search = AsyncMock(
            return_value=[
                {
                    "url": "https://working-provider.com/article",
                    "title": "Working Provider Article",
                    "content": "Content from working provider",
                    "provider": "working_provider",
                    "score": 0.8,
                }
            ]
        )

        # Test the search execution directly
        state = {"search_queries": ["test query"], "research_depth": "standard"}

        result = await mock_research_agent._execute_searches(state)

        # Should continue with working providers and return results.
        # _execute_searches returns a LangGraph Command-like object whose
        # ``update`` dict carries the state delta.
        assert hasattr(result, "update")
        assert "search_results" in result.update
        # Should have at least the working provider results
        assert (
            len(result.update["search_results"]) >= 0
        )  # May be 0 if all fail, but should not crash

    @pytest.mark.asyncio
    async def test_research_timeout_and_circuit_breaker(self, mock_research_agent):
        """Test research timeout handling and circuit breaker behavior."""
        # Test would require actual circuit breaker implementation
        # This is a placeholder for circuit breaker testing
        with patch(
            "maverick_mcp.agents.circuit_breaker.circuit_manager"
        ) as mock_circuit:
            mock_circuit.get_or_create = AsyncMock()
            circuit_instance = AsyncMock()
            mock_circuit.get_or_create.return_value = circuit_instance

            # Mock circuit breaker open state
            circuit_instance.call = AsyncMock(
                side_effect=Exception("Circuit breaker open")
            )

            # Research should handle circuit breaker gracefully
            # Implementation depends on actual circuit breaker behavior
            # TODO: replace placeholder once breaker semantics are finalized.
            pass
class TestResearchQualityAndValidation:
    """Test research quality assurance and validation mechanisms.

    Exercises the pure scoring helpers: confidence from source quality,
    diversity contribution, and the persona-sensitive action recommendation.
    """

    def test_research_confidence_calculation(self, mock_research_agent):
        """Test research confidence calculation based on multiple factors."""
        # Test with high-quality sources (authoritative domains, high
        # credibility and relevance scores).
        high_quality_sources = [
            {
                "url": "https://sec.gov/filing1",
                "credibility_score": 0.95,
                "analysis": {"relevance_score": 0.9},
            },
            {
                "url": "https://bloomberg.com/article1",
                "credibility_score": 0.85,
                "analysis": {"relevance_score": 0.8},
            },
            {
                "url": "https://reuters.com/article2",
                "credibility_score": 0.8,
                "analysis": {"relevance_score": 0.85},
            },
        ]

        confidence = mock_research_agent._calculate_research_confidence(
            high_quality_sources
        )
        assert confidence >= 0.65  # Should be reasonably high confidence

        # Test with low-quality sources (single weak source).
        low_quality_sources = [
            {
                "url": "https://random-blog.com/post1",
                "credibility_score": 0.3,
                "analysis": {"relevance_score": 0.4},
            }
        ]

        low_confidence = mock_research_agent._calculate_research_confidence(
            low_quality_sources
        )
        assert low_confidence < 0.5  # Should be low confidence

    def test_source_diversity_scoring(self, mock_research_agent):
        """Test source diversity calculation."""
        # Five distinct reputable domains — diversity should lift confidence.
        diverse_sources = [
            {"url": "https://sec.gov/filing"},
            {"url": "https://bloomberg.com/news"},
            {"url": "https://reuters.com/article"},
            {"url": "https://wsj.com/story"},
            {"url": "https://ft.com/content"},
        ]

        confidence = mock_research_agent._calculate_research_confidence(diverse_sources)

        # More diverse sources should contribute to higher confidence
        assert confidence > 0.6

    def test_investment_recommendation_logic(self, mock_research_agent):
        """Test investment recommendation based on research findings."""
        # Test bullish scenario
        bullish_sources = [
            {
                "analysis": {
                    "sentiment": {"direction": "bullish", "confidence": 0.9},
                    "credibility_score": 0.8,
                }
            }
        ]

        recommendation = mock_research_agent._recommend_action(bullish_sources)

        # The expected wording depends on the fixture's persona:
        # conservative personas hedge; others suggest taking a position.
        if mock_research_agent.persona.name.lower() == "conservative":
            assert (
                "gradual" in recommendation.lower()
                or "risk management" in recommendation.lower()
            )
        else:
            assert (
                "consider" in recommendation.lower()
                and "position" in recommendation.lower()
            )
if __name__ == "__main__":
    # Allow running this test module directly: verbose output, short tracebacks.
    pytest.main([__file__, "-v", "--tb=short"])
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/backtesting.py:
--------------------------------------------------------------------------------
```python
"""MCP router for VectorBT backtesting tools with structured logging."""
from typing import Any
import numpy as np
from fastmcp import Context
from maverick_mcp.backtesting import (
BacktestAnalyzer,
StrategyOptimizer,
VectorBTEngine,
)
from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES, StrategyParser
from maverick_mcp.backtesting.strategies.templates import (
get_strategy_info,
list_available_strategies,
)
from maverick_mcp.backtesting.visualization import (
generate_equity_curve,
generate_optimization_heatmap,
generate_performance_dashboard,
generate_trade_scatter,
)
from maverick_mcp.utils.debug_utils import debug_operation
from maverick_mcp.utils.logging import get_logger
from maverick_mcp.utils.structured_logger import (
CorrelationIDGenerator,
get_performance_logger,
with_structured_logging,
)
# Initialize performance logger for backtesting router
performance_logger = get_performance_logger("backtesting_router")
logger = get_logger(__name__)
def convert_numpy_types(obj: Any) -> Any:
"""Recursively convert numpy types to Python native types for JSON serialization.
Args:
obj: Any object that might contain numpy types
Returns:
Object with all numpy types converted to Python native types
"""
import pandas as pd
# Check for numpy integer types (more robust using numpy's type hierarchy)
if isinstance(obj, np.integer):
return int(obj)
# Check for numpy floating point types
elif isinstance(obj, np.floating):
return float(obj)
# Check for numpy boolean type
elif isinstance(obj, np.bool_ | bool) and hasattr(obj, "item"):
return bool(obj)
# Check for numpy complex types
elif isinstance(obj, np.complexfloating):
return complex(obj)
# Handle numpy arrays
elif isinstance(obj, np.ndarray):
return obj.tolist()
# Handle pandas Series
elif isinstance(obj, pd.Series):
return obj.tolist()
# Handle pandas DataFrame
elif isinstance(obj, pd.DataFrame):
return obj.to_dict("records")
# Handle NaN/None values
elif pd.isna(obj):
return None
# Handle other numpy scalars with .item() method
elif hasattr(obj, "item") and hasattr(obj, "dtype"):
try:
return obj.item()
except Exception:
return str(obj)
# Recursively handle dictionaries
elif isinstance(obj, dict):
return {key: convert_numpy_types(value) for key, value in obj.items()}
# Recursively handle lists and tuples
elif isinstance(obj, list | tuple):
return [convert_numpy_types(item) for item in obj]
# Try to handle custom objects with __dict__
elif hasattr(obj, "__dict__") and not isinstance(obj, type):
try:
return convert_numpy_types(obj.__dict__)
except Exception:
return str(obj)
else:
# Return as-is for regular Python types
return obj
def setup_backtesting_tools(mcp):
"""Set up VectorBT backtesting tools for MCP.
Args:
mcp: FastMCP instance
"""
@mcp.tool()
@with_structured_logging("run_backtest", include_performance=True, log_params=True)
@debug_operation("run_backtest", enable_profiling=True, symbol="backtest_symbol")
async def run_backtest(
    ctx: Context,
    symbol: str,
    strategy: str = "sma_cross",
    start_date: str | None = None,
    end_date: str | None = None,
    initial_capital: float = 10000.0,
    fast_period: str | int | None = None,
    slow_period: str | int | None = None,
    period: str | int | None = None,
    oversold: str | float | None = None,
    overbought: str | float | None = None,
    signal_period: str | int | None = None,
    std_dev: str | float | None = None,
    lookback: str | int | None = None,
    threshold: str | float | None = None,
    z_score_threshold: str | float | None = None,
    breakout_factor: str | float | None = None,
) -> dict[str, Any]:
    """Run a VectorBT backtest with specified strategy and parameters.

    Args:
        symbol: Stock symbol to backtest
        strategy: Strategy type (sma_cross, rsi, macd, bollinger, momentum, etc.)
        start_date: Start date (YYYY-MM-DD), defaults to 1 year ago
        end_date: End date (YYYY-MM-DD), defaults to today
        initial_capital: Starting capital for backtest
        Strategy-specific parameters passed as individual arguments (e.g., fast_period=10, slow_period=20)

    Returns:
        Comprehensive backtest results including metrics, trades, and analysis,
        with all numpy types converted to JSON-serializable Python natives.

    Examples:
        run_backtest("AAPL", "sma_cross", fast_period=10, slow_period=20)
        run_backtest("TSLA", "rsi", period=14, oversold=30, overbought=70)
    """
    from datetime import datetime, timedelta

    # Default date range: trailing one year ending today.
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    if not start_date:
        start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")

    # Some MCP clients send all parameters as strings; coerce to the
    # expected numeric type, raising a clear error on bad input.
    def convert_param(value, param_type):
        """Convert string parameter to appropriate type."""
        if value is None:
            return None
        if isinstance(value, str):
            try:
                if param_type is int:
                    return int(value)
                elif param_type is float:
                    return float(value)
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Invalid {param_type.__name__} value: {value}"
                ) from e
        return value

    # Build parameters dict from provided arguments with type conversion
    param_map = {
        "fast_period": convert_param(fast_period, int),
        "slow_period": convert_param(slow_period, int),
        "period": convert_param(period, int),
        "oversold": convert_param(oversold, float),
        "overbought": convert_param(overbought, float),
        "signal_period": convert_param(signal_period, int),
        "std_dev": convert_param(std_dev, float),
        "lookback": convert_param(lookback, int),
        "threshold": convert_param(threshold, float),
        "z_score_threshold": convert_param(z_score_threshold, float),
        "breakout_factor": convert_param(breakout_factor, float),
    }

    # Start from the strategy template's defaults, overridden by any
    # explicitly provided (non-None) parameters.
    if strategy in STRATEGY_TEMPLATES:
        parameters = dict(STRATEGY_TEMPLATES[strategy]["parameters"])
        for param_name, param_value in param_map.items():
            if param_value is not None:
                parameters[param_name] = param_value
    else:
        # Use only provided parameters for unknown strategies
        parameters = {k: v for k, v in param_map.items() if v is not None}

    # Initialize engine and run the backtest.
    engine = VectorBTEngine()
    results = await engine.run_backtest(
        symbol=symbol,
        strategy_type=strategy,
        parameters=parameters,
        start_date=start_date,
        end_date=end_date,
        initial_capital=initial_capital,
    )

    # Attach the analyzer's interpretation alongside the raw results.
    analyzer = BacktestAnalyzer()
    analysis = analyzer.analyze(results)
    results["analysis"] = analysis

    # Log business metrics for observability dashboards.
    if results.get("metrics"):
        metrics = results["metrics"]
        performance_logger.log_business_metric(
            "backtest_completion",
            1,
            symbol=symbol,
            strategy=strategy,
            total_return=metrics.get("total_return", 0),
            sharpe_ratio=metrics.get("sharpe_ratio", 0),
            max_drawdown=metrics.get("max_drawdown", 0),
            total_trades=metrics.get("total_trades", 0),
        )

    # Set correlation context for downstream operations
    CorrelationIDGenerator.set_correlation_id()

    # Fix: VectorBT results may contain numpy scalars/arrays that are not
    # JSON-serializable; normalize them before returning over MCP.
    return convert_numpy_types(results)
@mcp.tool()
@with_structured_logging(
    "optimize_strategy", include_performance=True, log_params=True
)
@debug_operation(
    "optimize_strategy", enable_profiling=True, strategy="optimization_strategy"
)
async def optimize_strategy(
    ctx: Context,
    symbol: str,
    strategy: str = "sma_cross",
    start_date: str | None = None,
    end_date: str | None = None,
    optimization_metric: str = "sharpe_ratio",
    optimization_level: str = "medium",
    top_n: int = 10,
) -> dict[str, Any]:
    """Optimize strategy parameters using VectorBT grid search.

    Args:
        symbol: Stock symbol to optimize
        strategy: Strategy type to optimize
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)
        optimization_metric: Metric to optimize (sharpe_ratio, total_return, win_rate, etc.)
        optimization_level: Level of optimization (coarse, medium, fine)
        top_n: Number of top results to return

    Returns:
        Optimization results with best parameters and performance metrics
    """
    from datetime import datetime, timedelta

    # Default window: trailing two years (optimization needs more history
    # than a single backtest).
    today = datetime.now()
    end_date = end_date or today.strftime("%Y-%m-%d")
    start_date = start_date or (today - timedelta(days=365 * 2)).strftime("%Y-%m-%d")

    # Wire up the engine and its optimizer.
    backtest_engine = VectorBTEngine()
    grid_optimizer = StrategyOptimizer(backtest_engine)

    # Build the parameter grid for the requested granularity, then run
    # the grid search and return the ranked results.
    grid = grid_optimizer.generate_param_grid(strategy, optimization_level)
    return await backtest_engine.optimize_parameters(
        symbol=symbol,
        strategy_type=strategy,
        param_grid=grid,
        start_date=start_date,
        end_date=end_date,
        optimization_metric=optimization_metric,
        top_n=top_n,
    )
@mcp.tool()
async def walk_forward_analysis(
    ctx: Context,
    symbol: str,
    strategy: str = "sma_cross",
    start_date: str | None = None,
    end_date: str | None = None,
    window_size: int = 252,
    step_size: int = 63,
) -> dict[str, Any]:
    """Perform walk-forward analysis to test strategy robustness.

    Args:
        symbol: Stock symbol to analyze
        strategy: Strategy type
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)
        window_size: Test window size in trading days (default: 1 year)
        step_size: Step size for rolling window (default: 1 quarter)

    Returns:
        Walk-forward analysis results with out-of-sample performance
    """
    from datetime import datetime, timedelta

    # Default window: trailing three years, long enough for several
    # rolling in-sample/out-of-sample splits.
    today = datetime.now()
    end_date = end_date or today.strftime("%Y-%m-%d")
    start_date = start_date or (today - timedelta(days=365 * 3)).strftime("%Y-%m-%d")

    backtest_engine = VectorBTEngine()
    wf_optimizer = StrategyOptimizer(backtest_engine)

    # Use the strategy template's default parameters (empty for unknown
    # strategies).
    template_params = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})

    return await wf_optimizer.walk_forward_analysis(
        symbol=symbol,
        strategy_type=strategy,
        parameters=template_params,
        start_date=start_date,
        end_date=end_date,
        window_size=window_size,
        step_size=step_size,
    )
@mcp.tool()
async def monte_carlo_simulation(
    ctx: Context,
    symbol: str,
    strategy: str = "sma_cross",
    start_date: str | None = None,
    end_date: str | None = None,
    num_simulations: int = 1000,
    fast_period: str | int | None = None,
    slow_period: str | int | None = None,
    period: str | int | None = None,
) -> dict[str, Any]:
    """Run Monte Carlo simulation on backtest results.

    Args:
        symbol: Stock symbol
        strategy: Strategy type
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)
        num_simulations: Number of Monte Carlo simulations
        Strategy-specific parameters as individual arguments

    Returns:
        Monte Carlo simulation results with confidence intervals
    """
    from datetime import datetime, timedelta

    # Default window: trailing one year ending today.
    today = datetime.now()
    end_date = end_date or today.strftime("%Y-%m-%d")
    start_date = start_date or (today - timedelta(days=365)).strftime("%Y-%m-%d")

    def convert_param(value, param_type):
        """Convert string parameter to appropriate type (int or float only)."""
        # Non-strings (including None) pass through untouched.
        if value is None or not isinstance(value, str):
            return value
        if param_type not in (int, float):
            return value
        try:
            return param_type(value)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Invalid {param_type.__name__} value: {value}") from e

    # Coerce client-supplied strings to the expected numeric types.
    provided = {
        "fast_period": convert_param(fast_period, int),
        "slow_period": convert_param(slow_period, int),
        "period": convert_param(period, int),
    }

    # Template defaults overridden by explicitly provided values; unknown
    # strategies use only what was provided.
    if strategy in STRATEGY_TEMPLATES:
        parameters = dict(STRATEGY_TEMPLATES[strategy]["parameters"])
        parameters.update(
            {name: value for name, value in provided.items() if value is not None}
        )
    else:
        parameters = {name: value for name, value in provided.items() if value is not None}

    # A baseline backtest supplies the trade distribution to resample.
    backtest_engine = VectorBTEngine()
    baseline = await backtest_engine.run_backtest(
        symbol=symbol,
        strategy_type=strategy,
        parameters=parameters,
        start_date=start_date,
        end_date=end_date,
    )

    # Resample the baseline to build confidence intervals.
    mc_optimizer = StrategyOptimizer(backtest_engine)
    return await mc_optimizer.monte_carlo_simulation(
        backtest_results=baseline,
        num_simulations=num_simulations,
    )
@mcp.tool()
async def compare_strategies(
    ctx: Context,
    symbol: str,
    strategies: list[str] | str | None = None,
    start_date: str | None = None,
    end_date: str | None = None,
) -> dict[str, Any]:
    """Compare multiple strategies on the same symbol.

    Args:
        symbol: Stock symbol
        strategies: List of strategy types to compare (defaults to all)
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)

    Returns:
        Comparison results with rankings and analysis
    """
    from datetime import datetime, timedelta

    # Default date range: trailing one year ending today.
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    if not start_date:
        start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")

    # Handle strategies as JSON string from some clients
    if isinstance(strategies, str):
        import json

        try:
            strategies = json.loads(strategies)
        except json.JSONDecodeError:
            # If it's not JSON, treat it as a single strategy
            strategies = [strategies]

    # Default to comparing top strategies
    if not strategies:
        strategies = ["sma_cross", "rsi", "macd", "bollinger", "momentum"]

    # Run backtests for each strategy using each template's default
    # parameters; comparison is best-effort, so a failing strategy is
    # skipped rather than aborting the whole comparison.
    engine = VectorBTEngine()
    results_list = []

    for strategy in strategies:
        try:
            parameters = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})
            results = await engine.run_backtest(
                symbol=symbol,
                strategy_type=strategy,
                parameters=parameters,
                start_date=start_date,
                end_date=end_date,
            )
            results_list.append(results)
        except Exception as exc:
            # Fix: previously failures were swallowed silently, making it
            # impossible to tell why a strategy was missing from the
            # comparison. Keep best-effort semantics but leave an audit trail.
            logger.warning(
                "compare_strategies: skipping strategy %s for %s: %s",
                strategy,
                symbol,
                exc,
            )
            continue

    # Rank and analyze whatever completed successfully.
    analyzer = BacktestAnalyzer()
    comparison = analyzer.compare_strategies(results_list)

    return comparison
@mcp.tool()
async def list_strategies(ctx: Context) -> dict[str, Any]:
    """List all available VectorBT strategies with descriptions.

    Returns:
        Dictionary of available strategies and their information
    """
    # Gather metadata for every registered strategy.
    catalog = {
        name: get_strategy_info(name) for name in list_available_strategies()
    }

    return {
        "available_strategies": catalog,
        "total_count": len(catalog),
        # Static grouping of strategies by trading style.
        "categories": {
            "trend_following": ["sma_cross", "ema_cross", "macd", "breakout"],
            "mean_reversion": ["rsi", "bollinger", "mean_reversion"],
            "momentum": ["momentum", "volume_momentum"],
        },
    }
@mcp.tool()
async def parse_strategy(ctx: Context, description: str) -> dict[str, Any]:
    """Parse natural language strategy description into VectorBT parameters.

    Args:
        description: Natural language description of trading strategy

    Returns:
        Parsed strategy configuration with type and parameters

    Examples:
        "Buy when RSI is below 30 and sell when above 70"
        "Use 10-day and 20-day moving average crossover"
        "MACD strategy with standard parameters"
    """
    parser = StrategyParser()
    config = parser.parse_simple(description)

    # Validation determines both the success flag and the message wording.
    parsed_ok = bool(parser.validate_strategy(config))
    if parsed_ok:
        message = f"Successfully parsed as {config['strategy_type']} strategy"
    else:
        message = "Could not fully parse strategy, using defaults"

    return {
        "success": parsed_ok,
        "strategy": config,
        "message": message,
    }
@mcp.tool()
async def backtest_portfolio(
    ctx: Context,
    symbols: list[str],
    strategy: str = "sma_cross",
    start_date: str | None = None,
    end_date: str | None = None,
    initial_capital: float = 10000.0,
    position_size: float = 0.1,
    fast_period: str | int | None = None,
    slow_period: str | int | None = None,
    period: str | int | None = None,
) -> dict[str, Any]:
    """Backtest a strategy across multiple symbols (portfolio).

    Args:
        symbols: List of stock symbols
        strategy: Strategy type to apply
        start_date: Start date (YYYY-MM-DD); defaults to one year ago
        end_date: End date (YYYY-MM-DD); defaults to today
        initial_capital: Starting capital for the whole portfolio
        position_size: Fraction of capital allocated per symbol (0.1 = 10%)
        fast_period/slow_period/period: Optional strategy parameter overrides;
            accepted as strings (MCP clients may send them) and coerced to int

    Returns:
        Portfolio backtest results with aggregate metrics. Note:
        "total_return" and "average_sharpe" are simple averages over the
        symbols that backtested successfully, and "max_drawdown" is the
        worst single-symbol drawdown (assumes the engine reports drawdown
        as a positive magnitude). Symbols that failed are listed under
        "failed_symbols".

    Raises:
        ValueError: If a period override cannot be coerced to an integer.
    """
    from datetime import datetime, timedelta

    # Default to a one-year lookback window ending today.
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    if not start_date:
        start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")

    def convert_param(value, param_type):
        """Coerce a string argument to param_type; pass through None/native values."""
        if value is None:
            return None
        if isinstance(value, str):
            try:
                if param_type is int:
                    return int(value)
                elif param_type is float:
                    return float(value)
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Invalid {param_type.__name__} value: {value}"
                ) from e
        return value

    # Explicit overrides from the caller; None means "use the template default".
    param_map = {
        "fast_period": convert_param(fast_period, int),
        "slow_period": convert_param(slow_period, int),
        "period": convert_param(period, int),
    }

    if strategy in STRATEGY_TEMPLATES:
        # Start from template defaults, then layer on explicit overrides.
        parameters = dict(STRATEGY_TEMPLATES[strategy]["parameters"])
        for param_name, param_value in param_map.items():
            if param_value is not None:
                parameters[param_name] = param_value
    else:
        # Unknown strategy: pass through only the explicitly supplied parameters.
        parameters = {k: v for k, v in param_map.items() if v is not None}

    engine = VectorBTEngine()
    portfolio_results: list[dict[str, Any]] = []
    failed_symbols: list[str] = []
    # Each symbol gets a fixed slice of the total capital.
    capital_per_symbol = initial_capital * position_size

    for symbol in symbols:
        try:
            results = await engine.run_backtest(
                symbol=symbol,
                strategy_type=strategy,
                parameters=parameters,
                start_date=start_date,
                end_date=end_date,
                initial_capital=capital_per_symbol,
            )
            portfolio_results.append(results)
        except Exception as e:
            # Best-effort portfolio run: record and skip failing symbols
            # instead of aborting the whole backtest (previously silent).
            failed_symbols.append(symbol)
            logger.warning("Portfolio backtest skipped %s: %s", symbol, e)
            continue

    if not portfolio_results:
        return {"error": "No symbols could be backtested"}

    n_ok = len(portfolio_results)
    # Equal-weight averages across the symbols that succeeded.
    total_return = sum(r["metrics"]["total_return"] for r in portfolio_results) / n_ok
    avg_sharpe = sum(r["metrics"]["sharpe_ratio"] for r in portfolio_results) / n_ok
    max_drawdown = max(r["metrics"]["max_drawdown"] for r in portfolio_results)
    total_trades = sum(r["metrics"]["total_trades"] for r in portfolio_results)

    return {
        "portfolio_metrics": {
            "symbols_tested": n_ok,
            "total_return": total_return,
            "average_sharpe": avg_sharpe,
            "max_drawdown": max_drawdown,
            "total_trades": total_trades,
        },
        "individual_results": portfolio_results,
        "failed_symbols": failed_symbols,
        "summary": f"Portfolio backtest of {n_ok} symbols with {strategy} strategy",
    }
@mcp.tool()
async def generate_backtest_charts(
    ctx: Context,
    symbol: str,
    strategy: str = "sma_cross",
    start_date: str | None = None,
    end_date: str | None = None,
    theme: str = "light",
) -> dict[str, str]:
    """Generate comprehensive charts for a backtest.

    Args:
        symbol: Stock symbol
        strategy: Strategy type
        start_date: Start date (YYYY-MM-DD); defaults to one year ago
        end_date: End date (YYYY-MM-DD); defaults to today
        theme: Chart theme (light or dark)

    Returns:
        Dictionary of base64-encoded chart images keyed by chart name.
    """
    from datetime import datetime, timedelta

    import pandas as pd

    from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES

    # Fill in the default one-year window.
    end_date = end_date or datetime.now().strftime("%Y-%m-%d")
    start_date = start_date or (datetime.now() - timedelta(days=365)).strftime(
        "%Y-%m-%d"
    )

    # Run the backtest with the strategy's template default parameters.
    engine = VectorBTEngine()
    strategy_params = STRATEGY_TEMPLATES.get(strategy, {}).get("parameters", {})
    backtest = await engine.run_backtest(
        symbol=symbol,
        strategy_type=strategy,
        parameters=strategy_params,
        start_date=start_date,
        end_date=end_date,
    )

    # The charting helpers expect pandas structures.
    equity = pd.Series(backtest["equity_curve"])
    drawdown = pd.Series(backtest["drawdown_series"])
    trade_frame = pd.DataFrame(backtest["trades"])

    return {
        "equity_curve": generate_equity_curve(
            equity, drawdown, f"{symbol} {strategy} Equity Curve", theme
        ),
        "trade_scatter": generate_trade_scatter(
            equity, trade_frame, f"{symbol} {strategy} Trades", theme
        ),
        "performance_dashboard": generate_performance_dashboard(
            backtest["metrics"], f"{symbol} {strategy} Performance", theme
        ),
    }
@mcp.tool()
async def generate_optimization_charts(
    ctx: Context,
    symbol: str,
    strategy: str = "sma_cross",
    start_date: str | None = None,
    end_date: str | None = None,
    theme: str = "light",
) -> dict[str, str]:
    """Generate chart for strategy parameter optimization.

    Args:
        symbol: Stock symbol
        strategy: Strategy type
        start_date: Start date (YYYY-MM-DD); defaults to one year ago
        end_date: End date (YYYY-MM-DD); defaults to today
        theme: Chart theme (light or dark)

    Returns:
        Dictionary with a single "optimization_heatmap" base64-encoded image.
    """
    from datetime import datetime, timedelta

    # Default date range: one year ending today.
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    if not start_date:
        start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
    # Build the medium-resolution parameter grid for this strategy.
    engine = VectorBTEngine()
    optimizer = StrategyOptimizer(engine)
    param_grid = optimizer.generate_param_grid(strategy, "medium")
    # Create optimization results dictionary for heatmap.
    # NOTE(review): this iterates the *parameter grid* as if its values were
    # backtest result dicts containing "total_return" — no optimization run
    # happens here, so every .get() likely falls back to 0. Verify against
    # StrategyOptimizer's API; an optimize/run call may be missing.
    optimization_results = {}
    for param_set, results in param_grid.items():
        optimization_results[str(param_set)] = {
            "performance": results.get("total_return", 0)
        }
    # Generate optimization heatmap from the {param-set: performance} mapping.
    heatmap = generate_optimization_heatmap(
        optimization_results, f"{symbol} {strategy} Parameter Optimization", theme
    )
    return {"optimization_heatmap": heatmap}
# ============ ML-ENHANCED STRATEGY TOOLS ============
@mcp.tool()
async def run_ml_strategy_backtest(
    ctx: Context,
    symbol: str,
    strategy_type: str = "ml_predictor",
    start_date: str | None = None,
    end_date: str | None = None,
    initial_capital: float = 10000.0,
    train_ratio: float = 0.8,
    model_type: str = "random_forest",
    n_estimators: int = 100,
    max_depth: int | None = None,
    learning_rate: float = 0.01,
    adaptation_method: str = "gradient",
) -> dict[str, Any]:
    """Run backtest using ML-enhanced strategies.

    Trains on the leading ``train_ratio`` slice of history and evaluates
    generated signals on the held-out remainder (out-of-sample backtest).

    Args:
        symbol: Stock symbol to backtest
        strategy_type: ML strategy type (ml_predictor, adaptive, ensemble,
            regime_aware; "online_learning" is accepted as an alias of adaptive)
        start_date: Start date (YYYY-MM-DD); defaults to two years ago
        end_date: End date (YYYY-MM-DD); defaults to today
        initial_capital: Initial capital amount
        train_ratio: Ratio of data for training (0.0-1.0)
        model_type, n_estimators, max_depth: MLPredictor parameters
        learning_rate, adaptation_method: AdaptiveStrategy parameters

    Returns:
        Backtest results with ML-specific metrics under "ml_metrics", or an
        {"error": ...} dict when data is insufficient or the strategy fails.
    """
    from datetime import datetime, timedelta

    from maverick_mcp.backtesting.strategies.ml import (
        AdaptiveStrategy,
        MLPredictor,
        RegimeAwareStrategy,
        StrategyEnsemble,
    )
    from maverick_mcp.backtesting.strategies.templates import (
        SimpleMovingAverageStrategy,
    )

    # Default date range: two years, since ML strategies need more history.
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    if not start_date:
        start_date = (datetime.now() - timedelta(days=730)).strftime(
            "%Y-%m-%d"
        )  # 2 years for ML

    # Get historical data.
    engine = VectorBTEngine()
    data = await engine.get_historical_data(symbol, start_date, end_date)

    # Enhanced data validation for ML strategies.
    min_total_data = 200  # Minimum total data points for ML strategies
    if len(data) < min_total_data:
        return {
            "error": f"Insufficient data for ML strategy: {len(data)} < {min_total_data} required"
        }

    # Chronological split: earlier rows train, later rows test.
    split_idx = int(len(data) * train_ratio)
    train_data = data.iloc[:split_idx]
    test_data = data.iloc[split_idx:]

    # Validate split data sizes so both phases have enough samples.
    min_train_data = 100
    min_test_data = 50
    if len(train_data) < min_train_data:
        return {
            "error": f"Insufficient training data: {len(train_data)} < {min_train_data} required"
        }
    if len(test_data) < min_test_data:
        return {
            "error": f"Insufficient test data: {len(test_data)} < {min_test_data} required"
        }

    logger.info(
        f"ML backtest data split: {len(train_data)} training, {len(test_data)} testing samples"
    )

    try:
        # Create ML strategy based on type; each branch also produces the
        # training_metrics dict reported back to the caller.
        if strategy_type == "ml_predictor":
            ml_strategy = MLPredictor(
                model_type=model_type,
                n_estimators=n_estimators,
                max_depth=max_depth,
            )
            # Train the model on the in-sample slice.
            training_metrics = ml_strategy.train(train_data)
        elif strategy_type == "adaptive" or strategy_type == "online_learning":
            # online_learning is an alias for adaptive strategy.
            base_strategy = SimpleMovingAverageStrategy()
            ml_strategy = AdaptiveStrategy(
                base_strategy,
                learning_rate=learning_rate,
                adaptation_method=adaptation_method,
            )
            training_metrics = {
                "adaptation_method": adaptation_method,
                "strategy_alias": strategy_type,
            }
        elif strategy_type == "ensemble":
            # Create ensemble with basic strategies (two SMA variants).
            base_strategies = [
                SimpleMovingAverageStrategy({"fast_period": 10, "slow_period": 20}),
                SimpleMovingAverageStrategy({"fast_period": 5, "slow_period": 15}),
            ]
            ml_strategy = StrategyEnsemble(base_strategies)
            training_metrics = {"ensemble_size": len(base_strategies)}
        elif strategy_type == "regime_aware":
            # One SMA variant per detected regime (keys are regime ids).
            base_strategies = {
                0: SimpleMovingAverageStrategy(
                    {"fast_period": 5, "slow_period": 20}
                ),  # Bear
                1: SimpleMovingAverageStrategy(
                    {"fast_period": 10, "slow_period": 30}
                ),  # Sideways
                2: SimpleMovingAverageStrategy(
                    {"fast_period": 20, "slow_period": 50}
                ),  # Bull
            }
            ml_strategy = RegimeAwareStrategy(base_strategies)
            # Fit regime detector on the training slice only.
            ml_strategy.fit_regime_detector(train_data)
            training_metrics = {"n_regimes": len(base_strategies)}
        else:
            return {"error": f"Unsupported ML strategy type: {strategy_type}"}

        # Generate signals on the out-of-sample test data.
        entry_signals, exit_signals = ml_strategy.generate_signals(test_data)

        # Run backtest analysis on the test period only.
        analyzer = BacktestAnalyzer()
        backtest_results = await analyzer.run_vectorbt_backtest(
            data=test_data,
            entry_signals=entry_signals,
            exit_signals=exit_signals,
            initial_capital=initial_capital,
        )

        # Add ML-specific metrics describing the train/test setup.
        ml_metrics = {
            "strategy_type": strategy_type,
            "training_period": len(train_data),
            "testing_period": len(test_data),
            "train_test_split": train_ratio,
            "training_metrics": training_metrics,
        }

        # Add strategy-specific analysis when the strategy exposes it.
        if hasattr(ml_strategy, "get_feature_importance"):
            ml_metrics["feature_importance"] = ml_strategy.get_feature_importance()
        if hasattr(ml_strategy, "get_regime_analysis"):
            ml_metrics["regime_analysis"] = ml_strategy.get_regime_analysis()
        if hasattr(ml_strategy, "get_strategy_weights"):
            ml_metrics["strategy_weights"] = ml_strategy.get_strategy_weights()

        backtest_results["ml_metrics"] = ml_metrics

        # Convert all numpy types before returning (JSON safety).
        return convert_numpy_types(backtest_results)
    except Exception as e:
        return {"error": f"ML backtest failed: {str(e)}"}
@mcp.tool()
async def train_ml_predictor(
    ctx: Context,
    symbol: str,
    start_date: str | None = None,
    end_date: str | None = None,
    model_type: str = "random_forest",
    target_periods: int = 5,
    return_threshold: float = 0.02,
    n_estimators: int = 100,
    max_depth: int | None = None,
    min_samples_split: int = 2,
) -> dict[str, Any]:
    """Train an ML predictor model for trading signals.

    Args:
        symbol: Stock symbol to train on
        start_date: Start date for training data; defaults to two years ago
        end_date: End date for training data; defaults to today
        model_type: ML model type (random_forest)
        target_periods: Forward periods for target variable
        return_threshold: Return threshold for signal classification
        n_estimators, max_depth, min_samples_split: Model-specific parameters

    Returns:
        Training results and model metrics, or an {"error": ...} dict on
        insufficient data or training failure.
    """
    from datetime import datetime, timedelta

    from maverick_mcp.backtesting.strategies.ml import MLPredictor

    # Default to a two-year window — enough history for a useful model.
    end_date = end_date or datetime.now().strftime("%Y-%m-%d")
    start_date = start_date or (datetime.now() - timedelta(days=730)).strftime(
        "%Y-%m-%d"
    )

    try:
        # Fetch the training history.
        engine = VectorBTEngine()
        history = await engine.get_historical_data(symbol, start_date, end_date)
        if len(history) < 200:
            return {
                "error": "Insufficient data for ML training (minimum 200 data points)"
            }

        # Build and fit the predictor with the requested hyperparameters.
        predictor = MLPredictor(
            model_type=model_type,
            n_estimators=n_estimators,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
        )
        fit_metrics = predictor.train(
            data=history,
            target_periods=target_periods,
            return_threshold=return_threshold,
        )

        # Assemble the training report for the caller.
        report = {
            "symbol": symbol,
            "model_type": model_type,
            "training_period": f"{start_date} to {end_date}",
            "data_points": len(history),
            "target_periods": target_periods,
            "return_threshold": return_threshold,
            "model_parameters": {
                "n_estimators": n_estimators,
                "max_depth": max_depth,
                "min_samples_split": min_samples_split,
            },
            "training_metrics": fit_metrics,
        }
        # Convert all numpy types before returning (JSON safety).
        return convert_numpy_types(report)
    except Exception as e:
        return {"error": f"ML training failed: {str(e)}"}
def _regime_run_durations(regimes: list[int], n_regimes: int) -> dict[int, list[int]]:
    """Run-length encode a regime sequence: regime id -> list of consecutive run lengths.

    Args:
        regimes: Non-empty sequence of regime ids (each in range(n_regimes)).
        n_regimes: Number of distinct regimes; every id gets an entry, even
            if it never occurs (empty list).

    Returns:
        Mapping of regime id to the lengths of its consecutive runs, in order.
    """
    durations: dict[int, list[int]] = {i: [] for i in range(n_regimes)}
    current = regimes[0]
    run_length = 1
    for regime in regimes[1:]:
        if regime == current:
            run_length += 1
        else:
            durations[current].append(run_length)
            current = regime
            run_length = 1
    # Flush the final run.
    durations[current].append(run_length)
    return durations


@mcp.tool()
async def analyze_market_regimes(
    ctx: Context,
    symbol: str,
    start_date: str | None = None,
    end_date: str | None = None,
    method: str = "hmm",
    n_regimes: int = 3,
    lookback_period: int = 50,
) -> dict[str, Any]:
    """Analyze market regimes for a stock using ML methods.

    Slides a ``lookback_period``-sized window over the history, detecting the
    regime at each step, then summarizes counts, percentages, durations, and
    switches.

    Args:
        symbol: Stock symbol to analyze
        start_date: Start date for analysis; defaults to one year ago
        end_date: End date for analysis; defaults to today
        method: Detection method (hmm, kmeans, threshold)
        n_regimes: Number of regimes to detect
        lookback_period: Lookback period for regime detection

    Returns:
        Market regime analysis results, or an {"error": ...} dict on
        insufficient data or detector failure.
    """
    from datetime import datetime, timedelta

    from maverick_mcp.backtesting.strategies.ml.regime_aware import (
        MarketRegimeDetector,
    )

    # Default date range: one year ending today.
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    if not start_date:
        start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")

    try:
        # Get historical data.
        engine = VectorBTEngine()
        data = await engine.get_historical_data(symbol, start_date, end_date)

        # Need the lookback window plus a reasonable analysis span.
        if len(data) < lookback_period + 50:
            return {
                "error": f"Insufficient data for regime analysis (minimum {lookback_period + 50} data points)"
            }

        # Create regime detector and fit it on the full history.
        regime_detector = MarketRegimeDetector(
            method=method, n_regimes=n_regimes, lookback_period=lookback_period
        )
        regime_detector.fit_regimes(data)

        # Walk forward: detect the regime at each step using a trailing window.
        regime_history = []
        regime_probabilities = []
        for i in range(lookback_period, len(data)):
            window_data = data.iloc[i - lookback_period : i + 1]
            current_regime = regime_detector.detect_current_regime(window_data)
            regime_probs = regime_detector.get_regime_probabilities(window_data)
            regime_history.append(
                {
                    "date": data.index[i].strftime("%Y-%m-%d"),
                    "regime": int(current_regime),
                    "probabilities": regime_probs.tolist(),
                }
            )
            regime_probabilities.append(regime_probs)

        # Occupancy statistics per regime.
        regimes = [r["regime"] for r in regime_history]
        regime_counts = {i: regimes.count(i) for i in range(n_regimes)}
        regime_percentages = {
            k: (v / len(regimes)) * 100 for k, v in regime_counts.items()
        }

        # Average length of each regime's consecutive runs.
        regime_durations = _regime_run_durations(regimes, n_regimes)
        avg_durations = {
            k: np.mean(v) if v else 0 for k, v in regime_durations.items()
        }

        analysis_results = {
            "symbol": symbol,
            "analysis_period": f"{start_date} to {end_date}",
            "method": method,
            "n_regimes": n_regimes,
            "regime_names": {
                0: "Bear/Declining",
                1: "Sideways/Uncertain",
                2: "Bull/Trending",
            },
            "current_regime": regimes[-1] if regimes else 1,
            "regime_counts": regime_counts,
            "regime_percentages": regime_percentages,
            "average_regime_durations": avg_durations,
            "recent_regime_history": regime_history[-20:],  # Last 20 periods
            "total_regime_switches": len(
                [i for i in range(1, len(regimes)) if regimes[i] != regimes[i - 1]]
            ),
        }
        return analysis_results
    except Exception as e:
        return {"error": f"Regime analysis failed: {str(e)}"}
@mcp.tool()
async def create_strategy_ensemble(
    ctx: Context,
    symbols: list[str],
    base_strategies: list[str] | None = None,
    weighting_method: str = "performance",
    start_date: str | None = None,
    end_date: str | None = None,
    initial_capital: float = 10000.0,
) -> dict[str, Any]:
    """Create and backtest a strategy ensemble across multiple symbols.

    At most five symbols are processed (performance cap); symbols with
    insufficient data or failing backtests are skipped silently.

    Args:
        symbols: List of stock symbols (only the first 5 are used)
        base_strategies: List of base strategy names to ensemble
            (defaults to ["sma_cross", "rsi", "macd"])
        weighting_method: Weighting method (performance, equal, volatility)
        start_date: Start date for backtesting; defaults to one year ago
        end_date: End date for backtesting; defaults to today
        initial_capital: Initial capital per symbol

    Returns:
        Ensemble backtest results with per-symbol results and strategy
        weights, or an {"error": ...} dict on failure.
    """
    from datetime import datetime, timedelta

    from maverick_mcp.backtesting.strategies.ml import StrategyEnsemble
    from maverick_mcp.backtesting.strategies.templates import (
        SimpleMovingAverageStrategy,
    )

    # Default strategies if none provided.
    if base_strategies is None:
        base_strategies = ["sma_cross", "rsi", "macd"]

    # Default date range: one year ending today.
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    if not start_date:
        start_date = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")

    try:
        # Create base strategy instances. NOTE: "rsi" and "macd" are
        # approximated by SMA strategies with RSI-/MACD-like periods — these
        # are NOT real RSI/MACD implementations.
        strategy_instances = []
        for strategy_name in base_strategies:
            if strategy_name == "sma_cross":
                strategy_instances.append(SimpleMovingAverageStrategy())
            elif strategy_name == "rsi":
                # Create RSI-based SMA strategy with different parameters
                strategy_instances.append(
                    SimpleMovingAverageStrategy(
                        {"fast_period": 14, "slow_period": 28}
                    )
                )
            elif strategy_name == "macd":
                # Create MACD-like SMA strategy with MACD default periods
                strategy_instances.append(
                    SimpleMovingAverageStrategy(
                        {"fast_period": 12, "slow_period": 26}
                    )
                )
            # Add more strategies as needed

        if not strategy_instances:
            return {"error": "No valid base strategies provided"}

        # Create ensemble strategy combining the base strategies.
        ensemble = StrategyEnsemble(
            strategies=strategy_instances, weighting_method=weighting_method
        )

        # Run ensemble backtest on multiple symbols, accumulating aggregates.
        ensemble_results = []
        total_return = 0
        total_trades = 0

        for symbol in symbols[:5]:  # Limit to 5 symbols for performance
            try:
                # Get data and run backtest.
                engine = VectorBTEngine()
                data = await engine.get_historical_data(
                    symbol, start_date, end_date
                )
                if len(data) < 100:
                    # Too little history for a meaningful backtest; skip.
                    continue

                # Generate ensemble signals.
                entry_signals, exit_signals = ensemble.generate_signals(data)

                # Run backtest on the generated signals.
                analyzer = BacktestAnalyzer()
                results = await analyzer.run_vectorbt_backtest(
                    data=data,
                    entry_signals=entry_signals,
                    exit_signals=exit_signals,
                    initial_capital=initial_capital,
                )

                # Add ensemble-specific metrics (weights evolve as symbols run).
                results["ensemble_metrics"] = {
                    "strategy_weights": ensemble.get_strategy_weights(),
                    "strategy_performance": ensemble.get_strategy_performance(),
                }

                ensemble_results.append({"symbol": symbol, "results": results})
                total_return += results["metrics"]["total_return"]
                total_trades += results["metrics"]["total_trades"]
            except Exception:
                # Best-effort: skip symbols whose backtest fails.
                continue

        if not ensemble_results:
            return {"error": "No symbols could be processed"}

        # Calculate aggregate metrics over the symbols that succeeded.
        avg_return = total_return / len(ensemble_results)
        avg_trades = total_trades / len(ensemble_results)

        # Convert all numpy types before returning (JSON safety).
        return convert_numpy_types(
            {
                "ensemble_summary": {
                    "symbols_tested": len(ensemble_results),
                    "base_strategies": base_strategies,
                    "weighting_method": weighting_method,
                    "average_return": avg_return,
                    "total_trades": total_trades,
                    "average_trades_per_symbol": avg_trades,
                },
                "individual_results": ensemble_results,
                "final_strategy_weights": ensemble.get_strategy_weights(),
                "strategy_performance_analysis": ensemble.get_strategy_performance(),
            }
        )
    except Exception as e:
        return {"error": f"Ensemble creation failed: {str(e)}"}
```
--------------------------------------------------------------------------------
/tests/data/test_portfolio_models.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive integration tests for portfolio database models and migration.
This module tests:
1. Migration upgrade and downgrade operations
2. SQLAlchemy model CRUD operations (Create, Read, Update, Delete)
3. Database constraints (unique constraints, foreign keys, cascade deletes)
4. Relationships between UserPortfolio and PortfolioPosition
5. Decimal field precision for financial data (Numeric(12,4) and Numeric(20,8))
6. Timezone-aware datetime fields
7. Index creation and query optimization
Test Coverage:
- Migration creates tables with correct schema
- Indexes are created properly for performance optimization
- Unique constraints work for both portfolio and position level
- Cascade delete removes positions when portfolio is deleted
- Decimal precision is maintained through round-trip database operations
- Relationships are properly loaded with selectin strategy
- Default values are applied correctly (user_id="default", name="My Portfolio")
- Timestamp mixin functionality (created_at, updated_at)
Test Markers:
- @pytest.mark.integration - Full database integration tests
"""
import uuid
from datetime import UTC, datetime, timedelta
from decimal import Decimal
import pytest
from sqlalchemy import exc, inspect
from sqlalchemy.orm import Session
from maverick_mcp.data.models import PortfolioPosition, UserPortfolio
pytestmark = pytest.mark.integration
# ============================================================================
# Migration Tests
# ============================================================================
class TestMigrationUpgrade:
    """Verify the portfolio migration produced the expected schema."""

    # --- inspector helpers (not collected by pytest) ---

    @staticmethod
    def _tables(db_session: Session) -> set[str]:
        """All table names visible on the session's bind."""
        return set(inspect(db_session.bind).get_table_names())

    @staticmethod
    def _columns(db_session: Session, table: str) -> set[str]:
        """Column names of *table*."""
        return {col["name"] for col in inspect(db_session.bind).get_columns(table)}

    @staticmethod
    def _indexes(db_session: Session, table: str) -> set[str]:
        """Index names of *table*."""
        return {idx["name"] for idx in inspect(db_session.bind).get_indexes(table)}

    @staticmethod
    def _unique_constraints(db_session: Session, table: str) -> set[str]:
        """Unique-constraint names of *table*."""
        return {
            c["name"]
            for c in inspect(db_session.bind).get_unique_constraints(table)
        }

    # --- table existence ---

    def test_migration_creates_portfolios_table(self, db_session: Session):
        """Test that migration creates mcp_portfolios table."""
        assert "mcp_portfolios" in self._tables(db_session)

    def test_migration_creates_positions_table(self, db_session: Session):
        """Test that migration creates mcp_portfolio_positions table."""
        assert "mcp_portfolio_positions" in self._tables(db_session)

    # --- column presence ---

    def test_portfolios_table_has_correct_columns(self, db_session: Session):
        """Test that portfolios table has all required columns."""
        required = {"id", "user_id", "name", "created_at", "updated_at"}
        assert required.issubset(self._columns(db_session, "mcp_portfolios"))

    def test_positions_table_has_correct_columns(self, db_session: Session):
        """Test that positions table has all required columns."""
        required = {
            "id",
            "portfolio_id",
            "ticker",
            "shares",
            "average_cost_basis",
            "total_cost",
            "purchase_date",
            "notes",
            "created_at",
            "updated_at",
        }
        assert required.issubset(
            self._columns(db_session, "mcp_portfolio_positions")
        )

    def test_portfolios_id_column_type(self, db_session: Session):
        """Test that portfolio id column is UUID type."""
        assert "id" in self._columns(db_session, "mcp_portfolios")
        # Column exists and is configured as primary key through Index and UniqueConstraint

    def test_positions_foreign_key_constraint(self, db_session: Session):
        """Test that positions table has foreign key to portfolios."""
        fks = inspect(db_session.bind).get_foreign_keys("mcp_portfolio_positions")
        assert fks
        assert any(fk["constrained_columns"] == ["portfolio_id"] for fk in fks)

    # --- indexes ---

    def test_migration_creates_portfolio_user_index(self, db_session: Session):
        """Test that migration creates index on portfolio user_id."""
        assert "idx_portfolio_user" in self._indexes(db_session, "mcp_portfolios")

    def test_migration_creates_position_portfolio_index(self, db_session: Session):
        """Test that migration creates index on position portfolio_id."""
        assert "idx_position_portfolio" in self._indexes(
            db_session, "mcp_portfolio_positions"
        )

    def test_migration_creates_position_ticker_index(self, db_session: Session):
        """Test that migration creates index on position ticker."""
        assert "idx_position_ticker" in self._indexes(
            db_session, "mcp_portfolio_positions"
        )

    def test_migration_creates_position_composite_index(self, db_session: Session):
        """Test that migration creates composite index on portfolio_id and ticker."""
        assert "idx_position_portfolio_ticker" in self._indexes(
            db_session, "mcp_portfolio_positions"
        )

    # --- unique constraints ---

    def test_migration_creates_unique_portfolio_constraint(self, db_session: Session):
        """Test that migration creates unique constraint on user_id and name."""
        assert "uq_user_portfolio_name" in self._unique_constraints(
            db_session, "mcp_portfolios"
        )

    def test_migration_creates_unique_position_constraint(self, db_session: Session):
        """Test that migration creates unique constraint on portfolio_id and ticker."""
        assert "uq_portfolio_position_ticker" in self._unique_constraints(
            db_session, "mcp_portfolio_positions"
        )

    # --- defaults / timestamp columns ---

    def test_portfolios_user_id_has_default(self, db_session: Session):
        """Test that user_id column exists and is not nullable."""
        assert "user_id" in self._columns(db_session, "mcp_portfolios")
        # Default is handled at model level, not server level

    def test_portfolios_name_has_default(self, db_session: Session):
        """Test that name column exists and is not nullable."""
        assert "name" in self._columns(db_session, "mcp_portfolios")
        # Default is handled at model level, not server level

    def test_portfolios_created_at_has_default(self, db_session: Session):
        """Test that created_at column exists for timestamp tracking."""
        assert "created_at" in self._columns(db_session, "mcp_portfolios")

    def test_portfolios_updated_at_has_default(self, db_session: Session):
        """Test that updated_at column exists for timestamp tracking."""
        assert "updated_at" in self._columns(db_session, "mcp_portfolios")

    def test_positions_created_at_has_default(self, db_session: Session):
        """Test that position created_at column exists for timestamp tracking."""
        assert "created_at" in self._columns(db_session, "mcp_portfolio_positions")

    def test_positions_updated_at_has_default(self, db_session: Session):
        """Test that position updated_at column exists for timestamp tracking."""
        assert "updated_at" in self._columns(db_session, "mcp_portfolio_positions")
# ============================================================================
# Model CRUD Operation Tests
# ============================================================================
class TestPortfolioModelCRUD:
    """Exercise Create/Read/Update/Delete paths for UserPortfolio."""

    @staticmethod
    def _by_id(db_session: Session, portfolio_id):
        """Fetch a portfolio row by primary key (None when absent)."""
        return db_session.query(UserPortfolio).filter_by(id=portfolio_id).first()

    def test_create_portfolio_with_all_fields(self, db_session: Session):
        """Test creating a portfolio with all fields specified."""
        created = UserPortfolio(
            id=uuid.uuid4(),
            user_id="test_user",
            name="Test Portfolio",
        )
        db_session.add(created)
        db_session.commit()

        fetched = self._by_id(db_session, created.id)
        assert fetched is not None
        assert fetched.user_id == "test_user"
        assert fetched.name == "Test Portfolio"
        # Timestamp mixin populates both fields on insert.
        assert fetched.created_at is not None
        assert fetched.updated_at is not None

    def test_create_portfolio_with_defaults(self, db_session: Session):
        """Test that portfolio defaults are applied correctly."""
        created = UserPortfolio()
        db_session.add(created)
        db_session.commit()

        fetched = self._by_id(db_session, created.id)
        assert fetched.user_id == "default"
        assert fetched.name == "My Portfolio"

    def test_read_portfolio_by_id(self, db_session: Session):
        """Test reading portfolio by ID."""
        created = UserPortfolio(user_id="user1", name="Portfolio 1")
        db_session.add(created)
        db_session.commit()

        fetched = self._by_id(db_session, created.id)
        assert fetched is not None
        assert fetched.id == created.id

    def test_read_portfolio_by_user_and_name(self, db_session: Session):
        """Test reading portfolio by user_id and name."""
        created = UserPortfolio(user_id="user2", name="My Portfolio 2")
        db_session.add(created)
        db_session.commit()

        fetched = (
            db_session.query(UserPortfolio)
            .filter_by(user_id="user2", name="My Portfolio 2")
            .first()
        )
        assert fetched is not None
        assert fetched.id == created.id

    def test_read_all_portfolios_for_user(self, db_session: Session):
        """Test reading all portfolios for a specific user."""
        # Unique user id isolates this test from other rows in the table.
        user_id = f"user_read_{uuid.uuid4()}"
        db_session.add_all(
            UserPortfolio(user_id=user_id, name=f"Portfolio {i}") for i in range(3)
        )
        db_session.commit()

        rows = db_session.query(UserPortfolio).filter_by(user_id=user_id).all()
        assert len(rows) == 3

    def test_update_portfolio_name(self, db_session: Session):
        """Test updating portfolio name."""
        created = UserPortfolio(user_id="user3", name="Original Name")
        db_session.add(created)
        db_session.commit()

        created.name = "Updated Name"
        db_session.commit()

        assert self._by_id(db_session, created.id).name == "Updated Name"

    def test_update_portfolio_user_id(self, db_session: Session):
        """Test updating portfolio user_id."""
        created = UserPortfolio(user_id="old_user", name="Portfolio")
        db_session.add(created)
        db_session.commit()

        created.user_id = "new_user"
        db_session.commit()

        assert self._by_id(db_session, created.id).user_id == "new_user"

    def test_delete_portfolio(self, db_session: Session):
        """Test deleting a portfolio."""
        created = UserPortfolio(user_id="user4", name="To Delete")
        db_session.add(created)
        db_session.commit()
        deleted_id = created.id

        db_session.delete(created)
        db_session.commit()

        assert self._by_id(db_session, deleted_id) is None

    def test_portfolio_repr(self, db_session: Session):
        """Test portfolio string representation."""
        created = UserPortfolio(user_id="user5", name="Test Portfolio")
        db_session.add(created)
        db_session.commit()

        text = repr(created)
        assert "UserPortfolio" in text
        assert "Test Portfolio" in text
class TestPositionModelCRUD:
"""Test suite for PortfolioPosition CRUD operations."""
@pytest.fixture
def portfolio(self, db_session: Session) -> UserPortfolio:
    """Create and persist a fresh portfolio to attach positions to."""
    # Unique name avoids collisions with the (user_id, name) constraint.
    owner = UserPortfolio(user_id="default", name=f"Test Portfolio {uuid.uuid4()}")
    db_session.add(owner)
    db_session.commit()
    return owner
def test_create_position_with_all_fields(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test creating a position with all fields."""
position = PortfolioPosition(
id=uuid.uuid4(),
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
notes="Test position",
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved is not None
assert retrieved.ticker == "AAPL"
assert retrieved.notes == "Test position"
def test_create_position_without_notes(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test creating a position without notes."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.notes is None
def test_read_position_by_id(self, db_session: Session, portfolio: UserPortfolio):
"""Test reading position by ID."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="GOOG",
shares=Decimal("2.00000000"),
average_cost_basis=Decimal("2750.0000"),
total_cost=Decimal("5500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved is not None
assert retrieved.ticker == "GOOG"
def test_read_position_by_portfolio_and_ticker(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test reading position by portfolio_id and ticker."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TSLA",
shares=Decimal("1.00000000"),
average_cost_basis=Decimal("250.0000"),
total_cost=Decimal("250.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio.id, ticker="TSLA")
.first()
)
assert retrieved is not None
def test_read_all_positions_in_portfolio(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test reading all positions in a portfolio."""
positions_data = [
("AAPL", Decimal("10"), Decimal("150.0000")),
("MSFT", Decimal("5"), Decimal("380.0000")),
("GOOG", Decimal("2"), Decimal("2750.0000")),
]
for ticker, shares, price in positions_data:
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker=ticker,
shares=shares,
average_cost_basis=price,
total_cost=shares * price,
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio.id)
.all()
)
assert len(retrieved) == 3
def test_update_position_shares(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test updating position shares."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
position.shares = Decimal("20.00000000")
position.average_cost_basis = Decimal("160.0000")
position.total_cost = Decimal("3200.0000")
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.shares == Decimal("20.00000000")
def test_update_position_cost_basis(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test updating position average cost basis."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
original_cost_basis = position.average_cost_basis
position.average_cost_basis = Decimal("390.0000")
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.average_cost_basis != original_cost_basis
assert retrieved.average_cost_basis == Decimal("390.0000")
def test_update_position_notes(self, db_session: Session, portfolio: UserPortfolio):
"""Test updating position notes."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="GOOG",
shares=Decimal("2.00000000"),
average_cost_basis=Decimal("2750.0000"),
total_cost=Decimal("5500.0000"),
purchase_date=datetime.now(UTC),
notes="Original notes",
)
db_session.add(position)
db_session.commit()
position.notes = "Updated notes"
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.notes == "Updated notes"
def test_delete_position(self, db_session: Session, portfolio: UserPortfolio):
"""Test deleting a position."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TSLA",
shares=Decimal("1.00000000"),
average_cost_basis=Decimal("250.0000"),
total_cost=Decimal("250.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
position_id = position.id
db_session.delete(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position_id).first()
)
assert retrieved is None
def test_position_repr(self, db_session: Session, portfolio: UserPortfolio):
"""Test position string representation."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="NVDA",
shares=Decimal("3.00000000"),
average_cost_basis=Decimal("900.0000"),
total_cost=Decimal("2700.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
repr_str = repr(position)
assert "PortfolioPosition" in repr_str
assert "NVDA" in repr_str
# ============================================================================
# Relationship Tests
# ============================================================================
class TestPortfolioPositionRelationships:
    """Test suite for relationships between UserPortfolio and PortfolioPosition."""

    @pytest.fixture
    def portfolio_with_positions(self, db_session: Session) -> UserPortfolio:
        """Create a portfolio with multiple positions."""
        owner = UserPortfolio(
            user_id="default", name=f"Relationship Test {uuid.uuid4()}"
        )
        db_session.add(owner)
        db_session.commit()

        seed = (
            ("AAPL", Decimal("10.00000000"), Decimal("150.0000"), Decimal("1500.0000")),
            ("MSFT", Decimal("5.00000000"), Decimal("380.0000"), Decimal("1900.0000")),
        )
        holdings = [
            PortfolioPosition(
                portfolio_id=owner.id,
                ticker=symbol,
                shares=qty,
                average_cost_basis=price,
                total_cost=total,
                purchase_date=datetime.now(UTC),
            )
            for symbol, qty, price, total in seed
        ]
        db_session.add_all(holdings)
        db_session.commit()
        return owner

    def test_portfolio_has_positions_relationship(
        self, db_session: Session, portfolio_with_positions: UserPortfolio
    ):
        """Test that portfolio has positions relationship."""
        owner = (
            db_session.query(UserPortfolio)
            .filter_by(id=portfolio_with_positions.id)
            .first()
        )
        assert hasattr(owner, "positions")
        assert isinstance(owner.positions, list)

    def test_positions_eagerly_loaded_via_selectin(
        self, db_session: Session, portfolio_with_positions: UserPortfolio
    ):
        """Test that positions are eagerly loaded (selectin strategy)."""
        owner = (
            db_session.query(UserPortfolio)
            .filter_by(id=portfolio_with_positions.id)
            .first()
        )
        assert len(owner.positions) == 2
        assert {p.ticker for p in owner.positions} == {"AAPL", "MSFT"}

    def test_position_has_portfolio_relationship(
        self, db_session: Session, portfolio_with_positions: UserPortfolio
    ):
        """Test that position has back reference to portfolio."""
        holding = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio_with_positions.id)
            .first()
        )
        assert holding.portfolio is not None
        assert holding.portfolio.id == portfolio_with_positions.id

    def test_position_portfolio_relationship_maintains_integrity(
        self, db_session: Session, portfolio_with_positions: UserPortfolio
    ):
        """Test that position portfolio relationship maintains data integrity."""
        holding = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio_with_positions.id, ticker="AAPL")
            .first()
        )
        assert holding.portfolio.name == portfolio_with_positions.name
        assert holding.portfolio.user_id == portfolio_with_positions.user_id

    def test_multiple_portfolios_have_separate_positions(self, db_session: Session):
        """Test that multiple portfolios have separate position lists."""
        owner_key = f"user_multi_{uuid.uuid4()}"
        first = UserPortfolio(user_id=owner_key, name=f"Portfolio 1 {uuid.uuid4()}")
        second = UserPortfolio(user_id=owner_key, name=f"Portfolio 2 {uuid.uuid4()}")
        db_session.add_all([first, second])
        db_session.commit()

        db_session.add_all(
            [
                PortfolioPosition(
                    portfolio_id=first.id,
                    ticker="AAPL",
                    shares=Decimal("10.00000000"),
                    average_cost_basis=Decimal("150.0000"),
                    total_cost=Decimal("1500.0000"),
                    purchase_date=datetime.now(UTC),
                ),
                PortfolioPosition(
                    portfolio_id=second.id,
                    ticker="MSFT",
                    shares=Decimal("5.00000000"),
                    average_cost_basis=Decimal("380.0000"),
                    total_cost=Decimal("1900.0000"),
                    purchase_date=datetime.now(UTC),
                ),
            ]
        )
        db_session.commit()

        p1 = db_session.query(UserPortfolio).filter_by(id=first.id).first()
        p2 = db_session.query(UserPortfolio).filter_by(id=second.id).first()
        assert len(p1.positions) == 1
        assert len(p2.positions) == 1
        assert p1.positions[0].ticker == "AAPL"
        assert p2.positions[0].ticker == "MSFT"
# ============================================================================
# Constraint Tests
# ============================================================================
class TestDatabaseConstraints:
    """Test suite for database constraints enforcement.

    These tests deliberately provoke IntegrityError on commit; the session
    fixture is expected to clean up the failed transaction between tests.
    """

    def test_unique_portfolio_name_constraint_enforced(self, db_session: Session):
        """Test that unique constraint on (user_id, name) is enforced."""
        # Random suffixes keep the test independent of rows left by other tests.
        user_id = f"user_constraint_{uuid.uuid4()}"
        name = f"Unique Portfolio {uuid.uuid4()}"
        portfolio1 = UserPortfolio(user_id=user_id, name=name)
        db_session.add(portfolio1)
        db_session.commit()
        # Try to create duplicate
        portfolio2 = UserPortfolio(user_id=user_id, name=name)
        db_session.add(portfolio2)
        # The violation surfaces at flush/commit time, not on add().
        with pytest.raises(exc.IntegrityError):
            db_session.commit()

    def test_unique_position_ticker_constraint_enforced(self, db_session: Session):
        """Test that unique constraint on (portfolio_id, ticker) is enforced."""
        portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
        db_session.add(portfolio)
        db_session.commit()
        position1 = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position1)
        db_session.commit()
        # Try to create duplicate ticker
        position2 = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="AAPL",
            shares=Decimal("5.00000000"),
            average_cost_basis=Decimal("160.0000"),
            total_cost=Decimal("800.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position2)
        with pytest.raises(exc.IntegrityError):
            db_session.commit()

    def test_foreign_key_constraint_enforced(self, db_session: Session):
        """Test that foreign key constraint is enforced."""
        position = PortfolioPosition(
            portfolio_id=uuid.uuid4(),  # Non-existent portfolio
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position)
        with pytest.raises(exc.IntegrityError):
            db_session.commit()

    def test_cascade_delete_removes_positions(self, db_session: Session):
        """Test that deleting a portfolio cascades delete to positions."""
        portfolio = UserPortfolio(user_id="default", name=f"Delete Test {uuid.uuid4()}")
        db_session.add(portfolio)
        db_session.commit()
        positions = [
            PortfolioPosition(
                portfolio_id=portfolio.id,
                ticker="AAPL",
                shares=Decimal("10.00000000"),
                average_cost_basis=Decimal("150.0000"),
                total_cost=Decimal("1500.0000"),
                purchase_date=datetime.now(UTC),
            ),
            PortfolioPosition(
                portfolio_id=portfolio.id,
                ticker="MSFT",
                shares=Decimal("5.00000000"),
                average_cost_basis=Decimal("380.0000"),
                total_cost=Decimal("1900.0000"),
                purchase_date=datetime.now(UTC),
            ),
        ]
        db_session.add_all(positions)
        db_session.commit()
        # Capture the key before deletion detaches the ORM object.
        portfolio_id = portfolio.id
        db_session.delete(portfolio)
        db_session.commit()
        # Verify portfolio is deleted
        p = db_session.query(UserPortfolio).filter_by(id=portfolio_id).first()
        assert p is None
        # Verify positions are also deleted
        pos = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio_id)
            .all()
        )
        assert len(pos) == 0

    def test_cascade_delete_doesnt_affect_other_portfolios(self, db_session: Session):
        """Test that deleting one portfolio doesn't affect others."""
        user_id = f"user_cascade_{uuid.uuid4()}"
        portfolio1 = UserPortfolio(user_id=user_id, name=f"Portfolio 1 {uuid.uuid4()}")
        portfolio2 = UserPortfolio(user_id=user_id, name=f"Portfolio 2 {uuid.uuid4()}")
        db_session.add_all([portfolio1, portfolio2])
        db_session.commit()
        position = PortfolioPosition(
            portfolio_id=portfolio1.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position)
        db_session.commit()
        db_session.delete(portfolio1)
        db_session.commit()
        # Portfolio2 should still exist
        p2 = db_session.query(UserPortfolio).filter_by(id=portfolio2.id).first()
        assert p2 is not None
# ============================================================================
# Decimal Precision Tests
# ============================================================================
class TestDecimalPrecision:
    """Test suite for Decimal field precision."""

    @pytest.fixture
    def portfolio(self, db_session: Session) -> UserPortfolio:
        """Create a test portfolio."""
        owner = UserPortfolio(
            user_id="default", name=f"Decimal Test {uuid.uuid4()}"
        )
        db_session.add(owner)
        db_session.commit()
        return owner

    def _roundtrip(
        self,
        db_session: Session,
        portfolio: UserPortfolio,
        ticker: str,
        shares: Decimal,
        cost_basis: Decimal,
        total_cost: Decimal,
    ) -> PortfolioPosition:
        """Persist a position dated now and return it reloaded from the DB."""
        row = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker=ticker,
            shares=shares,
            average_cost_basis=cost_basis,
            total_cost=total_cost,
            purchase_date=datetime.now(UTC),
        )
        db_session.add(row)
        db_session.commit()
        return db_session.query(PortfolioPosition).filter_by(id=row.id).first()

    def test_shares_numeric_20_8_precision(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test that shares maintains Numeric(20,8) precision."""
        shares = Decimal("12345678901.12345678")
        stored = self._roundtrip(
            db_session,
            portfolio,
            "TEST1",
            shares,
            Decimal("100.0000"),
            Decimal("1234567890112.3456"),
        )
        assert stored.shares == shares

    def test_cost_basis_numeric_12_4_precision(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test that average_cost_basis maintains Numeric(12,4) precision."""
        cost_basis = Decimal("99999999.9999")
        stored = self._roundtrip(
            db_session,
            portfolio,
            "TEST2",
            Decimal("100.00000000"),
            cost_basis,
            Decimal("9999999999.9999"),
        )
        assert stored.average_cost_basis == cost_basis

    def test_total_cost_numeric_20_4_precision(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test that total_cost maintains Numeric(20,4) precision."""
        total_cost = Decimal("9999999999999999.9999")
        stored = self._roundtrip(
            db_session,
            portfolio,
            "TEST3",
            Decimal("1000.00000000"),
            Decimal("9999999.9999"),
            total_cost,
        )
        assert stored.total_cost == total_cost

    def test_fractional_shares_precision(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test that fractional shares with high precision are maintained.

        Note: total_cost uses Numeric(20, 4), so values are truncated to 4
        decimal places.
        """
        shares = Decimal("0.33333333")
        cost_basis = Decimal("2750.1234")
        total_cost = Decimal("917.5041")  # Truncated from 917.50413522 to 4 decimals
        stored = self._roundtrip(
            db_session, portfolio, "TEST4", shares, cost_basis, total_cost
        )
        assert stored.shares == shares
        assert stored.average_cost_basis == cost_basis
        assert stored.total_cost == total_cost

    def test_very_small_decimal_values(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test handling of very small Decimal values.

        Note: total_cost uses Numeric(20, 4) precision, so values smaller than
        0.0001 will be truncated. This is appropriate for stock trading.
        """
        shares = Decimal("0.00000001")
        cost_basis = Decimal("0.0001")
        total_cost = Decimal("0.0000")  # Rounds to 0.0000 due to Numeric(20, 4)
        stored = self._roundtrip(
            db_session, portfolio, "TEST5", shares, cost_basis, total_cost
        )
        assert stored.shares == shares
        assert stored.average_cost_basis == cost_basis
        # Total cost truncated to 4 decimal places as per Numeric(20, 4)
        assert stored.total_cost == total_cost

    def test_multiple_positions_precision_preserved(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test that precision is maintained across multiple positions."""
        cases = [
            (Decimal("1"), Decimal("100.00"), Decimal("100.00")),
            (Decimal("1.5"), Decimal("200.5000"), Decimal("300.7500")),
            (Decimal("0.33333333"), Decimal("2750.1234"), Decimal("917.5041")),
            (Decimal("100"), Decimal("150.1234"), Decimal("15012.34")),
        ]
        db_session.add_all(
            PortfolioPosition(
                portfolio_id=portfolio.id,
                ticker=f"MULTI{i}",
                shares=qty,
                average_cost_basis=price,
                total_cost=total,
                purchase_date=datetime.now(UTC),
            )
            for i, (qty, price, total) in enumerate(cases)
        )
        db_session.commit()

        stored = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio.id)
            .all()
        )
        assert len(stored) == 4
        for i, (expected_shares, expected_cost, _expected_total) in enumerate(cases):
            row = next(p for p in stored if p.ticker == f"MULTI{i}")
            assert row.shares == expected_shares
            assert row.average_cost_basis == expected_cost
# ============================================================================
# Timestamp Tests
# ============================================================================
class TestTimestampMixin:
    """Test suite for TimestampMixin functionality."""

    def test_portfolio_created_at_set_on_creation(self, db_session: Session):
        """Verify created_at is stamped within the creation window."""
        window_start = datetime.now(UTC)
        record = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
        db_session.add(record)
        db_session.commit()
        window_end = datetime.now(UTC)

        fetched = db_session.query(UserPortfolio).filter_by(id=record.id).first()
        assert fetched.created_at is not None
        assert window_start <= fetched.created_at <= window_end

    def test_portfolio_updated_at_set_on_creation(self, db_session: Session):
        """Verify updated_at is stamped within the creation window."""
        window_start = datetime.now(UTC)
        record = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
        db_session.add(record)
        db_session.commit()
        window_end = datetime.now(UTC)

        fetched = db_session.query(UserPortfolio).filter_by(id=record.id).first()
        assert fetched.updated_at is not None
        assert window_start <= fetched.updated_at <= window_end

    def test_position_created_at_set_on_creation(self, db_session: Session):
        """Verify a new position's created_at falls in the creation window."""
        owner = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
        db_session.add(owner)
        db_session.commit()

        window_start = datetime.now(UTC)
        holding = PortfolioPosition(
            portfolio_id=owner.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(holding)
        db_session.commit()
        window_end = datetime.now(UTC)

        fetched = (
            db_session.query(PortfolioPosition).filter_by(id=holding.id).first()
        )
        assert fetched.created_at is not None
        assert window_start <= fetched.created_at <= window_end

    def test_position_updated_at_set_on_creation(self, db_session: Session):
        """Verify a new position's updated_at falls in the creation window."""
        owner = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
        db_session.add(owner)
        db_session.commit()

        window_start = datetime.now(UTC)
        holding = PortfolioPosition(
            portfolio_id=owner.id,
            ticker="MSFT",
            shares=Decimal("5.00000000"),
            average_cost_basis=Decimal("380.0000"),
            total_cost=Decimal("1900.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(holding)
        db_session.commit()
        window_end = datetime.now(UTC)

        fetched = (
            db_session.query(PortfolioPosition).filter_by(id=holding.id).first()
        )
        assert fetched.updated_at is not None
        assert window_start <= fetched.updated_at <= window_end

    def test_created_at_does_not_change_on_update(self, db_session: Session):
        """Verify created_at is immutable across later updates."""
        record = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
        db_session.add(record)
        db_session.commit()
        first_created_at = record.created_at

        import time

        # Small pause so an (incorrectly) rewritten created_at would differ.
        time.sleep(0.01)
        record.name = "Updated Name"
        db_session.commit()

        fetched = db_session.query(UserPortfolio).filter_by(id=record.id).first()
        assert fetched.created_at == first_created_at

    def test_timezone_aware_datetimes(self, db_session: Session):
        """Verify both timestamps carry timezone info."""
        record = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
        db_session.add(record)
        db_session.commit()

        fetched = db_session.query(UserPortfolio).filter_by(id=record.id).first()
        assert fetched.created_at.tzinfo is not None
        assert fetched.updated_at.tzinfo is not None
# ============================================================================
# Default Value Tests
# ============================================================================
class TestDefaultValues:
    """Test suite for default values in models."""

    def test_portfolio_default_user_id(self, db_session: Session):
        """A portfolio created without user_id falls back to "default"."""
        record = UserPortfolio(name="Custom Name")
        db_session.add(record)
        db_session.commit()

        fetched = db_session.query(UserPortfolio).filter_by(id=record.id).first()
        assert fetched.user_id == "default"

    def test_portfolio_default_name(self, db_session: Session):
        """A portfolio created without a name falls back to "My Portfolio"."""
        record = UserPortfolio(user_id="custom_user")
        db_session.add(record)
        db_session.commit()

        fetched = db_session.query(UserPortfolio).filter_by(id=record.id).first()
        assert fetched.name == "My Portfolio"

    def test_position_default_notes(self, db_session: Session):
        """A position created without notes stores NULL."""
        owner = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
        db_session.add(owner)
        db_session.commit()

        holding = PortfolioPosition(
            portfolio_id=owner.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(holding)
        db_session.commit()

        fetched = (
            db_session.query(PortfolioPosition).filter_by(id=holding.id).first()
        )
        assert fetched.notes is None
# ============================================================================
# Integration Tests
# ============================================================================
class TestPortfolioIntegration:
    """End-to-end integration tests combining multiple operations.

    These flows exercise create/read/update/delete in sequence; statement
    order matters because later steps depend on earlier commits.
    """

    def test_complete_portfolio_workflow(self, db_session: Session):
        """Test complete workflow: create, read, update, delete."""
        # Create portfolio
        user_id = f"test_user_{uuid.uuid4()}"
        portfolio_name = f"Integration Test {uuid.uuid4()}"
        portfolio = UserPortfolio(user_id=user_id, name=portfolio_name)
        db_session.add(portfolio)
        db_session.commit()
        # Add positions (purchase dates staggered into the past)
        position1 = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC) - timedelta(days=30),
            notes="Initial purchase",
        )
        position2 = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="MSFT",
            shares=Decimal("5.00000000"),
            average_cost_basis=Decimal("380.0000"),
            total_cost=Decimal("1900.0000"),
            purchase_date=datetime.now(UTC) - timedelta(days=15),
        )
        db_session.add_all([position1, position2])
        db_session.commit()
        # Read and verify
        retrieved_portfolio = (
            db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
        )
        assert retrieved_portfolio is not None
        assert len(retrieved_portfolio.positions) == 2
        # Update position: simulate averaging up — shares, cost basis and
        # total cost all change together.
        aapl_position = next(
            p for p in retrieved_portfolio.positions if p.ticker == "AAPL"
        )
        original_shares = aapl_position.shares
        aapl_position.shares = Decimal("20.00000000")
        aapl_position.average_cost_basis = Decimal("160.0000")
        aapl_position.total_cost = Decimal("3200.0000")
        db_session.commit()
        # Verify update
        retrieved_position = (
            db_session.query(PortfolioPosition).filter_by(id=aapl_position.id).first()
        )
        assert retrieved_position.shares == Decimal("20.00000000")
        assert retrieved_position.shares != original_shares
        # Delete one position
        db_session.delete(aapl_position)
        db_session.commit()
        # Verify deletion
        remaining_positions = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio.id)
            .all()
        )
        assert len(remaining_positions) == 1
        assert remaining_positions[0].ticker == "MSFT"
        # Delete portfolio (cascade delete)
        db_session.delete(retrieved_portfolio)
        db_session.commit()
        # Verify cascade delete
        portfolio_check = (
            db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
        )
        assert portfolio_check is None
        positions_check = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio.id)
            .all()
        )
        assert len(positions_check) == 0

    def test_complex_portfolio_with_multiple_users(self, db_session: Session):
        """Test complex scenario with multiple portfolios and users."""
        user_ids = [f"user_{uuid.uuid4()}" for _ in range(3)]
        portfolios = []
        # Create portfolios for multiple users (2 per user)
        for user_id in user_ids:
            for i in range(2):
                portfolio = UserPortfolio(
                    user_id=user_id, name=f"Portfolio {i} {uuid.uuid4()}"
                )
                db_session.add(portfolio)
                portfolios.append(portfolio)
        db_session.commit()
        # Add positions to each portfolio
        tickers = ["AAPL", "MSFT", "GOOG", "AMZN", "TSLA"]
        for portfolio in portfolios:
            for ticker in tickers[:3]:  # Add 3 positions per portfolio
                position = PortfolioPosition(
                    portfolio_id=portfolio.id,
                    ticker=ticker,
                    shares=Decimal("10.00000000"),
                    average_cost_basis=Decimal("150.0000"),
                    total_cost=Decimal("1500.0000"),
                    purchase_date=datetime.now(UTC),
                )
                db_session.add(position)
        db_session.commit()
        # Verify structure: each user has 2 portfolios of 3 positions each
        for user_id in user_ids:
            user_portfolios = (
                db_session.query(UserPortfolio).filter_by(user_id=user_id).all()
            )
            assert len(user_portfolios) == 2
            for portfolio in user_portfolios:
                assert len(portfolio.positions) == 3
```
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/vectorbt_engine.py:
--------------------------------------------------------------------------------
```python
"""VectorBT backtesting engine implementation with memory management and structured logging."""
import gc
from typing import Any
import numpy as np
import pandas as pd
import vectorbt as vbt
from pandas import DataFrame, Series
from maverick_mcp.backtesting.batch_processing import BatchProcessingMixin
from maverick_mcp.data.cache import (
CacheManager,
ensure_timezone_naive,
generate_cache_key,
)
from maverick_mcp.providers.stock_data import EnhancedStockDataProvider
from maverick_mcp.utils.cache_warmer import CacheWarmer
from maverick_mcp.utils.data_chunking import DataChunker, optimize_dataframe_dtypes
from maverick_mcp.utils.memory_profiler import (
check_memory_leak,
cleanup_dataframes,
get_memory_stats,
memory_context,
profile_memory,
)
from maverick_mcp.utils.structured_logger import (
get_performance_logger,
get_structured_logger,
with_structured_logging,
)
logger = get_structured_logger(__name__)
performance_logger = get_performance_logger("vectorbt_engine")
class VectorBTEngine(BatchProcessingMixin):
"""High-performance backtesting engine using VectorBT with memory management."""
def __init__(
    self,
    data_provider: EnhancedStockDataProvider | None = None,
    cache_service=None,
    enable_memory_profiling: bool = True,
    chunk_size_mb: float = 100.0,
):
    """Initialize VectorBT engine.

    Args:
        data_provider: Stock data provider instance
        cache_service: Cache service for data persistence
        enable_memory_profiling: Enable memory profiling and optimization
        chunk_size_mb: Chunk size for large dataset processing
    """
    self.data_provider = data_provider or EnhancedStockDataProvider()
    self.cache = cache_service or CacheManager()
    self.cache_warmer = CacheWarmer(
        data_provider=self.data_provider, cache_manager=self.cache
    )
    # Memory management configuration
    self.enable_memory_profiling = enable_memory_profiling
    self.chunker = DataChunker(
        chunk_size_mb=chunk_size_mb, optimize_chunks=True, auto_gc=True
    )
    # Configure VectorBT settings for optimal performance and memory usage.
    # Best-effort: a failure here only degrades performance, so log and continue.
    # (Previously `except (KeyError, Exception)` — Exception already subsumes
    # KeyError, so the tuple was redundant.)
    try:
        vbt.settings.array_wrapper["freq"] = "D"
        vbt.settings.caching["enabled"] = True  # Enable VectorBT's internal caching
        # Don't set whitelist to avoid cache condition issues
    except Exception as e:
        logger.warning(f"Could not configure VectorBT settings: {e}")
    logger.info(
        f"VectorBT engine initialized with memory profiling: {enable_memory_profiling}"
    )
    # Initialize memory tracking
    if self.enable_memory_profiling:
        initial_stats = get_memory_stats()
        logger.debug(f"Initial memory stats: {initial_stats}")
@with_structured_logging(
    "get_historical_data", include_performance=True, log_params=True
)
@profile_memory(log_results=True, threshold_mb=50.0)
async def get_historical_data(
    self, symbol: str, start_date: str, end_date: str, interval: str = "1d"
) -> DataFrame:
    """Fetch historical data for backtesting with memory optimization.

    Checks the cache first (handling both native-DataFrame and legacy
    dict-serialized entries), otherwise fetches from the provider,
    normalizes columns/index, optionally optimizes dtypes, and re-caches
    with an age-dependent TTL.

    Args:
        symbol: Stock symbol
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)
        interval: Data interval (1d, 1h, etc.)

    Returns:
        Memory-optimized DataFrame with OHLCV data

    Raises:
        ValueError: If the provider returns no rows for the request.
    """
    # Generate versioned cache key
    cache_key = generate_cache_key(
        "backtest_data",
        symbol=symbol,
        start_date=start_date,
        end_date=end_date,
        interval=interval,
    )
    # Try cache first with improved deserialization
    cached_data = await self.cache.get(cache_key)
    if cached_data is not None:
        if isinstance(cached_data, pd.DataFrame):
            # Already a DataFrame - ensure timezone-naive
            df = ensure_timezone_naive(cached_data)
        else:
            # Restore DataFrame from dict (legacy JSON cache)
            df = pd.DataFrame.from_dict(cached_data, orient="index")
            # Convert index back to datetime
            df.index = pd.to_datetime(df.index)
            df = ensure_timezone_naive(df)
        # Ensure column names are lowercase
        df.columns = [col.lower() for col in df.columns]
        return df
    # Fetch from provider - try async method first, fallback to sync
    try:
        data = await self._get_data_async(symbol, start_date, end_date, interval)
    except AttributeError:
        # Fallback to sync method if async not available
        data = self.data_provider.get_stock_data(
            symbol=symbol,
            start_date=start_date,
            end_date=end_date,
            interval=interval,
        )
    if data is None or data.empty:
        raise ValueError(f"No data available for {symbol}")
    # Normalize column names to lowercase for consistency
    data.columns = [col.lower() for col in data.columns]
    # Ensure timezone-naive index and fix any timezone comparison issues
    data = ensure_timezone_naive(data)
    # Optimize DataFrame memory usage
    if self.enable_memory_profiling:
        data = optimize_dataframe_dtypes(data, aggressive=False)
        logger.debug(f"Optimized {symbol} data memory usage")
    # Cache with adaptive TTL - longer for older data
    # (local import avoids clashing with other datetime usage in the module)
    from datetime import datetime

    end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    days_old = (datetime.now() - end_dt).days
    ttl = 86400 if days_old > 7 else 3600  # 24h for older data, 1h for recent
    await self.cache.set(cache_key, data, ttl=ttl)
    return data
async def _get_data_async(
self, symbol: str, start_date: str, end_date: str, interval: str
) -> DataFrame:
"""Get data using async method if available."""
if hasattr(self.data_provider, "get_stock_data_async"):
return await self.data_provider.get_stock_data_async(
symbol=symbol,
start_date=start_date,
end_date=end_date,
interval=interval,
)
else:
# Fallback to sync method
return self.data_provider.get_stock_data(
symbol=symbol,
start_date=start_date,
end_date=end_date,
interval=interval,
)
@with_structured_logging(
    "run_backtest", include_performance=True, log_params=True, log_result=False
)
@profile_memory(log_results=True, threshold_mb=200.0)
async def run_backtest(
    self,
    symbol: str,
    strategy_type: str,
    parameters: dict[str, Any],
    start_date: str,
    end_date: str,
    initial_capital: float = 10000.0,
    fees: float = 0.001,
    slippage: float = 0.001,
) -> dict[str, Any]:
    """Run a vectorized backtest with memory optimization.

    Pipeline: fetch (cached) data -> generate strategy signals -> simulate
    with VectorBT -> extract metrics/trades/curves, with explicit memory
    cleanup between stages when profiling is enabled.

    Args:
        symbol: Stock symbol
        strategy_type: Type of strategy (sma_cross, rsi, macd, etc.)
        parameters: Strategy parameters
        start_date: Start date
        end_date: End date
        initial_capital: Starting capital
        fees: Trading fees (percentage)
        slippage: Slippage (percentage)

    Returns:
        Dictionary with backtest results: metrics, trades, equity_curve,
        drawdown_series, and memory_stats when profiling is enabled.

    Raises:
        ValueError: If no data is available for the symbol or the strategy
            type is unknown (propagated from the fetch/signal helpers).
    """
    with memory_context("backtest_execution"):
        # Fetch data (interval defaults to daily bars)
        data = await self.get_historical_data(symbol, start_date, end_date)
        # Check for large datasets and warn
        data_memory_mb = data.memory_usage(deep=True).sum() / (1024**2)
        if data_memory_mb > 100:
            logger.warning(f"Large dataset detected: {data_memory_mb:.2f}MB")
        # Log business metrics
        performance_logger.log_business_metric(
            "dataset_size_mb",
            data_memory_mb,
            symbol=symbol,
            date_range_days=(
                pd.to_datetime(end_date) - pd.to_datetime(start_date)
            ).days,
        )
        # Generate signals based on strategy
        entries, exits = self._generate_signals(data, strategy_type, parameters)
        # Optimize memory usage - use efficient data types
        with memory_context("data_optimization"):
            # float32 halves the price-series footprint for the simulation
            close_prices = data["close"].astype(np.float32)
            entries = entries.astype(bool)
            exits = exits.astype(bool)
            # Clean up original data to free memory
            if self.enable_memory_profiling:
                cleanup_dataframes(data)
                del data  # Explicit deletion
                gc.collect()  # Force garbage collection
        # Run VectorBT portfolio simulation with memory optimizations
        with memory_context("portfolio_simulation"):
            portfolio = vbt.Portfolio.from_signals(
                close=close_prices,
                entries=entries,
                exits=exits,
                init_cash=initial_capital,
                fees=fees,
                slippage=slippage,
                freq="D",
                cash_sharing=False,  # Disable cash sharing for single asset
                call_seq="auto",  # Optimize call sequence
                group_by=False,  # Disable grouping for memory efficiency
                broadcast_kwargs={"wrapper_kwargs": {"freq": "D"}},
            )
        # Extract comprehensive metrics with memory tracking
        with memory_context("results_extraction"):
            metrics = self._extract_metrics(portfolio)
            trades = self._extract_trades(portfolio)
            # Get equity curve - stringified keys keep the cached payload small
            equity_curve = {
                str(k): float(v) for k, v in portfolio.value().to_dict().items()
            }
            drawdown_series = {
                str(k): float(v) for k, v in portfolio.drawdown().to_dict().items()
            }
        # Clean up portfolio object to free memory
        if self.enable_memory_profiling:
            del portfolio
            cleanup_dataframes(close_prices) if hasattr(
                close_prices, "_mgr"
            ) else None
            del close_prices, entries, exits
            gc.collect()
        # Add memory statistics to results if profiling enabled
        result = {
            "symbol": symbol,
            "strategy": strategy_type,
            "parameters": parameters,
            "metrics": metrics,
            "trades": trades,
            "equity_curve": equity_curve,
            "drawdown_series": drawdown_series,
            "start_date": start_date,
            "end_date": end_date,
            "initial_capital": initial_capital,
        }
        if self.enable_memory_profiling:
            result["memory_stats"] = get_memory_stats()
            # Check for potential memory leaks
            if check_memory_leak(threshold_mb=50.0):
                logger.warning("Potential memory leak detected during backtesting")
        # Log business metrics for backtesting results
        performance_logger.log_business_metric(
            "backtest_total_return",
            metrics.get("total_return", 0),
            symbol=symbol,
            strategy=strategy_type,
            trade_count=metrics.get("total_trades", 0),
        )
        performance_logger.log_business_metric(
            "backtest_sharpe_ratio",
            metrics.get("sharpe_ratio", 0),
            symbol=symbol,
            strategy=strategy_type,
        )
        return result
def _generate_signals(
self, data: DataFrame, strategy_type: str, parameters: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate entry and exit signals based on strategy.
Args:
data: Price data
strategy_type: Strategy type
parameters: Strategy parameters
Returns:
Tuple of (entry_signals, exit_signals)
"""
# Ensure we have the required price data
if "close" not in data.columns:
raise ValueError(
f"Missing 'close' column in price data. Available columns: {list(data.columns)}"
)
close = data["close"]
if strategy_type in ["sma_cross", "sma_crossover"]:
return self._sma_crossover_signals(close, parameters)
elif strategy_type == "rsi":
return self._rsi_signals(close, parameters)
elif strategy_type == "macd":
return self._macd_signals(close, parameters)
elif strategy_type == "bollinger":
return self._bollinger_bands_signals(close, parameters)
elif strategy_type == "momentum":
return self._momentum_signals(close, parameters)
elif strategy_type == "ema_cross":
return self._ema_crossover_signals(close, parameters)
elif strategy_type == "mean_reversion":
return self._mean_reversion_signals(close, parameters)
elif strategy_type == "breakout":
return self._breakout_signals(close, parameters)
elif strategy_type == "volume_momentum":
return self._volume_momentum_signals(data, parameters)
elif strategy_type == "online_learning":
return self._online_learning_signals(data, parameters)
elif strategy_type == "regime_aware":
return self._regime_aware_signals(data, parameters)
elif strategy_type == "ensemble":
return self._ensemble_signals(data, parameters)
else:
raise ValueError(f"Unknown strategy type: {strategy_type}")
def _sma_crossover_signals(
    self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
    """Enter on a fast-over-slow SMA cross up; exit on the cross down."""
    # Accept both historical naming conventions for the window lengths.
    fast_len = params.get("fast_period", params.get("fast_window", 10))
    slow_len = params.get("slow_period", params.get("slow_window", 20))
    fast = vbt.MA.run(close, fast_len, short_name="fast").ma.squeeze()
    slow = vbt.MA.run(close, slow_len, short_name="slow").ma.squeeze()
    fast_prev = fast.shift(1)
    slow_prev = slow.shift(1)
    # Flag only the bar on which the relationship flips.
    entries = (fast > slow) & (fast_prev <= slow_prev)
    exits = (fast < slow) & (fast_prev >= slow_prev)
    return entries, exits
def _rsi_signals(
    self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
    """Buy when RSI drops through the oversold level; sell when it rises through overbought."""
    window = params.get("period", 14)
    low_level = params.get("oversold", 30)
    high_level = params.get("overbought", 70)
    rsi_series = vbt.RSI.run(close, window).rsi.squeeze()
    rsi_prev = rsi_series.shift(1)
    # Only the crossing bar fires, not every bar inside the zone.
    entries = (rsi_series < low_level) & (rsi_prev >= low_level)
    exits = (rsi_series > high_level) & (rsi_prev <= high_level)
    return entries, exits
def _macd_signals(
    self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
    """Signal on MACD line / signal line crossovers."""
    fast_window = params.get("fast_period", 12)
    slow_window = params.get("slow_period", 26)
    signal_window = params.get("signal_period", 9)
    indicator = vbt.MACD.run(
        close,
        fast_window=fast_window,
        slow_window=slow_window,
        signal_window=signal_window,
    )
    macd_line = indicator.macd.squeeze()
    signal_line = indicator.signal.squeeze()
    macd_prev = macd_line.shift(1)
    signal_prev = signal_line.shift(1)
    # Bullish cross: MACD moves from at-or-below to above its signal line.
    entries = (macd_line > signal_line) & (macd_prev <= signal_prev)
    exits = (macd_line < signal_line) & (macd_prev >= signal_prev)
    return entries, exits
def _bollinger_bands_signals(
    self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
    """Fade the bands: buy a touch of the lower band, sell a touch of the upper."""
    window = params.get("period", 20)
    band_width = params.get("std_dev", 2)
    bands = vbt.BBANDS.run(close, window=window, alpha=band_width)
    upper_band = bands.upper.squeeze()
    lower_band = bands.lower.squeeze()
    # Trigger only on the first bar where price crosses the band.
    entries = (close <= lower_band) & (close.shift(1) > lower_band.shift(1))
    exits = (close >= upper_band) & (close.shift(1) < upper_band.shift(1))
    return entries, exits
def _momentum_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate momentum-based signals."""
lookback = params.get("lookback", 20)
threshold = params.get("threshold", 0.05)
returns = close.pct_change(lookback)
entries = returns > threshold
exits = returns < -threshold
return entries, exits
def _ema_crossover_signals(
    self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
    """Enter on a fast-over-slow EMA cross up; exit on the cross down."""
    fast_len = params.get("fast_period", 12)
    slow_len = params.get("slow_period", 26)
    # ewm=True switches vectorbt's MA indicator to exponential weighting.
    fast = vbt.MA.run(close, fast_len, ewm=True).ma.squeeze()
    slow = vbt.MA.run(close, slow_len, ewm=True).ma.squeeze()
    fast_prev = fast.shift(1)
    slow_prev = slow.shift(1)
    entries = (fast > slow) & (fast_prev <= slow_prev)
    exits = (fast < slow) & (fast_prev >= slow_prev)
    return entries, exits
def _mean_reversion_signals(
    self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
    """Generate mean reversion signals.

    Enters when price sits sufficiently below its moving average and exits
    once it stretches above it.

    Args:
        close: Closing price series
        params: 'ma_period', 'entry_threshold', 'exit_threshold'

    Returns:
        Tuple of (entry_signals, exit_signals) as boolean Series aligned
        to the input index.
    """
    ma_period = params.get("ma_period", 20)
    entry_threshold = params.get("entry_threshold", 0.02)
    exit_threshold = params.get("exit_threshold", 0.01)
    ma = vbt.MA.run(close, ma_period).ma.squeeze()
    # Avoid division by zero in deviation calculation
    with np.errstate(divide="ignore", invalid="ignore"):
        deviation = np.where(ma != 0, (close - ma) / ma, 0)
    # np.where returns a bare ndarray and drops the pandas index; restore it
    # so the returned signals stay index-aligned Series like every other
    # strategy (the declared return type is tuple[Series, Series]).
    deviation = pd.Series(deviation, index=close.index)
    entries = deviation < -entry_threshold
    exits = deviation > exit_threshold
    return entries, exits
def _breakout_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate channel breakout signals."""
lookback = params.get("lookback", 20)
exit_lookback = params.get("exit_lookback", 10)
upper_channel = close.rolling(lookback).max()
lower_channel = close.rolling(exit_lookback).min()
entries = close > upper_channel.shift(1)
exits = close < lower_channel.shift(1)
return entries, exits
def _volume_momentum_signals(
self, data: DataFrame, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate volume-weighted momentum signals."""
momentum_period = params.get("momentum_period", 20)
volume_period = params.get("volume_period", 20)
momentum_threshold = params.get("momentum_threshold", 0.05)
volume_multiplier = params.get("volume_multiplier", 1.5)
close = data["close"]
volume = data.get("volume")
if volume is None:
# Fallback to pure momentum if no volume data
returns = close.pct_change(momentum_period)
entries = returns > momentum_threshold
exits = returns < -momentum_threshold
return entries, exits
returns = close.pct_change(momentum_period)
avg_volume = volume.rolling(volume_period).mean()
volume_surge = volume > (avg_volume * volume_multiplier)
# Entry: positive momentum with volume surge
entries = (returns > momentum_threshold) & volume_surge
# Exit: negative momentum or volume dry up
exits = (returns < -momentum_threshold) | (volume < avg_volume * 0.8)
return entries, exits
def _extract_metrics(self, portfolio: vbt.Portfolio) -> dict[str, Any]:
    """Extract comprehensive metrics from portfolio.

    Every ratio/return metric is routed through ``safe_float_metric`` so a
    degenerate portfolio (no trades, zero variance, divide-by-zero inside
    vectorbt) yields 0.0 instead of NaN/inf or an exception.
    """

    def safe_float_metric(metric_func, default=0.0):
        """Safely extract float metrics, handling None and NaN values."""
        try:
            value = metric_func()
            if value is None or np.isnan(value) or np.isinf(value):
                return default
            return float(value)
        except (ZeroDivisionError, ValueError, TypeError):
            return default

    return {
        "total_return": safe_float_metric(portfolio.total_return),
        "annual_return": safe_float_metric(portfolio.annualized_return),
        "sharpe_ratio": safe_float_metric(portfolio.sharpe_ratio),
        "sortino_ratio": safe_float_metric(portfolio.sortino_ratio),
        "calmar_ratio": safe_float_metric(portfolio.calmar_ratio),
        "max_drawdown": safe_float_metric(portfolio.max_drawdown),
        "win_rate": safe_float_metric(lambda: portfolio.trades.win_rate()),
        "profit_factor": safe_float_metric(
            lambda: portfolio.trades.profit_factor()
        ),
        "expectancy": safe_float_metric(lambda: portfolio.trades.expectancy()),
        "total_trades": int(portfolio.trades.count()),
        # hasattr guards: the winning/losing accessors are presumably absent
        # on some vectorbt versions - confirm before removing them.
        "winning_trades": int(portfolio.trades.winning.count())
        if hasattr(portfolio.trades, "winning")
        else 0,
        "losing_trades": int(portfolio.trades.losing.count())
        if hasattr(portfolio.trades, "losing")
        else 0,
        "avg_win": safe_float_metric(
            lambda: portfolio.trades.winning.pnl.mean()
            if hasattr(portfolio.trades, "winning")
            and portfolio.trades.winning.count() > 0
            else None
        ),
        "avg_loss": safe_float_metric(
            lambda: portfolio.trades.losing.pnl.mean()
            if hasattr(portfolio.trades, "losing")
            and portfolio.trades.losing.count() > 0
            else None
        ),
        "best_trade": safe_float_metric(
            lambda: portfolio.trades.pnl.max()
            if portfolio.trades.count() > 0
            else None
        ),
        "worst_trade": safe_float_metric(
            lambda: portfolio.trades.pnl.min()
            if portfolio.trades.count() > 0
            else None
        ),
        "avg_duration": safe_float_metric(lambda: portfolio.trades.duration.mean()),
        "kelly_criterion": self._calculate_kelly(portfolio),
        "recovery_factor": self._calculate_recovery_factor(portfolio),
        "risk_reward_ratio": self._calculate_risk_reward(portfolio),
    }
def _extract_trades(self, portfolio: vbt.Portfolio) -> list:
"""Extract trade records from portfolio."""
if portfolio.trades.count() == 0:
return []
trades = portfolio.trades.records_readable
# Vectorized operation for better performance
trade_list = [
{
"entry_date": str(trade.get("Entry Timestamp", "")),
"exit_date": str(trade.get("Exit Timestamp", "")),
"entry_price": float(trade.get("Avg Entry Price", 0)),
"exit_price": float(trade.get("Avg Exit Price", 0)),
"size": float(trade.get("Size", 0)),
"pnl": float(trade.get("PnL", 0)),
"return": float(trade.get("Return", 0)),
"duration": str(trade.get("Duration", "")),
}
for _, trade in trades.iterrows()
]
return trade_list
def _calculate_kelly(self, portfolio: vbt.Portfolio) -> float:
"""Calculate Kelly Criterion."""
if portfolio.trades.count() == 0:
return 0.0
try:
win_rate = portfolio.trades.win_rate()
if win_rate is None or np.isnan(win_rate):
return 0.0
avg_win = (
abs(portfolio.trades.winning.returns.mean() or 0)
if hasattr(portfolio.trades, "winning")
and portfolio.trades.winning.count() > 0
else 0
)
avg_loss = (
abs(portfolio.trades.losing.returns.mean() or 0)
if hasattr(portfolio.trades, "losing")
and portfolio.trades.losing.count() > 0
else 0
)
# Check for division by zero and invalid values
if avg_loss == 0 or avg_win == 0 or np.isnan(avg_win) or np.isnan(avg_loss):
return 0.0
# Calculate Kelly with safe division
with np.errstate(divide="ignore", invalid="ignore"):
kelly = (win_rate * avg_win - (1 - win_rate) * avg_loss) / avg_win
# Check if result is valid
if np.isnan(kelly) or np.isinf(kelly):
return 0.0
return float(
min(max(kelly, -1.0), 0.25)
) # Cap between -100% and 25% for safety
except (ZeroDivisionError, ValueError, TypeError):
return 0.0
def get_memory_report(self) -> dict[str, Any]:
    """Return current memory statistics, or a notice when profiling is off."""
    if self.enable_memory_profiling:
        return get_memory_stats()
    return {"message": "Memory profiling disabled"}
def clear_memory_cache(self) -> None:
    """Drop VectorBT's internal caches and run a garbage-collection pass."""
    settings = vbt.settings
    # Older/newer vectorbt builds may not expose a caching section.
    if hasattr(settings, "caching"):
        settings.caching.clear()
    gc.collect()
    logger.info("Memory cache cleared and garbage collection performed")
def optimize_for_memory(self, aggressive: bool = False) -> None:
    """Tune VectorBT global settings for lower memory usage.

    Args:
        aggressive: Use aggressive memory optimizations
    """
    if not aggressive:
        # Conservative: keep caching enabled but bound its size.
        vbt.settings.caching["enabled"] = True
        vbt.settings.caching["max_size"] = 100
        logger.info("Applied conservative memory optimizations")
        return
    # Aggressive: trade recompute time for a smaller footprint.
    vbt.settings.caching["enabled"] = False
    vbt.settings.array_wrapper["dtype"] = np.float32
    logger.info("Applied aggressive memory optimizations")
async def run_memory_efficient_backtest(
    self,
    symbol: str,
    strategy_type: str,
    parameters: dict[str, Any],
    start_date: str,
    end_date: str,
    initial_capital: float = 10000.0,
    fees: float = 0.001,
    slippage: float = 0.001,
    chunk_data: bool = False,
) -> dict[str, Any]:
    """Run backtest with maximum memory efficiency.

    Temporarily applies aggressive VectorBT memory settings for the
    duration of the run and restores the previous global settings in a
    ``finally`` block, even if the backtest raises.

    Args:
        symbol: Stock symbol
        strategy_type: Strategy type
        parameters: Strategy parameters
        start_date: Start date
        end_date: End date
        initial_capital: Starting capital
        fees: Trading fees
        slippage: Slippage
        chunk_data: Whether to process data in chunks

    Returns:
        Backtest results with memory statistics
    """
    # Snapshot the global settings we are about to mutate so they can be
    # restored afterwards (these are process-wide VectorBT settings).
    original_settings = {
        "caching_enabled": vbt.settings.caching.get("enabled", True),
        "array_dtype": vbt.settings.array_wrapper.get("dtype", np.float64),
    }
    try:
        self.optimize_for_memory(aggressive=True)
        if chunk_data:
            # Use chunked processing for very large datasets
            return await self._run_chunked_backtest(
                symbol,
                strategy_type,
                parameters,
                start_date,
                end_date,
                initial_capital,
                fees,
                slippage,
            )
        else:
            return await self.run_backtest(
                symbol,
                strategy_type,
                parameters,
                start_date,
                end_date,
                initial_capital,
                fees,
                slippage,
            )
    finally:
        # Restore original settings
        vbt.settings.caching["enabled"] = original_settings["caching_enabled"]
        vbt.settings.array_wrapper["dtype"] = original_settings["array_dtype"]
async def _run_chunked_backtest(
    self,
    symbol: str,
    strategy_type: str,
    parameters: dict[str, Any],
    start_date: str,
    end_date: str,
    initial_capital: float,
    fees: float,
    slippage: float,
) -> dict[str, Any]:
    """Run backtest using data chunking for very large datasets.

    Splits the date range into ~90-day windows, backtests each window in
    sequence while compounding capital forward, then merges the per-chunk
    outputs via ``_combine_chunked_results``.
    """
    from datetime import datetime, timedelta

    # Calculate date chunks (monthly)
    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
    end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    results = []
    current_capital = initial_capital
    current_date = start_dt
    while current_date < end_dt:
        chunk_end = min(current_date + timedelta(days=90), end_dt)  # 3-month chunks
        chunk_start_str = current_date.strftime("%Y-%m-%d")
        chunk_end_str = chunk_end.strftime("%Y-%m-%d")
        logger.debug(f"Processing chunk: {chunk_start_str} to {chunk_end_str}")
        # Run backtest for chunk
        chunk_result = await self.run_backtest(
            symbol,
            strategy_type,
            parameters,
            chunk_start_str,
            chunk_end_str,
            current_capital,
            fees,
            slippage,
        )
        results.append(chunk_result)
        # Update capital for next chunk.
        # NOTE(review): this value is the chunk's fractional total_return,
        # not a final account value - the variable name looks historical.
        final_value = chunk_result.get("metrics", {}).get("total_return", 0)
        current_capital = current_capital * (1 + final_value)
        # NOTE(review): the next chunk starts on chunk_end, so the boundary
        # day is shared by two adjacent chunks - verify this is intended.
        current_date = chunk_end
    # Combine results
    return self._combine_chunked_results(results, symbol, strategy_type, parameters)
def _combine_chunked_results(
self,
chunk_results: list[dict],
symbol: str,
strategy_type: str,
parameters: dict[str, Any],
) -> dict[str, Any]:
"""Combine results from chunked backtesting."""
if not chunk_results:
return {}
# Combine trades
all_trades = []
for chunk in chunk_results:
all_trades.extend(chunk.get("trades", []))
# Combine equity curves
combined_equity = {}
combined_drawdown = {}
for chunk in chunk_results:
combined_equity.update(chunk.get("equity_curve", {}))
combined_drawdown.update(chunk.get("drawdown_series", {}))
# Calculate combined metrics
total_return = 1.0
for chunk in chunk_results:
chunk_return = chunk.get("metrics", {}).get("total_return", 0)
total_return *= 1 + chunk_return
total_return -= 1.0
combined_metrics = {
"total_return": total_return,
"total_trades": len(all_trades),
"chunks_processed": len(chunk_results),
}
return {
"symbol": symbol,
"strategy": strategy_type,
"parameters": parameters,
"metrics": combined_metrics,
"trades": all_trades,
"equity_curve": combined_equity,
"drawdown_series": combined_drawdown,
"processing_method": "chunked",
"memory_stats": get_memory_stats()
if self.enable_memory_profiling
else None,
}
def _calculate_recovery_factor(self, portfolio: vbt.Portfolio) -> float:
"""Calculate recovery factor (total return / max drawdown)."""
try:
max_dd = portfolio.max_drawdown()
total_return = portfolio.total_return()
# Check for invalid values
if (
max_dd is None
or np.isnan(max_dd)
or max_dd == 0
or total_return is None
or np.isnan(total_return)
):
return 0.0
# Calculate with safe division
with np.errstate(divide="ignore", invalid="ignore"):
recovery_factor = total_return / abs(max_dd)
# Check if result is valid
if np.isnan(recovery_factor) or np.isinf(recovery_factor):
return 0.0
return float(recovery_factor)
except (ZeroDivisionError, ValueError, TypeError):
return 0.0
def _calculate_risk_reward(self, portfolio: vbt.Portfolio) -> float:
"""Calculate risk-reward ratio."""
if portfolio.trades.count() == 0:
return 0.0
try:
avg_win = (
abs(portfolio.trades.winning.pnl.mean() or 0)
if hasattr(portfolio.trades, "winning")
and portfolio.trades.winning.count() > 0
else 0
)
avg_loss = (
abs(portfolio.trades.losing.pnl.mean() or 0)
if hasattr(portfolio.trades, "losing")
and portfolio.trades.losing.count() > 0
else 0
)
# Check for division by zero and invalid values
if (
avg_loss == 0
or avg_win == 0
or np.isnan(avg_win)
or np.isnan(avg_loss)
or np.isinf(avg_win)
or np.isinf(avg_loss)
):
return 0.0
# Calculate with safe division
with np.errstate(divide="ignore", invalid="ignore"):
risk_reward = avg_win / avg_loss
# Check if result is valid
if np.isnan(risk_reward) or np.isinf(risk_reward):
return 0.0
return float(risk_reward)
except (ZeroDivisionError, ValueError, TypeError):
return 0.0
@with_structured_logging(
    "optimize_parameters",
    include_performance=True,
    log_params=True,
    log_result=False,
)
@profile_memory(log_results=True, threshold_mb=500.0)
async def optimize_parameters(
    self,
    symbol: str,
    strategy_type: str,
    param_grid: dict[str, list],
    start_date: str,
    end_date: str,
    optimization_metric: str = "sharpe_ratio",
    initial_capital: float = 10000.0,
    top_n: int = 10,
    use_chunking: bool = True,
) -> dict[str, Any]:
    """Optimize strategy parameters using memory-efficient grid search.

    Data is fetched once; each parameter combination is simulated
    independently and invalid combinations are skipped (logged at debug
    level) rather than aborting the whole search.

    Args:
        symbol: Stock symbol
        strategy_type: Strategy type
        param_grid: Parameter grid for optimization
        start_date: Start date
        end_date: End date
        optimization_metric: Metric to optimize
        initial_capital: Starting capital
        top_n: Number of top results to return
        use_chunking: Use chunking for large parameter grids

    Returns:
        Optimization results with best parameters
    """
    with memory_context("parameter_optimization"):
        # Fetch data once
        data = await self.get_historical_data(symbol, start_date, end_date)
        # Create parameter combinations
        param_combos = vbt.utils.params.create_param_combs(param_grid)
        total_combos = len(param_combos)
        logger.info(
            f"Optimizing {total_combos} parameter combinations for {symbol}"
        )
        # Pre-convert data for optimization with memory efficiency
        close_prices = data["close"].astype(np.float32)
        # Check if we should use chunking for large parameter grids
        if use_chunking and total_combos > 100:
            logger.info(f"Using chunked processing for {total_combos} combinations")
            chunk_size = min(50, max(10, total_combos // 10))  # Adaptive chunk size
            results = self._optimize_parameters_chunked(
                data,
                close_prices,
                strategy_type,
                param_combos,
                optimization_metric,
                initial_capital,
                chunk_size,
            )
        else:
            results = []
            for i, params in enumerate(param_combos):
                try:
                    with memory_context(f"param_combo_{i}"):
                        # Generate signals for this parameter set
                        entries, exits = self._generate_signals(
                            data, strategy_type, params
                        )
                        # Convert to boolean arrays for memory efficiency
                        entries = entries.astype(bool)
                        exits = exits.astype(bool)
                        # Run backtest with optimizations
                        portfolio = vbt.Portfolio.from_signals(
                            close=close_prices,
                            entries=entries,
                            exits=exits,
                            init_cash=initial_capital,
                            fees=0.001,
                            freq="D",
                            cash_sharing=False,
                            call_seq="auto",
                            group_by=False,  # Memory optimization
                        )
                        # Get optimization metric
                        metric_value = self._get_metric_value(
                            portfolio, optimization_metric
                        )
                        results.append(
                            {
                                "parameters": params,
                                optimization_metric: metric_value,
                                "total_return": float(portfolio.total_return()),
                                "max_drawdown": float(portfolio.max_drawdown()),
                                "total_trades": int(portfolio.trades.count()),
                            }
                        )
                        # Clean up intermediate objects
                        del portfolio, entries, exits
                        if i % 20 == 0:  # Periodic cleanup
                            gc.collect()
                except Exception as e:
                    # A bad combo (e.g. fast window >= slow window) is
                    # skipped, not fatal to the whole grid search.
                    logger.debug(f"Skipping invalid parameter combination {i}: {e}")
                    continue
        # Clean up data objects
        if self.enable_memory_profiling:
            cleanup_dataframes(data, close_prices) if hasattr(
                data, "_mgr"
            ) else None
            del data, close_prices
            gc.collect()
        # Sort by optimization metric (descending: best first)
        results.sort(key=lambda x: x[optimization_metric], reverse=True)
        # Get top N results
        top_results = results[:top_n]
        result = {
            "symbol": symbol,
            "strategy": strategy_type,
            "optimization_metric": optimization_metric,
            "best_parameters": top_results[0]["parameters"] if top_results else {},
            "best_metric_value": top_results[0][optimization_metric]
            if top_results
            else 0,
            "top_results": top_results,
            "total_combinations_tested": total_combos,
            "valid_combinations": len(results),
        }
        if self.enable_memory_profiling:
            result["memory_stats"] = get_memory_stats()
        return result
def _optimize_parameters_chunked(
    self,
    data: DataFrame,
    close_prices: Series,
    strategy_type: str,
    param_combos: list,
    optimization_metric: str,
    initial_capital: float,
    chunk_size: int,
) -> list[dict]:
    """Optimize parameters using chunked processing for memory efficiency.

    Same per-combination work as the inline loop in ``optimize_parameters``,
    but grouped into chunks with a forced garbage-collection pass after each
    chunk to keep peak memory bounded on large grids.
    """
    results = []
    # Ceiling division: partial final chunk still counts.
    total_chunks = len(param_combos) // chunk_size + (
        1 if len(param_combos) % chunk_size else 0
    )
    for chunk_idx in range(0, len(param_combos), chunk_size):
        chunk_params = param_combos[chunk_idx : chunk_idx + chunk_size]
        logger.debug(
            f"Processing chunk {chunk_idx // chunk_size + 1}/{total_chunks}"
        )
        with memory_context(f"param_chunk_{chunk_idx // chunk_size}"):
            for _, params in enumerate(chunk_params):
                try:
                    # Generate signals for this parameter set
                    entries, exits = self._generate_signals(
                        data, strategy_type, params
                    )
                    # Convert to boolean arrays for memory efficiency
                    entries = entries.astype(bool)
                    exits = exits.astype(bool)
                    # Run backtest with optimizations
                    portfolio = vbt.Portfolio.from_signals(
                        close=close_prices,
                        entries=entries,
                        exits=exits,
                        init_cash=initial_capital,
                        fees=0.001,
                        freq="D",
                        cash_sharing=False,
                        call_seq="auto",
                        group_by=False,
                    )
                    # Get optimization metric
                    metric_value = self._get_metric_value(
                        portfolio, optimization_metric
                    )
                    results.append(
                        {
                            "parameters": params,
                            optimization_metric: metric_value,
                            "total_return": float(portfolio.total_return()),
                            "max_drawdown": float(portfolio.max_drawdown()),
                            "total_trades": int(portfolio.trades.count()),
                        }
                    )
                    # Clean up intermediate objects
                    del portfolio, entries, exits
                except Exception as e:
                    # Invalid combinations are skipped, not fatal.
                    logger.debug(f"Skipping invalid parameter combination: {e}")
                    continue
            # Force garbage collection after each chunk
            gc.collect()
    return results
def _get_metric_value(self, portfolio: vbt.Portfolio, metric_name: str) -> float:
"""Get specific metric value from portfolio."""
metric_map = {
"total_return": portfolio.total_return,
"sharpe_ratio": portfolio.sharpe_ratio,
"sortino_ratio": portfolio.sortino_ratio,
"calmar_ratio": portfolio.calmar_ratio,
"max_drawdown": lambda: -portfolio.max_drawdown(),
"win_rate": lambda: portfolio.trades.win_rate() or 0,
"profit_factor": lambda: portfolio.trades.profit_factor() or 0,
}
if metric_name not in metric_map:
raise ValueError(f"Unknown metric: {metric_name}")
try:
value = metric_map[metric_name]()
# Check for invalid values
if value is None or np.isnan(value) or np.isinf(value):
return 0.0
return float(value)
except (ZeroDivisionError, ValueError, TypeError):
return 0.0
def _online_learning_signals(
self, data: DataFrame, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate online learning ML strategy signals.
Simple implementation using momentum with adaptive thresholds.
"""
lookback = params.get("lookback", 20)
learning_rate = params.get("learning_rate", 0.01)
close = data["close"]
returns = close.pct_change(lookback)
# Adaptive threshold based on rolling statistics
rolling_mean = returns.rolling(window=lookback).mean()
rolling_std = returns.rolling(window=lookback).std()
# Dynamic entry/exit thresholds
entry_threshold = rolling_mean + learning_rate * rolling_std
exit_threshold = rolling_mean - learning_rate * rolling_std
# Generate signals
entries = returns > entry_threshold
exits = returns < exit_threshold
# Fill NaN values
entries = entries.fillna(False)
exits = exits.fillna(False)
return entries, exits
def _regime_aware_signals(
self, data: DataFrame, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate regime-aware strategy signals.
Detects market regime and applies appropriate strategy.
"""
regime_window = params.get("regime_window", 50)
threshold = params.get("threshold", 0.02)
close = data["close"]
# Calculate regime indicators
returns = close.pct_change()
volatility = returns.rolling(window=regime_window).std()
trend_strength = close.rolling(window=regime_window).apply(
lambda x: (x[-1] - x[0]) / x[0] if x[0] != 0 else 0
)
# Determine regime: trending vs ranging
is_trending = abs(trend_strength) > threshold
# Trend following signals
sma_short = close.rolling(window=regime_window // 2).mean()
sma_long = close.rolling(window=regime_window).mean()
trend_entries = (close > sma_long) & (sma_short > sma_long)
trend_exits = (close < sma_long) & (sma_short < sma_long)
# Mean reversion signals
bb_upper = sma_long + 2 * volatility
bb_lower = sma_long - 2 * volatility
reversion_entries = close < bb_lower
reversion_exits = close > bb_upper
# Combine based on regime
entries = (is_trending & trend_entries) | (~is_trending & reversion_entries)
exits = (is_trending & trend_exits) | (~is_trending & reversion_exits)
# Fill NaN values
entries = entries.fillna(False)
exits = exits.fillna(False)
return entries, exits
def _ensemble_signals(
self, data: DataFrame, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate ensemble strategy signals.
Combines multiple strategies with voting.
"""
fast_period = params.get("fast_period", 10)
slow_period = params.get("slow_period", 20)
rsi_period = params.get("rsi_period", 14)
close = data["close"]
# Strategy 1: SMA Crossover
fast_sma = close.rolling(window=fast_period).mean()
slow_sma = close.rolling(window=slow_period).mean()
sma_entries = (fast_sma > slow_sma) & (fast_sma.shift(1) <= slow_sma.shift(1))
sma_exits = (fast_sma < slow_sma) & (fast_sma.shift(1) >= slow_sma.shift(1))
# Strategy 2: RSI
delta = close.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=rsi_period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=rsi_period).mean()
rs = gain / loss.replace(0, 1e-10)
rsi = 100 - (100 / (1 + rs))
rsi_entries = (rsi < 30) & (rsi.shift(1) >= 30)
rsi_exits = (rsi > 70) & (rsi.shift(1) <= 70)
# Strategy 3: Momentum
momentum = close.pct_change(20)
mom_entries = momentum > 0.05
mom_exits = momentum < -0.05
# Ensemble voting - at least 2 out of 3 strategies agree
entry_votes = (
sma_entries.astype(int) + rsi_entries.astype(int) + mom_entries.astype(int)
)
exit_votes = (
sma_exits.astype(int) + rsi_exits.astype(int) + mom_exits.astype(int)
)
entries = entry_votes >= 2
exits = exit_votes >= 2
# Fill NaN values
entries = entries.fillna(False)
exits = exits.fillna(False)
return entries, exits
```