This is page 27 of 29. Use http://codebase.md/wshobson/maverick-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/data/test_portfolio_models.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive integration tests for portfolio database models and migration.
This module tests:
1. Migration upgrade and downgrade operations
2. SQLAlchemy model CRUD operations (Create, Read, Update, Delete)
3. Database constraints (unique constraints, foreign keys, cascade deletes)
4. Relationships between UserPortfolio and PortfolioPosition
5. Decimal field precision for financial data (Numeric(12,4), Numeric(20,4), and Numeric(20,8))
6. Timezone-aware datetime fields
7. Index creation and query optimization
Test Coverage:
- Migration creates tables with correct schema
- Indexes are created properly for performance optimization
- Unique constraints are enforced at both the portfolio and position level
- Cascade delete removes positions when portfolio is deleted
- Decimal precision is maintained through round-trip database operations
- Relationships are properly loaded with selectin strategy
- Default values are applied correctly (user_id="default", name="My Portfolio")
- Timestamp mixin functionality (created_at, updated_at)
Test Markers:
- @pytest.mark.integration - Full database integration tests
"""
import uuid
from datetime import UTC, datetime, timedelta
from decimal import Decimal
import pytest
from sqlalchemy import exc, inspect
from sqlalchemy.orm import Session
from maverick_mcp.data.models import PortfolioPosition, UserPortfolio
pytestmark = pytest.mark.integration
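# ---------------------------------------------------------------------------
# Context sketch (assumed, for orientation only): these tests rely on the
# shared `db_session` fixture from the suite's conftest, which is expected to
# yield a SQLAlchemy Session bound to a database where the portfolio migration
# has already been applied. Based solely on what the tests below assert, the
# mapped models are assumed to look roughly like the following; the
# authoritative definitions live in maverick_mcp/data/models.py and may differ
# in detail (mixin names and column types here are guesses):
#
#     class UserPortfolio(TimestampMixin, Base):
#         __tablename__ = "mcp_portfolios"
#         id = Column(Uuid, primary_key=True, default=uuid.uuid4)
#         user_id = Column(String, nullable=False, default="default")
#         name = Column(String, nullable=False, default="My Portfolio")
#         positions = relationship(
#             "PortfolioPosition",
#             back_populates="portfolio",
#             lazy="selectin",
#             cascade="all, delete-orphan",
#         )
#
#     class PortfolioPosition(TimestampMixin, Base):
#         __tablename__ = "mcp_portfolio_positions"
#         id = Column(Uuid, primary_key=True, default=uuid.uuid4)
#         portfolio_id = Column(ForeignKey("mcp_portfolios.id"), nullable=False)
#         ticker = Column(String, nullable=False)
#         shares = Column(Numeric(20, 8), nullable=False)
#         average_cost_basis = Column(Numeric(12, 4), nullable=False)
#         total_cost = Column(Numeric(20, 4), nullable=False)
#         purchase_date = Column(DateTime(timezone=True), nullable=False)
#         notes = Column(Text, nullable=True)
#
# Run just this module with: pytest -m integration tests/data/test_portfolio_models.py
# ---------------------------------------------------------------------------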
# ============================================================================
# Migration Tests
# ============================================================================
class TestMigrationUpgrade:
"""Test suite for migration upgrade operations."""
def test_migration_creates_portfolios_table(self, db_session: Session):
"""Test that migration creates mcp_portfolios table."""
inspector = inspect(db_session.bind)
tables = inspector.get_table_names()
assert "mcp_portfolios" in tables
def test_migration_creates_positions_table(self, db_session: Session):
"""Test that migration creates mcp_portfolio_positions table."""
inspector = inspect(db_session.bind)
tables = inspector.get_table_names()
assert "mcp_portfolio_positions" in tables
def test_portfolios_table_has_correct_columns(self, db_session: Session):
"""Test that portfolios table has all required columns."""
inspector = inspect(db_session.bind)
columns = {col["name"] for col in inspector.get_columns("mcp_portfolios")}
required_columns = {"id", "user_id", "name", "created_at", "updated_at"}
assert required_columns.issubset(columns)
def test_positions_table_has_correct_columns(self, db_session: Session):
"""Test that positions table has all required columns."""
inspector = inspect(db_session.bind)
columns = {
col["name"] for col in inspector.get_columns("mcp_portfolio_positions")
}
required_columns = {
"id",
"portfolio_id",
"ticker",
"shares",
"average_cost_basis",
"total_cost",
"purchase_date",
"notes",
"created_at",
"updated_at",
}
assert required_columns.issubset(columns)
def test_portfolios_id_column_type(self, db_session: Session):
"""Test that portfolio id column is UUID type."""
inspector = inspect(db_session.bind)
columns = {col["name"]: col for col in inspector.get_columns("mcp_portfolios")}
assert "id" in columns
        # Only the column's presence is asserted here; its concrete type is not inspected.
def test_positions_foreign_key_constraint(self, db_session: Session):
"""Test that positions table has foreign key to portfolios."""
inspector = inspect(db_session.bind)
fks = inspector.get_foreign_keys("mcp_portfolio_positions")
assert len(fks) > 0
assert any(fk["constrained_columns"] == ["portfolio_id"] for fk in fks)
def test_migration_creates_portfolio_user_index(self, db_session: Session):
"""Test that migration creates index on portfolio user_id."""
inspector = inspect(db_session.bind)
indexes = {idx["name"] for idx in inspector.get_indexes("mcp_portfolios")}
assert "idx_portfolio_user" in indexes
def test_migration_creates_position_portfolio_index(self, db_session: Session):
"""Test that migration creates index on position portfolio_id."""
inspector = inspect(db_session.bind)
indexes = {
idx["name"] for idx in inspector.get_indexes("mcp_portfolio_positions")
}
assert "idx_position_portfolio" in indexes
def test_migration_creates_position_ticker_index(self, db_session: Session):
"""Test that migration creates index on position ticker."""
inspector = inspect(db_session.bind)
indexes = {
idx["name"] for idx in inspector.get_indexes("mcp_portfolio_positions")
}
assert "idx_position_ticker" in indexes
def test_migration_creates_position_composite_index(self, db_session: Session):
"""Test that migration creates composite index on portfolio_id and ticker."""
inspector = inspect(db_session.bind)
indexes = {
idx["name"] for idx in inspector.get_indexes("mcp_portfolio_positions")
}
assert "idx_position_portfolio_ticker" in indexes
def test_migration_creates_unique_portfolio_constraint(self, db_session: Session):
"""Test that migration creates unique constraint on user_id and name."""
inspector = inspect(db_session.bind)
constraints = inspector.get_unique_constraints("mcp_portfolios")
constraint_names = {c["name"] for c in constraints}
assert "uq_user_portfolio_name" in constraint_names
def test_migration_creates_unique_position_constraint(self, db_session: Session):
"""Test that migration creates unique constraint on portfolio_id and ticker."""
inspector = inspect(db_session.bind)
constraints = inspector.get_unique_constraints("mcp_portfolio_positions")
constraint_names = {c["name"] for c in constraints}
assert "uq_portfolio_position_ticker" in constraint_names
def test_portfolios_user_id_has_default(self, db_session: Session):
"""Test that user_id column exists and is not nullable."""
inspector = inspect(db_session.bind)
columns = {col["name"]: col for col in inspector.get_columns("mcp_portfolios")}
assert "user_id" in columns
# Default is handled at model level, not server level
def test_portfolios_name_has_default(self, db_session: Session):
"""Test that name column exists and is not nullable."""
inspector = inspect(db_session.bind)
columns = {col["name"]: col for col in inspector.get_columns("mcp_portfolios")}
assert "name" in columns
# Default is handled at model level, not server level
def test_portfolios_created_at_has_default(self, db_session: Session):
"""Test that created_at column exists for timestamp tracking."""
inspector = inspect(db_session.bind)
columns = {col["name"]: col for col in inspector.get_columns("mcp_portfolios")}
assert "created_at" in columns
def test_portfolios_updated_at_has_default(self, db_session: Session):
"""Test that updated_at column exists for timestamp tracking."""
inspector = inspect(db_session.bind)
columns = {col["name"]: col for col in inspector.get_columns("mcp_portfolios")}
assert "updated_at" in columns
def test_positions_created_at_has_default(self, db_session: Session):
"""Test that position created_at column exists for timestamp tracking."""
inspector = inspect(db_session.bind)
columns = {
col["name"]: col for col in inspector.get_columns("mcp_portfolio_positions")
}
assert "created_at" in columns
def test_positions_updated_at_has_default(self, db_session: Session):
"""Test that position updated_at column exists for timestamp tracking."""
inspector = inspect(db_session.bind)
columns = {
col["name"]: col for col in inspector.get_columns("mcp_portfolio_positions")
}
assert "updated_at" in columns
# ============================================================================
# Model CRUD Operation Tests
# ============================================================================
class TestPortfolioModelCRUD:
"""Test suite for UserPortfolio CRUD operations."""
def test_create_portfolio_with_all_fields(self, db_session: Session):
"""Test creating a portfolio with all fields specified."""
portfolio = UserPortfolio(
id=uuid.uuid4(),
user_id="test_user",
name="Test Portfolio",
)
db_session.add(portfolio)
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved is not None
assert retrieved.user_id == "test_user"
assert retrieved.name == "Test Portfolio"
assert retrieved.created_at is not None
assert retrieved.updated_at is not None
def test_create_portfolio_with_defaults(self, db_session: Session):
"""Test that portfolio defaults are applied correctly."""
portfolio = UserPortfolio()
db_session.add(portfolio)
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.user_id == "default"
assert retrieved.name == "My Portfolio"
def test_read_portfolio_by_id(self, db_session: Session):
"""Test reading portfolio by ID."""
portfolio = UserPortfolio(user_id="user1", name="Portfolio 1")
db_session.add(portfolio)
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved is not None
assert retrieved.id == portfolio.id
def test_read_portfolio_by_user_and_name(self, db_session: Session):
"""Test reading portfolio by user_id and name."""
portfolio = UserPortfolio(user_id="user2", name="My Portfolio 2")
db_session.add(portfolio)
db_session.commit()
retrieved = (
db_session.query(UserPortfolio)
.filter_by(user_id="user2", name="My Portfolio 2")
.first()
)
assert retrieved is not None
assert retrieved.id == portfolio.id
def test_read_all_portfolios_for_user(self, db_session: Session):
"""Test reading all portfolios for a specific user."""
user_id = f"user_read_{uuid.uuid4()}"
portfolios = [
UserPortfolio(user_id=user_id, name=f"Portfolio {i}") for i in range(3)
]
db_session.add_all(portfolios)
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(user_id=user_id).all()
assert len(retrieved) == 3
def test_update_portfolio_name(self, db_session: Session):
"""Test updating portfolio name."""
portfolio = UserPortfolio(user_id="user3", name="Original Name")
db_session.add(portfolio)
db_session.commit()
portfolio.name = "Updated Name"
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.name == "Updated Name"
def test_update_portfolio_user_id(self, db_session: Session):
"""Test updating portfolio user_id."""
portfolio = UserPortfolio(user_id="old_user", name="Portfolio")
db_session.add(portfolio)
db_session.commit()
portfolio.user_id = "new_user"
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.user_id == "new_user"
def test_delete_portfolio(self, db_session: Session):
"""Test deleting a portfolio."""
portfolio = UserPortfolio(user_id="user4", name="To Delete")
db_session.add(portfolio)
db_session.commit()
portfolio_id = portfolio.id
db_session.delete(portfolio)
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio_id).first()
assert retrieved is None
def test_portfolio_repr(self, db_session: Session):
"""Test portfolio string representation."""
portfolio = UserPortfolio(user_id="user5", name="Test Portfolio")
db_session.add(portfolio)
db_session.commit()
repr_str = repr(portfolio)
assert "UserPortfolio" in repr_str
assert "Test Portfolio" in repr_str
class TestPositionModelCRUD:
"""Test suite for PortfolioPosition CRUD operations."""
@pytest.fixture
def portfolio(self, db_session: Session) -> UserPortfolio:
"""Create a test portfolio."""
portfolio = UserPortfolio(
user_id="default", name=f"Test Portfolio {uuid.uuid4()}"
)
db_session.add(portfolio)
db_session.commit()
return portfolio
def test_create_position_with_all_fields(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test creating a position with all fields."""
position = PortfolioPosition(
id=uuid.uuid4(),
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
notes="Test position",
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved is not None
assert retrieved.ticker == "AAPL"
assert retrieved.notes == "Test position"
def test_create_position_without_notes(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test creating a position without notes."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.notes is None
def test_read_position_by_id(self, db_session: Session, portfolio: UserPortfolio):
"""Test reading position by ID."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="GOOG",
shares=Decimal("2.00000000"),
average_cost_basis=Decimal("2750.0000"),
total_cost=Decimal("5500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved is not None
assert retrieved.ticker == "GOOG"
def test_read_position_by_portfolio_and_ticker(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test reading position by portfolio_id and ticker."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TSLA",
shares=Decimal("1.00000000"),
average_cost_basis=Decimal("250.0000"),
total_cost=Decimal("250.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio.id, ticker="TSLA")
.first()
)
assert retrieved is not None
def test_read_all_positions_in_portfolio(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test reading all positions in a portfolio."""
positions_data = [
("AAPL", Decimal("10"), Decimal("150.0000")),
("MSFT", Decimal("5"), Decimal("380.0000")),
("GOOG", Decimal("2"), Decimal("2750.0000")),
]
for ticker, shares, price in positions_data:
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker=ticker,
shares=shares,
average_cost_basis=price,
total_cost=shares * price,
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio.id)
.all()
)
assert len(retrieved) == 3
def test_update_position_shares(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test updating position shares."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
position.shares = Decimal("20.00000000")
position.average_cost_basis = Decimal("160.0000")
position.total_cost = Decimal("3200.0000")
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.shares == Decimal("20.00000000")
def test_update_position_cost_basis(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test updating position average cost basis."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
original_cost_basis = position.average_cost_basis
position.average_cost_basis = Decimal("390.0000")
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.average_cost_basis != original_cost_basis
assert retrieved.average_cost_basis == Decimal("390.0000")
def test_update_position_notes(self, db_session: Session, portfolio: UserPortfolio):
"""Test updating position notes."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="GOOG",
shares=Decimal("2.00000000"),
average_cost_basis=Decimal("2750.0000"),
total_cost=Decimal("5500.0000"),
purchase_date=datetime.now(UTC),
notes="Original notes",
)
db_session.add(position)
db_session.commit()
position.notes = "Updated notes"
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.notes == "Updated notes"
def test_delete_position(self, db_session: Session, portfolio: UserPortfolio):
"""Test deleting a position."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TSLA",
shares=Decimal("1.00000000"),
average_cost_basis=Decimal("250.0000"),
total_cost=Decimal("250.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
position_id = position.id
db_session.delete(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position_id).first()
)
assert retrieved is None
def test_position_repr(self, db_session: Session, portfolio: UserPortfolio):
"""Test position string representation."""
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="NVDA",
shares=Decimal("3.00000000"),
average_cost_basis=Decimal("900.0000"),
total_cost=Decimal("2700.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
repr_str = repr(position)
assert "PortfolioPosition" in repr_str
assert "NVDA" in repr_str
# ============================================================================
# Relationship Tests
# ============================================================================
class TestPortfolioPositionRelationships:
"""Test suite for relationships between UserPortfolio and PortfolioPosition."""
@pytest.fixture
def portfolio_with_positions(self, db_session: Session) -> UserPortfolio:
"""Create a portfolio with multiple positions."""
portfolio = UserPortfolio(
user_id="default", name=f"Relationship Test {uuid.uuid4()}"
)
db_session.add(portfolio)
db_session.commit()
positions = [
PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
),
PortfolioPosition(
portfolio_id=portfolio.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC),
),
]
db_session.add_all(positions)
db_session.commit()
return portfolio
def test_portfolio_has_positions_relationship(
self, db_session: Session, portfolio_with_positions: UserPortfolio
):
"""Test that portfolio has positions relationship."""
portfolio = (
db_session.query(UserPortfolio)
.filter_by(id=portfolio_with_positions.id)
.first()
)
assert hasattr(portfolio, "positions")
assert isinstance(portfolio.positions, list)
def test_positions_eagerly_loaded_via_selectin(
self, db_session: Session, portfolio_with_positions: UserPortfolio
):
"""Test that positions are eagerly loaded (selectin strategy)."""
portfolio = (
db_session.query(UserPortfolio)
.filter_by(id=portfolio_with_positions.id)
.first()
)
assert len(portfolio.positions) == 2
assert {p.ticker for p in portfolio.positions} == {"AAPL", "MSFT"}
def test_position_has_portfolio_relationship(
self, db_session: Session, portfolio_with_positions: UserPortfolio
):
"""Test that position has back reference to portfolio."""
position = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio_with_positions.id)
.first()
)
assert position.portfolio is not None
assert position.portfolio.id == portfolio_with_positions.id
def test_position_portfolio_relationship_maintains_integrity(
self, db_session: Session, portfolio_with_positions: UserPortfolio
):
"""Test that position portfolio relationship maintains data integrity."""
position = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio_with_positions.id, ticker="AAPL")
.first()
)
assert position.portfolio.name == portfolio_with_positions.name
assert position.portfolio.user_id == portfolio_with_positions.user_id
def test_multiple_portfolios_have_separate_positions(self, db_session: Session):
"""Test that multiple portfolios have separate position lists."""
user_id = f"user_multi_{uuid.uuid4()}"
portfolio1 = UserPortfolio(user_id=user_id, name=f"Portfolio 1 {uuid.uuid4()}")
portfolio2 = UserPortfolio(user_id=user_id, name=f"Portfolio 2 {uuid.uuid4()}")
db_session.add_all([portfolio1, portfolio2])
db_session.commit()
position1 = PortfolioPosition(
portfolio_id=portfolio1.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
position2 = PortfolioPosition(
portfolio_id=portfolio2.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add_all([position1, position2])
db_session.commit()
p1 = db_session.query(UserPortfolio).filter_by(id=portfolio1.id).first()
p2 = db_session.query(UserPortfolio).filter_by(id=portfolio2.id).first()
assert len(p1.positions) == 1
assert len(p2.positions) == 1
assert p1.positions[0].ticker == "AAPL"
assert p2.positions[0].ticker == "MSFT"
# ============================================================================
# Constraint Tests
# ============================================================================
class TestDatabaseConstraints:
"""Test suite for database constraints enforcement."""
def test_unique_portfolio_name_constraint_enforced(self, db_session: Session):
"""Test that unique constraint on (user_id, name) is enforced."""
user_id = f"user_constraint_{uuid.uuid4()}"
name = f"Unique Portfolio {uuid.uuid4()}"
portfolio1 = UserPortfolio(user_id=user_id, name=name)
db_session.add(portfolio1)
db_session.commit()
# Try to create duplicate
portfolio2 = UserPortfolio(user_id=user_id, name=name)
db_session.add(portfolio2)
with pytest.raises(exc.IntegrityError):
db_session.commit()
def test_unique_position_ticker_constraint_enforced(self, db_session: Session):
"""Test that unique constraint on (portfolio_id, ticker) is enforced."""
portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
position1 = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position1)
db_session.commit()
# Try to create duplicate ticker
position2 = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("160.0000"),
total_cost=Decimal("800.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position2)
with pytest.raises(exc.IntegrityError):
db_session.commit()
def test_foreign_key_constraint_enforced(self, db_session: Session):
"""Test that foreign key constraint is enforced."""
position = PortfolioPosition(
portfolio_id=uuid.uuid4(), # Non-existent portfolio
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
with pytest.raises(exc.IntegrityError):
db_session.commit()
def test_cascade_delete_removes_positions(self, db_session: Session):
"""Test that deleting a portfolio cascades delete to positions."""
portfolio = UserPortfolio(user_id="default", name=f"Delete Test {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
positions = [
PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
),
PortfolioPosition(
portfolio_id=portfolio.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC),
),
]
db_session.add_all(positions)
db_session.commit()
portfolio_id = portfolio.id
db_session.delete(portfolio)
db_session.commit()
# Verify portfolio is deleted
p = db_session.query(UserPortfolio).filter_by(id=portfolio_id).first()
assert p is None
# Verify positions are also deleted
pos = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio_id)
.all()
)
assert len(pos) == 0
def test_cascade_delete_doesnt_affect_other_portfolios(self, db_session: Session):
"""Test that deleting one portfolio doesn't affect others."""
user_id = f"user_cascade_{uuid.uuid4()}"
portfolio1 = UserPortfolio(user_id=user_id, name=f"Portfolio 1 {uuid.uuid4()}")
portfolio2 = UserPortfolio(user_id=user_id, name=f"Portfolio 2 {uuid.uuid4()}")
db_session.add_all([portfolio1, portfolio2])
db_session.commit()
position = PortfolioPosition(
portfolio_id=portfolio1.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
db_session.delete(portfolio1)
db_session.commit()
# Portfolio2 should still exist
p2 = db_session.query(UserPortfolio).filter_by(id=portfolio2.id).first()
assert p2 is not None
# ============================================================================
# Decimal Precision Tests
# ============================================================================
class TestDecimalPrecision:
"""Test suite for Decimal field precision."""
@pytest.fixture
def portfolio(self, db_session: Session) -> UserPortfolio:
"""Create a test portfolio."""
portfolio = UserPortfolio(
user_id="default", name=f"Decimal Test {uuid.uuid4()}"
)
db_session.add(portfolio)
db_session.commit()
return portfolio
def test_shares_numeric_20_8_precision(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test that shares maintains Numeric(20,8) precision."""
shares = Decimal("12345678901.12345678")
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TEST1",
shares=shares,
average_cost_basis=Decimal("100.0000"),
total_cost=Decimal("1234567890112.3456"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.shares == shares
def test_cost_basis_numeric_12_4_precision(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test that average_cost_basis maintains Numeric(12,4) precision."""
cost_basis = Decimal("99999999.9999")
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TEST2",
shares=Decimal("100.00000000"),
average_cost_basis=cost_basis,
total_cost=Decimal("9999999999.9999"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.average_cost_basis == cost_basis
def test_total_cost_numeric_20_4_precision(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test that total_cost maintains Numeric(20,4) precision."""
total_cost = Decimal("9999999999999999.9999")
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TEST3",
shares=Decimal("1000.00000000"),
average_cost_basis=Decimal("9999999.9999"),
total_cost=total_cost,
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.total_cost == total_cost
def test_fractional_shares_precision(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test that fractional shares with high precision are maintained.
Note: total_cost uses Numeric(20, 4), so values are truncated to 4 decimal places.
"""
shares = Decimal("0.33333333")
cost_basis = Decimal("2750.1234")
total_cost = Decimal("917.5041") # Truncated from 917.50413522 to 4 decimals
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TEST4",
shares=shares,
average_cost_basis=cost_basis,
total_cost=total_cost,
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.shares == shares
assert retrieved.average_cost_basis == cost_basis
assert retrieved.total_cost == total_cost
def test_very_small_decimal_values(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test handling of very small Decimal values.
Note: total_cost uses Numeric(20, 4) precision, so values smaller than
0.0001 will be truncated. This is appropriate for stock trading.
"""
shares = Decimal("0.00000001")
cost_basis = Decimal("0.0001")
total_cost = Decimal("0.0000") # Rounds to 0.0000 due to Numeric(20, 4)
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="TEST5",
shares=shares,
average_cost_basis=cost_basis,
total_cost=total_cost,
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.shares == shares
assert retrieved.average_cost_basis == cost_basis
# Total cost truncated to 4 decimal places as per Numeric(20, 4)
assert retrieved.total_cost == total_cost
def test_multiple_positions_precision_preserved(
self, db_session: Session, portfolio: UserPortfolio
):
"""Test that precision is maintained across multiple positions."""
test_data = [
(Decimal("1"), Decimal("100.00"), Decimal("100.00")),
(Decimal("1.5"), Decimal("200.5000"), Decimal("300.7500")),
(Decimal("0.33333333"), Decimal("2750.1234"), Decimal("917.5041")),
(Decimal("100"), Decimal("150.1234"), Decimal("15012.34")),
]
for i, (shares, cost_basis, total_cost) in enumerate(test_data):
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker=f"MULTI{i}",
shares=shares,
average_cost_basis=cost_basis,
total_cost=total_cost,
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
positions = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio.id)
.all()
)
assert len(positions) == 4
for i, (expected_shares, expected_cost, _expected_total) in enumerate(
test_data
):
position = next(p for p in positions if p.ticker == f"MULTI{i}")
assert position.shares == expected_shares
assert position.average_cost_basis == expected_cost
# ============================================================================
# Timestamp Tests
# ============================================================================
class TestTimestampMixin:
"""Test suite for TimestampMixin functionality."""
def test_portfolio_created_at_set_on_creation(self, db_session: Session):
"""Test that created_at is set when portfolio is created."""
before = datetime.now(UTC)
portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
after = datetime.now(UTC)
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.created_at is not None
assert before <= retrieved.created_at <= after
def test_portfolio_updated_at_set_on_creation(self, db_session: Session):
"""Test that updated_at is set when portfolio is created."""
before = datetime.now(UTC)
portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
after = datetime.now(UTC)
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.updated_at is not None
assert before <= retrieved.updated_at <= after
def test_position_created_at_set_on_creation(self, db_session: Session):
"""Test that created_at is set when position is created."""
portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
before = datetime.now(UTC)
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
after = datetime.now(UTC)
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.created_at is not None
assert before <= retrieved.created_at <= after
def test_position_updated_at_set_on_creation(self, db_session: Session):
"""Test that updated_at is set when position is created."""
portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
before = datetime.now(UTC)
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
after = datetime.now(UTC)
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.updated_at is not None
assert before <= retrieved.updated_at <= after
def test_created_at_does_not_change_on_update(self, db_session: Session):
"""Test that created_at remains unchanged when portfolio is updated."""
portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
original_created_at = portfolio.created_at
import time
time.sleep(0.01)
portfolio.name = "Updated Name"
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.created_at == original_created_at
def test_timezone_aware_datetimes(self, db_session: Session):
"""Test that datetimes are timezone-aware."""
portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.created_at.tzinfo is not None
assert retrieved.updated_at.tzinfo is not None
# ============================================================================
# Default Value Tests
# ============================================================================
class TestDefaultValues:
"""Test suite for default values in models."""
def test_portfolio_default_user_id(self, db_session: Session):
"""Test that portfolio has default user_id."""
portfolio = UserPortfolio(name="Custom Name")
db_session.add(portfolio)
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.user_id == "default"
def test_portfolio_default_name(self, db_session: Session):
"""Test that portfolio has default name."""
portfolio = UserPortfolio(user_id="custom_user")
db_session.add(portfolio)
db_session.commit()
retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
assert retrieved.name == "My Portfolio"
def test_position_default_notes(self, db_session: Session):
"""Test that position notes default to None."""
portfolio = UserPortfolio(user_id="default", name=f"Portfolio {uuid.uuid4()}")
db_session.add(portfolio)
db_session.commit()
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
retrieved = (
db_session.query(PortfolioPosition).filter_by(id=position.id).first()
)
assert retrieved.notes is None
# ============================================================================
# Integration Tests
# ============================================================================
class TestPortfolioIntegration:
"""End-to-end integration tests combining multiple operations."""
def test_complete_portfolio_workflow(self, db_session: Session):
"""Test complete workflow: create, read, update, delete."""
# Create portfolio
user_id = f"test_user_{uuid.uuid4()}"
portfolio_name = f"Integration Test {uuid.uuid4()}"
portfolio = UserPortfolio(user_id=user_id, name=portfolio_name)
db_session.add(portfolio)
db_session.commit()
# Add positions
position1 = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="AAPL",
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC) - timedelta(days=30),
notes="Initial purchase",
)
position2 = PortfolioPosition(
portfolio_id=portfolio.id,
ticker="MSFT",
shares=Decimal("5.00000000"),
average_cost_basis=Decimal("380.0000"),
total_cost=Decimal("1900.0000"),
purchase_date=datetime.now(UTC) - timedelta(days=15),
)
db_session.add_all([position1, position2])
db_session.commit()
# Read and verify
retrieved_portfolio = (
db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
)
assert retrieved_portfolio is not None
assert len(retrieved_portfolio.positions) == 2
# Update position
aapl_position = next(
p for p in retrieved_portfolio.positions if p.ticker == "AAPL"
)
original_shares = aapl_position.shares
aapl_position.shares = Decimal("20.00000000")
aapl_position.average_cost_basis = Decimal("160.0000")
aapl_position.total_cost = Decimal("3200.0000")
db_session.commit()
# Verify update
retrieved_position = (
db_session.query(PortfolioPosition).filter_by(id=aapl_position.id).first()
)
assert retrieved_position.shares == Decimal("20.00000000")
assert retrieved_position.shares != original_shares
# Delete one position
db_session.delete(aapl_position)
db_session.commit()
# Verify deletion
remaining_positions = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio.id)
.all()
)
assert len(remaining_positions) == 1
assert remaining_positions[0].ticker == "MSFT"
# Delete portfolio (cascade delete)
db_session.delete(retrieved_portfolio)
db_session.commit()
# Verify cascade delete
portfolio_check = (
db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
)
assert portfolio_check is None
positions_check = (
db_session.query(PortfolioPosition)
.filter_by(portfolio_id=portfolio.id)
.all()
)
assert len(positions_check) == 0
def test_complex_portfolio_with_multiple_users(self, db_session: Session):
"""Test complex scenario with multiple portfolios and users."""
user_ids = [f"user_{uuid.uuid4()}" for _ in range(3)]
portfolios = []
# Create portfolios for multiple users
for user_id in user_ids:
for i in range(2):
portfolio = UserPortfolio(
user_id=user_id, name=f"Portfolio {i} {uuid.uuid4()}"
)
db_session.add(portfolio)
portfolios.append(portfolio)
db_session.commit()
# Add positions to each portfolio
tickers = ["AAPL", "MSFT", "GOOG", "AMZN", "TSLA"]
for portfolio in portfolios:
for ticker in tickers[:3]: # Add 3 positions per portfolio
position = PortfolioPosition(
portfolio_id=portfolio.id,
ticker=ticker,
shares=Decimal("10.00000000"),
average_cost_basis=Decimal("150.0000"),
total_cost=Decimal("1500.0000"),
purchase_date=datetime.now(UTC),
)
db_session.add(position)
db_session.commit()
# Verify structure
for user_id in user_ids:
user_portfolios = (
db_session.query(UserPortfolio).filter_by(user_id=user_id).all()
)
assert len(user_portfolios) == 2
for portfolio in user_portfolios:
assert len(portfolio.positions) == 3
```
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/vectorbt_engine.py:
--------------------------------------------------------------------------------
```python
"""VectorBT backtesting engine implementation with memory management and structured logging."""
import gc
from typing import Any
import numpy as np
import pandas as pd
import vectorbt as vbt
from pandas import DataFrame, Series
from maverick_mcp.backtesting.batch_processing import BatchProcessingMixin
from maverick_mcp.data.cache import (
CacheManager,
ensure_timezone_naive,
generate_cache_key,
)
from maverick_mcp.providers.stock_data import EnhancedStockDataProvider
from maverick_mcp.utils.cache_warmer import CacheWarmer
from maverick_mcp.utils.data_chunking import DataChunker, optimize_dataframe_dtypes
from maverick_mcp.utils.memory_profiler import (
check_memory_leak,
cleanup_dataframes,
get_memory_stats,
memory_context,
profile_memory,
)
from maverick_mcp.utils.structured_logger import (
get_performance_logger,
get_structured_logger,
with_structured_logging,
)
logger = get_structured_logger(__name__)
performance_logger = get_performance_logger("vectorbt_engine")
class VectorBTEngine(BatchProcessingMixin):
"""High-performance backtesting engine using VectorBT with memory management."""
def __init__(
self,
data_provider: EnhancedStockDataProvider | None = None,
cache_service=None,
enable_memory_profiling: bool = True,
chunk_size_mb: float = 100.0,
):
"""Initialize VectorBT engine.
Args:
data_provider: Stock data provider instance
cache_service: Cache service for data persistence
enable_memory_profiling: Enable memory profiling and optimization
chunk_size_mb: Chunk size for large dataset processing
"""
self.data_provider = data_provider or EnhancedStockDataProvider()
self.cache = cache_service or CacheManager()
self.cache_warmer = CacheWarmer(
data_provider=self.data_provider, cache_manager=self.cache
)
# Memory management configuration
self.enable_memory_profiling = enable_memory_profiling
self.chunker = DataChunker(
chunk_size_mb=chunk_size_mb, optimize_chunks=True, auto_gc=True
)
# Configure VectorBT settings for optimal performance and memory usage
try:
vbt.settings.array_wrapper["freq"] = "D"
vbt.settings.caching["enabled"] = True # Enable VectorBT's internal caching
# Don't set whitelist to avoid cache condition issues
        except Exception as e:  # KeyError is the typical case when a settings key is missing
logger.warning(f"Could not configure VectorBT settings: {e}")
logger.info(
f"VectorBT engine initialized with memory profiling: {enable_memory_profiling}"
)
# Initialize memory tracking
if self.enable_memory_profiling:
initial_stats = get_memory_stats()
logger.debug(f"Initial memory stats: {initial_stats}")
@with_structured_logging(
"get_historical_data", include_performance=True, log_params=True
)
@profile_memory(log_results=True, threshold_mb=50.0)
async def get_historical_data(
self, symbol: str, start_date: str, end_date: str, interval: str = "1d"
) -> DataFrame:
"""Fetch historical data for backtesting with memory optimization.
Args:
symbol: Stock symbol
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
interval: Data interval (1d, 1h, etc.)
Returns:
Memory-optimized DataFrame with OHLCV data
"""
# Generate versioned cache key
cache_key = generate_cache_key(
"backtest_data",
symbol=symbol,
start_date=start_date,
end_date=end_date,
interval=interval,
)
# Try cache first with improved deserialization
cached_data = await self.cache.get(cache_key)
if cached_data is not None:
if isinstance(cached_data, pd.DataFrame):
# Already a DataFrame - ensure timezone-naive
df = ensure_timezone_naive(cached_data)
else:
# Restore DataFrame from dict (legacy JSON cache)
df = pd.DataFrame.from_dict(cached_data, orient="index")
# Convert index back to datetime
df.index = pd.to_datetime(df.index)
df = ensure_timezone_naive(df)
# Ensure column names are lowercase
df.columns = [col.lower() for col in df.columns]
return df
# Fetch from provider - try async method first, fallback to sync
try:
data = await self._get_data_async(symbol, start_date, end_date, interval)
except AttributeError:
# Fallback to sync method if async not available
data = self.data_provider.get_stock_data(
symbol=symbol,
start_date=start_date,
end_date=end_date,
interval=interval,
)
if data is None or data.empty:
raise ValueError(f"No data available for {symbol}")
# Normalize column names to lowercase for consistency
data.columns = [col.lower() for col in data.columns]
# Ensure timezone-naive index and fix any timezone comparison issues
data = ensure_timezone_naive(data)
# Optimize DataFrame memory usage
if self.enable_memory_profiling:
data = optimize_dataframe_dtypes(data, aggressive=False)
logger.debug(f"Optimized {symbol} data memory usage")
# Cache with adaptive TTL - longer for older data
from datetime import datetime
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
days_old = (datetime.now() - end_dt).days
ttl = 86400 if days_old > 7 else 3600 # 24h for older data, 1h for recent
await self.cache.set(cache_key, data, ttl=ttl)
return data
async def _get_data_async(
self, symbol: str, start_date: str, end_date: str, interval: str
) -> DataFrame:
"""Get data using async method if available."""
if hasattr(self.data_provider, "get_stock_data_async"):
return await self.data_provider.get_stock_data_async(
symbol=symbol,
start_date=start_date,
end_date=end_date,
interval=interval,
)
else:
# Fallback to sync method
return self.data_provider.get_stock_data(
symbol=symbol,
start_date=start_date,
end_date=end_date,
interval=interval,
)
@with_structured_logging(
"run_backtest", include_performance=True, log_params=True, log_result=False
)
@profile_memory(log_results=True, threshold_mb=200.0)
async def run_backtest(
self,
symbol: str,
strategy_type: str,
parameters: dict[str, Any],
start_date: str,
end_date: str,
initial_capital: float = 10000.0,
fees: float = 0.001,
slippage: float = 0.001,
) -> dict[str, Any]:
"""Run a vectorized backtest with memory optimization.
Args:
symbol: Stock symbol
strategy_type: Type of strategy (sma_cross, rsi, macd, etc.)
parameters: Strategy parameters
start_date: Start date
end_date: End date
initial_capital: Starting capital
fees: Trading fees (percentage)
slippage: Slippage (percentage)
Returns:
Dictionary with backtest results
"""
with memory_context("backtest_execution"):
# Fetch data
data = await self.get_historical_data(symbol, start_date, end_date)
# Check for large datasets and warn
data_memory_mb = data.memory_usage(deep=True).sum() / (1024**2)
if data_memory_mb > 100:
logger.warning(f"Large dataset detected: {data_memory_mb:.2f}MB")
# Log business metrics
performance_logger.log_business_metric(
"dataset_size_mb",
data_memory_mb,
symbol=symbol,
date_range_days=(
pd.to_datetime(end_date) - pd.to_datetime(start_date)
).days,
)
# Generate signals based on strategy
entries, exits = self._generate_signals(data, strategy_type, parameters)
# Optimize memory usage - use efficient data types
with memory_context("data_optimization"):
close_prices = data["close"].astype(np.float32)
entries = entries.astype(bool)
exits = exits.astype(bool)
# Clean up original data to free memory
if self.enable_memory_profiling:
cleanup_dataframes(data)
del data # Explicit deletion
gc.collect() # Force garbage collection
# Run VectorBT portfolio simulation with memory optimizations
with memory_context("portfolio_simulation"):
portfolio = vbt.Portfolio.from_signals(
close=close_prices,
entries=entries,
exits=exits,
init_cash=initial_capital,
fees=fees,
slippage=slippage,
freq="D",
cash_sharing=False, # Disable cash sharing for single asset
call_seq="auto", # Optimize call sequence
group_by=False, # Disable grouping for memory efficiency
broadcast_kwargs={"wrapper_kwargs": {"freq": "D"}},
)
# Extract comprehensive metrics with memory tracking
with memory_context("results_extraction"):
metrics = self._extract_metrics(portfolio)
trades = self._extract_trades(portfolio)
# Get equity curve - convert to list for smaller cache size
equity_curve = {
str(k): float(v) for k, v in portfolio.value().to_dict().items()
}
drawdown_series = {
str(k): float(v) for k, v in portfolio.drawdown().to_dict().items()
}
# Clean up portfolio object to free memory
if self.enable_memory_profiling:
del portfolio
                    if hasattr(close_prices, "_mgr"):
                        cleanup_dataframes(close_prices)
del close_prices, entries, exits
gc.collect()
# Add memory statistics to results if profiling enabled
result = {
"symbol": symbol,
"strategy": strategy_type,
"parameters": parameters,
"metrics": metrics,
"trades": trades,
"equity_curve": equity_curve,
"drawdown_series": drawdown_series,
"start_date": start_date,
"end_date": end_date,
"initial_capital": initial_capital,
}
if self.enable_memory_profiling:
result["memory_stats"] = get_memory_stats()
# Check for potential memory leaks
if check_memory_leak(threshold_mb=50.0):
logger.warning("Potential memory leak detected during backtesting")
# Log business metrics for backtesting results
performance_logger.log_business_metric(
"backtest_total_return",
metrics.get("total_return", 0),
symbol=symbol,
strategy=strategy_type,
trade_count=metrics.get("total_trades", 0),
)
performance_logger.log_business_metric(
"backtest_sharpe_ratio",
metrics.get("sharpe_ratio", 0),
symbol=symbol,
strategy=strategy_type,
)
return result
def _generate_signals(
self, data: DataFrame, strategy_type: str, parameters: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate entry and exit signals based on strategy.
Args:
data: Price data
strategy_type: Strategy type
parameters: Strategy parameters
Returns:
Tuple of (entry_signals, exit_signals)
"""
# Ensure we have the required price data
if "close" not in data.columns:
raise ValueError(
f"Missing 'close' column in price data. Available columns: {list(data.columns)}"
)
close = data["close"]
if strategy_type in ["sma_cross", "sma_crossover"]:
return self._sma_crossover_signals(close, parameters)
elif strategy_type == "rsi":
return self._rsi_signals(close, parameters)
elif strategy_type == "macd":
return self._macd_signals(close, parameters)
elif strategy_type == "bollinger":
return self._bollinger_bands_signals(close, parameters)
elif strategy_type == "momentum":
return self._momentum_signals(close, parameters)
elif strategy_type == "ema_cross":
return self._ema_crossover_signals(close, parameters)
elif strategy_type == "mean_reversion":
return self._mean_reversion_signals(close, parameters)
elif strategy_type == "breakout":
return self._breakout_signals(close, parameters)
elif strategy_type == "volume_momentum":
return self._volume_momentum_signals(data, parameters)
elif strategy_type == "online_learning":
return self._online_learning_signals(data, parameters)
elif strategy_type == "regime_aware":
return self._regime_aware_signals(data, parameters)
elif strategy_type == "ensemble":
return self._ensemble_signals(data, parameters)
else:
raise ValueError(f"Unknown strategy type: {strategy_type}")
def _sma_crossover_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate SMA crossover signals."""
# Support both parameter naming conventions
fast_period = params.get("fast_period", params.get("fast_window", 10))
slow_period = params.get("slow_period", params.get("slow_window", 20))
fast_sma = vbt.MA.run(close, fast_period, short_name="fast").ma.squeeze()
slow_sma = vbt.MA.run(close, slow_period, short_name="slow").ma.squeeze()
entries = (fast_sma > slow_sma) & (fast_sma.shift(1) <= slow_sma.shift(1))
exits = (fast_sma < slow_sma) & (fast_sma.shift(1) >= slow_sma.shift(1))
return entries, exits
def _rsi_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate RSI-based signals."""
period = params.get("period", 14)
oversold = params.get("oversold", 30)
overbought = params.get("overbought", 70)
rsi = vbt.RSI.run(close, period).rsi.squeeze()
entries = (rsi < oversold) & (rsi.shift(1) >= oversold)
exits = (rsi > overbought) & (rsi.shift(1) <= overbought)
return entries, exits
def _macd_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate MACD signals."""
fast_period = params.get("fast_period", 12)
slow_period = params.get("slow_period", 26)
signal_period = params.get("signal_period", 9)
macd = vbt.MACD.run(
close,
fast_window=fast_period,
slow_window=slow_period,
signal_window=signal_period,
)
macd_line = macd.macd.squeeze()
signal_line = macd.signal.squeeze()
entries = (macd_line > signal_line) & (
macd_line.shift(1) <= signal_line.shift(1)
)
exits = (macd_line < signal_line) & (macd_line.shift(1) >= signal_line.shift(1))
return entries, exits
def _bollinger_bands_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate Bollinger Bands signals."""
period = params.get("period", 20)
std_dev = params.get("std_dev", 2)
bb = vbt.BBANDS.run(close, window=period, alpha=std_dev)
upper = bb.upper.squeeze()
lower = bb.lower.squeeze()
# Buy when price touches lower band, sell when touches upper
entries = (close <= lower) & (close.shift(1) > lower.shift(1))
exits = (close >= upper) & (close.shift(1) < upper.shift(1))
return entries, exits
def _momentum_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate momentum-based signals."""
lookback = params.get("lookback", 20)
threshold = params.get("threshold", 0.05)
returns = close.pct_change(lookback)
entries = returns > threshold
exits = returns < -threshold
return entries, exits
def _ema_crossover_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate EMA crossover signals."""
fast_period = params.get("fast_period", 12)
slow_period = params.get("slow_period", 26)
fast_ema = vbt.MA.run(close, fast_period, ewm=True).ma.squeeze()
slow_ema = vbt.MA.run(close, slow_period, ewm=True).ma.squeeze()
entries = (fast_ema > slow_ema) & (fast_ema.shift(1) <= slow_ema.shift(1))
exits = (fast_ema < slow_ema) & (fast_ema.shift(1) >= slow_ema.shift(1))
return entries, exits
def _mean_reversion_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate mean reversion signals."""
ma_period = params.get("ma_period", 20)
entry_threshold = params.get("entry_threshold", 0.02)
exit_threshold = params.get("exit_threshold", 0.01)
ma = vbt.MA.run(close, ma_period).ma.squeeze()
# Avoid division by zero in deviation calculation
with np.errstate(divide="ignore", invalid="ignore"):
            deviation = pd.Series(np.where(ma != 0, (close - ma) / ma, 0), index=close.index)
entries = deviation < -entry_threshold
exits = deviation > exit_threshold
return entries, exits
def _breakout_signals(
self, close: Series, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate channel breakout signals."""
lookback = params.get("lookback", 20)
exit_lookback = params.get("exit_lookback", 10)
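        # Donchian-style channels: enter on a break above the prior `lookback`-day
        # high, exit on a break below the prior `exit_lookback`-day low.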
upper_channel = close.rolling(lookback).max()
lower_channel = close.rolling(exit_lookback).min()
entries = close > upper_channel.shift(1)
exits = close < lower_channel.shift(1)
return entries, exits
def _volume_momentum_signals(
self, data: DataFrame, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate volume-weighted momentum signals."""
momentum_period = params.get("momentum_period", 20)
volume_period = params.get("volume_period", 20)
momentum_threshold = params.get("momentum_threshold", 0.05)
volume_multiplier = params.get("volume_multiplier", 1.5)
close = data["close"]
volume = data.get("volume")
if volume is None:
# Fallback to pure momentum if no volume data
returns = close.pct_change(momentum_period)
entries = returns > momentum_threshold
exits = returns < -momentum_threshold
return entries, exits
returns = close.pct_change(momentum_period)
avg_volume = volume.rolling(volume_period).mean()
volume_surge = volume > (avg_volume * volume_multiplier)
# Entry: positive momentum with volume surge
entries = (returns > momentum_threshold) & volume_surge
# Exit: negative momentum or volume dry up
exits = (returns < -momentum_threshold) | (volume < avg_volume * 0.8)
return entries, exits
def _extract_metrics(self, portfolio: vbt.Portfolio) -> dict[str, Any]:
"""Extract comprehensive metrics from portfolio."""
def safe_float_metric(metric_func, default=0.0):
"""Safely extract float metrics, handling None and NaN values."""
try:
value = metric_func()
if value is None or np.isnan(value) or np.isinf(value):
return default
return float(value)
except (ZeroDivisionError, ValueError, TypeError):
return default
return {
"total_return": safe_float_metric(portfolio.total_return),
"annual_return": safe_float_metric(portfolio.annualized_return),
"sharpe_ratio": safe_float_metric(portfolio.sharpe_ratio),
"sortino_ratio": safe_float_metric(portfolio.sortino_ratio),
"calmar_ratio": safe_float_metric(portfolio.calmar_ratio),
"max_drawdown": safe_float_metric(portfolio.max_drawdown),
"win_rate": safe_float_metric(lambda: portfolio.trades.win_rate()),
"profit_factor": safe_float_metric(
lambda: portfolio.trades.profit_factor()
),
"expectancy": safe_float_metric(lambda: portfolio.trades.expectancy()),
"total_trades": int(portfolio.trades.count()),
"winning_trades": int(portfolio.trades.winning.count())
if hasattr(portfolio.trades, "winning")
else 0,
"losing_trades": int(portfolio.trades.losing.count())
if hasattr(portfolio.trades, "losing")
else 0,
"avg_win": safe_float_metric(
lambda: portfolio.trades.winning.pnl.mean()
if hasattr(portfolio.trades, "winning")
and portfolio.trades.winning.count() > 0
else None
),
"avg_loss": safe_float_metric(
lambda: portfolio.trades.losing.pnl.mean()
if hasattr(portfolio.trades, "losing")
and portfolio.trades.losing.count() > 0
else None
),
"best_trade": safe_float_metric(
lambda: portfolio.trades.pnl.max()
if portfolio.trades.count() > 0
else None
),
"worst_trade": safe_float_metric(
lambda: portfolio.trades.pnl.min()
if portfolio.trades.count() > 0
else None
),
"avg_duration": safe_float_metric(lambda: portfolio.trades.duration.mean()),
"kelly_criterion": self._calculate_kelly(portfolio),
"recovery_factor": self._calculate_recovery_factor(portfolio),
"risk_reward_ratio": self._calculate_risk_reward(portfolio),
}
def _extract_trades(self, portfolio: vbt.Portfolio) -> list:
"""Extract trade records from portfolio."""
if portfolio.trades.count() == 0:
return []
trades = portfolio.trades.records_readable
        # Build trade records from the readable trades table
trade_list = [
{
"entry_date": str(trade.get("Entry Timestamp", "")),
"exit_date": str(trade.get("Exit Timestamp", "")),
"entry_price": float(trade.get("Avg Entry Price", 0)),
"exit_price": float(trade.get("Avg Exit Price", 0)),
"size": float(trade.get("Size", 0)),
"pnl": float(trade.get("PnL", 0)),
"return": float(trade.get("Return", 0)),
"duration": str(trade.get("Duration", "")),
}
for _, trade in trades.iterrows()
]
return trade_list
def _calculate_kelly(self, portfolio: vbt.Portfolio) -> float:
"""Calculate Kelly Criterion."""
if portfolio.trades.count() == 0:
return 0.0
try:
win_rate = portfolio.trades.win_rate()
if win_rate is None or np.isnan(win_rate):
return 0.0
avg_win = (
abs(portfolio.trades.winning.returns.mean() or 0)
if hasattr(portfolio.trades, "winning")
and portfolio.trades.winning.count() > 0
else 0
)
avg_loss = (
abs(portfolio.trades.losing.returns.mean() or 0)
if hasattr(portfolio.trades, "losing")
and portfolio.trades.losing.count() > 0
else 0
)
# Check for division by zero and invalid values
if avg_loss == 0 or avg_win == 0 or np.isnan(avg_win) or np.isnan(avg_loss):
return 0.0
# Calculate Kelly with safe division
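            # Algebraically equal to the standard Kelly fraction f* = p - (1 - p) / b,
            # with p = win rate and b = avg_win / avg_loss (payoff ratio).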
with np.errstate(divide="ignore", invalid="ignore"):
kelly = (win_rate * avg_win - (1 - win_rate) * avg_loss) / avg_win
# Check if result is valid
if np.isnan(kelly) or np.isinf(kelly):
return 0.0
return float(
min(max(kelly, -1.0), 0.25)
) # Cap between -100% and 25% for safety
except (ZeroDivisionError, ValueError, TypeError):
return 0.0
def get_memory_report(self) -> dict[str, Any]:
"""Get comprehensive memory usage report."""
if not self.enable_memory_profiling:
return {"message": "Memory profiling disabled"}
return get_memory_stats()
def clear_memory_cache(self) -> None:
"""Clear internal memory caches and force garbage collection."""
if hasattr(vbt.settings, "caching"):
vbt.settings.caching.clear()
gc.collect()
logger.info("Memory cache cleared and garbage collection performed")
def optimize_for_memory(self, aggressive: bool = False) -> None:
"""Optimize VectorBT settings for memory efficiency.
Args:
aggressive: Use aggressive memory optimizations
"""
if aggressive:
# Aggressive memory settings
vbt.settings.caching["enabled"] = False # Disable caching
vbt.settings.array_wrapper["dtype"] = np.float32 # Use float32
logger.info("Applied aggressive memory optimizations")
else:
# Conservative memory settings
vbt.settings.caching["enabled"] = True
vbt.settings.caching["max_size"] = 100 # Limit cache size
logger.info("Applied conservative memory optimizations")
async def run_memory_efficient_backtest(
self,
symbol: str,
strategy_type: str,
parameters: dict[str, Any],
start_date: str,
end_date: str,
initial_capital: float = 10000.0,
fees: float = 0.001,
slippage: float = 0.001,
chunk_data: bool = False,
) -> dict[str, Any]:
"""Run backtest with maximum memory efficiency.
Args:
symbol: Stock symbol
strategy_type: Strategy type
parameters: Strategy parameters
start_date: Start date
end_date: End date
initial_capital: Starting capital
fees: Trading fees
slippage: Slippage
chunk_data: Whether to process data in chunks
Returns:
Backtest results with memory statistics
"""
# Temporarily optimize for memory
original_settings = {
"caching_enabled": vbt.settings.caching.get("enabled", True),
"array_dtype": vbt.settings.array_wrapper.get("dtype", np.float64),
}
try:
self.optimize_for_memory(aggressive=True)
if chunk_data:
# Use chunked processing for very large datasets
return await self._run_chunked_backtest(
symbol,
strategy_type,
parameters,
start_date,
end_date,
initial_capital,
fees,
slippage,
)
else:
return await self.run_backtest(
symbol,
strategy_type,
parameters,
start_date,
end_date,
initial_capital,
fees,
slippage,
)
finally:
# Restore original settings
vbt.settings.caching["enabled"] = original_settings["caching_enabled"]
vbt.settings.array_wrapper["dtype"] = original_settings["array_dtype"]
async def _run_chunked_backtest(
self,
symbol: str,
strategy_type: str,
parameters: dict[str, Any],
start_date: str,
end_date: str,
initial_capital: float,
fees: float,
slippage: float,
) -> dict[str, Any]:
"""Run backtest using data chunking for very large datasets."""
from datetime import datetime, timedelta
# Calculate date chunks (monthly)
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
results = []
current_capital = initial_capital
current_date = start_dt
while current_date < end_dt:
chunk_end = min(current_date + timedelta(days=90), end_dt) # 3-month chunks
chunk_start_str = current_date.strftime("%Y-%m-%d")
chunk_end_str = chunk_end.strftime("%Y-%m-%d")
logger.debug(f"Processing chunk: {chunk_start_str} to {chunk_end_str}")
# Run backtest for chunk
chunk_result = await self.run_backtest(
symbol,
strategy_type,
parameters,
chunk_start_str,
chunk_end_str,
current_capital,
fees,
slippage,
)
results.append(chunk_result)
            # Compound the chunk's return into the starting capital for the next chunk
            chunk_return = chunk_result.get("metrics", {}).get("total_return", 0)
            current_capital = current_capital * (1 + chunk_return)
current_date = chunk_end
# Combine results
return self._combine_chunked_results(results, symbol, strategy_type, parameters)
def _combine_chunked_results(
self,
chunk_results: list[dict],
symbol: str,
strategy_type: str,
parameters: dict[str, Any],
) -> dict[str, Any]:
"""Combine results from chunked backtesting."""
if not chunk_results:
return {}
# Combine trades
all_trades = []
for chunk in chunk_results:
all_trades.extend(chunk.get("trades", []))
# Combine equity curves
combined_equity = {}
combined_drawdown = {}
for chunk in chunk_results:
combined_equity.update(chunk.get("equity_curve", {}))
combined_drawdown.update(chunk.get("drawdown_series", {}))
# Calculate combined metrics
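        # Chunk returns compound multiplicatively: total = prod(1 + r_i) - 1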
total_return = 1.0
for chunk in chunk_results:
chunk_return = chunk.get("metrics", {}).get("total_return", 0)
total_return *= 1 + chunk_return
total_return -= 1.0
combined_metrics = {
"total_return": total_return,
"total_trades": len(all_trades),
"chunks_processed": len(chunk_results),
}
return {
"symbol": symbol,
"strategy": strategy_type,
"parameters": parameters,
"metrics": combined_metrics,
"trades": all_trades,
"equity_curve": combined_equity,
"drawdown_series": combined_drawdown,
"processing_method": "chunked",
"memory_stats": get_memory_stats()
if self.enable_memory_profiling
else None,
}
def _calculate_recovery_factor(self, portfolio: vbt.Portfolio) -> float:
"""Calculate recovery factor (total return / max drawdown)."""
try:
max_dd = portfolio.max_drawdown()
total_return = portfolio.total_return()
# Check for invalid values
if (
max_dd is None
or np.isnan(max_dd)
or max_dd == 0
or total_return is None
or np.isnan(total_return)
):
return 0.0
# Calculate with safe division
with np.errstate(divide="ignore", invalid="ignore"):
recovery_factor = total_return / abs(max_dd)
# Check if result is valid
if np.isnan(recovery_factor) or np.isinf(recovery_factor):
return 0.0
return float(recovery_factor)
except (ZeroDivisionError, ValueError, TypeError):
return 0.0
def _calculate_risk_reward(self, portfolio: vbt.Portfolio) -> float:
"""Calculate risk-reward ratio."""
if portfolio.trades.count() == 0:
return 0.0
try:
avg_win = (
abs(portfolio.trades.winning.pnl.mean() or 0)
if hasattr(portfolio.trades, "winning")
and portfolio.trades.winning.count() > 0
else 0
)
avg_loss = (
abs(portfolio.trades.losing.pnl.mean() or 0)
if hasattr(portfolio.trades, "losing")
and portfolio.trades.losing.count() > 0
else 0
)
# Check for division by zero and invalid values
if (
avg_loss == 0
or avg_win == 0
or np.isnan(avg_win)
or np.isnan(avg_loss)
or np.isinf(avg_win)
or np.isinf(avg_loss)
):
return 0.0
# Calculate with safe division
with np.errstate(divide="ignore", invalid="ignore"):
risk_reward = avg_win / avg_loss
# Check if result is valid
if np.isnan(risk_reward) or np.isinf(risk_reward):
return 0.0
return float(risk_reward)
except (ZeroDivisionError, ValueError, TypeError):
return 0.0
@with_structured_logging(
"optimize_parameters",
include_performance=True,
log_params=True,
log_result=False,
)
@profile_memory(log_results=True, threshold_mb=500.0)
async def optimize_parameters(
self,
symbol: str,
strategy_type: str,
param_grid: dict[str, list],
start_date: str,
end_date: str,
optimization_metric: str = "sharpe_ratio",
initial_capital: float = 10000.0,
top_n: int = 10,
use_chunking: bool = True,
) -> dict[str, Any]:
"""Optimize strategy parameters using memory-efficient grid search.
Args:
symbol: Stock symbol
strategy_type: Strategy type
param_grid: Parameter grid for optimization
start_date: Start date
end_date: End date
optimization_metric: Metric to optimize
initial_capital: Starting capital
top_n: Number of top results to return
use_chunking: Use chunking for large parameter grids
Returns:
Optimization results with best parameters
"""
with memory_context("parameter_optimization"):
# Fetch data once
data = await self.get_historical_data(symbol, start_date, end_date)
            # Create parameter combinations (Cartesian product of the grid)
            from itertools import product
            value_combos = product(*param_grid.values())
            param_combos = [dict(zip(param_grid, vals, strict=False)) for vals in value_combos]
            total_combos = len(param_combos)
logger.info(
f"Optimizing {total_combos} parameter combinations for {symbol}"
)
# Pre-convert data for optimization with memory efficiency
close_prices = data["close"].astype(np.float32)
# Check if we should use chunking for large parameter grids
if use_chunking and total_combos > 100:
logger.info(f"Using chunked processing for {total_combos} combinations")
chunk_size = min(50, max(10, total_combos // 10)) # Adaptive chunk size
results = self._optimize_parameters_chunked(
data,
close_prices,
strategy_type,
param_combos,
optimization_metric,
initial_capital,
chunk_size,
)
else:
results = []
for i, params in enumerate(param_combos):
try:
with memory_context(f"param_combo_{i}"):
# Generate signals for this parameter set
entries, exits = self._generate_signals(
data, strategy_type, params
)
# Convert to boolean arrays for memory efficiency
entries = entries.astype(bool)
exits = exits.astype(bool)
# Run backtest with optimizations
portfolio = vbt.Portfolio.from_signals(
close=close_prices,
entries=entries,
exits=exits,
init_cash=initial_capital,
fees=0.001,
freq="D",
cash_sharing=False,
call_seq="auto",
group_by=False, # Memory optimization
)
# Get optimization metric
metric_value = self._get_metric_value(
portfolio, optimization_metric
)
results.append(
{
"parameters": params,
optimization_metric: metric_value,
"total_return": float(portfolio.total_return()),
"max_drawdown": float(portfolio.max_drawdown()),
"total_trades": int(portfolio.trades.count()),
}
)
# Clean up intermediate objects
del portfolio, entries, exits
if i % 20 == 0: # Periodic cleanup
gc.collect()
except Exception as e:
logger.debug(f"Skipping invalid parameter combination {i}: {e}")
continue
# Clean up data objects
if self.enable_memory_profiling:
                if hasattr(data, "_mgr"):
                    cleanup_dataframes(data, close_prices)
del data, close_prices
gc.collect()
# Sort by optimization metric
results.sort(key=lambda x: x[optimization_metric], reverse=True)
# Get top N results
top_results = results[:top_n]
result = {
"symbol": symbol,
"strategy": strategy_type,
"optimization_metric": optimization_metric,
"best_parameters": top_results[0]["parameters"] if top_results else {},
"best_metric_value": top_results[0][optimization_metric]
if top_results
else 0,
"top_results": top_results,
"total_combinations_tested": total_combos,
"valid_combinations": len(results),
}
if self.enable_memory_profiling:
result["memory_stats"] = get_memory_stats()
return result
def _optimize_parameters_chunked(
self,
data: DataFrame,
close_prices: Series,
strategy_type: str,
param_combos: list,
optimization_metric: str,
initial_capital: float,
chunk_size: int,
) -> list[dict]:
"""Optimize parameters using chunked processing for memory efficiency."""
results = []
total_chunks = len(param_combos) // chunk_size + (
1 if len(param_combos) % chunk_size else 0
)
for chunk_idx in range(0, len(param_combos), chunk_size):
chunk_params = param_combos[chunk_idx : chunk_idx + chunk_size]
logger.debug(
f"Processing chunk {chunk_idx // chunk_size + 1}/{total_chunks}"
)
with memory_context(f"param_chunk_{chunk_idx // chunk_size}"):
for _, params in enumerate(chunk_params):
try:
# Generate signals for this parameter set
entries, exits = self._generate_signals(
data, strategy_type, params
)
# Convert to boolean arrays for memory efficiency
entries = entries.astype(bool)
exits = exits.astype(bool)
# Run backtest with optimizations
portfolio = vbt.Portfolio.from_signals(
close=close_prices,
entries=entries,
exits=exits,
init_cash=initial_capital,
fees=0.001,
freq="D",
cash_sharing=False,
call_seq="auto",
group_by=False,
)
# Get optimization metric
metric_value = self._get_metric_value(
portfolio, optimization_metric
)
results.append(
{
"parameters": params,
optimization_metric: metric_value,
"total_return": float(portfolio.total_return()),
"max_drawdown": float(portfolio.max_drawdown()),
"total_trades": int(portfolio.trades.count()),
}
)
# Clean up intermediate objects
del portfolio, entries, exits
except Exception as e:
logger.debug(f"Skipping invalid parameter combination: {e}")
continue
# Force garbage collection after each chunk
gc.collect()
return results
def _get_metric_value(self, portfolio: vbt.Portfolio, metric_name: str) -> float:
"""Get specific metric value from portfolio."""
metric_map = {
"total_return": portfolio.total_return,
"sharpe_ratio": portfolio.sharpe_ratio,
"sortino_ratio": portfolio.sortino_ratio,
"calmar_ratio": portfolio.calmar_ratio,
"max_drawdown": lambda: -portfolio.max_drawdown(),
"win_rate": lambda: portfolio.trades.win_rate() or 0,
"profit_factor": lambda: portfolio.trades.profit_factor() or 0,
}
if metric_name not in metric_map:
raise ValueError(f"Unknown metric: {metric_name}")
try:
value = metric_map[metric_name]()
# Check for invalid values
if value is None or np.isnan(value) or np.isinf(value):
return 0.0
return float(value)
except (ZeroDivisionError, ValueError, TypeError):
return 0.0
def _online_learning_signals(
self, data: DataFrame, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate online learning ML strategy signals.
Simple implementation using momentum with adaptive thresholds.
"""
lookback = params.get("lookback", 20)
learning_rate = params.get("learning_rate", 0.01)
close = data["close"]
returns = close.pct_change(lookback)
# Adaptive threshold based on rolling statistics
rolling_mean = returns.rolling(window=lookback).mean()
rolling_std = returns.rolling(window=lookback).std()
# Dynamic entry/exit thresholds
entry_threshold = rolling_mean + learning_rate * rolling_std
exit_threshold = rolling_mean - learning_rate * rolling_std
# Generate signals
entries = returns > entry_threshold
exits = returns < exit_threshold
# Fill NaN values
entries = entries.fillna(False)
exits = exits.fillna(False)
return entries, exits
def _regime_aware_signals(
self, data: DataFrame, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate regime-aware strategy signals.
Detects market regime and applies appropriate strategy.
"""
regime_window = params.get("regime_window", 50)
threshold = params.get("threshold", 0.02)
close = data["close"]
# Calculate regime indicators
returns = close.pct_change()
volatility = returns.rolling(window=regime_window).std()
        trend_strength = close.rolling(window=regime_window).apply(
            lambda x: (x[-1] - x[0]) / x[0] if x[0] != 0 else 0, raw=True
        )
# Determine regime: trending vs ranging
is_trending = abs(trend_strength) > threshold
# Trend following signals
sma_short = close.rolling(window=regime_window // 2).mean()
sma_long = close.rolling(window=regime_window).mean()
trend_entries = (close > sma_long) & (sma_short > sma_long)
trend_exits = (close < sma_long) & (sma_short < sma_long)
# Mean reversion signals
bb_upper = sma_long + 2 * volatility
bb_lower = sma_long - 2 * volatility
reversion_entries = close < bb_lower
reversion_exits = close > bb_upper
# Combine based on regime
entries = (is_trending & trend_entries) | (~is_trending & reversion_entries)
exits = (is_trending & trend_exits) | (~is_trending & reversion_exits)
# Fill NaN values
entries = entries.fillna(False)
exits = exits.fillna(False)
return entries, exits
def _ensemble_signals(
self, data: DataFrame, params: dict[str, Any]
) -> tuple[Series, Series]:
"""Generate ensemble strategy signals.
Combines multiple strategies with voting.
"""
fast_period = params.get("fast_period", 10)
slow_period = params.get("slow_period", 20)
rsi_period = params.get("rsi_period", 14)
close = data["close"]
# Strategy 1: SMA Crossover
fast_sma = close.rolling(window=fast_period).mean()
slow_sma = close.rolling(window=slow_period).mean()
sma_entries = (fast_sma > slow_sma) & (fast_sma.shift(1) <= slow_sma.shift(1))
sma_exits = (fast_sma < slow_sma) & (fast_sma.shift(1) >= slow_sma.shift(1))
# Strategy 2: RSI
delta = close.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=rsi_period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=rsi_period).mean()
rs = gain / loss.replace(0, 1e-10)
rsi = 100 - (100 / (1 + rs))
rsi_entries = (rsi < 30) & (rsi.shift(1) >= 30)
rsi_exits = (rsi > 70) & (rsi.shift(1) <= 70)
# Strategy 3: Momentum
momentum = close.pct_change(20)
mom_entries = momentum > 0.05
mom_exits = momentum < -0.05
# Ensemble voting - at least 2 out of 3 strategies agree
entry_votes = (
sma_entries.astype(int) + rsi_entries.astype(int) + mom_entries.astype(int)
)
exit_votes = (
sma_exits.astype(int) + rsi_exits.astype(int) + mom_exits.astype(int)
)
entries = entry_votes >= 2
exits = exit_votes >= 2
# Fill NaN values
entries = entries.fillna(False)
exits = exits.fillna(False)
return entries, exits
```
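A minimal, illustrative usage sketch for the engine above (not part of the repository): it assumes an asyncio entry point, that the engine's default data provider and cache can be constructed without extra configuration, and that the import path shown is where `VectorBTEngine` lives.

```python
import asyncio

from maverick_mcp.backtesting.vectorbt_engine import VectorBTEngine  # module path assumed


async def main() -> None:
    engine = VectorBTEngine(enable_memory_profiling=True)

    # Single SMA-crossover backtest on daily data.
    result = await engine.run_backtest(
        symbol="AAPL",
        strategy_type="sma_cross",
        parameters={"fast_period": 10, "slow_period": 20},
        start_date="2023-01-01",
        end_date="2024-01-01",
    )
    print(result["metrics"]["total_return"], result["metrics"]["sharpe_ratio"])

    # Grid search over crossover windows, ranked by Sharpe ratio.
    best = await engine.optimize_parameters(
        symbol="AAPL",
        strategy_type="sma_cross",
        param_grid={"fast_period": [5, 10, 20], "slow_period": [50, 100, 200]},
        start_date="2023-01-01",
        end_date="2024-01-01",
        optimization_metric="sharpe_ratio",
    )
    print(best["best_parameters"])


if __name__ == "__main__":
    asyncio.run(main())
```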
--------------------------------------------------------------------------------
/tests/fixtures/orchestration_fixtures.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive test fixtures for orchestration testing.
Provides realistic mock data for LLM responses, API responses, market data,
and test scenarios for the SupervisorAgent and DeepResearchAgent orchestration system.
"""
import json
from datetime import datetime, timedelta
from typing import Any
from unittest.mock import MagicMock
import numpy as np
import pandas as pd
import pytest
from langchain_core.messages import AIMessage
# ==============================================================================
# MOCK LLM RESPONSES
# ==============================================================================
class MockLLMResponses:
"""Realistic LLM responses for various orchestration scenarios."""
@staticmethod
def query_classification_response(
category: str = "stock_investment_decision",
confidence: float = 0.85,
parallel_capable: bool = True,
) -> str:
"""Mock query classification response from LLM."""
routing_agents_map = {
"market_screening": ["market"],
"technical_analysis": ["technical"],
"stock_investment_decision": ["market", "technical"],
"portfolio_analysis": ["market", "technical"],
"deep_research": ["research"],
"company_research": ["research"],
"sentiment_analysis": ["research"],
"risk_assessment": ["market", "technical"],
}
return json.dumps(
{
"category": category,
"confidence": confidence,
"required_agents": routing_agents_map.get(category, ["market"]),
"complexity": "moderate" if confidence > 0.7 else "complex",
"estimated_execution_time_ms": 45000
if category == "deep_research"
else 30000,
"parallel_capable": parallel_capable,
"reasoning": f"Query classified as {category} based on content analysis and intent detection.",
}
)
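    # Example (hypothetical) test usage:
    #   raw = MockLLMResponses.query_classification_response("technical_analysis")
    #   assert json.loads(raw)["required_agents"] == ["technical"]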
@staticmethod
def result_synthesis_response(
persona: str = "moderate",
        agents_used: list[str] | None = None,
confidence: float = 0.82,
) -> str:
"""Mock result synthesis response from LLM."""
if agents_used is None:
agents_used = ["market", "technical"]
persona_focused_content = {
"conservative": """
Based on comprehensive analysis from our specialist agents, AAPL presents a balanced investment opportunity
with strong fundamentals and reasonable risk profile. The market analysis indicates stable sector
positioning with consistent dividend growth, while technical indicators suggest a consolidation phase
with support at $170. For conservative investors, consider gradual position building with
strict stop-loss at $165 to preserve capital. The risk-adjusted return profile aligns well
with conservative portfolio objectives, offering both income stability and modest growth potential.
""",
"moderate": """
Our multi-agent analysis reveals AAPL as a compelling investment opportunity with balanced risk-reward
characteristics. Market screening identified strong fundamentals including 15% revenue growth and
expanding services segment. Technical analysis shows bullish momentum with RSI at 58 and MACD
trending positive. Entry points around $175-180 offer favorable risk-reward with targets at $200-210.
Position sizing of 3-5% of portfolio aligns with moderate risk tolerance while capitalizing on
the current uptrend momentum.
""",
"aggressive": """
Multi-agent analysis identifies AAPL as a high-conviction growth play with exceptional upside potential.
Market analysis reveals accelerating AI adoption driving hardware refresh cycles, while technical
indicators signal strong breakout momentum above $185 resistance. The confluence of fundamental
catalysts and technical setup supports aggressive position sizing up to 8-10% allocation.
Target price of $220+ represents 25% upside with momentum likely to continue through earnings season.
This represents a prime opportunity for growth-focused portfolios seeking alpha generation.
""",
}
return persona_focused_content.get(
persona, persona_focused_content["moderate"]
).strip()
@staticmethod
def content_analysis_response(
sentiment: str = "bullish", confidence: float = 0.75, credibility: float = 0.8
) -> str:
"""Mock content analysis response from LLM."""
return json.dumps(
{
"KEY_INSIGHTS": [
"Apple's Q4 earnings exceeded expectations with 15% revenue growth",
"Services segment continues to expand with 12% year-over-year growth",
"iPhone 15 sales showing strong adoption in key markets",
"Cash position remains robust at $165B supporting capital allocation",
"AI integration across product line driving next upgrade cycle",
],
"SENTIMENT": {"direction": sentiment, "confidence": confidence},
"RISK_FACTORS": [
"China market regulatory concerns persist",
"Supply chain dependencies in Taiwan and South Korea",
"Increasing competition in services market",
"Currency headwinds affecting international revenue",
],
"OPPORTUNITIES": [
"AI-powered device upgrade cycle beginning",
"Vision Pro market penetration expanding",
"Services recurring revenue model strengthening",
"Emerging markets iPhone adoption accelerating",
],
"CREDIBILITY": credibility,
"RELEVANCE": 0.9,
"SUMMARY": f"Comprehensive analysis suggests {sentiment} outlook for Apple with strong fundamentals and growth catalysts, though regulatory and competitive risks require monitoring.",
}
)
@staticmethod
def research_synthesis_response(persona: str = "moderate") -> str:
"""Mock research synthesis response for deep research agent."""
synthesis_by_persona = {
"conservative": """
## Executive Summary
Apple represents a stable, dividend-paying technology stock suitable for conservative portfolios seeking
balanced growth and income preservation.
## Key Findings
• Consistent dividend growth averaging 8% annually over past 5 years
• Strong balance sheet with $165B cash providing downside protection
• Services revenue provides recurring income stream growing at 12% annually
• P/E ratio of 28x reasonable for quality growth stock
• Beta of 1.1 indicates moderate volatility relative to market
• Debt-to-equity ratio of 0.3 shows conservative capital structure
• Free cash flow yield of 3.2% supports dividend sustainability
## Investment Implications for Conservative Investors
Apple's combination of dividend growth, balance sheet strength, and market leadership makes it suitable
for conservative portfolios. The company's pivot to services provides recurring revenue stability while
hardware sales offer moderate growth potential.
## Risk Considerations
Primary risks include China market exposure (19% of revenue), technology obsolescence, and regulatory
pressure on App Store policies. However, strong cash position provides significant downside protection.
## Recommended Actions
Consider 2-3% portfolio allocation with gradual accumulation on pullbacks below $170.
Appropriate stop-loss at $160 to limit downside risk.
""",
"moderate": """
## Executive Summary
Apple presents a balanced investment opportunity combining growth potential with quality fundamentals,
well-suited for diversified moderate-risk portfolios.
## Key Findings
• Revenue growth acceleration to 15% driven by AI-enhanced products
• Services segment margins expanding to 70%, improving overall profitability
• Strong competitive moats in ecosystem and brand loyalty
• Capital allocation balance between growth investment and shareholder returns
• Technical indicators suggesting continued uptrend momentum
• Valuation appears fair at current levels with room for multiple expansion
• Market leadership position in premium smartphone and tablet segments
## Investment Implications for Moderate Investors
Apple offers an attractive blend of stability and growth potential. The company's evolution toward
services provides recurring revenue while hardware innovation drives periodic upgrade cycles.
## Risk Considerations
Key risks include supply chain disruption, China regulatory issues, and increasing competition
in services. Currency headwinds may impact international revenue growth.
## Recommended Actions
Target 4-5% portfolio allocation with entry points between $175-185. Consider taking profits
above $210 and adding on weakness below $170.
""",
"aggressive": """
## Executive Summary
Apple stands at the forefront of the next technology revolution with AI integration across its ecosystem,
presenting significant alpha generation potential for growth-focused investors.
## Key Findings
• AI-driven product refresh cycle beginning with iPhone 15 Pro and Vision Pro launch
• Services revenue trajectory accelerating with 18% growth potential
• Market share expansion opportunities in emerging markets and enterprise
• Vision Pro early adoption exceeding expectations, validating spatial computing thesis
• Developer ecosystem strengthening with AI tools integration
• Operating leverage improving with services mix shift
• Stock momentum indicators showing bullish technical setup
## Investment Implications for Aggressive Investors
Apple represents a high-conviction growth play with multiple expansion catalysts. The convergence
of AI adoption, new product categories, and services growth creates exceptional upside potential.
## Risk Considerations
Execution risk on Vision Pro adoption, competitive response from Android ecosystem, and
regulatory pressure on App Store represent key downside risks requiring active monitoring.
## Recommended Actions
Consider aggressive 8-10% allocation with momentum-based entry above $185 resistance.
Target price $230+ over 12-month horizon with trailing stop at 15% to protect gains.
""",
}
return synthesis_by_persona.get(
persona, synthesis_by_persona["moderate"]
).strip()
# ==============================================================================
# MOCK EXA API RESPONSES
# ==============================================================================
class MockExaResponses:
"""Realistic Exa API responses for financial research."""
@staticmethod
def search_results_aapl() -> list[dict[str, Any]]:
"""Mock Exa search results for AAPL analysis."""
return [
{
"url": "https://www.bloomberg.com/news/articles/2024-01-15/apple-earnings-beat",
"title": "Apple Earnings Beat Expectations as iPhone Sales Surge",
"content": "Apple Inc. reported quarterly revenue of $119.6 billion, surpassing analyst expectations as iPhone 15 sales showed strong momentum in key markets. The technology giant's services segment grew 12% year-over-year to $23.1 billion, demonstrating the recurring revenue model's strength. CEO Tim Cook highlighted AI integration across the product lineup as a key driver for the next upgrade cycle. Gross margins expanded to 45.9% compared to 43.3% in the prior year period, reflecting improved mix and operational efficiency. The company's cash position remains robust at $165.1 billion, providing flexibility for strategic investments and shareholder returns. China revenue declined 2% due to competitive pressures, though management expressed optimism about long-term opportunities in the region.",
"summary": "Apple exceeded Q4 earnings expectations with strong iPhone 15 sales and services growth, while maintaining robust cash position and expanding margins despite China headwinds.",
"highlights": [
"iPhone 15 strong sales momentum",
"Services grew 12% year-over-year",
"$165.1B cash position",
],
"published_date": "2024-01-15T08:30:00Z",
"author": "Mark Gurman",
"score": 0.94,
"provider": "exa",
},
{
"url": "https://seekingalpha.com/article/4665432-apple-stock-analysis-ai-catalyst",
"title": "Apple Stock: AI Integration Could Drive Next Super Cycle",
"content": "Apple's integration of artificial intelligence across its ecosystem represents a potential catalyst for the next device super cycle. The company's on-device AI processing capabilities, enabled by the A17 Pro chip, position Apple uniquely in the mobile AI landscape. Industry analysts project AI-enhanced features could drive iPhone replacement cycles to accelerate from the current 3.5 years to approximately 2.8 years. The services ecosystem benefits significantly from AI integration, with enhanced Siri capabilities driving increased App Store engagement and subscription services adoption. Vision Pro early metrics suggest spatial computing adoption is tracking ahead of initial estimates, with developer interest surging 300% quarter-over-quarter. The convergence of AI, spatial computing, and services creates multiple revenue expansion vectors over the next 3-5 years.",
"summary": "AI integration across Apple's ecosystem could accelerate device replacement cycles and expand services revenue through enhanced user engagement.",
"highlights": [
"AI-driven replacement cycle acceleration",
"Vision Pro adoption tracking well",
"Services ecosystem AI benefits",
],
"published_date": "2024-01-14T14:20:00Z",
"author": "Tech Analyst Team",
"score": 0.87,
"provider": "exa",
},
{
"url": "https://www.morningstar.com/stocks/aapl-valuation-analysis",
"title": "Apple Valuation Analysis: Fair Value Assessment",
"content": "Our discounted cash flow analysis suggests Apple's fair value ranges between $185-195 per share, indicating the stock trades near intrinsic value at current levels. The company's transition toward higher-margin services revenue supports multiple expansion, though hardware cycle dependency introduces valuation volatility. Key valuation drivers include services attach rates (currently 85% of active devices), gross margin trajectory (target 47-48% long-term), and capital allocation efficiency. The dividend yield of 0.5% appears sustainable with strong free cash flow generation of $95+ billion annually. Compared to technology peers, Apple trades at a 15% premium to the sector median, justified by superior return on invested capital and cash generation capabilities.",
"summary": "DCF analysis places Apple's fair value at $185-195, with current valuation supported by services transition and strong cash generation.",
"highlights": [
"Fair value $185-195 range",
"Services driving multiple expansion",
"Strong free cash flow $95B+",
],
"published_date": "2024-01-13T11:45:00Z",
"author": "Sarah Chen",
"score": 0.91,
"provider": "exa",
},
{
"url": "https://www.reuters.com/technology/apple-china-challenges-2024-01-12",
"title": "Apple Faces Growing Competition in China Market",
"content": "Apple confronts intensifying competition in China as local brands gain market share and regulatory scrutiny increases. Huawei's Mate 60 Pro launch has resonated strongly with Chinese consumers, contributing to Apple's 2% revenue decline in Greater China for Q4. The Chinese government's restrictions on iPhone use in government agencies signal potential broader policy shifts. Despite challenges, Apple maintains premium market leadership with 47% share in smartphones priced above $600. Management highlighted ongoing investments in local partnerships and supply chain relationships to navigate the complex regulatory environment. The company's services revenue in China grew 8% despite hardware headwinds, demonstrating ecosystem stickiness among existing users.",
"summary": "Apple faces competitive and regulatory challenges in China, though maintains premium market leadership and growing services revenue.",
"highlights": [
"China revenue down 2%",
"Regulatory iPhone restrictions",
"Premium segment leadership maintained",
],
"published_date": "2024-01-12T16:15:00Z",
"author": "Reuters Technology Team",
"score": 0.89,
"provider": "exa",
},
]
@staticmethod
def search_results_market_sentiment() -> list[dict[str, Any]]:
"""Mock Exa results for market sentiment analysis."""
return [
{
"url": "https://www.cnbc.com/2024/01/16/market-outlook-tech-stocks",
"title": "Tech Stocks Rally on AI Optimism Despite Rate Concerns",
"content": "Technology stocks surged 2.3% as artificial intelligence momentum overcame Federal Reserve policy concerns. Investors rotated into AI-beneficiary names including Apple, Microsoft, and Nvidia following strong earnings guidance across the sector. The Technology Select Sector SPDR ETF (XLK) reached new 52-week highs despite 10-year Treasury yields hovering near 4.5%. Institutional flows show $12.8 billion net inflows to technology funds over the past month, the strongest since early 2023. Options activity indicates continued bullish sentiment with call volume exceeding puts by 1.8:1 across major tech names. Analyst upgrades accelerated with 67% of tech stocks carrying buy ratings versus 52% sector average.",
"summary": "Tech stocks rally on AI optimism with strong institutional inflows and bullish options activity despite interest rate headwinds.",
"highlights": [
"Tech sector +2.3%",
"$12.8B institutional inflows",
"Call/put ratio 1.8:1",
],
"published_date": "2024-01-16T09:45:00Z",
"author": "CNBC Markets Team",
"score": 0.92,
"provider": "exa",
},
{
"url": "https://finance.yahoo.com/news/vix-fear-greed-market-sentiment",
"title": "VIX Falls to Multi-Month Lows as Fear Subsides",
"content": "The VIX volatility index dropped to 13.8, the lowest level since November 2021, signaling reduced market anxiety and increased risk appetite among investors. The CNN Fear & Greed Index shifted to 'Greed' territory at 72, up from 'Neutral' just two weeks ago. Credit spreads tightened across investment-grade and high-yield markets, with IG spreads at 85 basis points versus 110 in December. Equity put/call ratios declined to 0.45, indicating overwhelming bullish positioning. Margin debt increased 8% month-over-month as investors leverage up for continued market gains.",
"summary": "Market sentiment indicators show reduced fear and increased greed with VIX at multi-month lows and bullish positioning accelerating.",
"highlights": [
"VIX at 13.8 multi-month low",
"Fear & Greed at 72",
"Margin debt up 8%",
],
"published_date": "2024-01-16T14:30:00Z",
"author": "Market Sentiment Team",
"score": 0.88,
"provider": "exa",
},
]
@staticmethod
def search_results_empty() -> list[dict[str, Any]]:
"""Mock empty Exa search results for testing edge cases."""
return []
@staticmethod
def search_results_low_quality() -> list[dict[str, Any]]:
"""Mock low-quality Exa search results for credibility testing."""
return [
{
"url": "https://sketchy-site.com/apple-prediction",
"title": "AAPL Will 100X - Trust Me Bro Analysis",
"content": "Apple stock is going to the moon because reasons. My uncle works at Apple and says they're releasing iPhones made of gold next year. This is not financial advice but also definitely is financial advice. Buy now or cry later. Diamond hands to the moon rockets.",
"summary": "Questionable analysis with unsubstantiated claims about Apple's prospects.",
"highlights": [
"Gold iPhones coming",
"100x returns predicted",
"Uncle insider info",
],
"published_date": "2024-01-16T23:59:00Z",
"author": "Random Internet User",
"score": 0.12,
"provider": "exa",
}
]
# ==============================================================================
# MOCK TAVILY API RESPONSES
# ==============================================================================
class MockTavilyResponses:
"""Realistic Tavily API responses for web search."""
@staticmethod
def search_results_aapl() -> dict[str, Any]:
"""Mock Tavily search response for AAPL analysis."""
return {
"query": "Apple stock analysis AAPL investment outlook",
"follow_up_questions": [
"What are Apple's main revenue drivers?",
"How does Apple compare to competitors?",
"What are the key risks for Apple stock?",
],
"answer": "Apple (AAPL) shows strong fundamentals with growing services revenue and AI integration opportunities, though faces competition in China and regulatory pressures.",
"results": [
{
"title": "Apple Stock Analysis: Strong Fundamentals Despite Headwinds",
"url": "https://www.fool.com/investing/2024/01/15/apple-stock-analysis",
"content": "Apple's latest quarter demonstrated the resilience of its business model, with services revenue hitting a new record and iPhone sales exceeding expectations. The company's focus on artificial intelligence integration across its product ecosystem positions it well for future growth cycles. However, investors should monitor China market dynamics and App Store regulatory challenges that could impact long-term growth trajectories.",
"raw_content": "Apple Inc. (AAPL) continues to demonstrate strong business fundamentals in its latest quarterly report, with services revenue reaching new records and iPhone sales beating analyst expectations across key markets. The technology giant has strategically positioned itself at the forefront of artificial intelligence integration, with on-device AI processing capabilities that differentiate its products from competitors. Looking ahead, the company's ecosystem approach and services transition provide multiple growth vectors, though challenges in China and regulatory pressures on App Store policies require careful monitoring. The stock's current valuation appears reasonable given the company's cash generation capabilities and market position.",
"published_date": "2024-01-15",
"score": 0.89,
},
{
"title": "Tech Sector Outlook: AI Revolution Drives Growth",
"url": "https://www.barrons.com/articles/tech-outlook-ai-growth",
"content": "The technology sector stands at the beginning of a multi-year artificial intelligence transformation that could reshape revenue models and competitive dynamics. Companies with strong AI integration capabilities, including Apple, Microsoft, and Google, are positioned to benefit from this shift. Apple's approach of on-device AI processing provides privacy advantages and reduces cloud infrastructure costs compared to competitors relying heavily on cloud-based AI services.",
"raw_content": "The technology sector is experiencing a fundamental transformation as artificial intelligence capabilities become central to product differentiation and user experience. Companies that can effectively integrate AI while maintaining user privacy and system performance are likely to capture disproportionate value creation over the next 3-5 years. Apple's strategy of combining custom silicon with on-device AI processing provides competitive advantages in both performance and privacy, potentially driving accelerated device replacement cycles and services engagement. This positions Apple favorably compared to competitors relying primarily on cloud-based AI infrastructure.",
"published_date": "2024-01-14",
"score": 0.85,
},
{
"title": "Investment Analysis: Apple's Services Transformation",
"url": "https://www.investopedia.com/apple-services-analysis",
"content": "Apple's transformation from a hardware-centric to services-enabled company continues to gain momentum, with services revenue now representing over 22% of total revenue and growing at double-digit rates. This shift toward recurring revenue streams improves business model predictability and supports higher valuation multiples. The company's services ecosystem benefits from its large installed base and strong customer loyalty metrics.",
"raw_content": "Apple Inc.'s strategic evolution toward a services-centric business model represents one of the most successful corporate transformations in technology sector history. The company has leveraged its installed base of over 2 billion active devices to create a thriving services ecosystem encompassing the App Store, Apple Music, iCloud, Apple Pay, and various subscription services. This services revenue now exceeds $85 billion annually and continues growing at rates exceeding 10% year-over-year, providing both revenue diversification and margin enhancement. The recurring nature of services revenue creates more predictable cash flows and justifies premium valuation multiples compared to pure hardware companies.",
"published_date": "2024-01-13",
"score": 0.91,
},
],
"response_time": 1.2,
}
@staticmethod
def search_results_market_sentiment() -> dict[str, Any]:
"""Mock Tavily search response for market sentiment analysis."""
return {
"query": "stock market sentiment investor mood analysis 2024",
"follow_up_questions": [
"What are current market sentiment indicators?",
"How do investors feel about tech stocks?",
"What factors are driving market optimism?",
],
"answer": "Current market sentiment shows cautious optimism with reduced volatility and increased risk appetite, driven by AI enthusiasm and strong corporate earnings despite interest rate concerns.",
"results": [
{
"title": "Market Sentiment Indicators Signal Bullish Mood",
"url": "https://www.marketwatch.com/story/market-sentiment-bullish",
"content": "Multiple sentiment indicators suggest investors have shifted from defensive to risk-on positioning as 2024 progresses. The VIX volatility index has declined to multi-month lows while institutional money flows accelerate into equities. Credit markets show tightening spreads and increased issuance activity, reflecting improved risk appetite across asset classes.",
"raw_content": "A comprehensive analysis of market sentiment indicators reveals a significant shift in investor psychology over the past month. The CBOE Volatility Index (VIX) has dropped below 14, its lowest level since late 2021, indicating reduced fear and increased complacency among options traders. Simultaneously, the American Association of Individual Investors (AAII) sentiment survey shows bullish respondents outnumbering bearish by a 2:1 margin, the widest spread since early 2023. Institutional flows data from EPFR shows $45 billion in net inflows to equity funds over the past four weeks, with technology and growth sectors receiving disproportionate allocation.",
"published_date": "2024-01-16",
"score": 0.93,
},
{
"title": "Investor Psychology: Fear of Missing Out Returns",
"url": "https://www.wsj.com/markets/stocks/fomo-returns-markets",
"content": "The fear of missing out (FOMO) mentality has returned to equity markets as investors chase performance and increase leverage. Margin debt has increased significantly while cash positions at major brokerages have declined to multi-year lows. This shift in behavior suggests sentiment has moved from cautious to optimistic, though some analysts warn of potential overextension.",
"raw_content": "Behavioral indicators suggest a fundamental shift in investor psychology from the cautious stance that characterized much of 2023 to a more aggressive, opportunity-seeking mindset. NYSE margin debt has increased 15% over the past two months, reaching $750 billion as investors leverage up to participate in market gains. Cash positions at major discount brokerages have declined to just 3.2% of assets, compared to 5.8% during peak uncertainty in October 2023. Options market activity shows call volume exceeding put volume by the widest margin in 18 months, with particular strength in technology and AI-related names.",
"published_date": "2024-01-15",
"score": 0.88,
},
],
"response_time": 1.4,
}
@staticmethod
def search_results_error() -> dict[str, Any]:
"""Mock Tavily error response for testing error handling."""
return {
"error": "rate_limit_exceeded",
"message": "API rate limit exceeded. Please try again later.",
"retry_after": 60,
}
# ==============================================================================
# MOCK MARKET DATA
# ==============================================================================
class MockMarketData:
"""Realistic market data for testing financial analysis."""
@staticmethod
def stock_price_history(
symbol: str = "AAPL", days: int = 100, current_price: float = 185.0
) -> pd.DataFrame:
"""Generate realistic stock price history."""
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
dates = pd.date_range(start=start_date, end=end_date, freq="D")
# Generate realistic price movement
np.random.seed(42) # Consistent data for testing
returns = np.random.normal(
0.0008, 0.02, len(dates)
) # ~0.2% daily return, 2% volatility
# Start with a base price and apply returns
base_price = current_price * 0.9 # Start 10% lower
prices = [base_price]
for return_val in returns[1:]:
next_price = prices[-1] * (1 + return_val)
prices.append(max(next_price, 50)) # Floor price at $50
# Create OHLCV data
data = []
for i, (date, close_price) in enumerate(zip(dates, prices, strict=False)):
# Generate realistic OHLC from close price
volatility = abs(np.random.normal(0, 0.015)) # Intraday volatility
high = close_price * (1 + volatility)
low = close_price * (1 - volatility)
# Determine open based on previous close with gap
if i == 0:
open_price = close_price
else:
gap = np.random.normal(0, 0.005) # Small gap
open_price = prices[i - 1] * (1 + gap)
# Ensure OHLC relationships are valid
high = max(high, open_price, close_price)
low = min(low, open_price, close_price)
# Generate volume
base_volume = 50_000_000 # Base volume
volume_multiplier = np.random.uniform(0.5, 2.0)
volume = int(base_volume * volume_multiplier)
data.append(
{
"Date": date,
"Open": round(open_price, 2),
"High": round(high, 2),
"Low": round(low, 2),
"Close": round(close_price, 2),
"Volume": volume,
}
)
df = pd.DataFrame(data)
df.set_index("Date", inplace=True)
return df
@staticmethod
def technical_indicators(symbol: str = "AAPL") -> dict[str, Any]:
"""Mock technical indicators for a stock."""
return {
"symbol": symbol,
"timestamp": datetime.now(),
"rsi": {
"value": 58.3,
"signal": "neutral",
"interpretation": "Neither overbought nor oversold",
},
"macd": {
"value": 2.15,
"signal_line": 1.89,
"histogram": 0.26,
"signal": "bullish",
"interpretation": "MACD above signal line suggests bullish momentum",
},
"bollinger_bands": {
"upper": 192.45,
"middle": 185.20,
"lower": 177.95,
"position": "middle",
"squeeze": False,
},
"moving_averages": {
"sma_20": 183.45,
"sma_50": 178.90,
"sma_200": 172.15,
"ema_12": 184.80,
"ema_26": 181.30,
},
"support_resistance": {
"support_levels": [175.00, 170.50, 165.25],
"resistance_levels": [190.00, 195.75, 200.50],
"current_level": "between_support_resistance",
},
"volume_analysis": {
"average_volume": 52_000_000,
"current_volume": 68_000_000,
"relative_volume": 1.31,
"volume_trend": "increasing",
},
}
@staticmethod
def market_overview() -> dict[str, Any]:
"""Mock market overview data."""
return {
"timestamp": datetime.now(),
"indices": {
"SPY": {"price": 485.30, "change": +2.15, "change_pct": +0.44},
"QQQ": {"price": 412.85, "change": +5.42, "change_pct": +1.33},
"IWM": {"price": 195.67, "change": -1.23, "change_pct": -0.62},
"VIX": {"price": 13.8, "change": -1.2, "change_pct": -8.0},
},
"sector_performance": {
"Technology": +1.85,
"Healthcare": +0.45,
"Financial Services": -0.32,
"Consumer Cyclical": +0.78,
"Industrials": -0.15,
"Energy": -1.22,
"Utilities": +0.33,
"Real Estate": +0.91,
"Materials": -0.67,
"Consumer Defensive": +0.12,
"Communication Services": +1.34,
},
"market_breadth": {
"advancers": 1845,
"decliners": 1230,
"unchanged": 125,
"new_highs": 89,
"new_lows": 12,
"up_volume": 8.2e9,
"down_volume": 4.1e9,
},
"sentiment_indicators": {
"fear_greed_index": 72,
"vix_level": "low",
"put_call_ratio": 0.45,
"margin_debt_trend": "increasing",
},
}
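# Illustrative usage sketch (not part of the original fixtures): how a test
# might derive a simple indicator from the synthetic OHLCV data above. The
# 20-day window is an arbitrary choice for demonstration.
def _example_sma_from_mock_history(days: int = 60) -> float:
    df = MockMarketData.stock_price_history(symbol="AAPL", days=days)
    # 20-day simple moving average of the generated closes
    return float(df["Close"].rolling(20).mean().iloc[-1])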
# ==============================================================================
# TEST QUERY EXAMPLES
# ==============================================================================
class TestQueries:
"""Realistic user queries for different classification categories."""
MARKET_SCREENING = [
"Find me momentum stocks in the technology sector with strong earnings growth",
"Screen for dividend-paying stocks with yields above 3% and consistent payout history",
"Show me small-cap stocks with high revenue growth and low debt levels",
"Find stocks breaking out of consolidation patterns with increasing volume",
"Screen for value stocks trading below book value with improving fundamentals",
]
COMPANY_RESEARCH = [
"Analyze Apple's competitive position in the smartphone market",
"Research Tesla's battery technology advantages and manufacturing scale",
"Provide comprehensive analysis of Microsoft's cloud computing strategy",
"Analyze Amazon's e-commerce margins and AWS growth potential",
"Research Nvidia's AI chip market dominance and competitive threats",
]
TECHNICAL_ANALYSIS = [
"Analyze AAPL's chart patterns and provide entry/exit recommendations",
"What do the technical indicators say about SPY's short-term direction?",
"Analyze TSLA's support and resistance levels for swing trading",
"Show me the RSI and MACD signals for QQQ",
"Identify chart patterns in the Nasdaq that suggest market direction",
]
SENTIMENT_ANALYSIS = [
"What's the current market sentiment around tech stocks?",
"Analyze investor sentiment toward electric vehicle companies",
"How are traders feeling about the Fed's interest rate policy?",
"What's the mood in crypto markets right now?",
"Analyze sentiment around bank stocks after recent earnings",
]
PORTFOLIO_ANALYSIS = [
"Optimize my portfolio allocation for moderate risk tolerance",
"Analyze the correlation between my holdings and suggest diversification",
"Review my portfolio for sector concentration risk",
"Suggest rebalancing strategy for my retirement portfolio",
"Analyze my portfolio's beta and suggest hedging strategies",
]
RISK_ASSESSMENT = [
"Calculate appropriate position size for AAPL given my $100k account",
"What's the maximum drawdown risk for a 60/40 portfolio?",
"Analyze the tail risk in my growth stock positions",
"Calculate VaR for my current portfolio allocation",
"Assess concentration risk in my tech-heavy portfolio",
]
@classmethod
def get_random_query(cls, category: str) -> str:
"""Get a random query from the specified category."""
queries_map = {
"market_screening": cls.MARKET_SCREENING,
"company_research": cls.COMPANY_RESEARCH,
"technical_analysis": cls.TECHNICAL_ANALYSIS,
"sentiment_analysis": cls.SENTIMENT_ANALYSIS,
"portfolio_analysis": cls.PORTFOLIO_ANALYSIS,
"risk_assessment": cls.RISK_ASSESSMENT,
}
queries = queries_map.get(category, cls.MARKET_SCREENING)
return np.random.choice(queries)
# ==============================================================================
# PERSONA-SPECIFIC FIXTURES
# ==============================================================================
class PersonaFixtures:
"""Persona-specific test data and responses."""
@staticmethod
def conservative_investor_data() -> dict[str, Any]:
"""Data for conservative investor persona testing."""
return {
"persona": "conservative",
"characteristics": [
"capital preservation",
"income generation",
"low volatility",
"dividend focus",
],
"risk_tolerance": 0.3,
"preferred_sectors": ["Utilities", "Consumer Defensive", "Healthcare"],
"analysis_focus": [
"dividend yield",
"debt levels",
"stability",
"downside protection",
],
"position_sizing": {
"max_single_position": 0.05, # 5% max
"stop_loss_multiplier": 1.5,
"target_volatility": 0.12,
},
"sample_recommendations": [
"Consider gradual position building with strict risk management",
"Focus on dividend-paying stocks with consistent payout history",
"Maintain defensive positioning until market clarity improves",
"Prioritize capital preservation over aggressive growth",
],
}
@staticmethod
def moderate_investor_data() -> dict[str, Any]:
"""Data for moderate investor persona testing."""
return {
"persona": "moderate",
"characteristics": [
"balanced growth",
"diversification",
"moderate risk",
"long-term focus",
],
"risk_tolerance": 0.6,
"preferred_sectors": [
"Technology",
"Healthcare",
"Financial Services",
"Industrials",
],
"analysis_focus": [
"risk-adjusted returns",
"diversification",
"growth potential",
"fundamentals",
],
"position_sizing": {
"max_single_position": 0.08, # 8% max
"stop_loss_multiplier": 2.0,
"target_volatility": 0.18,
},
"sample_recommendations": [
"Balance growth opportunities with risk management",
"Consider diversified allocation across sectors and market caps",
"Target 4-6% position sizing for high-conviction ideas",
"Monitor both technical and fundamental indicators",
],
}
@staticmethod
def aggressive_investor_data() -> dict[str, Any]:
"""Data for aggressive investor persona testing."""
return {
"persona": "aggressive",
"characteristics": [
"high growth",
"momentum",
"concentrated positions",
"active trading",
],
"risk_tolerance": 0.9,
"preferred_sectors": [
"Technology",
"Communication Services",
"Consumer Cyclical",
],
"analysis_focus": [
"growth potential",
"momentum",
"catalysts",
"alpha generation",
],
"position_sizing": {
"max_single_position": 0.15, # 15% max
"stop_loss_multiplier": 3.0,
"target_volatility": 0.25,
},
"sample_recommendations": [
"Consider concentrated positions in high-conviction names",
"Target momentum stocks with strong catalysts",
"Use 10-15% position sizing for best opportunities",
"Focus on alpha generation over risk management",
],
}
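# Illustrative usage sketch (not part of the original fixtures): applying a
# persona's position-sizing limit to an account value. The account figure a
# caller passes in is arbitrary; only the percentage comes from the fixture.
def _example_max_position_dollars(account_value: float, persona_data: dict[str, Any]) -> float:
    sizing = persona_data["position_sizing"]
    # e.g. a conservative persona caps any single position at 5% of the account
    return round(account_value * sizing["max_single_position"], 2)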
# ==============================================================================
# EDGE CASE AND ERROR FIXTURES
# ==============================================================================
class EdgeCaseFixtures:
"""Fixtures for testing edge cases and error conditions."""
@staticmethod
def api_failure_responses() -> dict[str, Any]:
"""Mock API failure responses for error handling testing."""
return {
"exa_rate_limit": {
"error": "rate_limit_exceeded",
"message": "You have exceeded your API rate limit",
"retry_after": 3600,
"status_code": 429,
},
"tavily_unauthorized": {
"error": "unauthorized",
"message": "Invalid API key provided",
"status_code": 401,
},
"llm_timeout": {
"error": "timeout",
"message": "Request timed out after 30 seconds",
"status_code": 408,
},
"network_error": {
"error": "network_error",
"message": "Unable to connect to external service",
"status_code": 503,
},
}
@staticmethod
def conflicting_agent_results() -> dict[str, dict[str, Any]]:
"""Mock conflicting results from different agents for synthesis testing."""
return {
"market": {
"recommendation": "BUY",
"confidence": 0.85,
"reasoning": "Strong fundamentals and sector rotation into technology",
"target_price": 210.0,
"sentiment": "bullish",
},
"technical": {
"recommendation": "SELL",
"confidence": 0.78,
"reasoning": "Bearish divergence in RSI and approaching strong resistance",
"target_price": 165.0,
"sentiment": "bearish",
},
"research": {
"recommendation": "HOLD",
"confidence": 0.72,
"reasoning": "Mixed signals from fundamental analysis and market conditions",
"target_price": 185.0,
"sentiment": "neutral",
},
}
@staticmethod
def incomplete_data() -> dict[str, Any]:
"""Mock incomplete or missing data scenarios."""
return {
"missing_price_data": {
"symbol": "AAPL",
"error": "Price data not available for requested timeframe",
"available_data": None,
},
"partial_search_results": {
"results_found": 2,
"results_expected": 10,
"provider_errors": ["exa_timeout", "tavily_rate_limit"],
"partial_data": True,
},
"llm_partial_response": {
"analysis": "Partial analysis completed before",
"truncated": True,
"completion_percentage": 0.6,
},
}
@staticmethod
def malformed_data() -> dict[str, Any]:
"""Mock malformed or invalid data for error testing."""
return {
"invalid_json": '{"analysis": "incomplete json"', # Missing closing brace
"wrong_schema": {
"unexpected_field": "value",
"missing_required_field": None,
},
"invalid_dates": {
"published_date": "not-a-date",
"timestamp": "invalid-timestamp",
},
"invalid_numbers": {"confidence": "not-a-number", "price": "invalid-price"},
}
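# Illustrative sketch (not part of the original fixtures): one simple way a
# synthesis test could reconcile the conflicting agent results above using a
# confidence-weighted target price. The weighting scheme is an assumption.
def _example_confidence_weighted_target(agent_results: dict[str, dict[str, Any]]) -> float:
    total_confidence = sum(r["confidence"] for r in agent_results.values())
    if not total_confidence:
        return 0.0
    weighted_sum = sum(r["target_price"] * r["confidence"] for r in agent_results.values())
    return round(weighted_sum / total_confidence, 2)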
# ==============================================================================
# PYTEST FIXTURES
# ==============================================================================
@pytest.fixture
def mock_llm_responses():
"""Fixture providing mock LLM responses."""
return MockLLMResponses()
@pytest.fixture
def mock_exa_responses():
"""Fixture providing mock Exa API responses."""
return MockExaResponses()
@pytest.fixture
def mock_tavily_responses():
"""Fixture providing mock Tavily API responses."""
return MockTavilyResponses()
@pytest.fixture
def mock_market_data():
"""Fixture providing mock market data."""
return MockMarketData()
@pytest.fixture
def test_queries():
"""Fixture providing test queries."""
return TestQueries()
@pytest.fixture
def persona_fixtures():
"""Fixture providing persona-specific data."""
return PersonaFixtures()
@pytest.fixture
def edge_case_fixtures():
"""Fixture providing edge case test data."""
return EdgeCaseFixtures()
@pytest.fixture(params=["conservative", "moderate", "aggressive"])
def investor_persona(request):
"""Parametrized fixture for testing across all investor personas."""
return request.param
@pytest.fixture(
params=[
"market_screening",
"company_research",
"technical_analysis",
"sentiment_analysis",
]
)
def query_category(request):
"""Parametrized fixture for testing across all query categories."""
return request.param
# ==============================================================================
# HELPER FUNCTIONS
# ==============================================================================
def create_mock_llm_with_responses(responses: list[str]) -> MagicMock:
"""Create a mock LLM that returns specific responses in order."""
mock_llm = MagicMock()
# Create AIMessage objects for each response
ai_messages = [AIMessage(content=response) for response in responses]
mock_llm.ainvoke.side_effect = ai_messages
return mock_llm
def create_mock_agent_result(
agent_type: str,
confidence: float = 0.8,
recommendation: str = "BUY",
additional_data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a mock agent result with realistic structure."""
base_result = {
"status": "success",
"agent_type": agent_type,
"confidence_score": confidence,
"recommendation": recommendation,
"timestamp": datetime.now(),
"execution_time_ms": np.random.uniform(1000, 5000),
}
if additional_data:
base_result.update(additional_data)
return base_result
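# Illustrative usage sketch (not part of the original helpers): building a
# pair of disagreeing agent results for a synthesis test with the helper above.
def _example_disagreeing_agent_results() -> dict[str, dict[str, Any]]:
    return {
        "market": create_mock_agent_result("market", confidence=0.85, recommendation="BUY"),
        "technical": create_mock_agent_result("technical", confidence=0.78, recommendation="SELL"),
    }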
def create_realistic_stock_data(
symbol: str = "AAPL", price: float = 185.0, volume: int = 50_000_000
) -> dict[str, Any]:
"""Create realistic stock data for testing."""
return {
"symbol": symbol,
"current_price": price,
"volume": volume,
"market_cap": 2_850_000_000_000, # $2.85T for AAPL
"pe_ratio": 28.5,
"dividend_yield": 0.005,
"beta": 1.1,
"52_week_high": 198.23,
"52_week_low": 164.08,
"average_volume": 48_000_000,
"sector": "Technology",
"industry": "Consumer Electronics",
}
# Export main classes for easy importing
__all__ = [
"MockLLMResponses",
"MockExaResponses",
"MockTavilyResponses",
"MockMarketData",
"TestQueries",
"PersonaFixtures",
"EdgeCaseFixtures",
"create_mock_llm_with_responses",
"create_mock_agent_result",
"create_realistic_stock_data",
]
```
--------------------------------------------------------------------------------
/tests/test_exa_research_integration.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive test suite for ExaSearch integration with research agents.
This test suite validates the complete research agent architecture with the ExaSearch provider,
including timeout handling, parallel execution, specialized subagents, and performance
benchmarking across all research depths and focus areas.
"""
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from exa_py import Exa
from maverick_mcp.agents.deep_research import (
RESEARCH_DEPTH_LEVELS,
CompetitiveResearchAgent,
ContentAnalyzer,
DeepResearchAgent,
ExaSearchProvider,
FundamentalResearchAgent,
SentimentResearchAgent,
TechnicalResearchAgent,
)
from maverick_mcp.api.routers.research import (
ResearchRequest,
comprehensive_research,
get_research_agent,
)
from maverick_mcp.exceptions import WebSearchError
from maverick_mcp.utils.parallel_research import (
ParallelResearchConfig,
ParallelResearchOrchestrator,
ResearchResult,
ResearchTask,
TaskDistributionEngine,
)
# Test Data Factories and Fixtures
@pytest.fixture
def mock_llm():
"""Mock LLM with realistic response patterns for research scenarios."""
llm = MagicMock()
llm.ainvoke = AsyncMock()
# Mock different response types for different research phases
def mock_response_content(messages):
"""Generate realistic mock responses based on message content."""
content = str(messages[-1].content).lower()
if "synthesis" in content:
return MagicMock(
content='{"synthesis": "Comprehensive analysis shows positive outlook", "confidence": 0.8}'
)
elif "analyze" in content or "financial" in content:
return MagicMock(
content='{"KEY_INSIGHTS": ["Strong earnings growth", "Market share expansion"], "SENTIMENT": {"direction": "bullish", "confidence": 0.75}, "RISK_FACTORS": ["Interest rate sensitivity"], "OPPORTUNITIES": ["Market expansion"], "CREDIBILITY": 0.8, "RELEVANCE": 0.9, "SUMMARY": "Positive financial outlook"}'
)
else:
return MagicMock(content="Analysis completed successfully")
llm.ainvoke.side_effect = lambda messages, **kwargs: mock_response_content(messages)
return llm
@pytest.fixture
def mock_exa_client():
"""Mock Exa client with realistic search responses."""
mock_client = MagicMock(spec=Exa)
def create_mock_result(title, text, url_suffix=""):
"""Create mock Exa result object."""
result = MagicMock()
result.url = f"https://example.com/{url_suffix}"
result.title = title
result.text = text
result.published_date = "2024-01-15T10:00:00Z"
result.score = 0.85
result.author = "Financial Analyst"
return result
def mock_search_and_contents(query, num_results=5, **kwargs):
"""Generate mock search results based on query content."""
response = MagicMock()
results = []
query_lower = query.lower()
if "aapl" in query_lower or "apple" in query_lower:
results.extend(
[
create_mock_result(
"Apple Q4 Earnings Beat Expectations",
"Apple reported strong quarterly earnings with iPhone sales growth of 15% and services revenue reaching new highs. The company's financial position remains robust with strong cash flow.",
"apple-earnings",
),
create_mock_result(
"Apple Stock Technical Analysis",
"Apple stock shows bullish technical patterns with support at $180 and resistance at $200. RSI indicates oversold conditions presenting buying opportunity.",
"apple-technical",
),
]
)
elif "sentiment" in query_lower:
results.extend(
[
create_mock_result(
"Market Sentiment Turns Positive",
"Investor sentiment shows improvement with increased confidence in tech sector. Analyst upgrades and positive earnings surprises drive optimism.",
"market-sentiment",
),
]
)
elif "competitive" in query_lower or "industry" in query_lower:
results.extend(
[
create_mock_result(
"Tech Industry Competitive Landscape",
"The technology sector shows fierce competition with market leaders maintaining strong positions. Innovation and market share battles intensify.",
"competitive-analysis",
),
]
)
else:
# Default financial research results
results.extend(
[
create_mock_result(
"Financial Market Analysis",
"Current market conditions show mixed signals with growth prospects balanced against economic uncertainties. Investors remain cautiously optimistic.",
"market-analysis",
),
create_mock_result(
"Investment Outlook 2024",
"Investment opportunities emerge in technology and healthcare sectors despite ongoing market volatility. Diversification remains key strategy.",
"investment-outlook",
),
]
)
# Limit results to requested number
response.results = results[:num_results]
return response
mock_client.search_and_contents.side_effect = mock_search_and_contents
return mock_client
@pytest.fixture
def sample_research_tasks():
"""Sample research tasks for parallel execution testing."""
return [
ResearchTask(
task_id="session_123_fundamental",
task_type="fundamental",
target_topic="AAPL financial analysis",
focus_areas=["earnings", "valuation", "growth"],
priority=8,
timeout=20,
),
ResearchTask(
task_id="session_123_technical",
task_type="technical",
target_topic="AAPL technical analysis",
focus_areas=["chart_patterns", "support_resistance"],
priority=7,
timeout=15,
),
ResearchTask(
task_id="session_123_sentiment",
task_type="sentiment",
target_topic="AAPL market sentiment",
focus_areas=["news_sentiment", "analyst_ratings"],
priority=6,
timeout=15,
),
]
@pytest.fixture
def mock_settings():
"""Mock settings with ExaSearch configuration."""
settings = MagicMock()
settings.research.exa_api_key = "test_exa_api_key"
settings.data_limits.max_parallel_agents = 4
settings.performance.search_timeout_failure_threshold = 12
settings.performance.search_circuit_breaker_failure_threshold = 8
settings.performance.search_circuit_breaker_recovery_timeout = 30
return settings
# ExaSearchProvider Tests
class TestExaSearchProvider:
"""Test ExaSearch provider integration and functionality."""
@pytest.mark.unit
def test_exa_provider_initialization(self):
"""Test ExaSearchProvider initialization."""
api_key = "test_api_key_123"
provider = ExaSearchProvider(api_key)
assert provider.api_key == api_key
assert provider._api_key_verified is True
assert provider.is_healthy() is True
assert provider._failure_count == 0
@pytest.mark.unit
def test_exa_provider_initialization_without_key(self):
"""Test ExaSearchProvider initialization without API key."""
provider = ExaSearchProvider("")
assert provider.api_key == ""
assert provider._api_key_verified is False
assert provider.is_healthy() is True # Still healthy, but searches will fail
@pytest.mark.unit
def test_timeout_calculation(self):
"""Test adaptive timeout calculation for different query complexities."""
provider = ExaSearchProvider("test_key")
# Simple query
timeout = provider._calculate_timeout("AAPL", None)
assert timeout >= 4.0 # Minimum for Exa reliability
# Complex query
complex_query = "comprehensive analysis of Apple Inc financial performance and market position with competitive analysis"
timeout_complex = provider._calculate_timeout(complex_query, None)
assert timeout_complex >= timeout
# Budget constrained query
timeout_budget = provider._calculate_timeout("AAPL", 8.0)
assert 4.0 <= timeout_budget <= 8.0
@pytest.mark.unit
def test_failure_recording_and_health_status(self):
"""Test failure recording and health status management."""
provider = ExaSearchProvider("test_key")
# Initially healthy
assert provider.is_healthy() is True
# Record several timeout failures
for _ in range(5):
provider._record_failure("timeout")
assert provider._failure_count == 5
assert provider.is_healthy() is True # Still healthy, threshold not reached
# Exceed timeout threshold (default 12)
for _ in range(8):
provider._record_failure("timeout")
assert provider._failure_count == 13
assert provider.is_healthy() is False # Now unhealthy
# Test recovery
provider._record_success()
assert provider.is_healthy() is True
assert provider._failure_count == 0
@pytest.mark.unit
@patch("exa_py.Exa")
async def test_exa_search_success(self, mock_exa_class, mock_exa_client):
"""Test successful ExaSearch operation."""
mock_exa_class.return_value = mock_exa_client
provider = ExaSearchProvider("test_key")
results = await provider.search("AAPL financial analysis", num_results=3)
assert len(results) >= 1
assert all("url" in result for result in results)
assert all("title" in result for result in results)
assert all("content" in result for result in results)
assert all(result["provider"] == "exa" for result in results)
@pytest.mark.unit
@patch("exa_py.Exa")
async def test_exa_search_timeout(self, mock_exa_class):
"""Test ExaSearch timeout handling."""
# Mock Exa client that takes too long
mock_client = MagicMock()
def slow_search(*args, **kwargs):
import time
time.sleep(10) # Simulate slow synchronous response
mock_client.search_and_contents.side_effect = slow_search
mock_exa_class.return_value = mock_client
provider = ExaSearchProvider("test_key")
with pytest.raises(WebSearchError, match="timed out"):
await provider.search("test query", timeout_budget=2.0)
# Check that failure was recorded
assert not provider.is_healthy() or provider._failure_count > 0
@pytest.mark.unit
@patch("exa_py.Exa")
async def test_exa_search_unhealthy_provider(self, mock_exa_class):
"""Test behavior when provider is marked as unhealthy."""
provider = ExaSearchProvider("test_key")
provider._is_healthy = False
with pytest.raises(WebSearchError, match="disabled due to repeated failures"):
await provider.search("test query")
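# Illustrative sketch (not part of the test suite): one way the adaptive
# timeout asserted above could be computed. The exact formula inside
# ExaSearchProvider._calculate_timeout is an assumption here; only the bounds
# mirror the tests (a 4.0s floor, capped by the caller's budget).
def _sketch_adaptive_timeout(query: str, timeout_budget: float | None) -> float:
    base = 4.0  # minimum for Exa reliability, per the assertions above
    complexity_bonus = min(len(query.split()) * 0.3, 6.0)  # longer queries get more time
    timeout = base + complexity_bonus
    if timeout_budget is not None:
        timeout = min(timeout, timeout_budget)  # never exceed the caller's budget
    return timeout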
# DeepResearchAgent Tests
class TestDeepResearchAgent:
"""Test DeepResearchAgent with ExaSearch integration."""
@pytest.mark.unit
@patch("maverick_mcp.agents.deep_research.get_cached_search_provider")
async def test_agent_initialization_with_exa(self, mock_provider, mock_llm):
"""Test DeepResearchAgent initialization with ExaSearch provider."""
mock_exa_provider = MagicMock(spec=ExaSearchProvider)
mock_provider.return_value = mock_exa_provider
agent = DeepResearchAgent(
llm=mock_llm,
persona="moderate",
exa_api_key="test_key",
research_depth="standard",
)
await agent.initialize()
assert agent.search_providers == [mock_exa_provider]
assert agent._search_providers_loaded is True
assert agent.default_depth == "standard"
@pytest.mark.unit
@patch("maverick_mcp.agents.deep_research.get_cached_search_provider")
async def test_agent_initialization_without_providers(
self, mock_provider, mock_llm
):
"""Test agent behavior when no search providers are available."""
mock_provider.return_value = None
agent = DeepResearchAgent(
llm=mock_llm,
persona="moderate",
exa_api_key=None,
)
await agent.initialize()
assert agent.search_providers == []
assert agent._search_providers_loaded is True
@pytest.mark.unit
@patch("maverick_mcp.agents.deep_research.get_cached_search_provider")
async def test_research_comprehensive_no_providers(self, mock_provider, mock_llm):
"""Test research behavior when no search providers are configured."""
mock_provider.return_value = None
agent = DeepResearchAgent(llm=mock_llm, exa_api_key=None)
result = await agent.research_comprehensive(
topic="AAPL analysis", session_id="test_session", depth="basic"
)
assert "error" in result
assert "no search providers configured" in result["error"]
assert result["topic"] == "AAPL analysis"
@pytest.mark.integration
@patch("maverick_mcp.agents.deep_research.get_cached_search_provider")
@patch("exa_py.Exa")
async def test_research_comprehensive_success(
self, mock_exa_class, mock_provider, mock_llm, mock_exa_client
):
"""Test successful comprehensive research with ExaSearch."""
# Setup mocks
mock_exa_provider = ExaSearchProvider("test_key")
mock_provider.return_value = mock_exa_provider
mock_exa_class.return_value = mock_exa_client
agent = DeepResearchAgent(
llm=mock_llm,
persona="moderate",
exa_api_key="test_key",
research_depth="basic",
)
# Execute research
result = await agent.research_comprehensive(
topic="AAPL financial analysis",
session_id="test_session_123",
depth="basic",
timeout_budget=15.0,
)
# Verify result structure
assert result["status"] == "success"
assert result["agent_type"] == "deep_research"
assert result["research_topic"] == "AAPL financial analysis"
assert result["research_depth"] == "basic"
assert "findings" in result
assert "confidence_score" in result
assert "execution_time_ms" in result
@pytest.mark.unit
def test_research_depth_levels(self):
"""Test research depth level configurations."""
assert "basic" in RESEARCH_DEPTH_LEVELS
assert "standard" in RESEARCH_DEPTH_LEVELS
assert "comprehensive" in RESEARCH_DEPTH_LEVELS
assert "exhaustive" in RESEARCH_DEPTH_LEVELS
# Verify basic level has minimal settings for speed
basic = RESEARCH_DEPTH_LEVELS["basic"]
assert basic["max_sources"] <= 5
assert basic["max_searches"] <= 2
assert basic["validation_required"] is False
# Verify exhaustive has maximum settings
exhaustive = RESEARCH_DEPTH_LEVELS["exhaustive"]
assert exhaustive["max_sources"] >= 10
assert exhaustive["validation_required"] is True
# Specialized Subagent Tests
class TestSpecializedSubagents:
"""Test specialized research subagents."""
@pytest.fixture
def mock_parent_agent(self, mock_llm):
"""Mock parent DeepResearchAgent for subagent testing."""
agent = MagicMock()
agent.llm = mock_llm
agent.search_providers = [MagicMock(spec=ExaSearchProvider)]
agent.content_analyzer = MagicMock(spec=ContentAnalyzer)
agent.persona = MagicMock()
agent.persona.name = "moderate"
agent._calculate_source_credibility = MagicMock(return_value=0.8)
return agent
@pytest.mark.unit
async def test_fundamental_research_agent(
self, mock_parent_agent, sample_research_tasks
):
"""Test FundamentalResearchAgent execution."""
task = sample_research_tasks[0] # fundamental task
agent = FundamentalResearchAgent(mock_parent_agent)
# Mock search results
mock_search_results = [
{
"title": "AAPL Q4 Earnings Report",
"url": "https://example.com/earnings",
"content": "Apple reported strong quarterly earnings with revenue growth of 12%...",
"published_date": "2024-01-15",
}
]
agent._perform_specialized_search = AsyncMock(return_value=mock_search_results)
agent._analyze_search_results = AsyncMock(
return_value=[
{
**mock_search_results[0],
"analysis": {
"insights": [
"Strong earnings growth",
"Revenue diversification",
],
"risk_factors": ["Market competition"],
"opportunities": ["Market expansion"],
"sentiment": {"direction": "bullish", "confidence": 0.8},
},
"credibility_score": 0.8,
}
]
)
result = await agent.execute_research(task)
assert result["research_type"] == "fundamental"
assert "insights" in result
assert "risk_factors" in result
assert "opportunities" in result
assert "sentiment" in result
assert "sources" in result
assert len(result["focus_areas"]) > 0
assert "earnings" in result["focus_areas"]
@pytest.mark.unit
async def test_technical_research_agent(
self, mock_parent_agent, sample_research_tasks
):
"""Test TechnicalResearchAgent execution."""
task = sample_research_tasks[1] # technical task
agent = TechnicalResearchAgent(mock_parent_agent)
# Mock search results with technical analysis
mock_search_results = [
{
"title": "AAPL Technical Analysis",
"url": "https://example.com/technical",
"content": "AAPL shows bullish chart patterns with support at $180 and resistance at $200...",
"published_date": "2024-01-15",
}
]
agent._perform_specialized_search = AsyncMock(return_value=mock_search_results)
agent._analyze_search_results = AsyncMock(
return_value=[
{
**mock_search_results[0],
"analysis": {
"insights": [
"Bullish breakout pattern",
"Strong support levels",
],
"risk_factors": ["Overbought conditions"],
"opportunities": ["Momentum continuation"],
"sentiment": {"direction": "bullish", "confidence": 0.7},
},
"credibility_score": 0.7,
}
]
)
result = await agent.execute_research(task)
assert result["research_type"] == "technical"
assert "price_action" in result["focus_areas"]
assert "chart_patterns" in result["focus_areas"]
@pytest.mark.unit
async def test_sentiment_research_agent(
self, mock_parent_agent, sample_research_tasks
):
"""Test SentimentResearchAgent execution."""
task = sample_research_tasks[2] # sentiment task
agent = SentimentResearchAgent(mock_parent_agent)
# Mock search results with sentiment data
mock_search_results = [
{
"title": "Apple Stock Sentiment Analysis",
"url": "https://example.com/sentiment",
"content": "Analyst sentiment remains positive on Apple with multiple upgrades...",
"published_date": "2024-01-15",
}
]
agent._perform_specialized_search = AsyncMock(return_value=mock_search_results)
agent._analyze_search_results = AsyncMock(
return_value=[
{
**mock_search_results[0],
"analysis": {
"insights": ["Positive analyst sentiment", "Upgrade momentum"],
"risk_factors": ["Market volatility concerns"],
"opportunities": ["Institutional accumulation"],
"sentiment": {"direction": "bullish", "confidence": 0.85},
},
"credibility_score": 0.9,
}
]
)
result = await agent.execute_research(task)
assert result["research_type"] == "sentiment"
assert "market_sentiment" in result["focus_areas"]
assert result["sentiment"]["direction"] == "bullish"
@pytest.mark.unit
async def test_competitive_research_agent(self, mock_parent_agent):
"""Test CompetitiveResearchAgent execution."""
task = ResearchTask(
task_id="test_competitive",
task_type="competitive",
target_topic="AAPL competitive analysis",
focus_areas=["competitive_position", "market_share"],
)
agent = CompetitiveResearchAgent(mock_parent_agent)
# Mock search results with competitive data
mock_search_results = [
{
"title": "Apple vs Samsung Market Share",
"url": "https://example.com/competitive",
"content": "Apple maintains strong competitive position in premium smartphone market...",
"published_date": "2024-01-15",
}
]
agent._perform_specialized_search = AsyncMock(return_value=mock_search_results)
agent._analyze_search_results = AsyncMock(
return_value=[
{
**mock_search_results[0],
"analysis": {
"insights": [
"Strong market position",
"Premium segment dominance",
],
"risk_factors": ["Android competition"],
"opportunities": ["Emerging markets"],
"sentiment": {"direction": "bullish", "confidence": 0.75},
},
"credibility_score": 0.8,
}
]
)
result = await agent.execute_research(task)
assert result["research_type"] == "competitive"
assert "competitive_position" in result["focus_areas"]
assert "industry_trends" in result["focus_areas"]
# Parallel Research Tests
class TestParallelResearchOrchestrator:
"""Test parallel research execution and orchestration."""
@pytest.mark.unit
def test_orchestrator_initialization(self):
"""Test ParallelResearchOrchestrator initialization."""
config = ParallelResearchConfig(max_concurrent_agents=6, timeout_per_agent=20)
orchestrator = ParallelResearchOrchestrator(config)
assert orchestrator.config.max_concurrent_agents == 6
assert orchestrator.config.timeout_per_agent == 20
assert orchestrator._semaphore._value == 6 # Semaphore initialized correctly
@pytest.mark.unit
async def test_task_preparation(self, sample_research_tasks):
"""Test task preparation and prioritization."""
orchestrator = ParallelResearchOrchestrator()
prepared_tasks = await orchestrator._prepare_tasks(sample_research_tasks)
# Should be sorted by priority (descending)
assert prepared_tasks[0].priority >= prepared_tasks[1].priority
# All tasks should have timeouts set
for task in prepared_tasks:
assert task.timeout is not None
assert task.status == "pending"
assert task.task_id in orchestrator.active_tasks
@pytest.mark.integration
async def test_parallel_execution_success(self, sample_research_tasks):
"""Test successful parallel execution of research tasks."""
orchestrator = ParallelResearchOrchestrator(
ParallelResearchConfig(max_concurrent_agents=3, timeout_per_agent=10)
)
# Mock research executor
async def mock_executor(task):
"""Mock research executor that simulates successful execution."""
await asyncio.sleep(0.1) # Simulate work
return {
"research_type": task.task_type,
"insights": [
f"{task.task_type} insight 1",
f"{task.task_type} insight 2",
],
"sentiment": {"direction": "bullish", "confidence": 0.8},
"sources": [
{"title": f"{task.task_type} source", "url": "https://example.com"}
],
}
# Mock synthesis callback
async def mock_synthesis(task_results):
return {
"synthesis": f"Synthesized results from {len(task_results)} tasks",
"confidence_score": 0.8,
}
result = await orchestrator.execute_parallel_research(
tasks=sample_research_tasks,
research_executor=mock_executor,
synthesis_callback=mock_synthesis,
)
assert isinstance(result, ResearchResult)
assert result.successful_tasks == len(sample_research_tasks)
assert result.failed_tasks == 0
assert result.parallel_efficiency > 1.0 # Should be faster than sequential
assert result.synthesis is not None
assert "synthesis" in result.synthesis
@pytest.mark.unit
async def test_parallel_execution_with_failures(self, sample_research_tasks):
"""Test parallel execution with some task failures."""
orchestrator = ParallelResearchOrchestrator()
# Mock research executor that fails for certain task types
async def mock_executor_with_failures(task):
if task.task_type == "technical":
raise TimeoutError("Task timed out")
elif task.task_type == "sentiment":
raise Exception("Network error")
else:
return {"research_type": task.task_type, "insights": ["Success"]}
result = await orchestrator.execute_parallel_research(
tasks=sample_research_tasks,
research_executor=mock_executor_with_failures,
)
assert result.successful_tasks == 1 # Only fundamental should succeed
assert result.failed_tasks == 2
# Check that failed tasks have error information
failed_tasks = [
task for task in result.task_results.values() if task.status == "failed"
]
assert len(failed_tasks) == 2
for task in failed_tasks:
assert task.error is not None
@pytest.mark.unit
async def test_circuit_breaker_integration(self, sample_research_tasks):
"""Test circuit breaker integration in parallel execution."""
orchestrator = ParallelResearchOrchestrator()
# Mock executor that consistently fails
failure_count = 0
async def failing_executor(task):
nonlocal failure_count
failure_count += 1
raise Exception(f"Failure {failure_count}")
result = await orchestrator.execute_parallel_research(
tasks=sample_research_tasks,
research_executor=failing_executor,
)
# All tasks should fail
assert result.failed_tasks == len(sample_research_tasks)
assert result.successful_tasks == 0
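# Illustrative sketch (not part of the suite): the concurrency pattern these
# orchestrator tests exercise - a semaphore bounding how many research tasks
# run at once. Names here are local to the sketch, not the orchestrator's API.
async def _sketch_bounded_gather(tasks, executor, max_concurrent: int = 4):
    semaphore = asyncio.Semaphore(max_concurrent)

    async def run_one(task):
        async with semaphore:  # at most max_concurrent executors run together
            return await executor(task)

    return await asyncio.gather(*(run_one(task) for task in tasks), return_exceptions=True)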
class TestTaskDistributionEngine:
"""Test intelligent task distribution for research topics."""
@pytest.mark.unit
def test_topic_relevance_analysis(self):
"""Test topic relevance analysis for different task types."""
engine = TaskDistributionEngine()
# Test financial topic
relevance = engine._analyze_topic_relevance(
"apple earnings financial performance",
focus_areas=["fundamentals", "financials"],
)
assert "fundamental" in relevance
assert "technical" in relevance
assert "sentiment" in relevance
assert "competitive" in relevance
# Fundamental should have highest relevance for earnings query
assert relevance["fundamental"] > relevance["technical"]
assert relevance["fundamental"] > relevance["competitive"]
@pytest.mark.unit
def test_task_distribution_basic(self):
"""Test basic task distribution for a research topic."""
engine = TaskDistributionEngine()
tasks = engine.distribute_research_tasks(
topic="AAPL financial analysis and market outlook",
session_id="test_session",
focus_areas=["fundamentals", "technical_analysis"],
)
assert len(tasks) > 0
# Should have variety of task types
task_types = {task.task_type for task in tasks}
assert "fundamental" in task_types # High relevance for financial analysis
# Tasks should be properly configured
for task in tasks:
assert task.session_id == "test_session"
assert task.target_topic == "AAPL financial analysis and market outlook"
assert task.priority > 0
assert len(task.focus_areas) > 0
@pytest.mark.unit
def test_task_distribution_fallback(self):
"""Test task distribution fallback when no relevant tasks found."""
engine = TaskDistributionEngine()
# Mock the relevance analysis to return very low scores
with patch.object(
engine,
"_analyze_topic_relevance",
return_value={
"fundamental": 0.1,
"technical": 0.1,
"sentiment": 0.1,
"competitive": 0.1,
},
):
tasks = engine.distribute_research_tasks(
topic="obscure topic with no clear relevance",
session_id="test_session",
)
# Should still create at least one task (fallback)
assert len(tasks) >= 1
# Fallback should be fundamental analysis
assert any(task.task_type == "fundamental" for task in tasks)
@pytest.mark.unit
def test_task_priority_assignment(self):
"""Test priority assignment based on relevance scores."""
engine = TaskDistributionEngine()
tasks = engine.distribute_research_tasks(
topic="AAPL fundamental analysis earnings valuation",
session_id="test_session",
)
# Find fundamental task (should have high priority)
fundamental_tasks = [t for t in tasks if t.task_type == "fundamental"]
if fundamental_tasks:
fundamental_task = fundamental_tasks[0]
assert fundamental_task.priority >= 7 # Should be high priority
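# Illustrative sketch (not part of the suite): a keyword-relevance heuristic of
# the kind these distribution tests assert against. The keyword lists are
# assumptions for demonstration only.
def _sketch_topic_relevance(topic: str) -> dict[str, float]:
    keywords = {
        "fundamental": ["earnings", "revenue", "valuation", "financial"],
        "technical": ["chart", "support", "resistance", "momentum"],
        "sentiment": ["sentiment", "analyst", "mood", "rating"],
        "competitive": ["competitive", "market share", "industry"],
    }
    topic_lower = topic.lower()
    return {
        task_type: sum(word in topic_lower for word in words) / len(words)
        for task_type, words in keywords.items()
    }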
# Timeout and Circuit Breaker Tests
class TestTimeoutAndCircuitBreaker:
"""Test timeout handling and circuit breaker patterns."""
@pytest.mark.unit
async def test_timeout_budget_allocation(self, mock_llm):
"""Test timeout budget allocation across research phases."""
agent = DeepResearchAgent(llm=mock_llm, exa_api_key="test_key")
# Test basic timeout allocation
timeout_budget = 20.0
result = await agent.research_comprehensive(
topic="test topic",
session_id="test_session",
depth="basic",
timeout_budget=timeout_budget,
)
# Should either complete or timeout gracefully
assert "status" in result or "error" in result
# If timeout occurred, should have appropriate error structure
if result.get("status") == "error" or "error" in result:
# Should be a timeout-related error for very short budget
assert (
"timeout" in str(result).lower()
or "search providers" in str(result).lower()
)
@pytest.mark.unit
def test_provider_health_monitoring(self):
"""Test search provider health monitoring and recovery."""
provider = ExaSearchProvider("test_key")
# Initially healthy
assert provider.is_healthy()
# Simulate multiple timeout failures
for _i in range(15): # Exceed default threshold of 12
provider._record_failure("timeout")
# Should be marked unhealthy
assert not provider.is_healthy()
# Recovery after success
provider._record_success()
assert provider.is_healthy()
assert provider._failure_count == 0
@pytest.mark.integration
@patch("maverick_mcp.agents.deep_research.get_cached_search_provider")
async def test_research_with_provider_failures(self, mock_provider, mock_llm):
"""Test research behavior when provider failures occur."""
# Create a provider that will fail
failing_provider = MagicMock(spec=ExaSearchProvider)
failing_provider.is_healthy.return_value = True
failing_provider.search = AsyncMock(side_effect=WebSearchError("Search failed"))
mock_provider.return_value = failing_provider
agent = DeepResearchAgent(llm=mock_llm, exa_api_key="test_key")
result = await agent.research_comprehensive(
topic="test topic",
session_id="test_session",
depth="basic",
)
# Should handle provider failure gracefully
assert "status" in result
# May succeed with fallback or fail gracefully
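# Illustrative sketch (not part of the suite): the failure-count health model
# the provider tests above assert - unhealthy once timeout failures exceed a
# threshold, fully reset on the next success. The threshold value mirrors the tests.
class _SketchProviderHealth:
    def __init__(self, timeout_failure_threshold: int = 12) -> None:
        self.timeout_failure_threshold = timeout_failure_threshold
        self.failure_count = 0
        self.healthy = True

    def record_failure(self, failure_type: str) -> None:
        self.failure_count += 1
        if failure_type == "timeout" and self.failure_count > self.timeout_failure_threshold:
            self.healthy = False

    def record_success(self) -> None:
        self.failure_count = 0
        self.healthy = True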
# Performance and Benchmarking Tests
class TestPerformanceBenchmarks:
"""Test performance across different research depths and configurations."""
@pytest.mark.slow
@pytest.mark.integration
@patch("maverick_mcp.agents.deep_research.get_cached_search_provider")
@patch("exa_py.Exa")
async def test_research_depth_performance(
self, mock_exa_class, mock_provider, mock_llm, mock_exa_client
):
"""Benchmark performance across different research depths."""
mock_provider.return_value = ExaSearchProvider("test_key")
mock_exa_class.return_value = mock_exa_client
performance_results = {}
for depth in ["basic", "standard", "comprehensive"]:
agent = DeepResearchAgent(
llm=mock_llm,
exa_api_key="test_key",
research_depth=depth,
)
start_time = time.time()
result = await agent.research_comprehensive(
topic="AAPL financial analysis",
session_id=f"perf_test_{depth}",
depth=depth,
timeout_budget=30.0,
)
execution_time = time.time() - start_time
performance_results[depth] = {
"execution_time": execution_time,
"success": result.get("status") == "success",
"sources_analyzed": result.get("sources_analyzed", 0),
}
# Verify performance characteristics
assert (
performance_results["basic"]["execution_time"]
<= performance_results["comprehensive"]["execution_time"]
)
# Basic should be fastest
if performance_results["basic"]["success"]:
assert (
performance_results["basic"]["execution_time"] < 15.0
) # Should be fast
@pytest.mark.slow
async def test_parallel_vs_sequential_performance(self, sample_research_tasks):
"""Compare parallel vs sequential execution performance."""
config = ParallelResearchConfig(max_concurrent_agents=4, timeout_per_agent=10)
orchestrator = ParallelResearchOrchestrator(config)
async def mock_executor(task):
await asyncio.sleep(1) # Simulate 1 second work per task
return {"research_type": task.task_type, "insights": ["Mock insight"]}
# Parallel execution
start_time = time.time()
parallel_result = await orchestrator.execute_parallel_research(
tasks=sample_research_tasks,
research_executor=mock_executor,
)
parallel_time = time.time() - start_time
# Sequential simulation
start_time = time.time()
for task in sample_research_tasks:
await mock_executor(task)
sequential_time = time.time() - start_time
# Parallel should be significantly faster
assert parallel_result.parallel_efficiency > 1.5 # At least 50% improvement
assert parallel_time < sequential_time * 0.7 # Should be at least 30% faster
@pytest.mark.unit
async def test_memory_usage_monitoring(self, sample_research_tasks):
"""Test memory usage stays reasonable during parallel execution."""
import os
import psutil
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
config = ParallelResearchConfig(max_concurrent_agents=4)
orchestrator = ParallelResearchOrchestrator(config)
async def mock_executor(task):
# Create some data but not excessive
data = {"results": ["data"] * 1000} # Small amount of data
await asyncio.sleep(0.1)
return data
await orchestrator.execute_parallel_research(
tasks=sample_research_tasks * 5, # More tasks to test scaling
research_executor=mock_executor,
)
final_memory = process.memory_info().rss / 1024 / 1024 # MB
memory_growth = final_memory - initial_memory
# Memory growth should be reasonable (less than 100MB for test)
assert memory_growth < 100, f"Memory grew by {memory_growth:.1f}MB"
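# Illustrative sketch (not part of the suite): how a parallel_efficiency figure
# like the one asserted above could be derived - an estimated sequential
# duration divided by the measured parallel wall-clock time. The exact metric
# inside ParallelResearchOrchestrator is an assumption here.
def _sketch_parallel_efficiency(task_durations: list[float], wall_clock_seconds: float) -> float:
    sequential_estimate = sum(task_durations)
    if wall_clock_seconds <= 0:
        return 0.0
    return sequential_estimate / wall_clock_seconds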
# MCP Integration Tests
class TestMCPIntegration:
"""Test MCP tool endpoints and research router integration."""
@pytest.mark.integration
@patch("maverick_mcp.api.routers.research.get_settings")
async def test_comprehensive_research_mcp_tool(self, mock_settings):
"""Test the comprehensive research MCP tool endpoint."""
mock_settings.return_value.research.exa_api_key = "test_key"
result = await comprehensive_research(
query="AAPL financial analysis",
persona="moderate",
research_scope="basic",
max_sources=5,
timeframe="1m",
)
# Should return structured response
assert isinstance(result, dict)
assert "success" in result
# If successful, should have proper structure
if result.get("success"):
assert "research_results" in result
assert "research_metadata" in result
assert "request_id" in result
assert "timestamp" in result
@pytest.mark.unit
@patch("maverick_mcp.api.routers.research.get_settings")
async def test_research_without_exa_key(self, mock_settings):
"""Test research behavior without ExaSearch API key."""
mock_settings.return_value.research.exa_api_key = None
result = await comprehensive_research(
query="test query",
persona="moderate",
research_scope="basic",
)
assert result["success"] is False
assert "Exa search provider not configured" in result["error"]
assert "setup_instructions" in result["details"]
@pytest.mark.unit
def test_research_request_validation(self):
"""Test ResearchRequest model validation."""
# Valid request
request = ResearchRequest(
query="AAPL analysis",
persona="moderate",
research_scope="standard",
max_sources=15,
timeframe="1m",
)
assert request.query == "AAPL analysis"
assert request.persona == "moderate"
assert request.research_scope == "standard"
assert request.max_sources == 15
assert request.timeframe == "1m"
# Test defaults
minimal_request = ResearchRequest(query="test")
assert minimal_request.persona == "moderate"
assert minimal_request.research_scope == "standard"
assert minimal_request.max_sources == 10
assert minimal_request.timeframe == "1m"
@pytest.mark.unit
def test_get_research_agent_optimization(self):
"""Test research agent creation with optimization parameters."""
# Test optimized agent creation
agent = get_research_agent(
query="complex financial analysis of multiple companies",
research_scope="comprehensive",
timeout_budget=25.0,
max_sources=20,
)
assert isinstance(agent, DeepResearchAgent)
assert agent.max_sources <= 20 # Should respect or optimize max sources
assert agent.default_depth in [
"basic",
"standard",
"comprehensive",
"exhaustive",
]
# Test standard agent creation
standard_agent = get_research_agent()
assert isinstance(standard_agent, DeepResearchAgent)
# Content Analysis Tests
class TestContentAnalyzer:
"""Test AI-powered content analysis functionality."""
@pytest.mark.unit
async def test_content_analysis_success(self, mock_llm):
"""Test successful content analysis."""
analyzer = ContentAnalyzer(mock_llm)
content = "Apple reported strong quarterly earnings with revenue growth of 12% and expanding market share in the services segment."
result = await analyzer.analyze_content(
content=content, persona="moderate", analysis_focus="financial"
)
assert "insights" in result
assert "sentiment" in result
assert "risk_factors" in result
assert "opportunities" in result
assert "credibility_score" in result
assert "relevance_score" in result
assert "summary" in result
assert "analysis_timestamp" in result
@pytest.mark.unit
async def test_content_analysis_fallback(self, mock_llm):
"""Test content analysis fallback when AI analysis fails."""
analyzer = ContentAnalyzer(mock_llm)
# Make LLM fail
mock_llm.ainvoke.side_effect = Exception("LLM error")
result = await analyzer.analyze_content(
content="Test content", persona="moderate"
)
# Should fall back to keyword-based analysis
assert result["fallback_used"] is True
assert "sentiment" in result
assert result["sentiment"]["direction"] in ["bullish", "bearish", "neutral"]
@pytest.mark.unit
async def test_batch_content_analysis(self, mock_llm):
"""Test batch content analysis functionality."""
analyzer = ContentAnalyzer(mock_llm)
content_items = [
("Apple shows strong growth", "source1"),
("Market conditions remain volatile", "source2"),
("Tech sector outlook positive", "source3"),
]
results = await analyzer.analyze_content_batch(
content_items=content_items, persona="moderate", analysis_focus="general"
)
assert len(results) == len(content_items)
for i, result in enumerate(results):
assert result["source_identifier"] == f"source{i + 1}"
assert result["batch_processed"] is True
assert "sentiment" in result
# Error Handling and Edge Cases
class TestErrorHandlingAndEdgeCases:
"""Test comprehensive error handling and edge cases."""
@pytest.mark.unit
async def test_empty_search_results(self, mock_llm):
"""Test behavior when search returns no results."""
provider = ExaSearchProvider("test_key")
with patch("exa_py.Exa") as mock_exa:
# Mock empty results
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.results = []
mock_client.search_and_contents.return_value = mock_response
mock_exa.return_value = mock_client
results = await provider.search("nonexistent topic", num_results=5)
assert results == []
@pytest.mark.unit
async def test_malformed_search_response(self, mock_llm):
"""Test handling of malformed search responses."""
provider = ExaSearchProvider("test_key")
with patch("exa_py.Exa") as mock_exa:
# Mock malformed response
mock_client = MagicMock()
mock_client.search_and_contents.side_effect = Exception(
"Invalid response format"
)
mock_exa.return_value = mock_client
with pytest.raises(WebSearchError):
await provider.search("test query")
@pytest.mark.unit
async def test_network_timeout_recovery(self):
"""Test network timeout recovery mechanisms."""
provider = ExaSearchProvider("test_key")
# Simulate multiple timeouts followed by success
with patch("exa_py.Exa") as mock_exa:
call_count = 0
async def mock_search_with_recovery(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count <= 2:
raise TimeoutError("Network timeout")
else:
# Success on third try
mock_response = MagicMock()
mock_result = MagicMock()
mock_result.url = "https://example.com"
mock_result.title = "Test Result"
mock_result.text = "Test content"
mock_result.published_date = "2024-01-15"
mock_result.score = 0.8
mock_response.results = [mock_result]
return mock_response
mock_client = MagicMock()
mock_client.search_and_contents.side_effect = mock_search_with_recovery
mock_exa.return_value = mock_client
# First two calls should fail and record failures
for _ in range(2):
with pytest.raises(WebSearchError):
await provider.search("test query", timeout_budget=1.0)
# Provider should still be healthy (failures recorded but not exceeded threshold)
assert provider._failure_count == 2
# Third call should succeed and reset failure count
results = await provider.search("test query")
assert len(results) > 0
assert provider._failure_count == 0 # Reset on success
@pytest.mark.unit
async def test_concurrent_request_limits(self, sample_research_tasks):
"""Test that concurrent request limits are respected."""
config = ParallelResearchConfig(max_concurrent_agents=2) # Very low limit
orchestrator = ParallelResearchOrchestrator(config)
execution_times = []
async def tracking_executor(task):
start = time.time()
await asyncio.sleep(0.5) # Simulate work
end = time.time()
execution_times.append((start, end))
return {"result": "success"}
await orchestrator.execute_parallel_research(
tasks=sample_research_tasks, # 3 tasks
research_executor=tracking_executor,
)
# With max_concurrent_agents=2, the third task should start after one of the first two finishes
# This means there should be overlap but not all three running simultaneously
assert len(execution_times) == 3
# Sort by start time
execution_times.sort()
# The third task should start after the first task finishes
# (allowing for some timing tolerance)
third_start = execution_times[2][0]
first_end = execution_times[0][1]
# Third should start after first ends (with small tolerance for async timing)
assert third_start >= (first_end - 0.1)
# Integration Test Suite
class TestFullIntegrationScenarios:
"""End-to-end integration tests for complete research workflows."""
@pytest.mark.integration
@pytest.mark.slow
@patch("maverick_mcp.agents.deep_research.get_cached_search_provider")
@patch("exa_py.Exa")
async def test_complete_research_workflow(
self, mock_exa_class, mock_provider, mock_llm, mock_exa_client
):
"""Test complete research workflow from query to final report."""
# Setup comprehensive mocks
mock_provider.return_value = ExaSearchProvider("test_key")
mock_exa_class.return_value = mock_exa_client
agent = DeepResearchAgent(
llm=mock_llm,
persona="moderate",
exa_api_key="test_key",
research_depth="standard",
enable_parallel_execution=True,
)
# Execute complete research workflow
result = await agent.research_comprehensive(
topic="Apple Inc (AAPL) investment analysis with market sentiment and competitive position",
session_id="integration_test_session",
depth="standard",
focus_areas=["fundamentals", "market_sentiment", "competitive_analysis"],
timeframe="1m",
use_parallel_execution=True,
)
# Verify comprehensive result structure
if result.get("status") == "success":
assert "findings" in result
assert "confidence_score" in result
assert isinstance(result["confidence_score"], int | float)
assert 0.0 <= result["confidence_score"] <= 1.0
assert "citations" in result
assert "execution_time_ms" in result
# Check for parallel execution indicators
if "parallel_execution_stats" in result:
assert "successful_tasks" in result["parallel_execution_stats"]
assert "parallel_efficiency" in result["parallel_execution_stats"]
# Should handle both success and controlled failure scenarios
assert "status" in result or "error" in result
@pytest.mark.integration
async def test_multi_persona_consistency(self, mock_llm, mock_exa_client):
"""Test research consistency across different investor personas."""
personas = ["conservative", "moderate", "aggressive", "day_trader"]
results = {}
for persona in personas:
with (
patch(
"maverick_mcp.agents.deep_research.get_cached_search_provider"
) as mock_provider,
patch("exa_py.Exa") as mock_exa_class,
):
mock_provider.return_value = ExaSearchProvider("test_key")
mock_exa_class.return_value = mock_exa_client
agent = DeepResearchAgent(
llm=mock_llm,
persona=persona,
exa_api_key="test_key",
research_depth="basic",
)
result = await agent.research_comprehensive(
topic="AAPL investment outlook",
session_id=f"persona_test_{persona}",
depth="basic",
)
results[persona] = result
# All personas should provide valid responses
for persona, result in results.items():
assert isinstance(result, dict), f"Invalid result for {persona}"
# Should have some form of result (success or controlled failure)
assert "status" in result or "error" in result or "success" in result
if __name__ == "__main__":
# Run specific test categories based on markers
pytest.main(
[
__file__,
"-v",
"--tb=short",
"-m",
"unit", # Run unit tests by default
]
)
```