This is page 32 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── config.yml
│   │   ├── feature_request.md
│   │   ├── question.md
│   │   └── security_report.md
│   ├── pull_request_template.md
│   └── workflows
│       ├── claude-code-review.yml
│       └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│   ├── launch.json
│   └── settings.json
├── alembic
│   ├── env.py
│   ├── script.py.mako
│   └── versions
│       ├── 001_initial_schema.py
│       ├── 003_add_performance_indexes.py
│       ├── 006_rename_metadata_columns.py
│       ├── 008_performance_optimization_indexes.py
│       ├── 009_rename_to_supply_demand.py
│       ├── 010_self_contained_schema.py
│       ├── 011_remove_proprietary_terms.py
│       ├── 013_add_backtest_persistence_models.py
│       ├── 014_add_portfolio_models.py
│       ├── 08e3945a0c93_merge_heads.py
│       ├── 9374a5c9b679_merge_heads_for_testing.py
│       ├── abf9b9afb134_merge_multiple_heads.py
│       ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│       ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│       ├── f0696e2cac15_add_essential_performance_indexes.py
│       └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│   ├── api
│   │   └── backtesting.md
│   ├── BACKTESTING.md
│   ├── COST_BASIS_SPECIFICATION.md
│   ├── deep_research_agent.md
│   ├── exa_research_testing_strategy.md
│   ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│   ├── PORTFOLIO.md
│   ├── SETUP_SELF_CONTAINED.md
│   └── speed_testing_framework.md
├── examples
│   ├── complete_speed_validation.py
│   ├── deep_research_integration.py
│   ├── llm_optimization_example.py
│   ├── llm_speed_demo.py
│   ├── monitoring_example.py
│   ├── parallel_research_example.py
│   ├── speed_optimization_demo.py
│   └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── circuit_breaker.py
│   │   ├── deep_research.py
│   │   ├── market_analysis.py
│   │   ├── optimized_research.py
│   │   ├── supervisor.py
│   │   └── technical_analysis.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── api_server.py
│   │   ├── connection_manager.py
│   │   ├── dependencies
│   │   │   ├── __init__.py
│   │   │   ├── stock_analysis.py
│   │   │   └── technical_analysis.py
│   │   ├── error_handling.py
│   │   ├── inspector_compatible_sse.py
│   │   ├── inspector_sse.py
│   │   ├── middleware
│   │   │   ├── error_handling.py
│   │   │   ├── mcp_logging.py
│   │   │   ├── rate_limiting_enhanced.py
│   │   │   └── security.py
│   │   ├── openapi_config.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── agents.py
│   │   │   ├── backtesting.py
│   │   │   ├── data_enhanced.py
│   │   │   ├── data.py
│   │   │   ├── health_enhanced.py
│   │   │   ├── health_tools.py
│   │   │   ├── health.py
│   │   │   ├── intelligent_backtesting.py
│   │   │   ├── introspection.py
│   │   │   ├── mcp_prompts.py
│   │   │   ├── monitoring.py
│   │   │   ├── news_sentiment_enhanced.py
│   │   │   ├── performance.py
│   │   │   ├── portfolio.py
│   │   │   ├── research.py
│   │   │   ├── screening_ddd.py
│   │   │   ├── screening_parallel.py
│   │   │   ├── screening.py
│   │   │   ├── technical_ddd.py
│   │   │   ├── technical_enhanced.py
│   │   │   ├── technical.py
│   │   │   └── tool_registry.py
│   │   ├── server.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── base_service.py
│   │   │   ├── market_service.py
│   │   │   ├── portfolio_service.py
│   │   │   ├── prompt_service.py
│   │   │   └── resource_service.py
│   │   ├── simple_sse.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── insomnia_export.py
│   │       └── postman_export.py
│   ├── application
│   │   ├── __init__.py
│   │   ├── commands
│   │   │   └── __init__.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_dto.py
│   │   ├── queries
│   │   │   ├── __init__.py
│   │   │   └── get_technical_analysis.py
│   │   └── screening
│   │       ├── __init__.py
│   │       ├── dtos.py
│   │       └── queries.py
│   ├── backtesting
│   │   ├── __init__.py
│   │   ├── ab_testing.py
│   │   ├── analysis.py
│   │   ├── batch_processing_stub.py
│   │   ├── batch_processing.py
│   │   ├── model_manager.py
│   │   ├── optimization.py
│   │   ├── persistence.py
│   │   ├── retraining_pipeline.py
│   │   ├── strategies
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── ml
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adaptive.py
│   │   │   │   ├── ensemble.py
│   │   │   │   ├── feature_engineering.py
│   │   │   │   └── regime_aware.py
│   │   │   ├── ml_strategies.py
│   │   │   ├── parser.py
│   │   │   └── templates.py
│   │   ├── strategy_executor.py
│   │   ├── vectorbt_engine.py
│   │   └── visualization.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── database_self_contained.py
│   │   ├── database.py
│   │   ├── llm_optimization_config.py
│   │   ├── logging_settings.py
│   │   ├── plotly_config.py
│   │   ├── security_utils.py
│   │   ├── security.py
│   │   ├── settings.py
│   │   ├── technical_constants.py
│   │   ├── tool_estimation.py
│   │   └── validation.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── technical_analysis.py
│   │   └── visualization.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cache_manager.py
│   │   ├── cache.py
│   │   ├── django_adapter.py
│   │   ├── health.py
│   │   ├── models.py
│   │   ├── performance.py
│   │   ├── session_management.py
│   │   └── validation.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── optimization.py
│   ├── dependencies.py
│   ├── domain
│   │   ├── __init__.py
│   │   ├── entities
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis.py
│   │   ├── events
│   │   │   └── __init__.py
│   │   ├── portfolio.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   ├── entities.py
│   │   │   ├── services.py
│   │   │   └── value_objects.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_service.py
│   │   ├── stock_analysis
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis_service.py
│   │   └── value_objects
│   │       ├── __init__.py
│   │       └── technical_indicators.py
│   ├── exceptions.py
│   ├── infrastructure
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   └── __init__.py
│   │   ├── caching
│   │   │   ├── __init__.py
│   │   │   └── cache_management_service.py
│   │   ├── connection_manager.py
│   │   ├── data_fetching
│   │   │   ├── __init__.py
│   │   │   └── stock_data_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_checker.py
│   │   ├── persistence
│   │   │   ├── __init__.py
│   │   │   └── stock_repository.py
│   │   ├── providers
│   │   │   └── __init__.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   └── repositories.py
│   │   └── sse_optimizer.py
│   ├── langchain_tools
│   │   ├── __init__.py
│   │   ├── adapters.py
│   │   └── registry.py
│   ├── logging_config.py
│   ├── memory
│   │   ├── __init__.py
│   │   └── stores.py
│   ├── monitoring
│   │   ├── __init__.py
│   │   ├── health_check.py
│   │   ├── health_monitor.py
│   │   ├── integration_example.py
│   │   ├── metrics.py
│   │   ├── middleware.py
│   │   └── status_dashboard.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── dependencies.py
│   │   ├── factories
│   │   │   ├── __init__.py
│   │   │   ├── config_factory.py
│   │   │   └── provider_factory.py
│   │   ├── implementations
│   │   │   ├── __init__.py
│   │   │   ├── cache_adapter.py
│   │   │   ├── macro_data_adapter.py
│   │   │   ├── market_data_adapter.py
│   │   │   ├── persistence_adapter.py
│   │   │   └── stock_data_adapter.py
│   │   ├── interfaces
│   │   │   ├── __init__.py
│   │   │   ├── cache.py
│   │   │   ├── config.py
│   │   │   ├── macro_data.py
│   │   │   ├── market_data.py
│   │   │   ├── persistence.py
│   │   │   └── stock_data.py
│   │   ├── llm_factory.py
│   │   ├── macro_data.py
│   │   ├── market_data.py
│   │   ├── mocks
│   │   │   ├── __init__.py
│   │   │   ├── mock_cache.py
│   │   │   ├── mock_config.py
│   │   │   ├── mock_macro_data.py
│   │   │   ├── mock_market_data.py
│   │   │   ├── mock_persistence.py
│   │   │   └── mock_stock_data.py
│   │   ├── openrouter_provider.py
│   │   ├── optimized_screening.py
│   │   ├── optimized_stock_data.py
│   │   └── stock_data.py
│   ├── README.md
│   ├── tests
│   │   ├── __init__.py
│   │   ├── README_INMEMORY_TESTS.md
│   │   ├── test_cache_debug.py
│   │   ├── test_fixes_validation.py
│   │   ├── test_in_memory_routers.py
│   │   ├── test_in_memory_server.py
│   │   ├── test_macro_data_provider.py
│   │   ├── test_mailgun_email.py
│   │   ├── test_market_calendar_caching.py
│   │   ├── test_mcp_tool_fixes_pytest.py
│   │   ├── test_mcp_tool_fixes.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_models_functional.py
│   │   ├── test_server.py
│   │   ├── test_stock_data_enhanced.py
│   │   ├── test_stock_data_provider.py
│   │   └── test_technical_analysis.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── performance_monitoring.py
│   │   ├── portfolio_manager.py
│   │   ├── risk_management.py
│   │   └── sentiment_analysis.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── agent_errors.py
│   │   ├── batch_processing.py
│   │   ├── cache_warmer.py
│   │   ├── circuit_breaker_decorators.py
│   │   ├── circuit_breaker_services.py
│   │   ├── circuit_breaker.py
│   │   ├── data_chunking.py
│   │   ├── database_monitoring.py
│   │   ├── debug_utils.py
│   │   ├── fallback_strategies.py
│   │   ├── llm_optimization.py
│   │   ├── logging_example.py
│   │   ├── logging_init.py
│   │   ├── logging.py
│   │   ├── mcp_logging.py
│   │   ├── memory_profiler.py
│   │   ├── monitoring_middleware.py
│   │   ├── monitoring.py
│   │   ├── orchestration_logging.py
│   │   ├── parallel_research.py
│   │   ├── parallel_screening.py
│   │   ├── quick_cache.py
│   │   ├── resource_manager.py
│   │   ├── shutdown.py
│   │   ├── stock_helpers.py
│   │   ├── structured_logger.py
│   │   ├── tool_monitoring.py
│   │   ├── tracing.py
│   │   └── yfinance_pool.py
│   ├── validation
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── data.py
│   │   ├── middleware.py
│   │   ├── portfolio.py
│   │   ├── responses.py
│   │   ├── screening.py
│   │   └── technical.py
│   └── workflows
│       ├── __init__.py
│       ├── agents
│       │   ├── __init__.py
│       │   ├── market_analyzer.py
│       │   ├── optimizer_agent.py
│       │   ├── strategy_selector.py
│       │   └── validator_agent.py
│       ├── backtesting_workflow.py
│       └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│   ├── dev.sh
│   ├── INSTALLATION_GUIDE.md
│   ├── load_example.py
│   ├── load_market_data.py
│   ├── load_tiingo_data.py
│   ├── migrate_db.py
│   ├── README_TIINGO_LOADER.md
│   ├── requirements_tiingo.txt
│   ├── run_stock_screening.py
│   ├── run-migrations.sh
│   ├── seed_db.py
│   ├── seed_sp500.py
│   ├── setup_database.sh
│   ├── setup_self_contained.py
│   ├── setup_sp500_database.sh
│   ├── test_seeded_data.py
│   ├── test_tiingo_loader.py
│   ├── tiingo_config.py
│   └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── core
│   │   └── test_technical_analysis.py
│   ├── data
│   │   └── test_portfolio_models.py
│   ├── domain
│   │   ├── conftest.py
│   │   ├── test_portfolio_entities.py
│   │   └── test_technical_analysis_service.py
│   ├── fixtures
│   │   └── orchestration_fixtures.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── README.md
│   │   ├── run_integration_tests.sh
│   │   ├── test_api_technical.py
│   │   ├── test_chaos_engineering.py
│   │   ├── test_config_management.py
│   │   ├── test_full_backtest_workflow_advanced.py
│   │   ├── test_full_backtest_workflow.py
│   │   ├── test_high_volume.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_orchestration_complete.py
│   │   ├── test_portfolio_persistence.py
│   │   ├── test_redis_cache.py
│   │   ├── test_security_integration.py.disabled
│   │   └── vcr_setup.py
│   ├── performance
│   │   ├── __init__.py
│   │   ├── test_benchmarks.py
│   │   ├── test_load.py
│   │   ├── test_profiling.py
│   │   └── test_stress.py
│   ├── providers
│   │   └── test_stock_data_simple.py
│   ├── README.md
│   ├── test_agents_router_mcp.py
│   ├── test_backtest_persistence.py
│   ├── test_cache_management_service.py
│   ├── test_cache_serialization.py
│   ├── test_circuit_breaker.py
│   ├── test_database_pool_config_simple.py
│   ├── test_database_pool_config.py
│   ├── test_deep_research_functional.py
│   ├── test_deep_research_integration.py
│   ├── test_deep_research_parallel_execution.py
│   ├── test_error_handling.py
│   ├── test_event_loop_integrity.py
│   ├── test_exa_research_integration.py
│   ├── test_exception_hierarchy.py
│   ├── test_financial_search.py
│   ├── test_graceful_shutdown.py
│   ├── test_integration_simple.py
│   ├── test_langgraph_workflow.py
│   ├── test_market_data_async.py
│   ├── test_market_data_simple.py
│   ├── test_mcp_orchestration_functional.py
│   ├── test_ml_strategies.py
│   ├── test_optimized_research_agent.py
│   ├── test_orchestration_integration.py
│   ├── test_orchestration_logging.py
│   ├── test_orchestration_tools_simple.py
│   ├── test_parallel_research_integration.py
│   ├── test_parallel_research_orchestrator.py
│   ├── test_parallel_research_performance.py
│   ├── test_performance_optimizations.py
│   ├── test_production_validation.py
│   ├── test_provider_architecture.py
│   ├── test_rate_limiting_enhanced.py
│   ├── test_runner_validation.py
│   ├── test_security_comprehensive.py.disabled
│   ├── test_security_cors.py
│   ├── test_security_enhancements.py.disabled
│   ├── test_security_headers.py
│   ├── test_security_penetration.py
│   ├── test_session_management.py
│   ├── test_speed_optimization_validation.py
│   ├── test_stock_analysis_dependencies.py
│   ├── test_stock_analysis_service.py
│   ├── test_stock_data_fetching_service.py
│   ├── test_supervisor_agent.py
│   ├── test_supervisor_functional.py
│   ├── test_tool_estimation_config.py
│   ├── test_visualization.py
│   └── utils
│       ├── test_agent_errors.py
│       ├── test_logging.py
│       ├── test_parallel_screening.py
│       └── test_quick_cache.py
├── tools
│   ├── check_orchestration_config.py
│   ├── experiments
│   │   ├── validation_examples.py
│   │   └── validation_fixed.py
│   ├── fast_dev.sh
│   ├── hot_reload.py
│   ├── quick_test.py
│   └── templates
│       ├── new_router_template.py
│       ├── new_tool_template.py
│       ├── screening_strategy_template.py
│       └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_ml_strategies.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive tests for ML-enhanced trading strategies.
3 |
4 | Tests cover:
5 | - Adaptive Strategy parameter adjustment and online learning
6 | - OnlineLearningStrategy with streaming ML algorithms
7 | - HybridAdaptiveStrategy combining multiple approaches
8 | - Feature engineering and extraction for ML models
9 | - Model training, prediction, and confidence scoring
10 | - Performance tracking and adaptation mechanisms
11 | - Parameter boundary enforcement and constraints
12 | - Strategy performance under different market regimes
13 | - Memory usage and computational efficiency
14 | - Error handling and model recovery scenarios
15 | """
16 |
17 | import warnings
18 | from typing import Any
19 | from unittest.mock import Mock, patch
20 |
21 | import numpy as np
22 | import pandas as pd
23 | import pytest
24 |
25 | from maverick_mcp.backtesting.strategies.base import Strategy
26 | from maverick_mcp.backtesting.strategies.ml.adaptive import (
27 | AdaptiveStrategy,
28 | HybridAdaptiveStrategy,
29 | OnlineLearningStrategy,
30 | )
31 |
32 | warnings.filterwarnings("ignore", category=FutureWarning)
33 |
34 |
35 | class MockBaseStrategy(Strategy):
36 | """Mock base strategy for testing adaptive strategies."""
37 |
 38 |     def __init__(self, parameters: dict[str, Any] | None = None):
39 | super().__init__(parameters or {"window": 20, "threshold": 0.02})
40 | self._signal_pattern = "alternating" # alternating, bullish, bearish, random
41 |
42 | @property
43 | def name(self) -> str:
44 | return "MockStrategy"
45 |
46 | @property
47 | def description(self) -> str:
48 | return "Mock strategy for testing"
49 |
50 | def generate_signals(self, data: pd.DataFrame) -> tuple[pd.Series, pd.Series]:
51 | """Generate mock signals based on pattern."""
52 | entry_signals = pd.Series(False, index=data.index)
53 | exit_signals = pd.Series(False, index=data.index)
54 |
55 | window = self.parameters.get("window", 20)
56 | threshold = float(self.parameters.get("threshold", 0.02) or 0.0)
57 | step = max(5, int(round(10 * (1 + abs(threshold) * 10))))
58 |
59 | if self._signal_pattern == "alternating":
60 | # Alternate between entry and exit signals with threshold-adjusted cadence
61 | for i in range(window, len(data), step):
62 | if (i // step) % 2 == 0:
63 | entry_signals.iloc[i] = True
64 | else:
65 | exit_signals.iloc[i] = True
66 | elif self._signal_pattern == "bullish":
67 | # More entry signals than exit
68 | entry_indices = np.random.choice(
69 | range(window, len(data)),
70 | size=min(20, len(data) - window),
71 | replace=False,
72 | )
73 | entry_signals.iloc[entry_indices] = True
74 | elif self._signal_pattern == "bearish":
75 | # More exit signals than entry
76 | exit_indices = np.random.choice(
77 | range(window, len(data)),
78 | size=min(20, len(data) - window),
79 | replace=False,
80 | )
81 | exit_signals.iloc[exit_indices] = True
82 |
83 | return entry_signals, exit_signals
84 |
85 |
86 | class TestAdaptiveStrategy:
87 | """Test suite for AdaptiveStrategy class."""
88 |
89 | @pytest.fixture
90 | def sample_market_data(self):
91 | """Create sample market data for testing."""
92 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
93 |
94 | # Generate realistic price data with trends
95 | returns = np.random.normal(0.0005, 0.02, len(dates))
96 | # Add some trending periods
97 | returns[100:150] += 0.003 # Bull period
98 | returns[200:250] -= 0.002 # Bear period
99 |
100 | prices = 100 * np.cumprod(1 + returns)
101 | volumes = np.random.randint(1000000, 5000000, len(dates))
102 |
103 | data = pd.DataFrame(
104 | {
105 | "open": prices * np.random.uniform(0.98, 1.02, len(dates)),
106 | "high": prices * np.random.uniform(1.00, 1.05, len(dates)),
107 | "low": prices * np.random.uniform(0.95, 1.00, len(dates)),
108 | "close": prices,
109 | "volume": volumes,
110 | },
111 | index=dates,
112 | )
113 |
114 |         # Ensure high >= max(open, close) and low <= min(open, close)
115 | data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
116 | data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
117 |
118 | return data
119 |
120 | @pytest.fixture
121 | def mock_base_strategy(self):
122 | """Create a mock base strategy."""
123 | return MockBaseStrategy({"window": 20, "threshold": 0.02})
124 |
125 | @pytest.fixture
126 | def adaptive_strategy(self, mock_base_strategy):
127 | """Create an adaptive strategy with mock base."""
128 | return AdaptiveStrategy(
129 | base_strategy=mock_base_strategy,
130 | adaptation_method="gradient",
131 | learning_rate=0.01,
132 | lookback_period=50,
133 | adaptation_frequency=10,
134 | )
135 |
136 | def test_adaptive_strategy_initialization(
137 | self, adaptive_strategy, mock_base_strategy
138 | ):
139 | """Test adaptive strategy initialization."""
140 | assert adaptive_strategy.base_strategy == mock_base_strategy
141 | assert adaptive_strategy.adaptation_method == "gradient"
142 | assert adaptive_strategy.learning_rate == 0.01
143 | assert adaptive_strategy.lookback_period == 50
144 | assert adaptive_strategy.adaptation_frequency == 10
145 |
146 | assert len(adaptive_strategy.performance_history) == 0
147 | assert len(adaptive_strategy.parameter_history) == 0
148 | assert adaptive_strategy.last_adaptation == 0
149 |
150 | # Test name and description
151 | assert "Adaptive" in adaptive_strategy.name
152 | assert "MockStrategy" in adaptive_strategy.name
153 | assert "gradient" in adaptive_strategy.description
154 |
155 | def test_performance_metric_calculation(self, adaptive_strategy):
156 | """Test performance metric calculation."""
157 | # Test with normal returns
158 | returns = pd.Series([0.01, 0.02, -0.01, 0.015, -0.005])
159 | performance = adaptive_strategy.calculate_performance_metric(returns)
160 |
161 | assert isinstance(performance, float)
162 | assert not np.isnan(performance)
163 |
164 | # Test with zero volatility
165 | constant_returns = pd.Series([0.01, 0.01, 0.01, 0.01])
166 | performance = adaptive_strategy.calculate_performance_metric(constant_returns)
167 | assert performance == 0.0
168 |
169 | # Test with empty returns
170 |         empty_returns = pd.Series([], dtype=float)
171 | performance = adaptive_strategy.calculate_performance_metric(empty_returns)
172 | assert performance == 0.0
173 |
174 | def test_adaptable_parameters_default(self, adaptive_strategy):
175 | """Test default adaptable parameters configuration."""
176 | adaptable_params = adaptive_strategy.get_adaptable_parameters()
177 |
178 | expected_params = ["lookback_period", "threshold", "window", "period"]
179 | for param in expected_params:
180 | assert param in adaptable_params
181 | assert "min" in adaptable_params[param]
182 | assert "max" in adaptable_params[param]
183 | assert "step" in adaptable_params[param]
184 |
185 | def test_gradient_parameter_adaptation(self, adaptive_strategy):
186 | """Test gradient-based parameter adaptation."""
187 | # Set up initial parameters
188 | initial_window = adaptive_strategy.base_strategy.parameters["window"]
189 | initial_threshold = adaptive_strategy.base_strategy.parameters["threshold"]
190 |
191 | # Simulate positive performance gradient
192 | adaptive_strategy.adapt_parameters_gradient(0.5) # Positive gradient
193 |
194 | # Parameters should have changed
195 | new_window = adaptive_strategy.base_strategy.parameters["window"]
196 | new_threshold = adaptive_strategy.base_strategy.parameters["threshold"]
197 |
198 | # At least one parameter should have changed
199 | assert new_window != initial_window or new_threshold != initial_threshold
200 |
201 | # Parameters should be within bounds
202 | adaptable_params = adaptive_strategy.get_adaptable_parameters()
203 | if "window" in adaptable_params:
204 | assert new_window >= adaptable_params["window"]["min"]
205 | assert new_window <= adaptable_params["window"]["max"]
206 |
207 | def test_random_search_parameter_adaptation(self, adaptive_strategy):
208 | """Test random search parameter adaptation."""
209 | adaptive_strategy.adaptation_method = "random_search"
210 |
211 | # Apply random search adaptation
212 | adaptive_strategy.adapt_parameters_random_search()
213 |
214 | # Parameters should potentially have changed
215 | new_params = adaptive_strategy.base_strategy.parameters
216 |
217 | # At least check that the method runs without error
218 | assert isinstance(new_params, dict)
219 | assert "window" in new_params
220 | assert "threshold" in new_params
221 |
222 | def test_adaptive_signal_generation(self, adaptive_strategy, sample_market_data):
223 | """Test adaptive signal generation with parameter updates."""
224 | entry_signals, exit_signals = adaptive_strategy.generate_signals(
225 | sample_market_data
226 | )
227 |
228 | # Basic signal validation
229 | assert len(entry_signals) == len(sample_market_data)
230 | assert len(exit_signals) == len(sample_market_data)
231 | assert entry_signals.dtype == bool
232 | assert exit_signals.dtype == bool
233 |
234 | # Check that some adaptations occurred
235 | assert len(adaptive_strategy.performance_history) > 0
236 |
237 | # Check that parameter history was recorded
238 | if len(adaptive_strategy.parameter_history) > 0:
239 | assert isinstance(adaptive_strategy.parameter_history[0], dict)
240 |
241 | def test_adaptation_frequency_control(self, adaptive_strategy, sample_market_data):
242 | """Test that adaptation occurs at correct frequency."""
243 | # Set a specific adaptation frequency
244 | adaptive_strategy.adaptation_frequency = 30
245 |
246 | # Generate signals
247 | adaptive_strategy.generate_signals(sample_market_data)
248 |
249 | # Number of adaptations should be roughly len(data) / adaptation_frequency
250 | expected_adaptations = len(sample_market_data) // 30
251 | actual_adaptations = len(adaptive_strategy.performance_history)
252 |
253 | # Allow some variance due to lookback period requirements
254 | assert abs(actual_adaptations - expected_adaptations) <= 2
255 |
256 | def test_adaptation_history_tracking(self, adaptive_strategy, sample_market_data):
257 | """Test adaptation history tracking."""
258 | adaptive_strategy.generate_signals(sample_market_data)
259 |
260 | history = adaptive_strategy.get_adaptation_history()
261 |
262 | assert "performance_history" in history
263 | assert "parameter_history" in history
264 | assert "current_parameters" in history
265 | assert "original_parameters" in history
266 |
267 | assert len(history["performance_history"]) > 0
268 | assert isinstance(history["current_parameters"], dict)
269 | assert isinstance(history["original_parameters"], dict)
270 |
271 | def test_reset_to_original_parameters(self, adaptive_strategy, sample_market_data):
272 | """Test resetting strategy to original parameters."""
273 | # Store original parameters
274 | original_params = adaptive_strategy.base_strategy.parameters.copy()
275 |
276 | # Generate signals to trigger adaptations
277 | adaptive_strategy.generate_signals(sample_market_data)
278 |
279 |         # Adaptations during signal generation may have changed the parameters
280 |
281 | # Reset to original
282 | adaptive_strategy.reset_to_original()
283 |
284 | # Should match original parameters
285 | assert adaptive_strategy.base_strategy.parameters == original_params
286 | assert len(adaptive_strategy.performance_history) == 0
287 | assert len(adaptive_strategy.parameter_history) == 0
288 | assert adaptive_strategy.last_adaptation == 0
289 |
290 | def test_adaptive_strategy_error_handling(self, adaptive_strategy):
291 | """Test error handling in adaptive strategy."""
292 | # Test with invalid data
293 | invalid_data = pd.DataFrame({"close": [np.nan, np.nan]})
294 |
295 | entry_signals, exit_signals = adaptive_strategy.generate_signals(invalid_data)
296 |
297 | # Should return valid series even with bad data
298 | assert isinstance(entry_signals, pd.Series)
299 | assert isinstance(exit_signals, pd.Series)
300 | assert len(entry_signals) == len(invalid_data)
301 |
302 |
303 | class TestOnlineLearningStrategy:
304 | """Test suite for OnlineLearningStrategy class."""
305 |
306 | @pytest.fixture
307 | def online_strategy(self):
308 | """Create an online learning strategy."""
309 | return OnlineLearningStrategy(
310 | model_type="sgd",
311 | update_frequency=10,
312 | feature_window=20,
313 | confidence_threshold=0.6,
314 | )
315 |
316 | @pytest.fixture
317 | def online_learning_strategy(self, online_strategy):
318 | """Alias for online_strategy fixture for backward compatibility."""
319 | return online_strategy
320 |
321 | @pytest.fixture
322 | def sample_market_data(self):
323 | """Create sample market data for testing."""
324 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
325 |
326 | # Generate realistic price data with trends
327 | returns = np.random.normal(0.0005, 0.02, len(dates))
328 | # Add some trending periods
329 | returns[100:150] += 0.003 # Bull period
330 | returns[200:250] -= 0.002 # Bear period
331 |
332 | prices = 100 * np.cumprod(1 + returns)
333 | volumes = np.random.randint(1000000, 5000000, len(dates))
334 |
335 | data = pd.DataFrame(
336 | {
337 | "open": prices * np.random.uniform(0.98, 1.02, len(dates)),
338 | "high": prices * np.random.uniform(1.00, 1.05, len(dates)),
339 | "low": prices * np.random.uniform(0.95, 1.00, len(dates)),
340 | "close": prices,
341 | "volume": volumes,
342 | },
343 | index=dates,
344 | )
345 |
346 |         # Ensure high >= max(open, close) and low <= min(open, close)
347 | data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
348 | data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
349 |
350 | return data
351 |
352 | def test_online_learning_initialization(self, online_strategy):
353 | """Test online learning strategy initialization."""
354 | assert online_strategy.model_type == "sgd"
355 | assert online_strategy.update_frequency == 10
356 | assert online_strategy.feature_window == 20
357 | assert online_strategy.confidence_threshold == 0.6
358 |
359 | assert online_strategy.model is not None
360 | assert hasattr(online_strategy.model, "fit") # Should be sklearn model
361 | assert not online_strategy.is_trained
362 | assert len(online_strategy.training_buffer) == 0
363 |
364 | # Test name and description
365 | assert "OnlineLearning" in online_strategy.name
366 | assert "SGD" in online_strategy.name
367 | assert "streaming" in online_strategy.description
368 |
369 | def test_model_initialization_error(self):
370 | """Test model initialization with unsupported type."""
371 | with pytest.raises(ValueError, match="Unsupported model type"):
372 | OnlineLearningStrategy(model_type="unsupported_model")
373 |
374 | def test_feature_extraction(self, online_strategy, sample_market_data):
375 | """Test feature extraction from market data."""
376 | # Test with sufficient data
377 | features = online_strategy.extract_features(sample_market_data, 30)
378 |
379 | assert isinstance(features, np.ndarray)
380 | assert len(features) > 0
381 | assert not np.any(np.isnan(features))
382 |
383 | # Test with insufficient data
384 | features = online_strategy.extract_features(sample_market_data, 1)
385 | assert len(features) == 0
386 |
387 | def test_target_creation(self, online_learning_strategy, sample_market_data):
388 | """Test target variable creation."""
389 | # Test normal case
390 | target = online_learning_strategy.create_target(sample_market_data, 30)
391 | assert target in [0, 1, 2] # sell, hold, buy
392 |
393 | # Test edge case - near end of data
394 | target = online_learning_strategy.create_target(
395 | sample_market_data, len(sample_market_data) - 1
396 | )
397 | assert target == 1 # Should default to hold
398 |
399 | def test_model_update_mechanism(self, online_strategy, sample_market_data):
400 | """Test online model update mechanism."""
401 | # Simulate model updates
402 | online_strategy.update_model(sample_market_data, 50)
403 |
404 | # Should not update if frequency not met
405 | assert online_strategy.last_update == 0 # No update yet
406 |
407 | # Force update by meeting frequency requirement
408 | online_strategy.last_update = 40
409 | online_strategy.update_model(sample_market_data, 51)
410 |
411 | # Now should have updated
412 | assert online_strategy.last_update > 40
413 |
414 | def test_online_signal_generation(self, online_strategy, sample_market_data):
415 | """Test online learning signal generation."""
416 | entry_signals, exit_signals = online_strategy.generate_signals(
417 | sample_market_data
418 | )
419 |
420 | # Basic validation
421 | assert len(entry_signals) == len(sample_market_data)
422 | assert len(exit_signals) == len(sample_market_data)
423 | assert entry_signals.dtype == bool
424 | assert exit_signals.dtype == bool
425 |
426 | # Should eventually train the model
427 | assert online_strategy.is_trained
428 |
429 | def test_model_info_retrieval(self, online_strategy, sample_market_data):
430 | """Test model information retrieval."""
431 | # Initially untrained
432 | info = online_strategy.get_model_info()
433 |
434 | assert info["model_type"] == "sgd"
435 | assert not info["is_trained"]
436 | assert info["feature_window"] == 20
437 | assert info["update_frequency"] == 10
438 | assert info["confidence_threshold"] == 0.6
439 |
440 | # Train the model
441 | online_strategy.generate_signals(sample_market_data)
442 |
443 | # Get info after training
444 | trained_info = online_strategy.get_model_info()
445 | assert trained_info["is_trained"]
446 |
447 | # Should have coefficients if model supports them
448 | if (
449 | hasattr(online_strategy.model, "coef_")
450 | and online_strategy.model.coef_ is not None
451 | ):
452 | assert "model_coefficients" in trained_info
453 |
454 | def test_confidence_threshold_filtering(self, online_strategy, sample_market_data):
455 | """Test that signals are filtered by confidence threshold."""
456 | # Use very high confidence threshold
457 | high_confidence_strategy = OnlineLearningStrategy(confidence_threshold=0.95)
458 |
459 | entry_signals, exit_signals = high_confidence_strategy.generate_signals(
460 | sample_market_data
461 | )
462 |
463 | # With high confidence threshold, should have fewer signals
464 | assert entry_signals.sum() <= 5 # Very few signals expected
465 | assert exit_signals.sum() <= 5
466 |
467 | def test_online_strategy_error_handling(self, online_strategy):
468 | """Test error handling in online learning strategy."""
469 | # Test with empty data
470 | empty_data = pd.DataFrame(columns=["close", "volume"])
471 |
472 | entry_signals, exit_signals = online_strategy.generate_signals(empty_data)
473 |
474 | assert len(entry_signals) == 0
475 | assert len(exit_signals) == 0
476 |
477 |
478 | class TestHybridAdaptiveStrategy:
479 | """Test suite for HybridAdaptiveStrategy class."""
480 |
481 | @pytest.fixture
482 | def mock_base_strategy(self):
483 | """Create a mock base strategy."""
484 | return MockBaseStrategy({"window": 20, "threshold": 0.02})
485 |
486 | @pytest.fixture
487 | def sample_market_data(self):
488 | """Create sample market data for testing."""
489 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
490 |
491 | # Generate realistic price data with trends
492 | returns = np.random.normal(0.0005, 0.02, len(dates))
493 | # Add some trending periods
494 | returns[100:150] += 0.003 # Bull period
495 | returns[200:250] -= 0.002 # Bear period
496 |
497 | prices = 100 * np.cumprod(1 + returns)
498 | volumes = np.random.randint(1000000, 5000000, len(dates))
499 |
500 | data = pd.DataFrame(
501 | {
502 | "open": prices * np.random.uniform(0.98, 1.02, len(dates)),
503 | "high": prices * np.random.uniform(1.00, 1.05, len(dates)),
504 | "low": prices * np.random.uniform(0.95, 1.00, len(dates)),
505 | "close": prices,
506 | "volume": volumes,
507 | },
508 | index=dates,
509 | )
510 |
511 |         # Ensure high >= max(open, close) and low <= min(open, close)
512 | data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
513 | data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
514 |
515 | return data
516 |
517 | @pytest.fixture
518 | def hybrid_strategy(self, mock_base_strategy):
519 | """Create a hybrid adaptive strategy."""
520 | return HybridAdaptiveStrategy(
521 | base_strategy=mock_base_strategy,
522 | online_learning_weight=0.3,
523 | adaptation_method="gradient",
524 | learning_rate=0.02,
525 | )
526 |
527 | def test_hybrid_strategy_initialization(self, hybrid_strategy, mock_base_strategy):
528 | """Test hybrid strategy initialization."""
529 | assert hybrid_strategy.base_strategy == mock_base_strategy
530 | assert hybrid_strategy.online_learning_weight == 0.3
531 | assert hybrid_strategy.online_strategy is not None
532 | assert isinstance(hybrid_strategy.online_strategy, OnlineLearningStrategy)
533 |
534 | # Test name and description
535 | assert "HybridAdaptive" in hybrid_strategy.name
536 | assert "MockStrategy" in hybrid_strategy.name
537 | assert "hybrid" in hybrid_strategy.description.lower()
538 |
539 | def test_hybrid_signal_generation(self, hybrid_strategy, sample_market_data):
540 | """Test hybrid signal generation combining both approaches."""
541 | entry_signals, exit_signals = hybrid_strategy.generate_signals(
542 | sample_market_data
543 | )
544 |
545 | # Basic validation
546 | assert len(entry_signals) == len(sample_market_data)
547 | assert len(exit_signals) == len(sample_market_data)
548 | assert entry_signals.dtype == bool
549 | assert exit_signals.dtype == bool
550 |
551 | # Should have some signals (combination of both strategies)
552 | total_signals = entry_signals.sum() + exit_signals.sum()
553 | assert total_signals > 0
554 |
555 | def test_signal_weighting_mechanism(self, hybrid_strategy, sample_market_data):
556 | """Test that signal weighting works correctly."""
557 | # Set base strategy to generate specific pattern
558 | hybrid_strategy.base_strategy._signal_pattern = "bullish"
559 |
560 | # Generate signals
561 | entry_signals, exit_signals = hybrid_strategy.generate_signals(
562 | sample_market_data
563 | )
564 |
565 | # With bullish base strategy, should have more entry signals
566 | assert entry_signals.sum() >= exit_signals.sum()
567 |
568 | def test_hybrid_info_retrieval(self, hybrid_strategy, sample_market_data):
569 | """Test hybrid strategy information retrieval."""
570 | # Generate some signals first
571 | hybrid_strategy.generate_signals(sample_market_data)
572 |
573 | hybrid_info = hybrid_strategy.get_hybrid_info()
574 |
575 | assert "adaptation_history" in hybrid_info
576 | assert "online_learning_info" in hybrid_info
577 | assert "online_learning_weight" in hybrid_info
578 | assert "base_weight" in hybrid_info
579 |
580 | assert hybrid_info["online_learning_weight"] == 0.3
581 | assert hybrid_info["base_weight"] == 0.7
582 |
583 | # Verify nested information structure
584 | assert "model_type" in hybrid_info["online_learning_info"]
585 | assert "performance_history" in hybrid_info["adaptation_history"]
586 |
587 | def test_different_weight_configurations(
588 | self, mock_base_strategy, sample_market_data
589 | ):
590 | """Test hybrid strategy with different weight configurations."""
591 | # Test heavy online learning weighting
592 | heavy_online = HybridAdaptiveStrategy(
593 | base_strategy=mock_base_strategy, online_learning_weight=0.8
594 | )
595 |
596 | entry1, exit1 = heavy_online.generate_signals(sample_market_data)
597 |
598 | # Test heavy base strategy weighting
599 | heavy_base = HybridAdaptiveStrategy(
600 | base_strategy=mock_base_strategy, online_learning_weight=0.2
601 | )
602 |
603 | entry2, exit2 = heavy_base.generate_signals(sample_market_data)
604 |
605 | # Both should generate valid signals
606 | assert len(entry1) == len(entry2) == len(sample_market_data)
607 | assert len(exit1) == len(exit2) == len(sample_market_data)
608 |
609 | # Different weights should potentially produce different signals
610 | # (though this is probabilistic and may not always be true)
611 | signal_diff1 = (entry1 != entry2).sum() + (exit1 != exit2).sum()
612 | assert signal_diff1 >= 0 # Allow for identical signals in edge cases
613 |
614 |
615 | class TestMLStrategiesPerformance:
616 | """Performance and benchmark tests for ML strategies."""
617 |
618 | @pytest.fixture
619 | def sample_market_data(self):
620 | """Create sample market data for testing."""
621 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
622 |
623 | # Generate realistic price data with trends
624 | returns = np.random.normal(0.0005, 0.02, len(dates))
625 | # Add some trending periods
626 | returns[100:150] += 0.003 # Bull period
627 | returns[200:250] -= 0.002 # Bear period
628 |
629 | prices = 100 * np.cumprod(1 + returns)
630 | volumes = np.random.randint(1000000, 5000000, len(dates))
631 |
632 | data = pd.DataFrame(
633 | {
634 | "open": prices * np.random.uniform(0.98, 1.02, len(dates)),
635 | "high": prices * np.random.uniform(1.00, 1.05, len(dates)),
636 | "low": prices * np.random.uniform(0.95, 1.00, len(dates)),
637 | "close": prices,
638 | "volume": volumes,
639 | },
640 | index=dates,
641 | )
642 |
643 |         # Ensure high >= max(open, close) and low <= min(open, close)
644 | data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
645 | data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
646 |
647 | return data
648 |
649 | def test_strategy_computational_efficiency(
650 | self, sample_market_data, benchmark_timer
651 | ):
652 | """Test computational efficiency of ML strategies."""
653 | strategies = [
654 | AdaptiveStrategy(MockBaseStrategy(), adaptation_method="gradient"),
655 | OnlineLearningStrategy(model_type="sgd"),
656 | HybridAdaptiveStrategy(MockBaseStrategy()),
657 | ]
658 |
659 | for strategy in strategies:
660 | with benchmark_timer() as timer:
661 | entry_signals, exit_signals = strategy.generate_signals(
662 | sample_market_data
663 | )
664 |
665 | # Should complete within reasonable time
666 | assert timer.elapsed < 10.0 # < 10 seconds
667 | assert len(entry_signals) == len(sample_market_data)
668 | assert len(exit_signals) == len(sample_market_data)
669 |
670 | def test_memory_usage_scalability(self, benchmark_timer):
671 | """Test memory usage with large datasets."""
672 | import os
673 |
674 | import psutil
675 |
676 | process = psutil.Process(os.getpid())
677 | initial_memory = process.memory_info().rss
678 |
679 | # Create large dataset
680 | dates = pd.date_range(start="2020-01-01", end="2023-12-31", freq="D") # 4 years
681 | large_data = pd.DataFrame(
682 | {
683 | "open": 100 + np.random.normal(0, 10, len(dates)),
684 | "high": 105 + np.random.normal(0, 10, len(dates)),
685 | "low": 95 + np.random.normal(0, 10, len(dates)),
686 | "close": 100 + np.random.normal(0, 10, len(dates)),
687 | "volume": np.random.randint(1000000, 10000000, len(dates)),
688 | },
689 | index=dates,
690 | )
691 |
692 | # Test online learning strategy (most memory intensive)
693 | strategy = OnlineLearningStrategy()
694 | strategy.generate_signals(large_data)
695 |
696 | final_memory = process.memory_info().rss
697 | memory_growth = (final_memory - initial_memory) / 1024 / 1024 # MB
698 |
699 | # Memory growth should be reasonable (< 200MB for 4 years of data)
700 | assert memory_growth < 200
701 |
702 | def test_strategy_adaptation_effectiveness(self, sample_market_data):
703 | """Test that adaptive strategies actually improve over time."""
704 | base_strategy = MockBaseStrategy()
705 | adaptive_strategy = AdaptiveStrategy(
706 | base_strategy=base_strategy, adaptation_method="gradient"
707 | )
708 |
709 | # Generate initial signals and measure performance
710 | initial_entry_signals, initial_exit_signals = (
711 | adaptive_strategy.generate_signals(sample_market_data)
712 | )
713 | assert len(initial_entry_signals) == len(sample_market_data)
714 | assert len(initial_exit_signals) == len(sample_market_data)
715 | assert len(adaptive_strategy.performance_history) > 0
716 |
717 | # Reset and generate again (should have different adaptations)
718 | adaptive_strategy.reset_to_original()
719 | post_reset_entry, post_reset_exit = adaptive_strategy.generate_signals(
720 | sample_market_data
721 | )
722 | assert len(post_reset_entry) == len(sample_market_data)
723 | assert len(post_reset_exit) == len(sample_market_data)
724 |
725 | # Should have recorded performance metrics again
726 | assert len(adaptive_strategy.performance_history) > 0
727 | assert len(adaptive_strategy.parameter_history) > 0
728 |
729 | def test_concurrent_strategy_execution(self, sample_market_data):
730 | """Test concurrent execution of multiple ML strategies."""
731 | import queue
732 | import threading
733 |
734 | results_queue = queue.Queue()
735 | error_queue = queue.Queue()
736 |
737 | def run_strategy(strategy_id, strategy_class):
738 | try:
739 | if strategy_class == AdaptiveStrategy:
740 | strategy = AdaptiveStrategy(MockBaseStrategy())
741 | elif strategy_class == OnlineLearningStrategy:
742 | strategy = OnlineLearningStrategy()
743 | else:
744 | strategy = HybridAdaptiveStrategy(MockBaseStrategy())
745 |
746 | entry_signals, exit_signals = strategy.generate_signals(
747 | sample_market_data
748 | )
749 | results_queue.put((strategy_id, len(entry_signals), len(exit_signals)))
750 | except Exception as e:
751 | error_queue.put(f"Strategy {strategy_id}: {e}")
752 |
753 | # Run multiple strategies concurrently
754 | threads = []
755 | strategy_classes = [
756 | AdaptiveStrategy,
757 | OnlineLearningStrategy,
758 | HybridAdaptiveStrategy,
759 | ]
760 |
761 | for i, strategy_class in enumerate(strategy_classes):
762 | thread = threading.Thread(target=run_strategy, args=(i, strategy_class))
763 | threads.append(thread)
764 | thread.start()
765 |
766 | # Wait for completion
767 | for thread in threads:
768 | thread.join(timeout=30) # 30 second timeout
769 |
770 | # Check results
771 | assert error_queue.empty(), f"Errors: {list(error_queue.queue)}"
772 | assert results_queue.qsize() == 3
773 |
774 | # All should have processed the full dataset
775 | while not results_queue.empty():
776 | strategy_id, entry_len, exit_len = results_queue.get()
777 | assert entry_len == len(sample_market_data)
778 | assert exit_len == len(sample_market_data)
779 |
780 |
781 | class TestMLStrategiesErrorHandling:
782 | """Error handling and edge case tests for ML strategies."""
783 |
784 | @pytest.fixture
785 | def sample_market_data(self):
786 | """Create sample market data for testing."""
787 | dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
788 |
789 | # Generate realistic price data with trends
790 | returns = np.random.normal(0.0005, 0.02, len(dates))
791 | # Add some trending periods
792 | returns[100:150] += 0.003 # Bull period
793 | returns[200:250] -= 0.002 # Bear period
794 |
795 | prices = 100 * np.cumprod(1 + returns)
796 | volumes = np.random.randint(1000000, 5000000, len(dates))
797 |
798 | data = pd.DataFrame(
799 | {
800 | "open": prices * np.random.uniform(0.98, 1.02, len(dates)),
801 | "high": prices * np.random.uniform(1.00, 1.05, len(dates)),
802 | "low": prices * np.random.uniform(0.95, 1.00, len(dates)),
803 | "close": prices,
804 | "volume": volumes,
805 | },
806 | index=dates,
807 | )
808 |
809 |         # Ensure high >= max(open, close) and low <= min(open, close)
810 | data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
811 | data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
812 |
813 | return data
814 |
815 | @pytest.fixture
816 | def mock_base_strategy(self):
817 | """Create a mock base strategy."""
818 | return MockBaseStrategy({"window": 20, "threshold": 0.02})
819 |
820 | def test_adaptive_strategy_with_failing_base(self, sample_market_data):
821 | """Test adaptive strategy when base strategy fails."""
822 | # Create a base strategy that fails
823 | failing_strategy = Mock(spec=Strategy)
824 | failing_strategy.parameters = {"window": 20}
825 | failing_strategy.generate_signals.side_effect = Exception(
826 | "Base strategy failed"
827 | )
828 |
829 | adaptive_strategy = AdaptiveStrategy(failing_strategy)
830 |
831 | # Should handle the error gracefully
832 | entry_signals, exit_signals = adaptive_strategy.generate_signals(
833 | sample_market_data
834 | )
835 |
836 | assert isinstance(entry_signals, pd.Series)
837 | assert isinstance(exit_signals, pd.Series)
838 | assert len(entry_signals) == len(sample_market_data)
839 |
840 | def test_online_learning_with_insufficient_data(self):
841 | """Test online learning strategy with insufficient training data."""
842 | # Very small dataset
843 | small_data = pd.DataFrame({"close": [100, 101], "volume": [1000, 1100]})
844 |
845 | strategy = OnlineLearningStrategy(feature_window=20) # Window larger than data
846 |
847 | entry_signals, exit_signals = strategy.generate_signals(small_data)
848 |
849 | # Should handle gracefully
850 | assert len(entry_signals) == len(small_data)
851 | assert len(exit_signals) == len(small_data)
852 | assert not strategy.is_trained # Insufficient data to train
853 |
854 | def test_model_prediction_failure_handling(self, sample_market_data):
855 | """Test handling of model prediction failures."""
856 | strategy = OnlineLearningStrategy()
857 |
858 | # Simulate model failure after training
859 | with patch.object(
860 | strategy.model, "predict", side_effect=Exception("Prediction failed")
861 | ):
862 | entry_signals, exit_signals = strategy.generate_signals(sample_market_data)
863 |
864 | # Should still return valid series
865 | assert isinstance(entry_signals, pd.Series)
866 | assert isinstance(exit_signals, pd.Series)
867 | assert len(entry_signals) == len(sample_market_data)
868 |
869 | def test_parameter_boundary_enforcement(self, mock_base_strategy):
870 | """Test that parameter adaptations respect boundaries."""
871 | adaptive_strategy = AdaptiveStrategy(mock_base_strategy)
872 |
873 | # Set extreme gradient that should be bounded
874 | large_gradient = 100.0
875 |
876 | # Store original parameter values
877 | original_window = mock_base_strategy.parameters["window"]
878 |
879 | # Apply extreme gradient
880 | adaptive_strategy.adapt_parameters_gradient(large_gradient)
881 |
882 | # Parameter should be bounded
883 | new_window = mock_base_strategy.parameters["window"]
884 | assert new_window != original_window
885 | adaptable_params = adaptive_strategy.get_adaptable_parameters()
886 |
887 | if "window" in adaptable_params:
888 | assert new_window >= adaptable_params["window"]["min"]
889 | assert new_window <= adaptable_params["window"]["max"]
890 |
891 | def test_strategy_state_consistency(self, mock_base_strategy, sample_market_data):
892 | """Test that strategy state remains consistent after errors."""
893 | adaptive_strategy = AdaptiveStrategy(mock_base_strategy)
894 |
895 | # Generate initial signals successfully
896 | initial_signals = adaptive_strategy.generate_signals(sample_market_data)
897 | assert isinstance(initial_signals, tuple)
898 | assert len(initial_signals) == 2
899 | initial_state = {
900 | "performance_history": len(adaptive_strategy.performance_history),
901 | "parameter_history": len(adaptive_strategy.parameter_history),
902 | "parameters": adaptive_strategy.base_strategy.parameters.copy(),
903 | }
904 |
905 | # Simulate error during signal generation
906 | with patch.object(
907 | mock_base_strategy,
908 | "generate_signals",
909 | side_effect=Exception("Signal generation failed"),
910 | ):
911 | error_signals = adaptive_strategy.generate_signals(sample_market_data)
912 |
913 | # State should remain consistent or be properly handled
914 | assert isinstance(error_signals, tuple)
915 | assert len(error_signals) == 2
916 | assert isinstance(error_signals[0], pd.Series)
917 | assert isinstance(error_signals[1], pd.Series)
918 | assert (
919 | len(adaptive_strategy.performance_history)
920 | == initial_state["performance_history"]
921 | )
922 | assert (
923 | len(adaptive_strategy.parameter_history)
924 | == initial_state["parameter_history"]
925 | )
926 | assert adaptive_strategy.base_strategy.parameters == initial_state["parameters"]
927 |
928 |
929 | if __name__ == "__main__":
930 | # Run tests with detailed output
931 | pytest.main([__file__, "-v", "--tb=short", "--asyncio-mode=auto"])
932 |
```
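
The suite above drives `AdaptiveStrategy`, `OnlineLearningStrategy`, and `HybridAdaptiveStrategy` through mock base strategies and synthetic OHLCV frames. For orientation, here is a minimal standalone sketch of the same call pattern outside pytest. It assumes only the constructor signatures and the `generate_signals` contract exercised by the fixtures above; `SmaCrossStrategy` is a hypothetical stand-in for `MockBaseStrategy`, not part of the package.

```python
import numpy as np
import pandas as pd

from maverick_mcp.backtesting.strategies.base import Strategy
from maverick_mcp.backtesting.strategies.ml.adaptive import AdaptiveStrategy


class SmaCrossStrategy(Strategy):
    """Toy base strategy: enter when the close crosses above its rolling mean."""

    def __init__(self, parameters: dict | None = None):
        super().__init__(parameters or {"window": 20, "threshold": 0.02})

    @property
    def name(self) -> str:
        return "SmaCross"

    @property
    def description(self) -> str:
        return "Moving-average crossover used as an adaptable base"

    def generate_signals(self, data: pd.DataFrame) -> tuple[pd.Series, pd.Series]:
        window = int(self.parameters.get("window", 20))
        sma = data["close"].rolling(window).mean()
        above = data["close"] > sma
        entries = above & ~above.shift(1, fill_value=False)  # cross up
        exits = ~above & above.shift(1, fill_value=False)  # cross down
        return entries, exits


if __name__ == "__main__":
    # Synthetic OHLCV frame, mirroring the fixtures in the test suite.
    rng = np.random.default_rng(42)
    dates = pd.date_range("2023-01-01", periods=252, freq="D")
    close = 100 * np.cumprod(1 + rng.normal(0.0005, 0.02, len(dates)))
    data = pd.DataFrame(
        {
            "open": close,
            "high": close * 1.01,
            "low": close * 0.99,
            "close": close,
            "volume": rng.integers(1_000_000, 5_000_000, len(dates)),
        },
        index=dates,
    )

    strategy = AdaptiveStrategy(
        base_strategy=SmaCrossStrategy(),
        adaptation_method="gradient",
        learning_rate=0.01,
        lookback_period=50,
        adaptation_frequency=10,
    )
    entries, exits = strategy.generate_signals(data)
    print(f"{strategy.name}: {int(entries.sum())} entries, {int(exits.sum())} exits")
    print(f"adaptations recorded: {len(strategy.performance_history)}")
```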
--------------------------------------------------------------------------------
/maverick_mcp/data/cache.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Cache utilities for Maverick-MCP.
3 | Implements Redis-based caching with fallback to in-memory caching.
4 | Now uses centralized Redis connection pooling for improved performance.
5 | Includes timezone handling, smart invalidation, and performance monitoring.
6 | """
7 |
8 | import asyncio
9 | import hashlib
10 | import json
11 | import logging
12 | import os
13 | import time
14 | import zlib
15 | from collections import defaultdict
16 | from collections.abc import Sequence
17 | from datetime import UTC, date, datetime
18 | from typing import Any, cast
19 |
20 | import msgpack
21 | import pandas as pd
22 | import redis
23 | from dotenv import load_dotenv
24 |
25 | from maverick_mcp.config.settings import get_settings
26 |
 27 | # Redis connection pooling is provided by the centralized performance module
28 |
29 | # Load environment variables
30 | load_dotenv()
31 |
32 | # Setup logging
33 | logging.basicConfig(
34 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
35 | )
36 | logger = logging.getLogger("maverick_mcp.cache")
37 |
38 | settings = get_settings()
39 |
40 | # Redis configuration (kept for backwards compatibility)
41 | REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
42 | REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
43 | REDIS_DB = int(os.getenv("REDIS_DB", "0"))
44 | REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "")
45 | REDIS_SSL = os.getenv("REDIS_SSL", "False").lower() == "true"
46 |
47 | # Cache configuration
48 | CACHE_ENABLED = os.getenv("CACHE_ENABLED", "True").lower() == "true"
49 | CACHE_TTL_SECONDS = settings.performance.cache_ttl_seconds
50 | CACHE_VERSION = os.getenv("CACHE_VERSION", "v1")
51 |
52 | # Cache statistics
53 | CacheStatMap = defaultdict[str, float]
54 | _cache_stats: CacheStatMap = defaultdict(float)
55 | _cache_stats["hits"] = 0.0
56 | _cache_stats["misses"] = 0.0
57 | _cache_stats["sets"] = 0.0
58 | _cache_stats["errors"] = 0.0
59 | _cache_stats["serialization_time"] = 0.0
60 | _cache_stats["deserialization_time"] = 0.0
61 |
62 | # In-memory cache as fallback with memory management
63 | _memory_cache: dict[str, dict[str, Any]] = {}
64 | _memory_cache_max_size = 1000 # Will be updated to use config
65 |
66 | # Cache metadata for version tracking
67 | _cache_metadata: dict[str, dict[str, Any]] = {}
68 |
69 | # Memory monitoring
70 | _cache_memory_stats: dict[str, float] = {
71 | "memory_cache_bytes": 0.0,
72 | "redis_connection_count": 0.0,
73 | "large_object_count": 0.0,
74 | "compression_savings_bytes": 0.0,
75 | }
76 |
77 |
78 | def _dataframe_to_payload(df: pd.DataFrame) -> dict[str, Any]:
79 | """Convert a DataFrame to a JSON-serializable payload."""
80 |
81 | normalized = ensure_timezone_naive(df)
82 | json_payload = cast(
83 | str,
84 | normalized.to_json(orient="split", date_format="iso", default_handler=str),
85 | )
86 | payload = json.loads(json_payload)
87 | payload["index_type"] = (
88 | "datetime" if isinstance(normalized.index, pd.DatetimeIndex) else "other"
89 | )
90 | payload["index_name"] = normalized.index.name
91 | return payload
92 |
93 |
94 | def _payload_to_dataframe(payload: dict[str, Any]) -> pd.DataFrame:
95 | """Reconstruct a DataFrame from a serialized payload."""
96 |
97 | data = payload.get("data", {})
98 | columns = data.get("columns", [])
99 | frame = pd.DataFrame(data.get("data", []), columns=columns)
100 | index_values = data.get("index", [])
101 |
102 | if payload.get("index_type") == "datetime":
103 | index_values = pd.to_datetime(index_values)
104 | index = normalize_timezone(pd.DatetimeIndex(index_values))
105 | else:
106 | index = index_values
107 |
108 | frame.index = index
109 | frame.index.name = payload.get("index_name")
110 | return ensure_timezone_naive(frame)
111 |
112 |
113 | def _json_default(value: Any) -> Any:
114 | """JSON serializer for unsupported types."""
115 |
116 | if isinstance(value, datetime | date):
117 | return value.isoformat()
118 | if isinstance(value, pd.Timestamp):
119 | return value.isoformat()
120 | if isinstance(value, pd.Series):
121 | return value.tolist()
122 | if isinstance(value, set):
123 | return list(value)
124 | raise TypeError(f"Unsupported type {type(value)!r} for cache serialization")
125 |
126 |
127 | def _decode_json_payload(raw_data: str) -> Any:
128 | """Decode JSON payloads with DataFrame support."""
129 |
130 | payload = json.loads(raw_data)
131 | if isinstance(payload, dict) and payload.get("__cache_type__") == "dataframe":
132 | return _payload_to_dataframe(payload)
133 | if isinstance(payload, dict) and payload.get("__cache_type__") == "dict":
134 | result: dict[str, Any] = {}
135 | for key, value in payload.get("data", {}).items():
136 | if isinstance(value, dict) and value.get("__cache_type__") == "dataframe":
137 | result[key] = _payload_to_dataframe(value)
138 | else:
139 | result[key] = value
140 | return result
141 | return payload
142 |
143 |
144 | def normalize_timezone(index: pd.Index | Sequence[Any]) -> pd.DatetimeIndex:
145 | """Return a timezone-naive :class:`~pandas.DatetimeIndex` in UTC."""
146 |
147 | dt_index = index if isinstance(index, pd.DatetimeIndex) else pd.DatetimeIndex(index)
148 |
149 | if dt_index.tz is not None:
150 | dt_index = dt_index.tz_convert("UTC").tz_localize(None)
151 |
152 | return dt_index
153 |
154 |
155 | def ensure_timezone_naive(df: pd.DataFrame) -> pd.DataFrame:
156 | """Ensure DataFrame has timezone-naive datetime index.
157 |
158 | Args:
159 | df: DataFrame with potentially timezone-aware index
160 |
161 | Returns:
162 | DataFrame with timezone-naive index
163 | """
164 | if isinstance(df.index, pd.DatetimeIndex):
165 | df = df.copy()
166 | df.index = normalize_timezone(df.index)
167 | return df
168 |
169 |
170 | def get_cache_stats() -> dict[str, Any]:
171 | """Get current cache statistics with memory information.
172 |
173 | Returns:
174 | Dictionary containing cache performance metrics
175 | """
176 | stats: dict[str, float | int] = cast(dict[str, float | int], dict(_cache_stats))
177 |
178 | # Calculate hit rate
179 | total_requests = stats["hits"] + stats["misses"]
180 | hit_rate = (stats["hits"] / total_requests * 100) if total_requests > 0 else 0
181 |
182 | stats["hit_rate_percent"] = round(hit_rate, 2)
183 | stats["total_requests"] = total_requests
184 |
185 | # Memory cache stats
186 | stats["memory_cache_size"] = len(_memory_cache)
187 | stats["memory_cache_max_size"] = _memory_cache_max_size
188 |
189 | # Add memory statistics
190 | stats.update(_cache_memory_stats)
191 |
192 | # Calculate memory cache size in bytes
193 | memory_size_bytes = 0
194 | for entry in _memory_cache.values():
195 | if "data" in entry:
196 | try:
197 |                 if isinstance(entry["data"], pd.DataFrame):
198 |                     memory_size_bytes += entry["data"].memory_usage(deep=True).sum()
199 |                 elif isinstance(entry["data"], str | bytes):
200 |                     memory_size_bytes += len(entry["data"])
201 |                 elif hasattr(entry["data"], "__sizeof__"):
202 |                     memory_size_bytes += entry["data"].__sizeof__()
203 | except Exception:
204 | pass # Skip if size calculation fails
205 |
206 | stats["memory_cache_bytes"] = memory_size_bytes
207 | stats["memory_cache_mb"] = memory_size_bytes / (1024**2)
208 |
209 | return stats
210 |
211 |
212 | def reset_cache_stats() -> None:
213 | """Reset cache statistics."""
214 | global _cache_stats
215 | _cache_stats.clear()
216 | _cache_stats.update(
217 | {
218 | "hits": 0.0,
219 | "misses": 0.0,
220 | "sets": 0.0,
221 | "errors": 0.0,
222 | "serialization_time": 0.0,
223 | "deserialization_time": 0.0,
224 | }
225 | )
226 |
227 |
228 | def generate_cache_key(base_key: str, **kwargs) -> str:
229 | """Generate versioned cache key with consistent hashing.
230 |
231 | Args:
232 | base_key: Base cache key
233 | **kwargs: Additional parameters to include in key
234 |
235 | Returns:
236 | Versioned and hashed cache key
237 | """
238 | # Include cache version and sorted parameters
239 | key_parts = [CACHE_VERSION, base_key]
240 |
241 | # Sort kwargs for consistent key generation
242 | if kwargs:
243 | sorted_params = sorted(kwargs.items())
244 | param_str = ":".join(f"{k}={v}" for k, v in sorted_params)
245 | key_parts.append(param_str)
246 |
247 | full_key = ":".join(str(part) for part in key_parts)
248 |
249 | # Hash long keys to prevent Redis key length limits
250 | if len(full_key) > 250:
251 | key_hash = hashlib.md5(full_key.encode()).hexdigest()
252 | return f"{CACHE_VERSION}:hashed:{key_hash}"
253 |
254 | return full_key
255 |
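# Sketch, assuming CACHE_VERSION is e.g. "v1":
#   generate_cache_key("stock_data", symbol="AAPL", interval="1d")
#   -> "v1:stock_data:interval=1d:symbol=AAPL"
#   Keys longer than 250 chars collapse to "v1:hashed:<md5 hexdigest>".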
256 |
257 | def _cleanup_expired_memory_cache():
258 | """Clean up expired entries from memory cache and enforce size limit with memory tracking."""
259 | current_time = time.time()
260 | bytes_freed = 0
261 |
262 | # Remove expired entries
263 | expired_keys = [
264 | k
265 | for k, v in _memory_cache.items()
266 | if "expiry" in v and v["expiry"] < current_time
267 | ]
268 | for k in expired_keys:
269 | entry = _memory_cache[k]
270 | if "data" in entry and isinstance(entry["data"], pd.DataFrame):
271 | bytes_freed += entry["data"].memory_usage(deep=True).sum()
272 | del _memory_cache[k]
273 |
274 | # Calculate current memory usage
275 | current_memory_bytes = 0
276 | for entry in _memory_cache.values():
277 | if "data" in entry and isinstance(entry["data"], pd.DataFrame):
278 | current_memory_bytes += entry["data"].memory_usage(deep=True).sum()
279 |
280 |     # Enforce memory-based size limit (hardcoded 100MB)
281 | memory_limit_bytes = 100 * 1024 * 1024 # 100MB
282 |
283 | # Enforce size limit - remove oldest entries if over limit
284 | if (
285 | len(_memory_cache) > _memory_cache_max_size
286 | or current_memory_bytes > memory_limit_bytes
287 | ):
288 | # Sort by expiry time (oldest first)
289 | sorted_items = sorted(
290 | _memory_cache.items(), key=lambda x: x[1].get("expiry", float("inf"))
291 | )
292 |
293 | # Calculate how many to remove
294 | num_to_remove = max(len(_memory_cache) - _memory_cache_max_size, 0)
295 |
296 | # Remove by memory if over memory limit
297 | if current_memory_bytes > memory_limit_bytes:
298 | removed_memory = 0
299 | for k, v in sorted_items:
300 | if "data" in v and isinstance(v["data"], pd.DataFrame):
301 | entry_size = v["data"].memory_usage(deep=True).sum()
302 | removed_memory += entry_size
303 | bytes_freed += entry_size
304 | del _memory_cache[k]
305 | num_to_remove = max(num_to_remove, 1)
306 |
307 | if removed_memory >= (current_memory_bytes - memory_limit_bytes):
308 | break
309 | else:
310 | # Remove by count
311 | for k, v in sorted_items[:num_to_remove]:
312 | if "data" in v and isinstance(v["data"], pd.DataFrame):
313 | bytes_freed += v["data"].memory_usage(deep=True).sum()
314 | del _memory_cache[k]
315 |
316 | if num_to_remove > 0:
317 | logger.debug(
318 | f"Removed {num_to_remove} entries from memory cache "
319 | f"(freed {bytes_freed / (1024**2):.2f}MB)"
320 | )
321 |
322 |     # Update memory stats (never below zero; bytes_freed also counts expired entries)
323 |     _cache_memory_stats["memory_cache_bytes"] = max(current_memory_bytes - bytes_freed, 0)
324 |
325 |
326 | # Global Redis connection pool - created once and reused
327 | _redis_pool: redis.ConnectionPool | None = None
328 |
329 |
330 | def _get_or_create_redis_pool() -> redis.ConnectionPool | None:
331 | """Create or return existing Redis connection pool."""
332 | global _redis_pool
333 |
334 | if _redis_pool is not None:
335 | return _redis_pool
336 |
337 | try:
338 | # Build connection pool parameters
339 | pool_params = {
340 | "host": REDIS_HOST,
341 | "port": REDIS_PORT,
342 | "db": REDIS_DB,
343 | "max_connections": settings.db.redis_max_connections,
344 | "retry_on_timeout": settings.db.redis_retry_on_timeout,
345 | "socket_timeout": settings.db.redis_socket_timeout,
346 | "socket_connect_timeout": settings.db.redis_socket_connect_timeout,
347 | "health_check_interval": 30, # Check connection health every 30 seconds
348 | }
349 |
350 | # Only add password if provided
351 | if REDIS_PASSWORD:
352 | pool_params["password"] = REDIS_PASSWORD
353 |
354 | # Only add SSL params if SSL is enabled
355 | if REDIS_SSL:
356 | pool_params["ssl"] = True
357 | pool_params["ssl_check_hostname"] = False
358 |
359 | _redis_pool = redis.ConnectionPool(**pool_params)
360 | logger.debug(
361 | f"Created Redis connection pool with {settings.db.redis_max_connections} max connections"
362 | )
363 | return _redis_pool
364 | except Exception as e:
365 | logger.warning(f"Failed to create Redis connection pool: {e}")
366 | return None
367 |
368 |
369 | def get_redis_client() -> redis.Redis | None:
370 | """
371 | Get a Redis client using the centralized connection pool.
372 |
373 | This function uses a singleton connection pool to avoid pool exhaustion
374 | and provides robust error handling with graceful fallback.
375 | """
376 | if not CACHE_ENABLED:
377 | return None
378 |
379 | try:
380 | # Get or create the connection pool
381 | pool = _get_or_create_redis_pool()
382 | if pool is None:
383 | return None
384 |
385 | # Create client using the shared pool
386 | client = redis.Redis(
387 | connection_pool=pool,
388 | decode_responses=False,
389 | )
390 |
391 | # Test connection with a timeout to avoid hanging
392 | client.ping()
393 | return client # type: ignore[no-any-return]
394 |
395 | except redis.ConnectionError as e:
396 | logger.warning(f"Redis connection failed: {e}. Using in-memory cache.")
397 | return None
398 | except redis.TimeoutError as e:
399 | logger.warning(f"Redis connection timeout: {e}. Using in-memory cache.")
400 | return None
401 | except Exception as e:
402 | # Handle the IndexError: pop from empty list and other unexpected errors
403 | logger.warning(f"Redis client error: {e}. Using in-memory cache.")
404 | # Reset the pool if we encounter unexpected errors
405 | global _redis_pool
406 | _redis_pool = None
407 | return None
408 |
409 |
410 | def _deserialize_cached_data(data: bytes, key: str) -> Any:
411 | """Deserialize cached data with multiple format support and timezone handling."""
412 | start_time = time.time()
413 |
414 | try:
415 | # Try msgpack with zlib compression first (most efficient for DataFrames)
416 | if data[:2] == b"\x78\x9c": # zlib magic bytes
417 | try:
418 | decompressed = zlib.decompress(data)
419 | # Try msgpack first
420 | try:
421 | result = msgpack.loads(decompressed, raw=False)
422 | # Handle DataFrame reconstruction with timezone normalization
423 | if isinstance(result, dict) and result.get("_type") == "dataframe":
424 | df = pd.DataFrame.from_dict(result["data"], orient="index")
425 |
426 | # Restore proper index
427 | if result.get("index_data"):
428 | if result.get("index_type") == "datetime":
429 | df.index = pd.to_datetime(result["index_data"])
430 | df.index = normalize_timezone(df.index)
431 | else:
432 | df.index = result["index_data"]
433 | elif result.get("index_type") == "datetime":
434 | df.index = pd.to_datetime(df.index)
435 | df.index = normalize_timezone(df.index)
436 |
437 | # Restore column order
438 | if result.get("columns"):
439 | df = df[result["columns"]]
440 |
441 | return df
442 | return result
443 | except Exception as e:
444 |                     logger.debug(f"Msgpack decode failed for decompressed {key}: {e}")
445 | try:
446 | return _decode_json_payload(decompressed.decode("utf-8"))
447 | except Exception as e2:
448 |                         logger.debug(f"JSON decode failed for decompressed {key}: {e2}")
449 | pass
450 | except Exception:
451 | pass
452 |
453 | # Try msgpack uncompressed
454 | try:
455 | result = msgpack.loads(data, raw=False)
456 | if isinstance(result, dict) and result.get("_type") == "dataframe":
457 | df = pd.DataFrame.from_dict(result["data"], orient="index")
458 |
459 | # Restore proper index
460 | if result.get("index_data"):
461 | if result.get("index_type") == "datetime":
462 | df.index = pd.to_datetime(result["index_data"])
463 | df.index = normalize_timezone(df.index)
464 | else:
465 | df.index = result["index_data"]
466 | elif result.get("index_type") == "datetime":
467 | df.index = pd.to_datetime(df.index)
468 | df.index = normalize_timezone(df.index)
469 |
470 | # Restore column order
471 | if result.get("columns"):
472 | df = df[result["columns"]]
473 |
474 | return df
475 | return result
476 | except Exception:
477 | pass
478 |
479 | # Fall back to JSON
480 | try:
481 | decoded = data.decode() if isinstance(data, bytes) else data
482 | return _decode_json_payload(decoded)
483 | except Exception:
484 | pass
485 |
486 | except Exception as e:
487 | _cache_stats["errors"] += 1
488 | logger.warning(f"Failed to deserialize cache data for key {key}: {e}")
489 | return None
490 | finally:
491 | _cache_stats["deserialization_time"] += time.time() - start_time
492 |
493 | _cache_stats["errors"] += 1
494 | logger.warning(f"Failed to deserialize cache data for key {key} - no format worked")
495 | return None
496 |
497 |
498 | def get_from_cache(key: str) -> Any | None:
499 | """
500 | Get data from the cache.
501 |
502 | Args:
503 | key: Cache key
504 |
505 | Returns:
506 | Cached data or None if not found
507 | """
508 | if not CACHE_ENABLED:
509 | return None
510 |
511 | # Try Redis first
512 | redis_client = get_redis_client()
513 | if redis_client:
514 | try:
515 | data = redis_client.get(key)
516 | if data:
517 | _cache_stats["hits"] += 1
518 | logger.debug(f"Cache hit for {key} (Redis)")
519 | result = _deserialize_cached_data(data, key) # type: ignore[arg-type]
520 | return result
521 | except Exception as e:
522 | _cache_stats["errors"] += 1
523 | logger.warning(f"Error reading from Redis cache: {e}")
524 |
525 | # Fall back to in-memory cache
526 | if key in _memory_cache:
527 | entry = _memory_cache[key]
528 | if "expiry" not in entry or entry["expiry"] > time.time():
529 | _cache_stats["hits"] += 1
530 | logger.debug(f"Cache hit for {key} (memory)")
531 | return entry["data"]
532 | else:
533 | # Clean up expired entry
534 | del _memory_cache[key]
535 |
536 | _cache_stats["misses"] += 1
537 | logger.debug(f"Cache miss for {key}")
538 | return None
539 |
540 |
541 | def _serialize_data(data: Any, key: str) -> bytes:
542 | """Serialize data efficiently based on type with optimized formats and memory tracking."""
543 | start_time = time.time()
544 | original_size = 0
545 | compressed_size = 0
546 |
547 | try:
548 | # Special handling for DataFrames - use msgpack with timezone normalization
549 | if isinstance(data, pd.DataFrame):
550 | original_size = data.memory_usage(deep=True).sum()
551 |
552 | # Track large objects
553 | if original_size > 10 * 1024 * 1024: # 10MB threshold
554 | _cache_memory_stats["large_object_count"] += 1
555 | logger.debug(
556 | f"Serializing large DataFrame for {key}: {original_size / (1024**2):.2f}MB"
557 | )
558 |
559 | # Ensure timezone-naive DataFrame
560 | df = ensure_timezone_naive(data)
561 |
562 | # Try msgpack first (most efficient for DataFrames)
563 | try:
564 | # Convert to msgpack-serializable format with proper index handling
565 | df_dict = {
566 | "_type": "dataframe",
567 | "data": df.to_dict("index"),
568 | "index_type": (
569 | "datetime"
570 | if isinstance(df.index, pd.DatetimeIndex)
571 | else "other"
572 | ),
573 | "columns": list(df.columns),
574 | "index_data": [str(idx) for idx in df.index],
575 | }
576 | msgpack_data = cast(bytes, msgpack.packb(df_dict))
577 | compressed = zlib.compress(msgpack_data, level=1)
578 | compressed_size = len(compressed)
579 |
580 | # Track compression savings
581 | if original_size > compressed_size:
582 | _cache_memory_stats["compression_savings_bytes"] += (
583 | original_size - compressed_size
584 | )
585 |
586 | return compressed
587 | except Exception as e:
588 | logger.debug(f"Msgpack DataFrame serialization failed for {key}: {e}")
589 | json_payload = {
590 | "__cache_type__": "dataframe",
591 | "data": _dataframe_to_payload(df),
592 | }
593 | compressed = zlib.compress(
594 | json.dumps(json_payload).encode("utf-8"), level=1
595 | )
596 | compressed_size = len(compressed)
597 |
598 | if original_size > compressed_size:
599 | _cache_memory_stats["compression_savings_bytes"] += (
600 | original_size - compressed_size
601 | )
602 |
603 | return compressed
604 |
605 | # For dictionaries with DataFrames (like backtest results)
606 | if isinstance(data, dict) and any(
607 | isinstance(v, pd.DataFrame) for v in data.values()
608 | ):
609 | # Ensure all DataFrames are timezone-naive
610 | processed_data = {}
611 | for k, v in data.items():
612 | if isinstance(v, pd.DataFrame):
613 | processed_data[k] = ensure_timezone_naive(v)
614 | else:
615 | processed_data[k] = v
616 |
617 | try:
618 | # Try msgpack for mixed dict with DataFrames
619 | serializable_data = {}
620 | for k, v in processed_data.items():
621 | if isinstance(v, pd.DataFrame):
622 | serializable_data[k] = {
623 | "_type": "dataframe",
624 | "data": v.to_dict("index"),
625 | "index_type": (
626 | "datetime"
627 | if isinstance(v.index, pd.DatetimeIndex)
628 | else "other"
629 | ),
630 | }
631 | else:
632 | serializable_data[k] = v
633 |
634 | msgpack_data = cast(bytes, msgpack.packb(serializable_data))
635 | compressed = zlib.compress(msgpack_data, level=1)
636 | return compressed
637 | except Exception:
638 | payload = {
639 | "__cache_type__": "dict",
640 | "data": {
641 |                         item_key: (
642 |                             {
643 |                                 "__cache_type__": "dataframe",
644 |                                 "data": _dataframe_to_payload(item_value),
645 |                             }
646 |                             if isinstance(item_value, pd.DataFrame)
647 |                             else item_value
648 |                         )
649 |                         for item_key, item_value in processed_data.items()
650 | },
651 | }
652 | compressed = zlib.compress(
653 | json.dumps(payload, default=_json_default).encode("utf-8"),
654 | level=1,
655 | )
656 | return compressed
657 |
658 | # For simple data types, try msgpack first (more efficient than JSON)
659 | if isinstance(data, dict | list | str | int | float | bool | type(None)):
660 | try:
661 | return cast(bytes, msgpack.packb(data))
662 | except Exception:
663 | # Fall back to JSON
664 | return json.dumps(data, default=_json_default).encode("utf-8")
665 |
666 | raise TypeError(f"Unsupported cache data type {type(data)!r} for key {key}")
667 |
668 | except TypeError as exc:
669 | _cache_stats["errors"] += 1
670 | logger.warning(f"Unsupported data type for cache key {key}: {exc}")
671 | raise
672 | except Exception as e:
673 | _cache_stats["errors"] += 1
674 | logger.warning(f"Failed to serialize data for key {key}: {e}")
675 | # Fall back to JSON string representation
676 | try:
677 | return json.dumps(str(data)).encode("utf-8")
678 | except Exception:
679 | return b"null"
680 | finally:
681 | _cache_stats["serialization_time"] += time.time() - start_time
682 |
683 |
684 | def save_to_cache(key: str, data: Any, ttl: int | None = None) -> bool:
685 | """
686 | Save data to the cache.
687 |
688 | Args:
689 | key: Cache key
690 | data: Data to cache
691 | ttl: Time-to-live in seconds (default: CACHE_TTL_SECONDS)
692 |
693 | Returns:
694 | True if saved successfully, False otherwise
695 | """
696 | if not CACHE_ENABLED:
697 | return False
698 |
699 | resolved_ttl = CACHE_TTL_SECONDS if ttl is None else ttl
700 |
701 | # Serialize data efficiently
702 | try:
703 | serialized_data = _serialize_data(data, key)
704 | except TypeError as exc:
705 | logger.warning(f"Skipping cache for {key}: {exc}")
706 | return False
707 |
708 | # Store cache metadata
709 | _cache_metadata[key] = {
710 | "created_at": datetime.now(UTC).isoformat(),
711 | "ttl": resolved_ttl,
712 | "size_bytes": len(serialized_data),
713 | "version": CACHE_VERSION,
714 | }
715 |
716 | success = False
717 |
718 | # Try Redis first
719 | redis_client = get_redis_client()
720 | if redis_client:
721 | try:
722 | redis_client.setex(key, resolved_ttl, serialized_data)
723 | logger.debug(f"Saved to Redis cache: {key}")
724 | success = True
725 | except Exception as e:
726 | _cache_stats["errors"] += 1
727 | logger.warning(f"Error saving to Redis cache: {e}")
728 |
729 | if not success:
730 | # Fall back to in-memory cache
731 | _memory_cache[key] = {"data": data, "expiry": time.time() + resolved_ttl}
732 | logger.debug(f"Saved to memory cache: {key}")
733 | success = True
734 |
735 | # Clean up memory cache if needed
736 | if len(_memory_cache) > _memory_cache_max_size:
737 | _cleanup_expired_memory_cache()
738 |
739 | if success:
740 | _cache_stats["sets"] += 1
741 |
742 | return success
743 |
744 |
745 | def cleanup_redis_pool() -> None:
746 | """Cleanup Redis connection pool."""
747 | global _redis_pool
748 | if _redis_pool:
749 | try:
750 | _redis_pool.disconnect()
751 | logger.debug("Redis connection pool disconnected")
752 | except Exception as e:
753 | logger.warning(f"Error disconnecting Redis pool: {e}")
754 | finally:
755 | _redis_pool = None
756 |
757 |
758 | def clear_cache(pattern: str | None = None) -> int:
759 | """
760 | Clear cache entries matching the pattern.
761 |
762 | Args:
763 | pattern: Pattern to match keys (e.g., "stock:*")
764 | If None, clears all cache
765 |
766 | Returns:
767 | Number of entries cleared
768 | """
769 | count = 0
770 |
771 | # Clear from Redis
772 | redis_client = get_redis_client()
773 | if redis_client:
774 | try:
775 | if pattern:
776 | keys = cast(list[bytes], redis_client.keys(pattern))
777 | if keys:
778 | delete_result = cast(int, redis_client.delete(*keys))
779 | count += delete_result
780 | else:
781 | flush_result = cast(int, redis_client.flushdb())
782 | count += flush_result
783 | logger.info(f"Cleared {count} entries from Redis cache")
784 | except Exception as e:
785 | logger.warning(f"Error clearing Redis cache: {e}")
786 |
787 | # Clear from memory cache
788 | if pattern:
789 | # Simple pattern matching for memory cache (only supports prefix*)
790 | if pattern.endswith("*"):
791 | prefix = pattern[:-1]
792 | memory_keys = [k for k in _memory_cache.keys() if k.startswith(prefix)]
793 | else:
794 | memory_keys = [k for k in _memory_cache.keys() if k == pattern]
795 |
796 | for k in memory_keys:
797 | del _memory_cache[k]
798 | count += len(memory_keys)
799 | else:
800 | count += len(_memory_cache)
801 | _memory_cache.clear()
802 |
803 | logger.info(f"Cleared {count} total cache entries")
804 | return count
805 |
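# Sketch: clear_cache("stock:*") deletes matching Redis keys plus memory-cache
# entries with the "stock:" prefix; a pattern without "*" must match exactly.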
806 |
807 | class CacheManager:
808 | """
809 | Enhanced cache manager with async support and additional methods.
810 |
811 | This manager now integrates with the centralized Redis connection pool
812 | for improved performance and resource management.
813 | """
814 |
815 | def __init__(self):
816 | """Initialize the cache manager."""
817 | self._redis_client = None
818 | self._initialized = False
819 | self._use_performance_redis = True # Flag to use new performance module
820 |
821 | def _ensure_client(self) -> redis.Redis | None:
822 | """Ensure Redis client is initialized with connection pooling."""
823 | if not self._initialized:
824 | # Always use the new robust connection pooling approach
825 | self._redis_client = get_redis_client()
826 | self._initialized = True
827 | return self._redis_client
828 |
829 | async def get(self, key: str) -> Any | None:
830 | """Async wrapper for get_from_cache."""
831 |         return await asyncio.get_running_loop().run_in_executor(None, get_from_cache, key)
832 |
833 | async def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
834 | """Async wrapper for save_to_cache."""
835 |         return await asyncio.get_running_loop().run_in_executor(
836 | None, save_to_cache, key, value, ttl
837 | )
838 |
839 | async def set_with_ttl(self, key: str, value: str, ttl: int) -> bool:
840 | """Set a value with specific TTL."""
841 | if not CACHE_ENABLED:
842 | return False
843 |
844 | client = self._ensure_client()
845 | if client:
846 | try:
847 | client.setex(key, ttl, value)
848 | return True
849 | except Exception as e:
850 | logger.warning(f"Error setting value with TTL: {e}")
851 |
852 | # Fallback to memory cache
853 | _memory_cache[key] = {"data": value, "expiry": time.time() + ttl}
854 | return True
855 |
856 | async def set_many_with_ttl(self, items: list[tuple[str, str, int]]) -> bool:
857 | """Set multiple values with TTL in a batch."""
858 | if not CACHE_ENABLED:
859 | return False
860 |
861 | client = self._ensure_client()
862 | if client:
863 | try:
864 | pipe = client.pipeline()
865 | for key, value, ttl in items:
866 | pipe.setex(key, ttl, value)
867 | pipe.execute()
868 | return True
869 | except Exception as e:
870 | logger.warning(f"Error in batch set with TTL: {e}")
871 |
872 | # Fallback to memory cache
873 | for key, value, ttl in items:
874 | _memory_cache[key] = {"data": value, "expiry": time.time() + ttl}
875 | return True
876 |
877 | async def get_many(self, keys: list[str]) -> dict[str, Any]:
878 | """Get multiple values at once using pipeline for better performance."""
879 | results: dict[str, Any] = {}
880 |
881 | if not CACHE_ENABLED:
882 | return results
883 |
884 | client = self._ensure_client()
885 | if client:
886 | try:
887 | # Use pipeline for better performance with multiple operations
888 | pipe = client.pipeline(transaction=False)
889 | for key in keys:
890 | pipe.get(key)
891 | values = pipe.execute()
892 |
893 | for key, value in zip(keys, values, strict=False): # type: ignore[arg-type]
894 | if value:
895 | decoded_value: Any
896 | if isinstance(value, bytes):
897 | decoded_value = value.decode()
898 | else:
899 | decoded_value = value
900 |
901 | if isinstance(decoded_value, str):
902 | try:
903 | # Try to decode JSON if it's stored as JSON
904 | results[key] = json.loads(decoded_value)
905 | continue
906 | except json.JSONDecodeError:
907 | pass
908 |
909 | # If not JSON or decoding fails, store as-is
910 | results[key] = decoded_value
911 | except Exception as e:
912 | logger.warning(f"Error in batch get: {e}")
913 |
914 | # Fallback to memory cache for missing keys
915 | for key in keys:
916 | if key not in results and key in _memory_cache:
917 | entry = _memory_cache[key]
918 | if "expiry" not in entry or entry["expiry"] > time.time():
919 | results[key] = entry["data"]
920 |
921 | return results
922 |
923 | async def delete(self, key: str) -> bool:
924 | """Delete a key from cache."""
925 | if not CACHE_ENABLED:
926 | return False
927 |
928 | deleted = False
929 | client = self._ensure_client()
930 | if client:
931 | try:
932 | deleted = bool(client.delete(key))
933 | except Exception as e:
934 | logger.warning(f"Error deleting key: {e}")
935 |
936 | # Also delete from memory cache
937 | if key in _memory_cache:
938 | del _memory_cache[key]
939 | deleted = True
940 |
941 | return deleted
942 |
943 | async def delete_pattern(self, pattern: str) -> int:
944 | """Delete all keys matching a pattern."""
945 | count = 0
946 |
947 | if not CACHE_ENABLED:
948 | return count
949 |
950 | client = self._ensure_client()
951 | if client:
952 | try:
953 | keys = cast(list[bytes], client.keys(pattern))
954 | if keys:
955 | delete_result = cast(int, client.delete(*keys))
956 | count += delete_result
957 | except Exception as e:
958 | logger.warning(f"Error deleting pattern: {e}")
959 |
960 | # Also delete from memory cache
961 | if pattern.endswith("*"):
962 | prefix = pattern[:-1]
963 | memory_keys = [k for k in _memory_cache.keys() if k.startswith(prefix)]
964 | for k in memory_keys:
965 | del _memory_cache[k]
966 | count += 1
967 |
968 | return count
969 |
970 | async def exists(self, key: str) -> bool:
971 | """Check if a key exists."""
972 | if not CACHE_ENABLED:
973 | return False
974 |
975 | client = self._ensure_client()
976 | if client:
977 | try:
978 | return bool(client.exists(key))
979 | except Exception as e:
980 | logger.warning(f"Error checking key existence: {e}")
981 |
982 | # Fallback to memory cache
983 | if key in _memory_cache:
984 | entry = _memory_cache[key]
985 | return "expiry" not in entry or entry["expiry"] > time.time()
986 |
987 | return False
988 |
989 | async def count_keys(self, pattern: str) -> int:
990 | """Count keys matching a pattern."""
991 | if not CACHE_ENABLED:
992 | return 0
993 |
994 | count = 0
995 | client = self._ensure_client()
996 | if client:
997 | try:
998 | cursor = 0
999 | while True:
1000 | cursor, keys = client.scan(cursor, match=pattern, count=1000) # type: ignore[misc]
1001 | count += len(keys)
1002 | if cursor == 0:
1003 | break
1004 | except Exception as e:
1005 | logger.warning(f"Error counting keys: {e}")
1006 |
1007 | # Add memory cache count
1008 | if pattern.endswith("*"):
1009 | prefix = pattern[:-1]
1010 | count += sum(1 for k in _memory_cache.keys() if k.startswith(prefix))
1011 |
1012 | return count
1013 |
1014 | async def batch_save(self, items: list[tuple[str, Any, int | None]]) -> int:
1015 | """
1016 | Save multiple items to cache using pipeline for better performance.
1017 |
1018 | Args:
1019 | items: List of tuples (key, data, ttl)
1020 |
1021 | Returns:
1022 | Number of items successfully saved
1023 | """
1024 | if not CACHE_ENABLED:
1025 | return 0
1026 |
1027 | saved_count = 0
1028 | client = self._ensure_client()
1029 |
1030 | if client:
1031 | try:
1032 | pipe = client.pipeline(transaction=False)
1033 |
1034 | for key, data, ttl in items:
1035 | if ttl is None:
1036 | ttl = CACHE_TTL_SECONDS
1037 |
1038 | # Convert data to JSON
1039 |                     json_data = json.dumps(data, default=_json_default)
1040 | pipe.setex(key, ttl, json_data)
1041 |
1042 | results = pipe.execute()
1043 | saved_count = sum(1 for r in results if r)
1044 | logger.debug(f"Batch saved {saved_count} items to Redis cache")
1045 | except Exception as e:
1046 | logger.warning(f"Error in batch save to Redis: {e}")
1047 |
1048 |         # Fall back to memory cache when the Redis batch saved nothing
1049 |         if saved_count == 0:
1050 | for key, data, ttl in items:
1051 | if ttl is None:
1052 | ttl = CACHE_TTL_SECONDS
1053 | _memory_cache[key] = {"data": data, "expiry": time.time() + ttl}
1054 | saved_count += 1
1055 |
1056 | return saved_count
1057 |
1058 | async def batch_delete(self, keys: list[str]) -> int:
1059 | """
1060 | Delete multiple keys from cache using pipeline for better performance.
1061 |
1062 | Args:
1063 | keys: List of keys to delete
1064 |
1065 | Returns:
1066 | Number of keys deleted
1067 | """
1068 | if not CACHE_ENABLED:
1069 | return 0
1070 |
1071 | deleted_count = 0
1072 | client = self._ensure_client()
1073 |
1074 | if client and keys:
1075 | try:
1076 | # Use single delete command for multiple keys
1077 | deleted_result = client.delete(*keys)
1078 | deleted_count = cast(int, deleted_result)
1079 | logger.debug(f"Batch deleted {deleted_count} keys from Redis cache")
1080 | except Exception as e:
1081 | logger.warning(f"Error in batch delete from Redis: {e}")
1082 |
1083 | # Also delete from memory cache
1084 | for key in keys:
1085 | if key in _memory_cache:
1086 | del _memory_cache[key]
1087 | deleted_count += 1
1088 |
1089 | return deleted_count
1090 |
1091 | async def batch_exists(self, keys: list[str]) -> dict[str, bool]:
1092 | """
1093 | Check existence of multiple keys using pipeline for better performance.
1094 |
1095 | Args:
1096 | keys: List of keys to check
1097 |
1098 | Returns:
1099 | Dictionary mapping key to existence boolean
1100 | """
1101 | results: dict[str, bool] = {}
1102 |
1103 | if not CACHE_ENABLED:
1104 | return dict.fromkeys(keys, False)
1105 |
1106 | client = self._ensure_client()
1107 |
1108 | if client:
1109 | try:
1110 | pipe = client.pipeline(transaction=False)
1111 | for key in keys:
1112 | pipe.exists(key)
1113 |
1114 | existence_results = pipe.execute()
1115 | for key, exists in zip(keys, existence_results, strict=False):
1116 | results[key] = bool(exists)
1117 | except Exception as e:
1118 | logger.warning(f"Error in batch exists check: {e}")
1119 |
1120 | # Check memory cache for missing keys
1121 | for key in keys:
1122 | if key not in results and key in _memory_cache:
1123 | entry = _memory_cache[key]
1124 | results[key] = "expiry" not in entry or entry["expiry"] > time.time()
1125 | elif key not in results:
1126 | results[key] = False
1127 |
1128 | return results
1129 |
1130 | async def batch_get_or_set(
1131 | self, items: list[tuple[str, Any, int | None]]
1132 | ) -> dict[str, Any]:
1133 | """
1134 | Get multiple values, setting missing ones atomically using pipeline.
1135 |
1136 | Args:
1137 | items: List of tuples (key, default_value, ttl)
1138 |
1139 | Returns:
1140 | Dictionary of key-value pairs
1141 | """
1142 | if not CACHE_ENABLED:
1143 | return {key: default for key, default, _ in items}
1144 |
1145 | results: dict[str, Any] = {}
1146 | keys = [item[0] for item in items]
1147 |
1148 | # First, try to get all values
1149 | existing = await self.get_many(keys)
1150 |
1151 | # Identify missing keys
1152 | missing_items = [item for item in items if item[0] not in existing]
1153 |
1154 | # Set missing values if any
1155 | if missing_items:
1156 | await self.batch_save(missing_items)
1157 |
1158 | # Add default values to results
1159 | for key, default_value, _ in missing_items:
1160 | results[key] = default_value
1161 |
1162 | # Add existing values to results
1163 | results.update(existing)
1164 |
1165 | return results
1166 |
```
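
A minimal usage sketch of the module above. The import path is an assumption
(the file's location is not shown on this page; adjust to where it lives in
your checkout), and Redis is optional — every call degrades to the in-process
memory cache:

```python
import pandas as pd

from maverick_mcp.data.cache import (  # import path assumed
    generate_cache_key,
    get_cache_stats,
    get_from_cache,
    save_to_cache,
)

# Versioned, order-independent key; long keys collapse to an md5 digest.
key = generate_cache_key("stock_data", symbol="AAPL", interval="1d")

# DataFrames are compressed on write and come back with a tz-naive index.
df = pd.DataFrame(
    {"close": [187.0, 189.5]},
    index=pd.DatetimeIndex(["2024-01-02", "2024-01-03"], tz="UTC"),
)
save_to_cache(key, df, ttl=300)  # True if Redis or the memory cache took it

frame = get_from_cache(key)  # DataFrame on a hit, None on a miss
print(get_cache_stats()["hit_rate_percent"])
```

The async `CacheManager` wraps the same primitives for batch work; a sketch
under the same import assumption:

```python
import asyncio

from maverick_mcp.data.cache import CacheManager  # import path assumed


async def main() -> None:
    cache = CacheManager()
    await cache.set_with_ttl("quote:AAPL", "189.50", ttl=60)
    quotes = await cache.get_many(["quote:AAPL", "quote:MSFT"])
    print(quotes)  # JSON-decoded where possible, e.g. {'quote:AAPL': 189.5}
    print(await cache.count_keys("quote:*"))


asyncio.run(main())
```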
--------------------------------------------------------------------------------
/tests/test_deep_research_parallel_execution.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive test suite for DeepResearchAgent parallel execution functionality.
3 |
4 | This test suite covers:
5 | - Parallel vs sequential execution modes
6 | - Subagent creation and orchestration
7 | - Task routing to specialized subagents
8 | - Parallel execution fallback mechanisms
9 | - Result synthesis from parallel tasks
10 | - Performance characteristics of parallel execution
11 | """
12 |
13 | import asyncio
14 | import time
15 | from datetime import datetime
16 | from unittest.mock import AsyncMock, Mock, patch
17 |
18 | import pytest
19 | from langchain_core.language_models import BaseChatModel
20 | from langchain_core.messages import AIMessage
21 | from langgraph.checkpoint.memory import MemorySaver
22 | from pydantic import ConfigDict
23 |
24 | from maverick_mcp.agents.deep_research import (
25 | BaseSubagent,
26 | CompetitiveResearchAgent,
27 | DeepResearchAgent,
28 | FundamentalResearchAgent,
29 | SentimentResearchAgent,
30 | TechnicalResearchAgent,
31 | )
32 | from maverick_mcp.utils.parallel_research import (
33 | ParallelResearchConfig,
34 | ResearchResult,
35 | ResearchTask,
36 | )
37 |
38 |
39 | class MockLLM(BaseChatModel):
40 | """Mock LLM for testing."""
41 |
42 | # Allow extra fields to be set on this Pydantic model
43 | model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
44 |
45 | def __init__(self, **kwargs):
46 | # Extract our custom fields before calling super()
47 | self._response_content = kwargs.pop("response_content", "Mock response")
48 | super().__init__(**kwargs)
49 |
50 | def _generate(self, messages, stop=None, **kwargs):
51 | # This method should not be called in async tests
52 | raise NotImplementedError("Use ainvoke for async tests")
53 |
54 | async def ainvoke(self, messages, config=None, **kwargs):
55 | """Mock async invocation."""
56 | await asyncio.sleep(0.01) # Simulate processing time
57 | return AIMessage(content=self._response_content)
58 |
59 | @property
60 | def _llm_type(self) -> str:
61 | return "mock_llm"
62 |
63 |
64 | class TestDeepResearchAgentParallelExecution:
65 | """Test DeepResearchAgent parallel execution capabilities."""
66 |
67 | @pytest.fixture
68 | def mock_llm(self):
69 | """Create mock LLM."""
70 | return MockLLM(
71 | response_content='{"KEY_INSIGHTS": ["Test insight"], "SENTIMENT": {"direction": "bullish", "confidence": 0.8}, "CREDIBILITY": 0.9}'
72 | )
73 |
74 | @pytest.fixture
75 | def parallel_config(self):
76 | """Create test parallel configuration."""
77 | return ParallelResearchConfig(
78 | max_concurrent_agents=3,
79 | timeout_per_agent=5,
80 | enable_fallbacks=True,
81 | rate_limit_delay=0.1,
82 | )
83 |
84 | @pytest.fixture
85 | def deep_research_agent(self, mock_llm, parallel_config):
86 | """Create DeepResearchAgent with parallel execution enabled."""
87 | return DeepResearchAgent(
88 | llm=mock_llm,
89 | persona="moderate",
90 | checkpointer=MemorySaver(),
91 | enable_parallel_execution=True,
92 | parallel_config=parallel_config,
93 | )
94 |
95 | @pytest.fixture
96 | def sequential_agent(self, mock_llm):
97 | """Create DeepResearchAgent with sequential execution."""
98 | return DeepResearchAgent(
99 | llm=mock_llm,
100 | persona="moderate",
101 | checkpointer=MemorySaver(),
102 | enable_parallel_execution=False,
103 | )
104 |
105 | def test_agent_initialization_parallel_enabled(self, deep_research_agent):
106 | """Test agent initialization with parallel execution enabled."""
107 | assert deep_research_agent.enable_parallel_execution is True
108 | assert deep_research_agent.parallel_config is not None
109 | assert deep_research_agent.parallel_orchestrator is not None
110 | assert deep_research_agent.task_distributor is not None
111 | assert deep_research_agent.parallel_config.max_concurrent_agents == 3
112 |
113 | def test_agent_initialization_sequential(self, sequential_agent):
114 | """Test agent initialization with sequential execution."""
115 | assert sequential_agent.enable_parallel_execution is False
116 | # These components should still be initialized for potential future use
117 | assert sequential_agent.parallel_orchestrator is not None
118 |
119 | @pytest.mark.asyncio
120 | async def test_parallel_execution_mode_selection(self, deep_research_agent):
121 | """Test parallel execution mode selection."""
122 | # Mock search providers to be available
123 | mock_provider = AsyncMock()
124 | deep_research_agent.search_providers = [mock_provider]
125 |
126 | with (
127 | patch.object(
128 | deep_research_agent,
129 | "_execute_parallel_research",
130 | new_callable=AsyncMock,
131 | ) as mock_parallel,
132 | patch.object(deep_research_agent.graph, "ainvoke") as mock_sequential,
133 | patch.object(
134 | deep_research_agent,
135 | "_ensure_search_providers_loaded",
136 | return_value=None,
137 | ),
138 | ):
139 | mock_parallel.return_value = {
140 | "status": "success",
141 | "execution_mode": "parallel",
142 | "agent_type": "deep_research",
143 | }
144 |
145 | # Test with parallel execution enabled (default)
146 | result = await deep_research_agent.research_comprehensive(
147 | topic="AAPL analysis", session_id="test_123"
148 | )
149 |
150 | # Should use parallel execution
151 | mock_parallel.assert_called_once()
152 | mock_sequential.assert_not_called()
153 | assert result["execution_mode"] == "parallel"
154 |
155 | @pytest.mark.asyncio
156 | async def test_sequential_execution_mode_selection(self, sequential_agent):
157 | """Test sequential execution mode selection."""
158 | # Mock search providers to be available
159 | mock_provider = AsyncMock()
160 | sequential_agent.search_providers = [mock_provider]
161 |
162 | with (
163 | patch.object(
164 | sequential_agent, "_execute_parallel_research"
165 | ) as mock_parallel,
166 | patch.object(sequential_agent.graph, "ainvoke") as mock_sequential,
167 | patch.object(
168 | sequential_agent, "_ensure_search_providers_loaded", return_value=None
169 | ),
170 | ):
171 | mock_sequential.return_value = {
172 | "status": "success",
173 | "persona": "moderate",
174 | "research_confidence": 0.8,
175 | }
176 |
177 | # Test with parallel execution disabled
178 | await sequential_agent.research_comprehensive(
179 | topic="AAPL analysis", session_id="test_123"
180 | )
181 |
182 | # Should use sequential execution
183 | mock_parallel.assert_not_called()
184 | mock_sequential.assert_called_once()
185 |
186 | @pytest.mark.asyncio
187 | async def test_parallel_execution_override(self, deep_research_agent):
188 | """Test overriding parallel execution at runtime."""
189 | # Mock search providers to be available
190 | mock_provider = AsyncMock()
191 | deep_research_agent.search_providers = [mock_provider]
192 |
193 | with (
194 | patch.object(
195 | deep_research_agent, "_execute_parallel_research"
196 | ) as mock_parallel,
197 | patch.object(deep_research_agent.graph, "ainvoke") as mock_sequential,
198 | patch.object(
199 | deep_research_agent,
200 | "_ensure_search_providers_loaded",
201 | return_value=None,
202 | ),
203 | ):
204 | mock_sequential.return_value = {"status": "success", "persona": "moderate"}
205 |
206 | # Override parallel execution to false
207 | await deep_research_agent.research_comprehensive(
208 | topic="AAPL analysis",
209 | session_id="test_123",
210 | use_parallel_execution=False,
211 | )
212 |
213 | # Should use sequential despite agent default
214 | mock_parallel.assert_not_called()
215 | mock_sequential.assert_called_once()
216 |
217 | @pytest.mark.asyncio
218 | async def test_parallel_execution_fallback(self, deep_research_agent):
219 | """Test fallback to sequential when parallel execution fails."""
220 | # Mock search providers to be available
221 | mock_provider = AsyncMock()
222 | deep_research_agent.search_providers = [mock_provider]
223 |
224 | with (
225 | patch.object(
226 | deep_research_agent,
227 | "_execute_parallel_research",
228 | new_callable=AsyncMock,
229 | ) as mock_parallel,
230 | patch.object(deep_research_agent.graph, "ainvoke") as mock_sequential,
231 | patch.object(
232 | deep_research_agent,
233 | "_ensure_search_providers_loaded",
234 | return_value=None,
235 | ),
236 | ):
237 | # Parallel execution fails
238 | mock_parallel.side_effect = RuntimeError("Parallel execution failed")
239 | mock_sequential.return_value = {
240 | "status": "success",
241 | "persona": "moderate",
242 | "research_confidence": 0.7,
243 | }
244 |
245 | result = await deep_research_agent.research_comprehensive(
246 | topic="AAPL analysis", session_id="test_123"
247 | )
248 |
249 | # Should attempt parallel then fall back to sequential
250 | mock_parallel.assert_called_once()
251 | mock_sequential.assert_called_once()
252 | assert result["status"] == "success"
253 |
254 | @pytest.mark.asyncio
255 | async def test_execute_parallel_research_task_distribution(
256 | self, deep_research_agent
257 | ):
258 | """Test parallel research task distribution."""
259 | with (
260 | patch.object(
261 | deep_research_agent.task_distributor, "distribute_research_tasks"
262 | ) as mock_distribute,
263 | patch.object(
264 | deep_research_agent.parallel_orchestrator,
265 | "execute_parallel_research",
266 | new_callable=AsyncMock,
267 | ) as mock_execute,
268 | ):
269 | # Mock task distribution
270 | mock_tasks = [
271 | ResearchTask(
272 | "test_123_fundamental", "fundamental", "AAPL", ["earnings"]
273 | ),
274 | ResearchTask("test_123_sentiment", "sentiment", "AAPL", ["news"]),
275 | ]
276 | mock_distribute.return_value = mock_tasks
277 |
278 | # Mock orchestrator execution
279 | mock_result = ResearchResult()
280 | mock_result.successful_tasks = 2
281 | mock_result.failed_tasks = 0
282 | mock_result.synthesis = {"confidence_score": 0.85}
283 | mock_execute.return_value = mock_result
284 |
285 | initial_state = {
286 | "persona": "moderate",
287 | "research_topic": "AAPL analysis",
288 | "session_id": "test_123",
289 | "focus_areas": ["earnings", "sentiment"],
290 | }
291 | await deep_research_agent._execute_parallel_research(
292 | topic="AAPL analysis",
293 | session_id="test_123",
294 | depth="standard",
295 | focus_areas=["earnings", "sentiment"],
296 | start_time=datetime.now(),
297 | initial_state=initial_state,
298 | )
299 |
300 | # Verify task distribution was called correctly
301 | mock_distribute.assert_called_once_with(
302 | topic="AAPL analysis",
303 | session_id="test_123",
304 | focus_areas=["earnings", "sentiment"],
305 | )
306 |
307 | # Verify orchestrator was called with distributed tasks
308 | mock_execute.assert_called_once()
309 | args, kwargs = mock_execute.call_args
310 | assert kwargs["tasks"] == mock_tasks
311 |
312 | @pytest.mark.asyncio
313 | async def test_subagent_task_routing(self, deep_research_agent):
314 | """Test routing tasks to appropriate subagents."""
315 | # Test fundamental routing
316 | fundamental_task = ResearchTask(
317 | "test_fundamental", "fundamental", "AAPL", ["earnings"]
318 | )
319 |
320 | with patch(
321 | "maverick_mcp.agents.deep_research.FundamentalResearchAgent"
322 | ) as mock_fundamental:
323 | mock_subagent = AsyncMock()
324 | mock_subagent.execute_research.return_value = {
325 | "research_type": "fundamental"
326 | }
327 | mock_fundamental.return_value = mock_subagent
328 |
329 | # This would normally be called by the orchestrator
330 | # We're testing the routing logic directly
331 | await deep_research_agent._execute_subagent_task(fundamental_task)
332 |
333 | mock_fundamental.assert_called_once_with(deep_research_agent)
334 | mock_subagent.execute_research.assert_called_once_with(fundamental_task)
335 |
336 | @pytest.mark.asyncio
337 | async def test_unknown_task_type_fallback(self, deep_research_agent):
338 | """Test fallback for unknown task types."""
339 | unknown_task = ResearchTask("test_unknown", "unknown_type", "AAPL", ["test"])
340 |
341 | with patch(
342 | "maverick_mcp.agents.deep_research.FundamentalResearchAgent"
343 | ) as mock_fundamental:
344 | mock_subagent = AsyncMock()
345 | mock_subagent.execute_research.return_value = {
346 | "research_type": "fundamental"
347 | }
348 | mock_fundamental.return_value = mock_subagent
349 |
350 | await deep_research_agent._execute_subagent_task(unknown_task)
351 |
352 | # Should fall back to fundamental analysis
353 | mock_fundamental.assert_called_once_with(deep_research_agent)
354 |
355 | @pytest.mark.asyncio
356 | async def test_parallel_result_synthesis(self, deep_research_agent, mock_llm):
357 | """Test synthesis of results from parallel tasks."""
358 | # Create mock task results
359 | task_results = {
360 | "test_123_fundamental": ResearchTask(
361 | "test_123_fundamental", "fundamental", "AAPL", ["earnings"]
362 | ),
363 | "test_123_sentiment": ResearchTask(
364 | "test_123_sentiment", "sentiment", "AAPL", ["news"]
365 | ),
366 | }
367 |
368 | # Set tasks as completed with results
369 | task_results["test_123_fundamental"].status = "completed"
370 | task_results["test_123_fundamental"].result = {
371 | "insights": ["Strong earnings growth"],
372 | "sentiment": {"direction": "bullish", "confidence": 0.8},
373 | "credibility_score": 0.9,
374 | }
375 |
376 | task_results["test_123_sentiment"].status = "completed"
377 | task_results["test_123_sentiment"].result = {
378 | "insights": ["Positive market sentiment"],
379 | "sentiment": {"direction": "bullish", "confidence": 0.7},
380 | "credibility_score": 0.8,
381 | }
382 |
383 | # Mock LLM synthesis response
384 | mock_llm._response_content = "Synthesized analysis showing strong bullish outlook based on fundamental and sentiment analysis"
385 |
386 | result = await deep_research_agent._synthesize_parallel_results(task_results)
387 |
388 | assert result is not None
389 | assert "synthesis" in result
390 | assert "key_insights" in result
391 | assert "overall_sentiment" in result
392 | assert len(result["key_insights"]) > 0
393 | assert result["overall_sentiment"]["direction"] == "bullish"
394 |
395 | @pytest.mark.asyncio
396 | async def test_synthesis_with_mixed_results(self, deep_research_agent):
397 | """Test synthesis with mixed successful and failed tasks."""
398 | task_results = {
399 | "test_123_fundamental": ResearchTask(
400 | "test_123_fundamental", "fundamental", "AAPL", ["earnings"]
401 | ),
402 | "test_123_technical": ResearchTask(
403 | "test_123_technical", "technical", "AAPL", ["charts"]
404 | ),
405 | "test_123_sentiment": ResearchTask(
406 | "test_123_sentiment", "sentiment", "AAPL", ["news"]
407 | ),
408 | }
409 |
410 | # One successful, one failed, one successful
411 | task_results["test_123_fundamental"].status = "completed"
412 | task_results["test_123_fundamental"].result = {
413 | "insights": ["Strong fundamentals"],
414 | "sentiment": {"direction": "bullish", "confidence": 0.8},
415 | }
416 |
417 | task_results["test_123_technical"].status = "failed"
418 | task_results["test_123_technical"].error = "Technical analysis failed"
419 |
420 | task_results["test_123_sentiment"].status = "completed"
421 | task_results["test_123_sentiment"].result = {
422 | "insights": ["Mixed sentiment"],
423 | "sentiment": {"direction": "neutral", "confidence": 0.6},
424 | }
425 |
426 | result = await deep_research_agent._synthesize_parallel_results(task_results)
427 |
428 | # Should handle mixed results gracefully
429 | assert result is not None
430 | assert len(result["key_insights"]) > 0
431 | assert "task_breakdown" in result
432 | assert result["task_breakdown"]["test_123_technical"]["status"] == "failed"
433 |
434 | @pytest.mark.asyncio
435 | async def test_synthesis_with_no_successful_results(self, deep_research_agent):
436 | """Test synthesis when all tasks fail."""
437 | task_results = {
438 | "test_123_fundamental": ResearchTask(
439 | "test_123_fundamental", "fundamental", "AAPL", ["earnings"]
440 | ),
441 | "test_123_sentiment": ResearchTask(
442 | "test_123_sentiment", "sentiment", "AAPL", ["news"]
443 | ),
444 | }
445 |
446 | # Both tasks failed
447 | task_results["test_123_fundamental"].status = "failed"
448 | task_results["test_123_fundamental"].error = "API timeout"
449 |
450 | task_results["test_123_sentiment"].status = "failed"
451 | task_results["test_123_sentiment"].error = "No data available"
452 |
453 | result = await deep_research_agent._synthesize_parallel_results(task_results)
454 |
455 | # Should handle gracefully
456 | assert result is not None
457 | assert result["confidence_score"] == 0.0
458 | assert "No research results available" in result["synthesis"]
459 |
460 | @pytest.mark.asyncio
461 | async def test_synthesis_llm_failure_fallback(self, deep_research_agent):
462 | """Test fallback when LLM synthesis fails."""
463 | task_results = {
464 | "test_123_fundamental": ResearchTask(
465 | "test_123_fundamental", "fundamental", "AAPL", ["earnings"]
466 | ),
467 | }
468 |
469 | task_results["test_123_fundamental"].status = "completed"
470 | task_results["test_123_fundamental"].result = {
471 | "insights": ["Good insights"],
472 | "sentiment": {"direction": "bullish", "confidence": 0.8},
473 | }
474 |
475 | # Mock LLM to fail
476 | with patch.object(
477 | deep_research_agent.llm, "ainvoke", side_effect=RuntimeError("LLM failed")
478 | ):
479 | result = await deep_research_agent._synthesize_parallel_results(
480 | task_results
481 | )
482 |
483 | # Should use fallback synthesis
484 | assert result is not None
485 | assert "fallback synthesis" in result["synthesis"].lower()
486 |
487 | @pytest.mark.asyncio
488 | async def test_format_parallel_research_response(self, deep_research_agent):
489 | """Test formatting of parallel research response."""
490 | # Create mock research result
491 | research_result = ResearchResult()
492 | research_result.successful_tasks = 2
493 | research_result.failed_tasks = 0
494 | research_result.total_execution_time = 1.5
495 | research_result.parallel_efficiency = 2.1
496 | research_result.synthesis = {
497 | "confidence_score": 0.85,
498 | "key_findings": ["Finding 1", "Finding 2"],
499 | }
500 |
501 | # Mock task results with sources
502 | task1 = ResearchTask(
503 | "test_123_fundamental", "fundamental", "AAPL", ["earnings"]
504 | )
505 | task1.status = "completed"
506 | task1.result = {
507 | "sources": [
508 | {
509 | "title": "AAPL Earnings Report",
510 | "url": "https://example.com/earnings",
511 | "credibility_score": 0.9,
512 | }
513 | ]
514 | }
515 | research_result.task_results = {"test_123_fundamental": task1}
516 |
517 | start_time = datetime.now()
518 | formatted_result = await deep_research_agent._format_parallel_research_response(
519 | research_result=research_result,
520 | topic="AAPL analysis",
521 | session_id="test_123",
522 | depth="standard",
523 | initial_state={"persona": "moderate"},
524 | start_time=start_time,
525 | )
526 |
527 | # Verify formatted response structure
528 | assert formatted_result["status"] == "success"
529 | assert formatted_result["agent_type"] == "deep_research"
530 | assert formatted_result["execution_mode"] == "parallel"
531 | assert formatted_result["research_topic"] == "AAPL analysis"
532 | assert formatted_result["confidence_score"] == 0.85
533 | assert "parallel_execution_stats" in formatted_result
534 | assert formatted_result["parallel_execution_stats"]["successful_tasks"] == 2
535 | assert len(formatted_result["citations"]) > 0
536 |
537 | @pytest.mark.asyncio
538 | async def test_aggregated_sentiment_calculation(self, deep_research_agent):
539 | """Test aggregation of sentiment from multiple sources."""
540 | sentiment_scores = [
541 | {"direction": "bullish", "confidence": 0.8},
542 | {"direction": "bullish", "confidence": 0.6},
543 | {"direction": "neutral", "confidence": 0.7},
544 | {"direction": "bearish", "confidence": 0.5},
545 | ]
546 |
547 | result = deep_research_agent._calculate_aggregated_sentiment(sentiment_scores)
548 |
549 | assert result is not None
550 | assert "direction" in result
551 | assert "confidence" in result
552 | assert "consensus" in result
553 | assert "source_count" in result
554 | assert result["source_count"] == 4
555 |
556 | @pytest.mark.asyncio
557 | async def test_parallel_recommendation_derivation(self, deep_research_agent):
558 | """Test derivation of investment recommendations from parallel analysis."""
559 | # Test strong bullish signal
560 | bullish_sentiment = {"direction": "bullish", "confidence": 0.9}
561 | recommendation = deep_research_agent._derive_parallel_recommendation(
562 | bullish_sentiment
563 | )
564 |         assert "buy" in recommendation.lower()  # matches "buy" and "strong buy"
565 |
566 | # Test bearish signal
567 | bearish_sentiment = {"direction": "bearish", "confidence": 0.8}
568 | recommendation = deep_research_agent._derive_parallel_recommendation(
569 | bearish_sentiment
570 | )
571 | assert (
572 | "caution" in recommendation.lower() or "negative" in recommendation.lower()
573 | )
574 |
575 | # Test neutral/mixed signals
576 | neutral_sentiment = {"direction": "neutral", "confidence": 0.5}
577 | recommendation = deep_research_agent._derive_parallel_recommendation(
578 | neutral_sentiment
579 | )
580 | assert "neutral" in recommendation.lower() or "mixed" in recommendation.lower()
581 |
582 |
583 | class TestSpecializedSubagents:
584 | """Test specialized research subagent functionality."""
585 |
586 | @pytest.fixture
587 | def mock_parent_agent(self):
588 | """Create mock parent DeepResearchAgent."""
589 | parent = Mock()
590 | parent.llm = MockLLM()
591 | parent.search_providers = []
592 | parent.content_analyzer = Mock()
593 | parent.persona = Mock()
594 | parent.persona.name = "moderate"
595 | parent._calculate_source_credibility = Mock(return_value=0.8)
596 | return parent
597 |
598 | def test_base_subagent_initialization(self, mock_parent_agent):
599 | """Test BaseSubagent initialization."""
600 | subagent = BaseSubagent(mock_parent_agent)
601 |
602 | assert subagent.parent == mock_parent_agent
603 | assert subagent.llm == mock_parent_agent.llm
604 | assert subagent.search_providers == mock_parent_agent.search_providers
605 | assert subagent.content_analyzer == mock_parent_agent.content_analyzer
606 | assert subagent.persona == mock_parent_agent.persona
607 |
608 | @pytest.mark.asyncio
609 | async def test_fundamental_research_agent(self, mock_parent_agent):
610 | """Test FundamentalResearchAgent execution."""
611 | # Mock content analyzer
612 | mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
613 | return_value={
614 | "insights": ["Strong earnings growth"],
615 | "sentiment": {"direction": "bullish", "confidence": 0.8},
616 | "risk_factors": ["Market volatility"],
617 | "opportunities": ["Dividend growth"],
618 | "credibility_score": 0.9,
619 | }
620 | )
621 |
622 | subagent = FundamentalResearchAgent(mock_parent_agent)
623 |
624 | # Mock search results
625 | with patch.object(subagent, "_perform_specialized_search") as mock_search:
626 | mock_search.return_value = [
627 | {
628 | "title": "AAPL Earnings Report",
629 | "url": "https://example.com/earnings",
630 | "content": "Apple reported strong quarterly earnings...",
631 | "credibility_score": 0.9,
632 | }
633 | ]
634 |
635 | task = ResearchTask(
636 | "fund_task", "fundamental", "AAPL analysis", ["earnings"]
637 | )
638 | result = await subagent.execute_research(task)
639 |
640 | assert result["research_type"] == "fundamental"
641 | assert len(result["insights"]) > 0
642 | assert "sentiment" in result
643 | assert result["sentiment"]["direction"] == "bullish"
644 | assert len(result["sources"]) > 0
645 |
646 | def test_fundamental_query_generation(self, mock_parent_agent):
647 | """Test fundamental analysis query generation."""
648 | subagent = FundamentalResearchAgent(mock_parent_agent)
649 | queries = subagent._generate_fundamental_queries("AAPL")
650 |
651 | assert len(queries) > 0
652 | assert any("earnings" in query.lower() for query in queries)
653 | assert any("revenue" in query.lower() for query in queries)
654 | assert any("valuation" in query.lower() for query in queries)
655 |
656 | @pytest.mark.asyncio
657 | async def test_technical_research_agent(self, mock_parent_agent):
658 | """Test TechnicalResearchAgent execution."""
659 | mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
660 | return_value={
661 | "insights": ["Bullish chart pattern"],
662 | "sentiment": {"direction": "bullish", "confidence": 0.7},
663 | "risk_factors": ["Support level break"],
664 | "opportunities": ["Breakout potential"],
665 | "credibility_score": 0.8,
666 | }
667 | )
668 |
669 | subagent = TechnicalResearchAgent(mock_parent_agent)
670 |
671 | with patch.object(subagent, "_perform_specialized_search") as mock_search:
672 | mock_search.return_value = [
673 | {
674 | "title": "AAPL Technical Analysis",
675 | "url": "https://example.com/technical",
676 | "content": "Apple stock showing strong technical indicators...",
677 | "credibility_score": 0.8,
678 | }
679 | ]
680 |
681 | task = ResearchTask("tech_task", "technical", "AAPL analysis", ["charts"])
682 | result = await subagent.execute_research(task)
683 |
684 | assert result["research_type"] == "technical"
685 | assert "price_action" in result["focus_areas"]
686 | assert "technical_indicators" in result["focus_areas"]
687 |
688 | def test_technical_query_generation(self, mock_parent_agent):
689 | """Test technical analysis query generation."""
690 | subagent = TechnicalResearchAgent(mock_parent_agent)
691 | queries = subagent._generate_technical_queries("AAPL")
692 |
693 | assert any("technical analysis" in query.lower() for query in queries)
694 | assert any("chart pattern" in query.lower() for query in queries)
695 | assert any(
696 | "rsi" in query.lower() or "macd" in query.lower() for query in queries
697 | )
698 |
699 | @pytest.mark.asyncio
700 | async def test_sentiment_research_agent(self, mock_parent_agent):
701 | """Test SentimentResearchAgent execution."""
702 | mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
703 | return_value={
704 | "insights": ["Positive analyst sentiment"],
705 | "sentiment": {"direction": "bullish", "confidence": 0.9},
706 | "risk_factors": ["Market sentiment shift"],
707 | "opportunities": ["Upgrade potential"],
708 | "credibility_score": 0.85,
709 | }
710 | )
711 |
712 | subagent = SentimentResearchAgent(mock_parent_agent)
713 |
714 | with patch.object(subagent, "_perform_specialized_search") as mock_search:
715 | mock_search.return_value = [
716 | {
717 | "title": "AAPL Analyst Upgrade",
718 | "url": "https://example.com/upgrade",
719 | "content": "Apple receives analyst upgrade...",
720 | "credibility_score": 0.85,
721 | }
722 | ]
723 |
724 | task = ResearchTask("sent_task", "sentiment", "AAPL analysis", ["news"])
725 | result = await subagent.execute_research(task)
726 |
727 | assert result["research_type"] == "sentiment"
728 | assert "market_sentiment" in result["focus_areas"]
729 | assert result["sentiment"]["confidence"] > 0.8
730 |
731 | @pytest.mark.asyncio
732 | async def test_competitive_research_agent(self, mock_parent_agent):
733 | """Test CompetitiveResearchAgent execution."""
734 | mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
735 | return_value={
736 | "insights": ["Strong competitive position"],
737 | "sentiment": {"direction": "bullish", "confidence": 0.7},
738 | "risk_factors": ["Increased competition"],
739 | "opportunities": ["Market expansion"],
740 | "credibility_score": 0.8,
741 | }
742 | )
743 |
744 | subagent = CompetitiveResearchAgent(mock_parent_agent)
745 |
746 | with patch.object(subagent, "_perform_specialized_search") as mock_search:
747 | mock_search.return_value = [
748 | {
749 | "title": "AAPL Market Share Analysis",
750 | "url": "https://example.com/marketshare",
751 | "content": "Apple maintains strong market position...",
752 | "credibility_score": 0.8,
753 | }
754 | ]
755 |
756 | task = ResearchTask(
757 | "comp_task", "competitive", "AAPL analysis", ["market_share"]
758 | )
759 | result = await subagent.execute_research(task)
760 |
761 | assert result["research_type"] == "competitive"
762 | assert "competitive_position" in result["focus_areas"]
763 | assert "market_share" in result["focus_areas"]
764 |
765 | @pytest.mark.asyncio
766 | async def test_subagent_search_deduplication(self, mock_parent_agent):
767 | """Test search result deduplication in subagents."""
768 | subagent = BaseSubagent(mock_parent_agent)
769 |
770 | # Mock search providers with duplicate results
771 | mock_provider1 = AsyncMock()
772 | mock_provider1.search.return_value = [
773 | {"url": "https://example.com/article1", "title": "Article 1"},
774 | {"url": "https://example.com/article2", "title": "Article 2"},
775 | ]
776 |
777 | mock_provider2 = AsyncMock()
778 | mock_provider2.search.return_value = [
779 | {"url": "https://example.com/article1", "title": "Article 1"}, # Duplicate
780 | {"url": "https://example.com/article3", "title": "Article 3"},
781 | ]
782 |
783 | subagent.search_providers = [mock_provider1, mock_provider2]
784 |
785 | results = await subagent._perform_specialized_search(
786 | "test topic", ["test query"], max_results=10
787 | )
788 |
789 | # Should deduplicate by URL
790 | urls = [result["url"] for result in results]
791 | assert len(urls) == len(set(urls)) # No duplicates
792 | assert len(results) == 3 # Should have 3 unique results
793 |
794 | @pytest.mark.asyncio
795 | async def test_subagent_search_error_handling(self, mock_parent_agent):
796 | """Test error handling in subagent search."""
797 | subagent = BaseSubagent(mock_parent_agent)
798 |
799 | # Mock provider that fails
800 | mock_provider = AsyncMock()
801 | mock_provider.search.side_effect = RuntimeError("Search failed")
802 | subagent.search_providers = [mock_provider]
803 |
804 | # Should handle errors gracefully and return empty results
805 | results = await subagent._perform_specialized_search(
806 | "test topic", ["test query"], max_results=10
807 | )
808 |
809 | assert results == [] # Should return empty list on error
810 |
811 | @pytest.mark.asyncio
812 | async def test_subagent_content_analysis_error_handling(self, mock_parent_agent):
813 | """Test content analysis error handling in subagents."""
814 | # Mock content analyzer that fails
815 | mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
816 | side_effect=RuntimeError("Analysis failed")
817 | )
818 |
819 | subagent = BaseSubagent(mock_parent_agent)
820 |
821 | search_results = [
822 | {
823 | "title": "Test Article",
824 | "url": "https://example.com/test",
825 | "content": "Test content",
826 | }
827 | ]
828 |
829 | # Should handle analysis errors gracefully
830 | results = await subagent._analyze_search_results(
831 | search_results, "test_analysis"
832 | )
833 |
834 | # Should return empty results when analysis fails
835 | assert results == []
836 |
837 |
838 | @pytest.mark.integration
839 | class TestDeepResearchParallelIntegration:
840 | """Integration tests for DeepResearchAgent parallel execution."""
841 |
842 | @pytest.fixture
843 | def integration_agent(self):
844 | """Create agent for integration testing."""
845 | llm = MockLLM(
846 | '{"KEY_INSIGHTS": ["Integration insight"], "SENTIMENT": {"direction": "bullish", "confidence": 0.8}}'
847 | )
848 |
849 | config = ParallelResearchConfig(
850 | max_concurrent_agents=2,
851 | timeout_per_agent=5,
852 | enable_fallbacks=True,
853 | rate_limit_delay=0.05,
854 | )
855 |
856 | return DeepResearchAgent(
857 | llm=llm,
858 | persona="moderate",
859 | enable_parallel_execution=True,
860 | parallel_config=config,
861 | )
862 |
863 | @pytest.mark.asyncio
864 | async def test_end_to_end_parallel_research(self, integration_agent):
865 | """Test complete end-to-end parallel research workflow."""
866 | # Mock the search providers and subagent execution
867 | with patch.object(integration_agent, "_execute_subagent_task") as mock_execute:
868 | mock_execute.return_value = {
869 | "research_type": "fundamental",
870 | "insights": ["Strong financial health", "Growing revenue"],
871 | "sentiment": {"direction": "bullish", "confidence": 0.8},
872 | "risk_factors": ["Market volatility"],
873 | "opportunities": ["Expansion potential"],
874 | "credibility_score": 0.85,
875 | "sources": [
876 | {
877 | "title": "Financial Report",
878 | "url": "https://example.com/report",
879 | "credibility_score": 0.9,
880 | }
881 | ],
882 | }
883 |
884 | start_time = time.time()
885 | result = await integration_agent.research_comprehensive(
886 | topic="Apple Inc comprehensive financial analysis",
887 | session_id="integration_test_123",
888 | depth="comprehensive",
889 | focus_areas=["fundamentals", "sentiment", "competitive"],
890 | )
891 | execution_time = time.time() - start_time
892 |
893 | # Verify result structure
894 | assert result["status"] == "success"
895 | assert result["agent_type"] == "deep_research"
896 | assert result["execution_mode"] == "parallel"
897 | assert (
898 | result["research_topic"] == "Apple Inc comprehensive financial analysis"
899 | )
900 | assert result["confidence_score"] > 0
901 | assert len(result["citations"]) > 0
902 | assert "parallel_execution_stats" in result
903 |
904 | # Verify performance characteristics
905 | assert execution_time < 10 # Should complete reasonably quickly
906 | assert result["execution_time_ms"] > 0
907 |
908 | # Verify parallel execution stats
909 | stats = result["parallel_execution_stats"]
910 | assert stats["total_tasks"] > 0
911 | assert stats["successful_tasks"] >= 0
912 | assert stats["parallel_efficiency"] > 0
913 |
914 | @pytest.mark.asyncio
915 | async def test_parallel_vs_sequential_performance(self, integration_agent):
916 | """Test performance comparison between parallel and sequential execution."""
917 | topic = "Microsoft Corp investment analysis"
918 | session_id = "perf_test_123"
919 |
920 | # Mock subagent execution with realistic delay
921 | async def mock_subagent_execution(task):
922 | await asyncio.sleep(0.1) # Simulate work
923 | return {
924 | "research_type": task.task_type,
925 | "insights": [f"Insight from {task.task_type}"],
926 | "sentiment": {"direction": "bullish", "confidence": 0.7},
927 | "credibility_score": 0.8,
928 | "sources": [],
929 | }
930 |
931 | with patch.object(
932 | integration_agent,
933 | "_execute_subagent_task",
934 | side_effect=mock_subagent_execution,
935 | ):
936 | # Test parallel execution
937 | start_parallel = time.time()
938 | parallel_result = await integration_agent.research_comprehensive(
939 | topic=topic, session_id=session_id, use_parallel_execution=True
940 | )
941 |             _parallel_duration = time.time() - start_parallel  # measured for reference; not asserted
942 |
943 | # Test sequential execution
944 | start_sequential = time.time()
945 | sequential_result = await integration_agent.research_comprehensive(
946 | topic=topic,
947 | session_id=f"{session_id}_seq",
948 | use_parallel_execution=False,
949 | )
950 |             _sequential_duration = time.time() - start_sequential  # measured for reference; not asserted
951 |
952 | # Verify both succeeded
953 | assert parallel_result["status"] == "success"
954 | assert sequential_result["status"] == "success"
955 |
956 | # Parallel should generally be faster (though not guaranteed in all test environments)
957 | # At minimum, parallel efficiency should be calculated
958 | if "parallel_execution_stats" in parallel_result:
959 | assert (
960 | parallel_result["parallel_execution_stats"]["parallel_efficiency"]
961 | > 0
962 | )
963 |
964 | @pytest.mark.asyncio
965 | async def test_research_quality_consistency(self, integration_agent):
966 | """Test that parallel and sequential execution produce consistent quality."""
967 | topic = "Tesla Inc strategic analysis"
968 |
969 | # Mock consistent subagent responses
970 | mock_response = {
971 | "research_type": "fundamental",
972 | "insights": ["Consistent insight 1", "Consistent insight 2"],
973 | "sentiment": {"direction": "bullish", "confidence": 0.75},
974 | "credibility_score": 0.8,
975 | "sources": [
976 | {
977 | "title": "Source",
978 | "url": "https://example.com",
979 | "credibility_score": 0.8,
980 | }
981 | ],
982 | }
983 |
984 | with patch.object(
985 | integration_agent, "_execute_subagent_task", return_value=mock_response
986 | ):
987 | parallel_result = await integration_agent.research_comprehensive(
988 | topic=topic,
989 | session_id="quality_test_parallel",
990 | use_parallel_execution=True,
991 | )
992 |
993 | sequential_result = await integration_agent.research_comprehensive(
994 | topic=topic,
995 | session_id="quality_test_sequential",
996 | use_parallel_execution=False,
997 | )
998 |
999 | # Both should succeed with reasonable confidence
1000 | assert parallel_result["status"] == "success"
1001 | assert sequential_result["status"] == "success"
1002 | assert parallel_result["confidence_score"] > 0.5
1003 | assert sequential_result["confidence_score"] > 0.5
1004 |
```
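The deduplication contract pinned down by `test_subagent_search_deduplication` above — merge results from every provider, drop repeated URLs, keep first-seen order — can be met with a simple seen-set. A minimal sketch under those assumptions; `dedupe_by_url` is a hypothetical helper, not the actual `BaseSubagent._perform_specialized_search` internals:

```python
from typing import Any


def dedupe_by_url(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Drop results whose URL was already seen, keeping first-occurrence order."""
    seen: set[str] = set()
    unique: list[dict[str, Any]] = []
    for result in results:
        url = result.get("url")
        if url and url not in seen:
            seen.add(url)
            unique.append(result)
    return unique


# Mirrors the fixture data from test_subagent_search_deduplication.
merged = dedupe_by_url(
    [
        {"url": "https://example.com/article1", "title": "Article 1"},
        {"url": "https://example.com/article2", "title": "Article 2"},
        {"url": "https://example.com/article1", "title": "Article 1"},  # duplicate
        {"url": "https://example.com/article3", "title": "Article 3"},
    ]
)
assert len(merged) == 3  # three unique URLs survive, as the test asserts
```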
--------------------------------------------------------------------------------
/maverick_mcp/api/server.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | MaverickMCP Server Implementation - Simple Stock Analysis MCP Server.
3 |
4 | This module implements a simplified FastMCP server focused on stock analysis with:
5 | - No authentication required
6 | - No billing system
7 | - Core stock data and technical analysis functionality
8 | - Multi-transport support (stdio, SSE, streamable-http)
9 | """
10 |
11 | # Configure warnings filter BEFORE any other imports to suppress known deprecation warnings
12 | import warnings
13 |
14 | warnings.filterwarnings(
15 | "ignore",
16 | message="pkg_resources is deprecated as an API.*",
17 | category=UserWarning,
18 | module="pandas_ta.*",
19 | )
20 |
21 | warnings.filterwarnings(
22 | "ignore",
23 | message="'crypt' is deprecated and slated for removal.*",
24 | category=DeprecationWarning,
25 | module="passlib.*",
26 | )
27 |
28 | warnings.filterwarnings(
29 | "ignore",
30 | message=".*pydantic.* is deprecated.*",
31 | category=DeprecationWarning,
32 | module="langchain.*",
33 | )
34 |
35 | warnings.filterwarnings(
36 | "ignore",
37 | message=".*cookie.*deprecated.*",
38 | category=DeprecationWarning,
39 | module="starlette.*",
40 | )
41 |
42 | # Suppress Plotly/Kaleido deprecation warnings from library internals
43 | # These warnings come from the libraries themselves and can't be fixed at user level
44 | # Comprehensive suppression patterns for all known kaleido warnings
45 | kaleido_patterns = [
46 | r".*plotly\.io\.kaleido\.scope\..*is deprecated.*",
47 | r".*Use of plotly\.io\.kaleido\.scope\..*is deprecated.*",
48 | r".*default_format.*deprecated.*",
49 | r".*default_width.*deprecated.*",
50 | r".*default_height.*deprecated.*",
51 | r".*default_scale.*deprecated.*",
52 | r".*mathjax.*deprecated.*",
53 | r".*plotlyjs.*deprecated.*",
54 | ]
55 |
56 | for pattern in kaleido_patterns:
57 | warnings.filterwarnings(
58 | "ignore",
59 | category=DeprecationWarning,
60 | message=pattern,
61 | )
62 |
63 | # Also suppress by module to catch any we missed
64 | warnings.filterwarnings(
65 | "ignore",
66 | category=DeprecationWarning,
67 | module=r".*kaleido.*",
68 | )
69 |
70 | warnings.filterwarnings(
71 | "ignore",
72 | category=DeprecationWarning,
73 | module=r"plotly\.io\._kaleido",
74 | )
75 |
76 | # Suppress websockets deprecation warnings from uvicorn internals
77 | # These warnings come from uvicorn's use of deprecated websockets APIs and cannot be fixed at our level
78 | warnings.filterwarnings(
79 | "ignore",
80 | message=".*websockets.legacy is deprecated.*",
81 | category=DeprecationWarning,
82 | )
83 |
84 | warnings.filterwarnings(
85 | "ignore",
86 | message=".*websockets.server.WebSocketServerProtocol is deprecated.*",
87 | category=DeprecationWarning,
88 | )
89 |
90 | # Broad suppression for all websockets deprecation warnings from third-party libs
91 | warnings.filterwarnings(
92 | "ignore",
93 | category=DeprecationWarning,
94 | module="websockets.*",
95 | )
96 |
97 | warnings.filterwarnings(
98 | "ignore",
99 | category=DeprecationWarning,
100 | module="uvicorn.protocols.websockets.*",
101 | )
102 |
103 | # ruff: noqa: E402 - Imports after warnings config for proper deprecation warning suppression
104 | import argparse
105 | import json
106 | import sys
107 | import uuid
108 | from collections.abc import Awaitable, Callable
109 | from datetime import UTC, datetime
110 | from typing import TYPE_CHECKING, Any, Protocol, cast
111 |
112 | from fastapi import FastAPI
113 | from fastmcp import FastMCP
114 |
115 | # Import tool registry for direct registration
116 | # This avoids Claude Desktop's issue with mounted router tool names
117 | from maverick_mcp.api.routers.tool_registry import register_all_router_tools
118 | from maverick_mcp.config.settings import settings
119 | from maverick_mcp.data.models import get_db
120 | from maverick_mcp.data.performance import (
121 | cleanup_performance_systems,
122 | initialize_performance_systems,
123 | )
124 | from maverick_mcp.providers.market_data import MarketDataProvider
125 | from maverick_mcp.providers.stock_data import StockDataProvider
126 | from maverick_mcp.utils.logging import get_logger, setup_structured_logging
127 | from maverick_mcp.utils.monitoring import initialize_monitoring
128 | from maverick_mcp.utils.structured_logger import (
129 | get_logger_manager,
130 | setup_backtesting_logging,
131 | )
132 | from maverick_mcp.utils.tracing import initialize_tracing
133 |
134 | # Connection manager temporarily disabled for compatibility
135 | if TYPE_CHECKING: # pragma: no cover - import used for static typing only
136 | from maverick_mcp.infrastructure.connection_manager import MCPConnectionManager
137 |
138 | # Monkey-patch FastMCP's create_sse_app to register both /sse and /sse/ routes
139 | # This allows both paths to work without 307 redirects
140 | # Fixes the mcp-remote tool registration failure issue
141 | from fastmcp.server import http as fastmcp_http
142 | from starlette.middleware import Middleware
143 | from starlette.routing import BaseRoute, Route
144 |
145 | _original_create_sse_app = fastmcp_http.create_sse_app
146 |
147 |
148 | def _patched_create_sse_app(
149 | server: Any,
150 | message_path: str,
151 | sse_path: str,
152 | auth: Any | None = None,
153 | debug: bool = False,
154 | routes: list[BaseRoute] | None = None,
155 | middleware: list[Middleware] | None = None,
156 | ) -> Any:
157 | """Patched version of create_sse_app that registers both /sse and /sse/ paths.
158 |
159 | This prevents 307 redirects by registering both path variants explicitly,
160 | fixing tool registration failures with mcp-remote that occurred when clients
161 | used /sse instead of /sse/.
162 | """
163 |     # sys is already imported at module level; reused here for stderr diagnostics
164 |
165 | print(
166 | f"🔧 Patched create_sse_app called with sse_path={sse_path}",
167 | file=sys.stderr,
168 | flush=True,
169 | )
170 |
171 | # Call the original create_sse_app function
172 | app = _original_create_sse_app(
173 | server=server,
174 | message_path=message_path,
175 | sse_path=sse_path,
176 | auth=auth,
177 | debug=debug,
178 | routes=routes,
179 | middleware=middleware,
180 | )
181 |
182 | # Register both path variants (with and without trailing slash)
183 |
184 | # Find the SSE endpoint handler from the existing routes
185 | sse_endpoint = None
186 | for route in app.router.routes:
187 | if isinstance(route, Route) and route.path == sse_path:
188 | sse_endpoint = route.endpoint
189 | break
190 |
191 | if sse_endpoint:
192 | # Determine the alternative path
193 | if sse_path.endswith("/"):
194 | alt_path = sse_path.rstrip("/") # Remove trailing slash
195 | else:
196 | alt_path = sse_path + "/" # Add trailing slash
197 |
198 | # Register the alternative path
199 | new_route = Route(
200 | alt_path,
201 | endpoint=sse_endpoint,
202 | methods=["GET"],
203 | )
204 | app.router.routes.insert(0, new_route)
205 | print(
206 | f"✅ Registered SSE routes: {sse_path} AND {alt_path}",
207 | file=sys.stderr,
208 | flush=True,
209 | )
210 | else:
211 | print(
212 | f"⚠️ Could not find SSE endpoint for {sse_path}",
213 | file=sys.stderr,
214 | flush=True,
215 | )
216 |
217 | return app
218 |
219 |
220 | # Apply the monkey-patch
221 | fastmcp_http.create_sse_app = _patched_create_sse_app
222 |
223 |
224 | class FastMCPProtocol(Protocol):
225 | """Protocol describing the FastMCP interface we rely upon."""
226 |
227 | fastapi_app: FastAPI | None
228 | dependencies: list[Any]
229 |
230 | def resource(
231 | self, uri: str
232 | ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
233 |
234 | def event(
235 | self, name: str
236 | ) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]: ...
237 |
238 | def prompt(
239 | self, name: str | None = None, *, description: str | None = None
240 | ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
241 |
242 | def tool(
243 | self, name: str | None = None, *, description: str | None = None
244 | ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
245 |
246 | def run(self, *args: Any, **kwargs: Any) -> None: ...
247 |
248 |
249 | _use_stderr = "--transport" in sys.argv and "stdio" in sys.argv
250 |
251 | # Setup enhanced structured logging for backtesting
252 | setup_backtesting_logging(
253 | log_level=settings.api.log_level.upper(),
254 | enable_debug=settings.api.debug,
255 | log_file="logs/maverick_mcp.log" if not _use_stderr else None,
256 | )
257 |
258 | # Also setup the original logging for compatibility
259 | setup_structured_logging(
260 | log_level=settings.api.log_level.upper(),
261 | log_format="json" if settings.api.debug else "text",
262 | use_stderr=_use_stderr,
263 | )
264 |
265 | logger = get_logger("maverick_mcp.server")
266 | logger_manager = get_logger_manager()
267 |
268 | # Initialize FastMCP with enhanced connection management
269 | _fastmcp_instance = FastMCP(
270 | name=settings.app_name,
271 | )
272 | _fastmcp_instance.dependencies = []
273 | mcp = cast(FastMCPProtocol, _fastmcp_instance)
274 |
275 | # Initialize connection manager for stability
276 | connection_manager: "MCPConnectionManager | None" = None
277 |
278 | # TEMPORARILY DISABLED: MCP logging middleware - was breaking SSE transport
279 | # TODO: Fix middleware to work properly with SSE transport
280 | # logger.info("Adding comprehensive MCP logging middleware...")
281 | # try:
282 | # from maverick_mcp.api.middleware.mcp_logging import add_mcp_logging_middleware
283 | #
284 | # # Add logging middleware with debug mode based on settings
285 | # include_payloads = settings.api.debug or settings.api.log_level.upper() == "DEBUG"
286 | # import logging as py_logging
287 | # add_mcp_logging_middleware(
288 | # mcp,
289 | # include_payloads=include_payloads,
290 | # max_payload_length=3000, # Larger payloads in debug mode
291 | # log_level=getattr(py_logging, settings.api.log_level.upper())
292 | # )
293 | # logger.info("✅ MCP logging middleware added successfully")
294 | #
295 | # # Add console notification
296 | # print("🔧 MCP Server Enhanced Logging Enabled")
297 | # print(" 📊 Tool calls will be logged with execution details")
298 | # print(" 🔍 Protocol messages will be tracked for debugging")
299 | # print(" ⏱️ Timeout detection and warnings active")
300 | # print()
301 | #
302 | # except Exception as e:
303 | # logger.warning(f"Failed to add MCP logging middleware: {e}")
304 | # print("⚠️ Warning: MCP logging middleware could not be added")
305 |
306 | # Initialize monitoring and observability systems
307 | logger.info("Initializing monitoring and observability systems...")
308 |
309 | # Initialize core monitoring
310 | initialize_monitoring()
311 |
312 | # Initialize distributed tracing
313 | initialize_tracing()
314 |
315 | # Initialize backtesting metrics collector
316 | logger.info("Initializing backtesting metrics system...")
317 | try:
318 | from maverick_mcp.monitoring.metrics import get_backtesting_metrics
319 |
320 |     backtesting_collector = get_backtesting_metrics()  # instantiate the shared collector singleton
321 | logger.info("✅ Backtesting metrics system initialized successfully")
322 |
323 | # Log metrics system capabilities
324 | print("🎯 Enhanced Backtesting Metrics System Enabled")
325 | print(" 📊 Strategy performance tracking active")
326 | print(" 🔄 API rate limiting and failure monitoring enabled")
327 | print(" 💾 Resource usage monitoring configured")
328 | print(" 🚨 Anomaly detection and alerting ready")
329 | print(" 📈 Prometheus metrics available at /metrics")
330 | print()
331 |
332 | except Exception as e:
333 | logger.warning(f"Failed to initialize backtesting metrics: {e}")
334 | print("⚠️ Warning: Backtesting metrics system could not be initialized")
335 |
336 | logger.info("Monitoring and observability systems initialized")
337 |
338 | # ENHANCED CONNECTION MANAGEMENT: registering tools through the connection manager
339 | # keeps them from disappearing across connection cycles. It is currently disabled
340 | # for compatibility, so tools are registered directly below instead.
341 |
342 | # Import connection manager and SSE optimizer
343 | # Connection management imports disabled for compatibility
344 | # from maverick_mcp.infrastructure.connection_manager import initialize_connection_management
345 | # from maverick_mcp.infrastructure.sse_optimizer import apply_sse_optimizations
346 |
347 | # Register all tools from routers directly for basic functionality
348 | register_all_router_tools(_fastmcp_instance)
349 | logger.info("Tools registered successfully")
350 |
351 | # Register monitoring and health endpoints directly with FastMCP
352 | from maverick_mcp.api.routers.health_enhanced import router as health_router
353 | from maverick_mcp.api.routers.monitoring import router as monitoring_router
354 |
355 | # Add monitoring and health endpoints to the FastMCP app's FastAPI instance
356 | if hasattr(mcp, "fastapi_app") and mcp.fastapi_app:
357 | mcp.fastapi_app.include_router(monitoring_router, tags=["monitoring"])
358 | mcp.fastapi_app.include_router(health_router, tags=["health"])
359 | logger.info("Monitoring and health endpoints registered with FastAPI application")
360 |
361 | # Initialize enhanced health monitoring system
362 | logger.info("Initializing enhanced health monitoring system...")
363 | try:
364 | from maverick_mcp.monitoring.health_monitor import get_health_monitor
365 | from maverick_mcp.utils.circuit_breaker import initialize_all_circuit_breakers
366 |
367 | # Initialize circuit breakers for all external APIs
368 | circuit_breaker_success = initialize_all_circuit_breakers()
369 | if circuit_breaker_success:
370 | logger.info("✅ Circuit breakers initialized for all external APIs")
371 | print("🛡️ Enhanced Circuit Breaker Protection Enabled")
372 | print(" 🔄 yfinance, Tiingo, FRED, OpenRouter, Exa APIs protected")
373 | print(" 📊 Failure detection and automatic recovery active")
374 | print(" 🚨 Circuit breaker monitoring and alerting enabled")
375 | else:
376 | logger.warning("⚠️ Some circuit breakers failed to initialize")
377 |
378 | # Get health monitor (will be started later in async context)
379 | health_monitor = get_health_monitor()
380 | logger.info("✅ Health monitoring system prepared")
381 |
382 | print("🏥 Comprehensive Health Monitoring System Ready")
383 | print(" 📈 Real-time component health tracking")
384 | print(" 🔍 Database, cache, and external API monitoring")
385 | print(" 💾 Resource usage monitoring (CPU, memory, disk)")
386 | print(" 📊 Status dashboard with historical metrics")
387 | print(" 🚨 Automated alerting and recovery actions")
388 | print(
389 | " 🩺 Health endpoints: /health, /health/detailed, /health/ready, /health/live"
390 | )
391 | print()
392 |
393 | except Exception as e:
394 | logger.warning(f"Failed to initialize enhanced health monitoring: {e}")
395 | print("⚠️ Warning: Enhanced health monitoring could not be fully initialized")
396 |
397 |
398 | # Add enhanced health endpoint as a resource
399 | @mcp.resource("health://")
400 | def health_resource() -> dict[str, Any]:
401 | """
402 | Enhanced comprehensive health check endpoint.
403 |
404 | Provides detailed system health including:
405 | - Component status (database, cache, external APIs)
406 | - Circuit breaker states
407 | - Resource utilization
408 | - Performance metrics
409 |
410 | Financial Disclaimer: This health check is for system monitoring only and does not
411 | provide any investment or financial advice.
412 | """
413 | try:
414 | import asyncio
415 |
416 | from maverick_mcp.api.routers.health_enhanced import _get_detailed_health_status
417 |
418 | loop_policy = asyncio.get_event_loop_policy()
419 | try:
420 | previous_loop = loop_policy.get_event_loop()
421 | except RuntimeError:
422 | previous_loop = None
423 |
424 | loop = loop_policy.new_event_loop()
425 | try:
426 | asyncio.set_event_loop(loop)
427 | health_status = loop.run_until_complete(_get_detailed_health_status())
428 | finally:
429 | loop.close()
430 | if previous_loop is not None:
431 | asyncio.set_event_loop(previous_loop)
432 | else:
433 | asyncio.set_event_loop(None)
434 |
435 | # Add service-specific information
436 | health_status.update(
437 | {
438 | "service": settings.app_name,
439 | "version": "1.0.0",
440 | "mode": "backtesting_with_enhanced_monitoring",
441 | }
442 | )
443 |
444 | return health_status
445 |
446 | except Exception as e:
447 | logger.error(f"Health resource check failed: {e}")
448 | return {
449 | "status": "unhealthy",
450 | "service": settings.app_name,
451 | "version": "1.0.0",
452 | "error": str(e),
453 | "timestamp": datetime.now(UTC).isoformat(),
454 | }
455 |
456 |
457 | # Add status dashboard endpoint as a resource
458 | @mcp.resource("dashboard://")
459 | def status_dashboard_resource() -> dict[str, Any]:
460 | """
461 | Comprehensive status dashboard with real-time metrics.
462 |
463 | Provides aggregated health status, performance metrics, alerts,
464 | and historical trends for the backtesting system.
465 | """
466 | try:
467 | import asyncio
468 |
469 | from maverick_mcp.monitoring.status_dashboard import get_dashboard_data
470 |
471 | loop_policy = asyncio.get_event_loop_policy()
472 | try:
473 | previous_loop = loop_policy.get_event_loop()
474 | except RuntimeError:
475 | previous_loop = None
476 |
477 | loop = loop_policy.new_event_loop()
478 | try:
479 | asyncio.set_event_loop(loop)
480 | dashboard_data = loop.run_until_complete(get_dashboard_data())
481 | finally:
482 | loop.close()
483 | if previous_loop is not None:
484 | asyncio.set_event_loop(previous_loop)
485 | else:
486 | asyncio.set_event_loop(None)
487 |
488 | return dashboard_data
489 |
490 | except Exception as e:
491 | logger.error(f"Dashboard resource failed: {e}")
492 | return {
493 | "error": "Failed to generate dashboard",
494 | "message": str(e),
495 | "timestamp": datetime.now(UTC).isoformat(),
496 | }
497 |
498 |
499 | # Add performance dashboard endpoint as a resource
500 | @mcp.resource("performance://")
501 | def performance_dashboard() -> dict[str, Any]:
502 | """
503 | Performance metrics dashboard showing backtesting system health.
504 |
505 | Provides real-time performance metrics, resource usage, and operational statistics
506 | for the backtesting infrastructure.
507 | """
508 | try:
509 | dashboard_metrics = logger_manager.create_dashboard_metrics()
510 |
511 | # Add additional context
512 | dashboard_metrics.update(
513 | {
514 | "service": settings.app_name,
515 | "environment": settings.environment,
516 | "version": "1.0.0",
517 | "dashboard_type": "backtesting_performance",
518 | "generated_at": datetime.now(UTC).isoformat(),
519 | }
520 | )
521 |
522 | return dashboard_metrics
523 | except Exception as e:
524 | logger.error(f"Failed to generate performance dashboard: {e}", exc_info=True)
525 | return {
526 | "error": "Failed to generate performance dashboard",
527 | "message": str(e),
528 | "timestamp": datetime.now(UTC).isoformat(),
529 | }
530 |
531 |
532 | # Prompts for Trading and Investing
533 |
534 |
535 | @mcp.prompt()
536 | def technical_analysis(ticker: str, timeframe: str = "daily") -> str:
537 | """Generate a comprehensive technical analysis prompt for a stock."""
538 | return f"""Please perform a comprehensive technical analysis for {ticker} on the {timeframe} timeframe.
539 |
540 | Use the available tools to:
541 | 1. Fetch historical price data and current stock information
542 | 2. Generate a full technical analysis including:
543 | - Trend analysis (primary, secondary trends)
544 | - Support and resistance levels
545 | - Moving averages (SMA, EMA analysis)
546 | - Key indicators (RSI, MACD, Stochastic)
547 | - Volume analysis and patterns
548 | - Chart patterns identification
549 | 3. Create a technical chart visualization
550 | 4. Provide a short-term outlook
551 |
552 | Focus on:
553 | - Price action and volume confirmation
554 | - Convergence/divergence of indicators
555 | - Risk/reward setup quality
556 | - Key decision levels for traders
557 |
558 | Present findings in a structured format with clear entry/exit suggestions if applicable."""
559 |
560 |
561 | @mcp.prompt()
562 | def stock_screening_report(strategy: str = "momentum") -> str:
563 | """Generate a stock screening report based on different strategies."""
564 | strategies = {
565 | "momentum": "high momentum and relative strength",
566 | "value": "undervalued with strong fundamentals",
567 | "growth": "high growth potential",
568 | "quality": "strong balance sheets and consistent earnings",
569 | }
570 |
571 | strategy_desc = strategies.get(strategy.lower(), "balanced approach")
572 |
573 | return f"""Please generate a comprehensive stock screening report focused on {strategy_desc}.
574 |
575 | Use the screening tools to:
576 | 1. Retrieve Maverick bullish stocks (for momentum/growth strategies)
577 | 2. Get Maverick bearish stocks (for short opportunities)
578 | 3. Fetch trending stocks (for breakout setups)
579 | 4. Analyze the top candidates with technical indicators
580 |
581 | For each recommended stock:
582 | - Current technical setup and score
583 | - Key levels (support, resistance, stop loss)
584 | - Risk/reward analysis
585 | - Volume and momentum characteristics
586 | - Sector/industry context
587 |
588 | Organize results by:
589 | 1. Top picks (highest conviction)
590 | 2. Watch list (developing setups)
591 | 3. Avoid list (deteriorating technicals)
592 |
593 | Include market context and any relevant economic factors."""
594 |
595 |
596 | # Simplified portfolio and watchlist tools (no authentication required)
597 | @mcp.tool()
598 | async def get_user_portfolio_summary() -> dict[str, Any]:
599 | """
600 | Get basic portfolio summary and stock analysis capabilities.
601 |
602 | Returns available features and sample stock data.
603 | """
604 | return {
605 | "mode": "simple_stock_analysis",
606 | "features": {
607 | "stock_data": True,
608 | "technical_analysis": True,
609 | "market_screening": True,
610 | "portfolio_analysis": True,
611 | "real_time_quotes": True,
612 | },
613 | "sample_data": "Use get_watchlist() to see sample stock data",
614 | "usage": "All stock analysis tools are available without restrictions",
615 | "last_updated": datetime.now(UTC).isoformat(),
616 | }
617 |
618 |
619 | @mcp.tool()
620 | async def get_watchlist(limit: int = 20) -> dict[str, Any]:
621 | """
622 | Get sample watchlist with real-time stock data.
623 |
624 | Provides stock data for popular tickers to demonstrate functionality.
625 | """
626 | # Sample watchlist for demonstration
627 | watchlist_tickers = [
628 | "AAPL",
629 | "MSFT",
630 | "GOOGL",
631 | "AMZN",
632 | "TSLA",
633 | "META",
634 | "NVDA",
635 | "JPM",
636 | "V",
637 | "JNJ",
638 | "UNH",
639 | "PG",
640 | "HD",
641 | "MA",
642 | "DIS",
643 | ][:limit]
644 |
645 | import asyncio
646 |
647 | def _build_watchlist() -> dict[str, Any]:
648 | db_session = next(get_db())
649 | try:
650 | provider = StockDataProvider(db_session=db_session)
651 | watchlist_data: list[dict[str, Any]] = []
652 | for ticker in watchlist_tickers:
653 | try:
654 | info = provider.get_stock_info(ticker)
655 | current_price = info.get("currentPrice", 0)
656 | previous_close = info.get("previousClose", current_price)
657 | change = current_price - previous_close
658 | change_pct = (
659 | (change / previous_close * 100) if previous_close else 0
660 | )
661 |
662 | ticker_data = {
663 | "ticker": ticker,
664 | "name": info.get("longName", ticker),
665 | "current_price": round(current_price, 2),
666 | "change": round(change, 2),
667 | "change_percent": round(change_pct, 2),
668 | "volume": info.get("volume", 0),
669 | "market_cap": info.get("marketCap", 0),
670 | "bid": info.get("bid", 0),
671 | "ask": info.get("ask", 0),
672 | "bid_size": info.get("bidSize", 0),
673 | "ask_size": info.get("askSize", 0),
674 | "last_trade_time": datetime.now(UTC).isoformat(),
675 | }
676 |
677 | watchlist_data.append(ticker_data)
678 |
679 | except Exception as exc:
680 | logger.error(f"Error fetching data for {ticker}: {str(exc)}")
681 | continue
682 |
683 | return {
684 | "watchlist": watchlist_data,
685 | "count": len(watchlist_data),
686 | "mode": "simple_stock_analysis",
687 | "last_updated": datetime.now(UTC).isoformat(),
688 | }
689 | finally:
690 | db_session.close()
691 |
692 | return await asyncio.to_thread(_build_watchlist)
693 |
694 |
695 | # Market Overview Tools (full access)
696 | @mcp.tool()
697 | async def get_market_overview() -> dict[str, Any]:
698 | """
699 | Get comprehensive market overview including indices, sectors, and market breadth.
700 |
701 | Provides full market data without restrictions.
702 | """
703 | try:
704 | # Create market provider instance
705 | import asyncio
706 |
707 | provider = MarketDataProvider()
708 |
709 | indices, sectors, breadth = await asyncio.gather(
710 | provider.get_market_summary_async(),
711 | provider.get_sector_performance_async(),
712 | provider.get_market_overview_async(),
713 | )
714 |
715 | overview = {
716 | "indices": indices,
717 | "sectors": sectors,
718 | "market_breadth": breadth,
719 | "last_updated": datetime.now(UTC).isoformat(),
720 | "mode": "simple_stock_analysis",
721 | }
722 |
723 | vix_value = indices.get("current_price", 0)
724 | overview["volatility"] = {
725 | "vix": vix_value,
726 | "vix_change": indices.get("change_percent", 0),
727 | "fear_level": (
728 | "extreme"
729 | if vix_value > 30
730 | else (
731 | "high"
732 | if vix_value > 20
733 | else "moderate"
734 | if vix_value > 15
735 | else "low"
736 | )
737 | ),
738 | }
739 |
740 | return overview
741 |
742 | except Exception as e:
743 | logger.error(f"Error getting market overview: {str(e)}")
744 | return {"error": str(e), "status": "error"}
745 |
746 |
747 | @mcp.tool()
748 | async def get_economic_calendar(days_ahead: int = 7) -> dict[str, Any]:
749 | """
750 | Get upcoming economic events and indicators.
751 |
752 | Provides full access to economic calendar data.
753 | """
754 | try:
755 | # Get economic calendar events (placeholder implementation)
756 | events: list[
757 | dict[str, Any]
758 | ] = [] # macro_provider doesn't have get_economic_calendar method
759 |
760 | return {
761 | "events": events,
762 | "days_ahead": days_ahead,
763 | "event_count": len(events),
764 | "mode": "simple_stock_analysis",
765 | "last_updated": datetime.now(UTC).isoformat(),
766 | }
767 |
768 | except Exception as e:
769 | logger.error(f"Error getting economic calendar: {str(e)}")
770 | return {"error": str(e), "status": "error"}
771 |
772 |
773 | @mcp.tool()
774 | async def get_mcp_connection_status() -> dict[str, Any]:
775 | """
776 | Get current MCP connection status for debugging connection stability issues.
777 |
778 | Returns detailed information about active connections, tool registration status,
779 | and connection health metrics to help diagnose disappearing tools.
780 | """
781 | try:
782 | global connection_manager
783 | if connection_manager is None:
784 | return {
785 | "error": "Connection manager not initialized",
786 | "status": "error",
787 | "server_mode": "simple_stock_analysis",
788 | "timestamp": datetime.now(UTC).isoformat(),
789 | }
790 |
791 | # Get connection status from manager
792 | status = connection_manager.get_connection_status()
793 |
794 | # Add additional debugging info
795 | status.update(
796 | {
797 | "server_mode": "simple_stock_analysis",
798 | "mcp_server_name": settings.app_name,
799 | "transport_modes": ["stdio", "sse", "streamable-http"],
800 | "debugging_info": {
801 | "tools_should_be_visible": status["tools_registered"],
802 | "recommended_action": (
803 | "Tools are registered and should be visible"
804 | if status["tools_registered"]
805 | else "Tools not registered - check connection manager"
806 | ),
807 | },
808 | "timestamp": datetime.now(UTC).isoformat(),
809 | }
810 | )
811 |
812 | return status
813 |
814 | except Exception as e:
815 | logger.error(f"Error getting connection status: {str(e)}")
816 | return {
817 | "error": str(e),
818 | "status": "error",
819 | "timestamp": datetime.now(UTC).isoformat(),
820 | }
821 |
822 |
823 | # Resources (public access)
824 | @mcp.resource("stock://{ticker}")
825 | def stock_resource(ticker: str) -> Any:
826 | """Get the latest stock data for a given ticker"""
827 | db_session = next(get_db())
828 | try:
829 | provider = StockDataProvider(db_session=db_session)
830 | df = provider.get_stock_data(ticker)
831 | payload = cast(str, df.to_json(orient="split", date_format="iso"))
832 | return json.loads(payload)
833 | finally:
834 | db_session.close()
835 |
836 |
837 | @mcp.resource("stock://{ticker}/{start_date}/{end_date}")
838 | def stock_resource_with_dates(ticker: str, start_date: str, end_date: str) -> Any:
839 | """Get stock data for a given ticker and date range"""
840 | db_session = next(get_db())
841 | try:
842 | provider = StockDataProvider(db_session=db_session)
843 | df = provider.get_stock_data(ticker, start_date, end_date)
844 | payload = cast(str, df.to_json(orient="split", date_format="iso"))
845 | return json.loads(payload)
846 | finally:
847 | db_session.close()
848 |
849 |
850 | @mcp.resource("stock_info://{ticker}")
851 | def stock_info_resource(ticker: str) -> dict[str, Any]:
852 | """Get detailed information about a stock"""
853 | db_session = next(get_db())
854 | try:
855 | provider = StockDataProvider(db_session=db_session)
856 | info = provider.get_stock_info(ticker)
857 | # Convert any non-serializable objects to strings
858 | return {
859 | k: (
860 | str(v)
861 | if not isinstance(
862 | v, int | float | bool | str | list | dict | type(None)
863 | )
864 | else v
865 | )
866 | for k, v in info.items()
867 | }
868 | finally:
869 | db_session.close()
870 |
871 |
872 | @mcp.resource("portfolio://my-holdings")
873 | def portfolio_holdings_resource() -> dict[str, Any]:
874 | """
875 | Get your current portfolio holdings as an MCP resource.
876 |
877 | This resource provides AI-enriched context about your portfolio for Claude to use
878 | in conversations. It includes all positions with current prices and P&L calculations.
879 |
880 | Returns:
881 | Dictionary containing portfolio holdings with performance metrics
882 | """
883 | from maverick_mcp.api.routers.portfolio import get_my_portfolio
884 |
885 | try:
886 | # Get portfolio with current prices
887 | portfolio_data = get_my_portfolio(
888 | user_id="default",
889 | portfolio_name="My Portfolio",
890 | include_current_prices=True,
891 | )
892 |
893 | if portfolio_data.get("status") == "error":
894 | return {
895 | "error": portfolio_data.get("error", "Unknown error"),
896 | "uri": "portfolio://my-holdings",
897 | "description": "Error retrieving portfolio holdings",
898 | }
899 |
900 | # Add resource metadata
901 | portfolio_data["uri"] = "portfolio://my-holdings"
902 | portfolio_data["description"] = (
903 | "Your current stock portfolio with live prices and P&L"
904 | )
905 | portfolio_data["mimeType"] = "application/json"
906 |
907 | return portfolio_data
908 |
909 | except Exception as e:
910 | logger.error(f"Portfolio holdings resource failed: {e}")
911 | return {
912 | "error": str(e),
913 | "uri": "portfolio://my-holdings",
914 | "description": "Failed to retrieve portfolio holdings",
915 | }
916 |
917 |
918 | # Main execution block
919 | if __name__ == "__main__":
920 | import asyncio
921 |
922 | from maverick_mcp.config.validation import validate_environment
923 | from maverick_mcp.utils.shutdown import graceful_shutdown
924 |
925 | # Parse command line arguments
926 | parser = argparse.ArgumentParser(
927 | description=f"{settings.app_name} Simple Stock Analysis MCP Server"
928 | )
929 | parser.add_argument(
930 | "--transport",
931 | choices=["stdio", "sse", "streamable-http"],
932 | default="sse",
933 | help="Transport method to use (default: sse)",
934 | )
935 | parser.add_argument(
936 | "--port",
937 | type=int,
938 | default=settings.api.port,
939 | help=f"Port to run the server on (default: {settings.api.port})",
940 | )
941 | parser.add_argument(
942 | "--host",
943 | default=settings.api.host,
944 | help=f"Host to run the server on (default: {settings.api.host})",
945 | )
946 |
947 | args = parser.parse_args()
948 |
949 | # Reconfigure logging for stdio transport to use stderr
950 | if args.transport == "stdio":
951 | setup_structured_logging(
952 | log_level=settings.api.log_level.upper(),
953 | log_format="json" if settings.api.debug else "text",
954 | use_stderr=True,
955 | )
956 |
957 | # Validate environment before starting
958 | # For stdio transport, use lenient validation to support testing
959 | fail_on_validation_error = args.transport != "stdio"
960 | logger.info("Validating environment configuration...")
961 | validate_environment(fail_on_error=fail_on_validation_error)
962 |
963 | # Initialize performance systems and health monitoring
964 | async def init_systems():
965 | logger.info("Initializing performance optimization systems...")
966 | try:
967 | performance_status = await initialize_performance_systems()
968 | logger.info(f"Performance systems initialized: {performance_status}")
969 | except Exception as e:
970 | logger.error(f"Failed to initialize performance systems: {e}")
971 |
972 | # Initialize background health monitoring
973 | logger.info("Starting background health monitoring...")
974 | try:
975 | from maverick_mcp.monitoring.health_monitor import start_health_monitoring
976 |
977 | await start_health_monitoring()
978 | logger.info("✅ Background health monitoring started")
979 | except Exception as e:
980 | logger.error(f"Failed to start health monitoring: {e}")
981 |
982 | asyncio.run(init_systems())
983 |
984 | # Initialize connection management and transport optimizations
985 | async def init_connection_management():
986 | global connection_manager
987 |
988 |         # Connection manager initialization removed to satisfy linting; log the skip
989 |         logger.info("Connection manager initialization skipped (disabled for compatibility)")
990 |
991 |         # SSE transport optimizations likewise removed to satisfy linting
992 |         logger.info("SSE transport optimizations skipped (disabled for compatibility)")
993 |
994 | # Add connection event handlers for monitoring
995 | @mcp.event("connection_opened")
996 | async def on_connection_open(session_id: str | None = None) -> str:
997 | """Handle new MCP connection with enhanced stability."""
998 | if connection_manager is None:
999 | fallback_session_id = session_id or str(uuid.uuid4())
1000 | logger.info(
1001 | "MCP connection opened without manager: %s", fallback_session_id[:8]
1002 | )
1003 | return fallback_session_id
1004 |
1005 | try:
1006 | actual_session_id = await connection_manager.handle_new_connection(
1007 | session_id
1008 | )
1009 | logger.info(f"MCP connection opened: {actual_session_id[:8]}")
1010 | return actual_session_id
1011 | except Exception as e:
1012 | logger.error(f"Failed to handle connection open: {e}")
1013 | raise
1014 |
1015 | @mcp.event("connection_closed")
1016 | async def on_connection_close(session_id: str) -> None:
1017 | """Handle MCP connection close with cleanup."""
1018 | if connection_manager is None:
1019 | logger.info(
1020 | "MCP connection close received without manager: %s", session_id[:8]
1021 | )
1022 | return
1023 |
1024 | try:
1025 | await connection_manager.handle_connection_close(session_id)
1026 | logger.info(f"MCP connection closed: {session_id[:8]}")
1027 | except Exception as e:
1028 | logger.error(f"Failed to handle connection close: {e}")
1029 |
1030 | @mcp.event("message_received")
1031 | async def on_message_received(session_id: str, message: dict[str, Any]) -> None:
1032 | """Update session activity on message received."""
1033 | if connection_manager is None:
1034 | logger.debug(
1035 | "Skipping session activity update; connection manager disabled."
1036 | )
1037 | return
1038 |
1039 | try:
1040 | await connection_manager.update_session_activity(session_id)
1041 | except Exception as e:
1042 | logger.error(f"Failed to update session activity: {e}")
1043 |
1044 | logger.info("Connection event handlers registered")
1045 |
1046 | # Connection management disabled for compatibility
1047 | # asyncio.run(init_connection_management())
1048 |
1049 | logger.info(f"Starting {settings.app_name} simple stock analysis server")
1050 |
1051 | # Add initialization delay for connection stability
1052 | import time
1053 |
1054 | logger.info("Adding startup delay for connection stability...")
1055 | time.sleep(3) # 3 second delay to ensure full initialization
1056 | logger.info("Startup delay completed, server ready for connections")
1057 |
1058 | # Use graceful shutdown handler
1059 | with graceful_shutdown(f"{settings.app_name}-{args.transport}") as shutdown_handler:
1060 | # Log startup configuration
1061 | logger.info(
1062 | "Server configuration",
1063 | extra={
1064 | "transport": args.transport,
1065 | "host": args.host,
1066 | "port": args.port,
1067 | "mode": "simple_stock_analysis",
1068 | "auth_enabled": False,
1069 | "debug_mode": settings.api.debug,
1070 | "environment": settings.environment,
1071 | },
1072 | )
1073 |
1074 | # Register performance systems cleanup
1075 | async def cleanup_performance():
1076 | """Cleanup performance optimization systems during shutdown."""
1077 | try:
1078 | await cleanup_performance_systems()
1079 | except Exception as e:
1080 | logger.error(f"Error cleaning up performance systems: {e}")
1081 |
1082 | shutdown_handler.register_cleanup(cleanup_performance)
1083 |
1084 | # Register health monitoring cleanup
1085 | async def cleanup_health_monitoring():
1086 | """Cleanup health monitoring during shutdown."""
1087 | try:
1088 | from maverick_mcp.monitoring.health_monitor import (
1089 | stop_health_monitoring,
1090 | )
1091 |
1092 | await stop_health_monitoring()
1093 | logger.info("Health monitoring stopped")
1094 | except Exception as e:
1095 | logger.error(f"Error stopping health monitoring: {e}")
1096 |
1097 | shutdown_handler.register_cleanup(cleanup_health_monitoring)
1098 |
1099 | # Register connection manager cleanup
1100 | async def cleanup_connection_manager():
1101 | """Cleanup connection manager during shutdown."""
1102 | try:
1103 | if connection_manager:
1104 | await connection_manager.shutdown()
1105 | logger.info("Connection manager shutdown complete")
1106 | except Exception as e:
1107 | logger.error(f"Error shutting down connection manager: {e}")
1108 |
1109 | shutdown_handler.register_cleanup(cleanup_connection_manager)
1110 |
1111 | # Register cache cleanup
1112 | def close_cache():
1113 | """Close Redis connections during shutdown."""
1114 | from maverick_mcp.data.cache import get_redis_client
1115 |
1116 | try:
1117 | redis_client = get_redis_client()
1118 | if redis_client:
1119 | logger.info("Closing Redis connections...")
1120 | redis_client.close()
1121 | logger.info("Redis connections closed")
1122 | except Exception as e:
1123 | logger.error(f"Error closing Redis: {e}")
1124 |
1125 | shutdown_handler.register_cleanup(close_cache)
1126 |
1127 | # Run with the appropriate transport
1128 | if args.transport == "stdio":
1129 | logger.info(f"Starting {settings.app_name} server with stdio transport")
1130 | mcp.run(
1131 | transport="stdio",
1132 | debug=settings.api.debug,
1133 | log_level=settings.api.log_level.upper(),
1134 | )
1135 | elif args.transport == "streamable-http":
1136 | logger.info(
1137 | f"Starting {settings.app_name} server with streamable-http transport on http://{args.host}:{args.port}"
1138 | )
1139 | mcp.run(
1140 | transport="streamable-http",
1141 | port=args.port,
1142 | host=args.host,
1143 | )
1144 | else: # sse
1145 | logger.info(
1146 | f"Starting {settings.app_name} server with SSE transport on http://{args.host}:{args.port}"
1147 | )
1148 | mcp.run(
1149 | transport="sse",
1150 | port=args.port,
1151 | host=args.host,
1152 | path="/sse", # No trailing slash - both /sse and /sse/ will work with the monkey-patch
1153 | )
1154 |
```
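The `_patched_create_sse_app` monkey-patch above boils down to one routing idea: serve both the trailing-slash and slashless SSE paths from the same endpoint so neither triggers a 307 redirect. A standalone sketch of that idea on a bare Starlette app — illustrative only, it does not touch FastMCP, and it assumes Starlette's httpx-based `TestClient`:

```python
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient


async def sse_endpoint(request):
    # Stand-in for the real SSE handler; a plain response is enough to
    # demonstrate the routing behaviour.
    return PlainTextResponse("ok")


# Only /sse/ is registered initially, so /sse would normally 307-redirect.
app = Starlette(routes=[Route("/sse/", sse_endpoint, methods=["GET"])])

# Register the slashless variant explicitly, as the patch does, so clients
# hitting /sse are served directly instead of being redirected.
app.router.routes.insert(0, Route("/sse", sse_endpoint, methods=["GET"]))

client = TestClient(app)
assert client.get("/sse", follow_redirects=False).status_code == 200
assert client.get("/sse/", follow_redirects=False).status_code == 200
```

In `server.py` the same effect is achieved after the fact: the patch locates the route FastMCP already created, then inserts a twin `Route` for the alternate path, leaving FastMCP's own handler untouched.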