This is page 24 of 29. Use http://codebase.md/wshobson/maverick-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_parallel_research_integration.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive integration tests for parallel research functionality.
This test suite covers:
- End-to-end parallel research workflows
- Integration between all parallel research components
- Performance characteristics under realistic conditions
- Error scenarios and recovery mechanisms
- Logging integration across all components
- Resource usage and scalability testing
"""
import asyncio
import time
from datetime import datetime
from typing import Any
from unittest.mock import Mock, patch
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import MemorySaver
from maverick_mcp.agents.deep_research import DeepResearchAgent
from maverick_mcp.utils.parallel_research import (
ParallelResearchConfig,
)
class MockSearchProvider:
"""Mock search provider for integration testing."""
def __init__(self, provider_name: str, fail_rate: float = 0.0):
self.provider_name = provider_name
self.fail_rate = fail_rate
self.call_count = 0
async def search(self, query: str, num_results: int = 10) -> list[dict[str, Any]]:
"""Mock search with configurable failure rate."""
self.call_count += 1
# Simulate failures based on fail_rate
import random
if random.random() < self.fail_rate:
raise RuntimeError(f"{self.provider_name} search failed")
await asyncio.sleep(0.02) # Simulate network latency
# Generate mock search results
results = []
for i in range(min(num_results, 3)): # Return up to 3 results
results.append(
{
"url": f"https://{self.provider_name.lower()}.example.com/article_{i}_{self.call_count}",
"title": f"{query} - Article {i + 1} from {self.provider_name}",
"content": f"This is detailed content about {query} from {self.provider_name}. "
f"It contains valuable insights and analysis relevant to the research topic. "
f"Provider: {self.provider_name}, Call: {self.call_count}",
"published_date": datetime.now().isoformat(),
"author": f"Expert Analyst {i + 1}",
"score": 0.8 - (i * 0.1),
"provider": self.provider_name.lower(),
}
)
return results
class MockContentAnalyzer:
"""Mock content analyzer for integration testing."""
def __init__(self, analysis_delay: float = 0.01):
self.analysis_delay = analysis_delay
self.analysis_count = 0
async def analyze_content(
self, content: str, persona: str, analysis_focus: str = "general"
) -> dict[str, Any]:
"""Mock content analysis."""
self.analysis_count += 1
await asyncio.sleep(self.analysis_delay)
# Generate realistic analysis based on content keywords
insights = []
risk_factors = []
opportunities = []
content_lower = content.lower()
if "earnings" in content_lower or "revenue" in content_lower:
insights.append("Strong earnings performance indicated")
opportunities.append("Potential for continued revenue growth")
if "technical" in content_lower or "chart" in content_lower:
insights.append("Technical indicators suggest trend continuation")
risk_factors.append("Support level break could trigger selling")
if "sentiment" in content_lower or "analyst" in content_lower:
insights.append("Market sentiment appears positive")
opportunities.append("Analyst upgrades possible")
if "competitive" in content_lower or "market share" in content_lower:
insights.append("Competitive position remains strong")
risk_factors.append("Increased competitive pressure in market")
# Default insights if no specific keywords found
if not insights:
insights = [
f"General analysis insight {self.analysis_count} for {persona} investor"
]
sentiment_mapping = {
"conservative": {"direction": "neutral", "confidence": 0.6},
"moderate": {"direction": "bullish", "confidence": 0.7},
"aggressive": {"direction": "bullish", "confidence": 0.8},
}
return {
"insights": insights,
"sentiment": sentiment_mapping.get(
persona, {"direction": "neutral", "confidence": 0.5}
),
"risk_factors": risk_factors or ["Standard market risks apply"],
"opportunities": opportunities or ["Monitor for opportunities"],
"credibility_score": 0.8,
"relevance_score": 0.75,
"summary": f"Analysis for {persona} investor from {analysis_focus} perspective",
"analysis_timestamp": datetime.now(),
}
class MockLLM(BaseChatModel):
"""Mock LLM for integration testing."""
def __init__(self, response_delay: float = 0.05, fail_rate: float = 0.0):
super().__init__()
self.response_delay = response_delay
self.fail_rate = fail_rate
self.invocation_count = 0
async def ainvoke(self, messages, config=None, **kwargs):
"""Mock async LLM invocation."""
self.invocation_count += 1
# Simulate failures
import random
if random.random() < self.fail_rate:
raise RuntimeError("LLM service unavailable")
await asyncio.sleep(self.response_delay)
# Generate contextual response based on message content
message_content = str(messages[-1].content).lower()
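# Two reply modes: free-form prose for synthesis-style prompts, otherwise a
# structured JSON-style payload (KEY_INSIGHTS / SENTIMENT / CREDIBILITY).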
if "synthesis" in message_content:
response = """
Based on the comprehensive research from multiple specialized agents, this analysis provides
a well-rounded view of the investment opportunity. The fundamental analysis shows strong
financial metrics, while sentiment analysis indicates positive market reception. Technical
analysis suggests favorable entry points, and competitive analysis reveals sustainable
advantages. Overall, this presents a compelling investment case for the specified investor persona.
"""
else:
response = '{"KEY_INSIGHTS": ["AI-generated insight"], "SENTIMENT": {"direction": "bullish", "confidence": 0.75}, "CREDIBILITY": 0.8}'
return AIMessage(content=response)
def _generate(self, messages, stop=None, **kwargs):
raise NotImplementedError("Use ainvoke for async tests")
@property
def _llm_type(self) -> str:
return "mock_llm"
@pytest.mark.integration
class TestParallelResearchEndToEnd:
"""Test complete end-to-end parallel research workflows."""
@pytest.fixture
def integration_config(self):
"""Configuration for integration testing."""
return ParallelResearchConfig(
max_concurrent_agents=3,
timeout_per_agent=10,
enable_fallbacks=True,
rate_limit_delay=0.1,
)
@pytest.fixture
def mock_search_providers(self):
"""Create mock search providers."""
return [
MockSearchProvider("Exa", fail_rate=0.1),
MockSearchProvider("Tavily", fail_rate=0.1),
]
@pytest.fixture
def integration_agent(self, integration_config, mock_search_providers):
"""Create DeepResearchAgent for integration testing."""
llm = MockLLM(response_delay=0.05, fail_rate=0.05)
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
checkpointer=MemorySaver(),
enable_parallel_execution=True,
parallel_config=integration_config,
)
# Replace search providers with mocks
agent.search_providers = mock_search_providers
# Replace content analyzer with mock
agent.content_analyzer = MockContentAnalyzer(analysis_delay=0.02)
return agent
@pytest.mark.asyncio
async def test_complete_parallel_research_workflow(self, integration_agent):
"""Test complete parallel research workflow from start to finish."""
start_time = time.time()
result = await integration_agent.research_comprehensive(
topic="Apple Inc comprehensive investment analysis for Q4 2024",
session_id="integration_test_001",
depth="comprehensive",
focus_areas=["fundamentals", "technical_analysis", "market_sentiment"],
timeframe="30d",
)
execution_time = time.time() - start_time
# Verify successful execution
assert result["status"] == "success"
assert result["agent_type"] == "deep_research"
assert result["execution_mode"] == "parallel"
assert (
result["research_topic"]
== "Apple Inc comprehensive investment analysis for Q4 2024"
)
# Verify research quality
assert result["confidence_score"] > 0.5
assert result["sources_analyzed"] > 0
assert len(result["citations"]) > 0
# Verify parallel execution stats
assert "parallel_execution_stats" in result
stats = result["parallel_execution_stats"]
assert stats["total_tasks"] > 0
assert stats["successful_tasks"] >= 0
assert stats["parallel_efficiency"] > 0
# Verify findings structure
assert "findings" in result
findings = result["findings"]
assert "synthesis" in findings
assert "confidence_score" in findings
# Verify performance characteristics
assert execution_time < 15 # Should complete within reasonable time
assert result["execution_time_ms"] > 0
@pytest.mark.asyncio
async def test_parallel_vs_sequential_performance_comparison(
self, integration_agent
):
"""Compare parallel vs sequential execution performance."""
topic = "Tesla Inc strategic analysis and market position"
# Test parallel execution
start_parallel = time.time()
parallel_result = await integration_agent.research_comprehensive(
topic=topic,
session_id="perf_test_parallel",
use_parallel_execution=True,
depth="standard",
)
parallel_time = time.time() - start_parallel
# Test sequential execution
start_sequential = time.time()
sequential_result = await integration_agent.research_comprehensive(
topic=topic,
session_id="perf_test_sequential",
use_parallel_execution=False,
depth="standard",
)
sequential_time = time.time() - start_sequential
# Both should succeed
assert parallel_result["status"] == "success"
assert sequential_result["status"] == "success"
# Verify execution modes
assert parallel_result["execution_mode"] == "parallel"
# Sequential won't have execution_mode in result
# Parallel should show efficiency metrics
if "parallel_execution_stats" in parallel_result:
stats = parallel_result["parallel_execution_stats"]
assert stats["parallel_efficiency"] > 0
# If multiple tasks were executed in parallel, should show some efficiency gain
if stats["total_tasks"] > 1:
print(
f"Parallel time: {parallel_time:.3f}s, Sequential time: {sequential_time:.3f}s"
)
print(f"Parallel efficiency: {stats['parallel_efficiency']:.2f}x")
@pytest.mark.asyncio
async def test_error_resilience_and_fallback(self, integration_config):
"""Test error resilience and fallback mechanisms."""
# Create agent with higher failure rates to test resilience
failing_llm = MockLLM(fail_rate=0.3) # 30% failure rate
failing_providers = [
MockSearchProvider("FailingProvider", fail_rate=0.5) # 50% failure rate
]
agent = DeepResearchAgent(
llm=failing_llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=integration_config,
)
agent.search_providers = failing_providers
agent.content_analyzer = MockContentAnalyzer()
# Should handle failures gracefully
result = await agent.research_comprehensive(
topic="Resilience test analysis",
session_id="resilience_test_001",
depth="basic",
)
# Should complete successfully despite some failures
assert result["status"] == "success"
# May have lower confidence due to failures
assert result["confidence_score"] >= 0.0
# Should have some parallel execution stats even if tasks failed
if result.get("execution_mode") == "parallel":
stats = result["parallel_execution_stats"]
# Total tasks should be >= failed tasks (some tasks were attempted)
assert stats["total_tasks"] >= stats.get("failed_tasks", 0)
@pytest.mark.asyncio
async def test_different_research_depths(self, integration_agent):
"""Test parallel research with different depth configurations."""
depths_to_test = ["basic", "standard", "comprehensive"]
results = {}
for depth in depths_to_test:
result = await integration_agent.research_comprehensive(
topic=f"Microsoft Corp analysis - {depth} depth",
session_id=f"depth_test_{depth}",
depth=depth,
use_parallel_execution=True,
)
results[depth] = result
# All should succeed
assert result["status"] == "success"
assert result["research_depth"] == depth
# Comprehensive research should generally analyze at least as many sources as basic
if all(r["status"] == "success" for r in results.values()):
basic_sources = results["basic"]["sources_analyzed"]
comprehensive_sources = results["comprehensive"]["sources_analyzed"]
# More comprehensive research should analyze more sources (when successful)
if basic_sources > 0 and comprehensive_sources > 0:
assert comprehensive_sources >= basic_sources
@pytest.mark.asyncio
async def test_persona_specific_research(self, integration_config):
"""Test parallel research with different investor personas."""
personas_to_test = ["conservative", "moderate", "aggressive"]
topic = "Amazon Inc investment opportunity analysis"
for persona in personas_to_test:
llm = MockLLM(response_delay=0.03)
agent = DeepResearchAgent(
llm=llm,
persona=persona,
enable_parallel_execution=True,
parallel_config=integration_config,
)
# Mock components
agent.search_providers = [MockSearchProvider("TestProvider")]
agent.content_analyzer = MockContentAnalyzer()
result = await agent.research_comprehensive(
topic=topic,
session_id=f"persona_test_{persona}",
use_parallel_execution=True,
)
assert result["status"] == "success"
assert result["persona"] == persona
# Should have findings tailored to persona
assert "findings" in result
@pytest.mark.asyncio
async def test_concurrent_research_sessions(self, integration_agent):
"""Test multiple concurrent research sessions."""
topics = [
"Google Alphabet strategic analysis",
"Meta Platforms competitive position",
"Netflix content strategy evaluation",
]
# Run multiple research sessions concurrently
tasks = [
integration_agent.research_comprehensive(
topic=topic,
session_id=f"concurrent_test_{i}",
use_parallel_execution=True,
depth="standard",
)
for i, topic in enumerate(topics)
]
start_time = time.time()
results = await asyncio.gather(*tasks, return_exceptions=True)
execution_time = time.time() - start_time
# All should succeed (or be exceptions we can handle)
successful_results = [
r for r in results if isinstance(r, dict) and r.get("status") == "success"
]
assert (
len(successful_results) >= len(topics) // 2
) # At least half should succeed
# Should complete in reasonable time despite concurrency
assert execution_time < 30
# Verify each result has proper session isolation
for _i, result in enumerate(successful_results):
if "findings" in result:
# Each result should map back to one of the requested topics
assert result["research_topic"] in topics
@pytest.mark.integration
class TestParallelResearchScalability:
"""Test scalability characteristics of parallel research."""
@pytest.fixture
def scalability_config(self):
"""Configuration for scalability testing."""
return ParallelResearchConfig(
max_concurrent_agents=4,
timeout_per_agent=8,
enable_fallbacks=True,
rate_limit_delay=0.05,
)
@pytest.mark.asyncio
async def test_agent_limit_enforcement(self, scalability_config):
"""Test that concurrent agent limits are properly enforced."""
llm = MockLLM(response_delay=0.1) # Slower to see concurrency effects
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=scalability_config,
)
# Mock components with tracking
call_tracker = {"max_concurrent": 0, "current_concurrent": 0}
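# The peak of "current_concurrent" records how many searches were in flight at
# once; the assertion at the end of this test compares it to max_concurrent_agents.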
class TrackingProvider(MockSearchProvider):
async def search(self, query: str, num_results: int = 10):
call_tracker["current_concurrent"] += 1
call_tracker["max_concurrent"] = max(
call_tracker["max_concurrent"], call_tracker["current_concurrent"]
)
try:
return await super().search(query, num_results)
finally:
call_tracker["current_concurrent"] -= 1
agent.search_providers = [TrackingProvider("Tracker")]
agent.content_analyzer = MockContentAnalyzer()
result = await agent.research_comprehensive(
topic="Scalability test with many potential subtasks",
session_id="scalability_test_001",
focus_areas=[
"fundamentals",
"technical",
"sentiment",
"competitive",
"extra1",
"extra2",
],
use_parallel_execution=True,
)
assert result["status"] == "success"
# Should not exceed configured max concurrent agents
assert (
call_tracker["max_concurrent"] <= scalability_config.max_concurrent_agents
)
@pytest.mark.asyncio
async def test_memory_usage_under_load(self, scalability_config):
"""Test memory usage characteristics under load."""
import gc
import os
import psutil
# Get initial memory usage
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
llm = MockLLM(response_delay=0.02)
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=scalability_config,
)
agent.search_providers = [MockSearchProvider("MemoryTest")]
agent.content_analyzer = MockContentAnalyzer(analysis_delay=0.01)
# Perform multiple research operations
for i in range(10): # 10 operations to test memory accumulation
result = await agent.research_comprehensive(
topic=f"Memory test analysis {i}",
session_id=f"memory_test_{i}",
use_parallel_execution=True,
depth="basic",
)
assert result["status"] == "success"
# Force garbage collection
gc.collect()
# Check final memory usage
final_memory = process.memory_info().rss / 1024 / 1024 # MB
memory_growth = final_memory - initial_memory
# Memory growth should be reasonable (not indicative of leaks)
assert memory_growth < 100 # Less than 100MB growth for 10 operations
@pytest.mark.asyncio
async def test_large_scale_task_distribution(self, scalability_config):
"""Test task distribution with many potential research areas."""
llm = MockLLM()
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=scalability_config,
)
agent.search_providers = [MockSearchProvider("LargeScale")]
agent.content_analyzer = MockContentAnalyzer()
# Test with many focus areas (more than max concurrent agents)
many_focus_areas = [
"earnings",
"revenue",
"profit_margins",
"debt_analysis",
"cash_flow",
"technical_indicators",
"chart_patterns",
"support_levels",
"momentum",
"analyst_ratings",
"news_sentiment",
"social_sentiment",
"institutional_sentiment",
"market_share",
"competitive_position",
"industry_trends",
"regulatory_environment",
]
result = await agent.research_comprehensive(
topic="Large scale comprehensive analysis with many research dimensions",
session_id="large_scale_test_001",
focus_areas=many_focus_areas,
use_parallel_execution=True,
depth="comprehensive",
)
assert result["status"] == "success"
# Should handle large number of focus areas efficiently
if "parallel_execution_stats" in result:
stats = result["parallel_execution_stats"]
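# The assertions below treat parallel_efficiency as a speedup multiple
# (values above 1.0 mean the parallel run beat a sequential baseline).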
# Should not create more tasks than max concurrent agents allows
assert stats["total_tasks"] <= scalability_config.max_concurrent_agents
# Should achieve some parallel efficiency
if stats["successful_tasks"] > 1:
assert stats["parallel_efficiency"] > 1.0
@pytest.mark.integration
class TestParallelResearchLoggingIntegration:
"""Test integration of logging throughout parallel research workflow."""
@pytest.fixture
def logged_agent(self):
"""Create agent with comprehensive logging."""
llm = MockLLM(response_delay=0.02)
config = ParallelResearchConfig(
max_concurrent_agents=2,
timeout_per_agent=5,
enable_fallbacks=True,
rate_limit_delay=0.05,
)
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=config,
)
agent.search_providers = [MockSearchProvider("LoggedProvider")]
agent.content_analyzer = MockContentAnalyzer()
return agent
@pytest.mark.asyncio
async def test_comprehensive_logging_workflow(self, logged_agent):
"""Test that comprehensive logging occurs throughout workflow."""
with patch(
"maverick_mcp.utils.orchestration_logging.get_orchestration_logger"
) as mock_get_logger:
mock_logger = Mock()
mock_get_logger.return_value = mock_logger
result = await logged_agent.research_comprehensive(
topic="Comprehensive logging test analysis",
session_id="logging_test_001",
use_parallel_execution=True,
)
assert result["status"] == "success"
# Should have multiple logging calls
assert mock_logger.info.call_count >= 10 # Multiple stages should log
# Verify different types of log messages occurred
all_log_calls = [call[0][0] for call in mock_logger.info.call_args_list]
# Should contain various logging elements
assert any("RESEARCH_START" in call for call in all_log_calls)
assert any("PARALLEL" in call for call in all_log_calls)
@pytest.mark.asyncio
async def test_error_logging_integration(self, logged_agent):
"""Test error logging integration in parallel workflow."""
# Create a scenario that will cause some errors
failing_llm = MockLLM(fail_rate=0.5) # High failure rate
logged_agent.llm = failing_llm
with patch(
"maverick_mcp.utils.orchestration_logging.get_orchestration_logger"
) as mock_get_logger:
mock_logger = Mock()
mock_get_logger.return_value = mock_logger
# This may succeed or fail, but should log appropriately
try:
result = await logged_agent.research_comprehensive(
topic="Error logging test",
session_id="error_logging_test_001",
use_parallel_execution=True,
)
# If it succeeds, should still have logged errors from failed components
assert result["status"] == "success" or result["status"] == "error"
except Exception:
# If it fails completely, that's also acceptable for this test
pass
# Should have some error or warning logs due to high failure rate
has_error_logs = (
mock_logger.error.call_count > 0 or mock_logger.warning.call_count > 0
)
assert has_error_logs
@pytest.mark.asyncio
async def test_performance_metrics_logging(self, logged_agent):
"""Test that performance metrics are properly logged."""
with patch(
"maverick_mcp.utils.orchestration_logging.log_performance_metrics"
) as mock_perf_log:
result = await logged_agent.research_comprehensive(
topic="Performance metrics test",
session_id="perf_metrics_test_001",
use_parallel_execution=True,
)
assert result["status"] == "success"
# Should have logged performance metrics
assert mock_perf_log.call_count >= 1
# Verify metrics content
perf_call = mock_perf_log.call_args_list[0]
metrics = perf_call[0][1]
assert isinstance(metrics, dict)
# Should contain relevant performance metrics
expected_metrics = [
"total_tasks",
"successful_tasks",
"failed_tasks",
"parallel_efficiency",
]
assert any(metric in metrics for metric in expected_metrics)
@pytest.mark.integration
class TestParallelResearchErrorRecovery:
"""Test error recovery and resilience in parallel research."""
@pytest.mark.asyncio
async def test_partial_failure_recovery(self):
"""Test recovery when some parallel tasks fail."""
config = ParallelResearchConfig(
max_concurrent_agents=3,
timeout_per_agent=5,
enable_fallbacks=True,
rate_limit_delay=0.05,
)
# Create agent with mixed success/failure providers
llm = MockLLM(response_delay=0.03)
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=config,
)
# Mix of failing and working providers
agent.search_providers = [
MockSearchProvider("WorkingProvider", fail_rate=0.0),
MockSearchProvider("FailingProvider", fail_rate=0.8), # 80% failure rate
]
agent.content_analyzer = MockContentAnalyzer()
result = await agent.research_comprehensive(
topic="Partial failure recovery test",
session_id="partial_failure_test_001",
use_parallel_execution=True,
)
# Should complete successfully despite some failures
assert result["status"] == "success"
# Should have parallel execution stats showing mixed results
if "parallel_execution_stats" in result:
stats = result["parallel_execution_stats"]
# Should have attempted multiple tasks
assert stats["total_tasks"] >= 1
# May have some failures but should have some successes
if stats["total_tasks"] > 1:
assert (
stats["successful_tasks"] + stats["failed_tasks"]
== stats["total_tasks"]
)
@pytest.mark.asyncio
async def test_complete_failure_fallback(self):
"""Test fallback to sequential when parallel execution completely fails."""
config = ParallelResearchConfig(
max_concurrent_agents=2,
timeout_per_agent=3,
enable_fallbacks=True,
)
# Create agent that will fail in parallel mode
failing_llm = MockLLM(fail_rate=0.9) # Very high failure rate
agent = DeepResearchAgent(
llm=failing_llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=config,
)
agent.search_providers = [MockSearchProvider("FailingProvider", fail_rate=0.9)]
agent.content_analyzer = MockContentAnalyzer()
# Mock the sequential execution to succeed
with patch.object(agent.graph, "ainvoke") as mock_sequential:
mock_sequential.return_value = {
"status": "success",
"persona": "moderate",
"research_confidence": 0.6,
"research_findings": {"synthesis": "Fallback analysis"},
}
result = await agent.research_comprehensive(
topic="Complete failure fallback test",
session_id="complete_failure_test_001",
use_parallel_execution=True,
)
# Should fall back to sequential and succeed
assert result["status"] == "success"
# Sequential execution should have been called due to parallel failure
mock_sequential.assert_called_once()
@pytest.mark.asyncio
async def test_timeout_handling_in_parallel_execution(self):
"""Test handling of timeouts in parallel execution."""
config = ParallelResearchConfig(
max_concurrent_agents=2,
timeout_per_agent=1, # Very short timeout
enable_fallbacks=True,
)
# Create components with delays longer than timeout
slow_llm = MockLLM(response_delay=2.0) # Slower than timeout
agent = DeepResearchAgent(
llm=slow_llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=config,
)
agent.search_providers = [MockSearchProvider("SlowProvider")]
agent.content_analyzer = MockContentAnalyzer(analysis_delay=0.5)
# Should handle timeouts gracefully
result = await agent.research_comprehensive(
topic="Timeout handling test",
session_id="timeout_test_001",
use_parallel_execution=True,
)
# Should complete with some status (success or error)
assert result["status"] in ["success", "error"]
# If parallel execution stats are available, should show timeout effects
if (
"parallel_execution_stats" in result
and result["parallel_execution_stats"]["total_tasks"] > 0
):
stats = result["parallel_execution_stats"]
# Timeouts should result in failed tasks
assert stats["failed_tasks"] >= 0
@pytest.mark.integration
class TestParallelResearchDataFlow:
"""Test data flow and consistency in parallel research."""
@pytest.mark.asyncio
async def test_data_consistency_across_parallel_tasks(self):
"""Test that data remains consistent across parallel task execution."""
config = ParallelResearchConfig(
max_concurrent_agents=3,
timeout_per_agent=5,
)
llm = MockLLM()
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=config,
)
# Create providers that return consistent data
consistent_provider = MockSearchProvider("ConsistentProvider")
agent.search_providers = [consistent_provider]
agent.content_analyzer = MockContentAnalyzer()
result = await agent.research_comprehensive(
topic="Data consistency test for Apple Inc",
session_id="consistency_test_001",
use_parallel_execution=True,
)
assert result["status"] == "success"
# Verify data structure consistency
assert "research_topic" in result
assert "confidence_score" in result
assert "citations" in result
assert isinstance(result["citations"], list)
# If parallel execution occurred, verify stats structure
if "parallel_execution_stats" in result:
stats = result["parallel_execution_stats"]
required_stats = [
"total_tasks",
"successful_tasks",
"failed_tasks",
"parallel_efficiency",
]
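# Note: isinstance() with a PEP 604 union such as int | float requires Python 3.10+.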
for stat in required_stats:
assert stat in stats
assert isinstance(stats[stat], int | float)
@pytest.mark.asyncio
async def test_citation_aggregation_across_tasks(self):
"""Test that citations are properly aggregated from parallel tasks."""
config = ParallelResearchConfig(max_concurrent_agents=2)
llm = MockLLM()
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=config,
)
# Multiple providers to generate multiple sources
agent.search_providers = [
MockSearchProvider("Provider1"),
MockSearchProvider("Provider2"),
]
agent.content_analyzer = MockContentAnalyzer()
result = await agent.research_comprehensive(
topic="Citation aggregation test",
session_id="citation_test_001",
use_parallel_execution=True,
)
assert result["status"] == "success"
# Should have citations from multiple sources
citations = result.get("citations", [])
if len(citations) > 0:
# Citations should have required fields
for citation in citations:
assert "id" in citation
assert "title" in citation
assert "url" in citation
assert "credibility_score" in citation
# Should have unique citation IDs
citation_ids = [c["id"] for c in citations]
assert len(citation_ids) == len(set(citation_ids))
@pytest.mark.asyncio
async def test_research_quality_metrics(self):
"""Test research quality metrics in parallel execution."""
config = ParallelResearchConfig(max_concurrent_agents=2)
llm = MockLLM()
agent = DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=config,
)
agent.search_providers = [MockSearchProvider("QualityProvider")]
agent.content_analyzer = MockContentAnalyzer()
result = await agent.research_comprehensive(
topic="Research quality metrics test",
session_id="quality_test_001",
use_parallel_execution=True,
)
assert result["status"] == "success"
# Verify quality metrics
assert "confidence_score" in result
assert 0.0 <= result["confidence_score"] <= 1.0
assert "sources_analyzed" in result
assert isinstance(result["sources_analyzed"], int)
assert result["sources_analyzed"] >= 0
if "source_diversity" in result:
assert 0.0 <= result["source_diversity"] <= 1.0
```
--------------------------------------------------------------------------------
/tests/test_ml_strategies.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive tests for ML-enhanced trading strategies.
Tests cover:
- Adaptive Strategy parameter adjustment and online learning
- OnlineLearningStrategy with streaming ML algorithms
- HybridAdaptiveStrategy combining multiple approaches
- Feature engineering and extraction for ML models
- Model training, prediction, and confidence scoring
- Performance tracking and adaptation mechanisms
- Parameter boundary enforcement and constraints
- Strategy performance under different market regimes
- Memory usage and computational efficiency
- Error handling and model recovery scenarios
"""
import warnings
from typing import Any
from unittest.mock import Mock, patch
import numpy as np
import pandas as pd
import pytest
from maverick_mcp.backtesting.strategies.base import Strategy
from maverick_mcp.backtesting.strategies.ml.adaptive import (
AdaptiveStrategy,
HybridAdaptiveStrategy,
OnlineLearningStrategy,
)
warnings.filterwarnings("ignore", category=FutureWarning)
class MockBaseStrategy(Strategy):
"""Mock base strategy for testing adaptive strategies."""
def __init__(self, parameters: dict[str, Any] | None = None):
super().__init__(parameters or {"window": 20, "threshold": 0.02})
self._signal_pattern = "alternating"  # alternating, bullish, or bearish
@property
def name(self) -> str:
return "MockStrategy"
@property
def description(self) -> str:
return "Mock strategy for testing"
def generate_signals(self, data: pd.DataFrame) -> tuple[pd.Series, pd.Series]:
"""Generate mock signals based on pattern."""
entry_signals = pd.Series(False, index=data.index)
exit_signals = pd.Series(False, index=data.index)
window = self.parameters.get("window", 20)
threshold = float(self.parameters.get("threshold", 0.02) or 0.0)
step = max(5, int(round(10 * (1 + abs(threshold) * 10))))
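# Spacing between signals: 10 bars when threshold is 0, stretched as |threshold| grows, never below 5.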
if self._signal_pattern == "alternating":
# Alternate between entry and exit signals with threshold-adjusted cadence
for i in range(window, len(data), step):
if (i // step) % 2 == 0:
entry_signals.iloc[i] = True
else:
exit_signals.iloc[i] = True
elif self._signal_pattern == "bullish":
# More entry signals than exit
entry_indices = np.random.choice(
range(window, len(data)),
size=min(20, len(data) - window),
replace=False,
)
entry_signals.iloc[entry_indices] = True
elif self._signal_pattern == "bearish":
# More exit signals than entry
exit_indices = np.random.choice(
range(window, len(data)),
size=min(20, len(data) - window),
replace=False,
)
exit_signals.iloc[exit_indices] = True
return entry_signals, exit_signals
class TestAdaptiveStrategy:
"""Test suite for AdaptiveStrategy class."""
@pytest.fixture
def sample_market_data(self):
"""Create sample market data for testing."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
# Generate realistic price data with trends
returns = np.random.normal(0.0005, 0.02, len(dates))
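# Daily returns: ~0.05% mean drift with 2% volatility.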
# Add some trending periods
returns[100:150] += 0.003 # Bull period
returns[200:250] -= 0.002 # Bear period
prices = 100 * np.cumprod(1 + returns)
volumes = np.random.randint(1000000, 5000000, len(dates))
data = pd.DataFrame(
{
"open": prices * np.random.uniform(0.98, 1.02, len(dates)),
"high": prices * np.random.uniform(1.00, 1.05, len(dates)),
"low": prices * np.random.uniform(0.95, 1.00, len(dates)),
"close": prices,
"volume": volumes,
},
index=dates,
)
# Ensure high >= max(open, close) and low <= min(open, close)
data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
return data
@pytest.fixture
def mock_base_strategy(self):
"""Create a mock base strategy."""
return MockBaseStrategy({"window": 20, "threshold": 0.02})
@pytest.fixture
def adaptive_strategy(self, mock_base_strategy):
"""Create an adaptive strategy with mock base."""
return AdaptiveStrategy(
base_strategy=mock_base_strategy,
adaptation_method="gradient",
learning_rate=0.01,
lookback_period=50,
adaptation_frequency=10,
)
def test_adaptive_strategy_initialization(
self, adaptive_strategy, mock_base_strategy
):
"""Test adaptive strategy initialization."""
assert adaptive_strategy.base_strategy == mock_base_strategy
assert adaptive_strategy.adaptation_method == "gradient"
assert adaptive_strategy.learning_rate == 0.01
assert adaptive_strategy.lookback_period == 50
assert adaptive_strategy.adaptation_frequency == 10
assert len(adaptive_strategy.performance_history) == 0
assert len(adaptive_strategy.parameter_history) == 0
assert adaptive_strategy.last_adaptation == 0
# Test name and description
assert "Adaptive" in adaptive_strategy.name
assert "MockStrategy" in adaptive_strategy.name
assert "gradient" in adaptive_strategy.description
def test_performance_metric_calculation(self, adaptive_strategy):
"""Test performance metric calculation."""
# Test with normal returns
returns = pd.Series([0.01, 0.02, -0.01, 0.015, -0.005])
performance = adaptive_strategy.calculate_performance_metric(returns)
assert isinstance(performance, float)
assert not np.isnan(performance)
# Test with zero volatility
constant_returns = pd.Series([0.01, 0.01, 0.01, 0.01])
performance = adaptive_strategy.calculate_performance_metric(constant_returns)
assert performance == 0.0
# Test with empty returns
empty_returns = pd.Series([])
performance = adaptive_strategy.calculate_performance_metric(empty_returns)
assert performance == 0.0
def test_adaptable_parameters_default(self, adaptive_strategy):
"""Test default adaptable parameters configuration."""
adaptable_params = adaptive_strategy.get_adaptable_parameters()
expected_params = ["lookback_period", "threshold", "window", "period"]
for param in expected_params:
assert param in adaptable_params
assert "min" in adaptable_params[param]
assert "max" in adaptable_params[param]
assert "step" in adaptable_params[param]
def test_gradient_parameter_adaptation(self, adaptive_strategy):
"""Test gradient-based parameter adaptation."""
# Set up initial parameters
initial_window = adaptive_strategy.base_strategy.parameters["window"]
initial_threshold = adaptive_strategy.base_strategy.parameters["threshold"]
# Simulate positive performance gradient
adaptive_strategy.adapt_parameters_gradient(0.5) # Positive gradient
# Parameters should have changed
new_window = adaptive_strategy.base_strategy.parameters["window"]
new_threshold = adaptive_strategy.base_strategy.parameters["threshold"]
# At least one parameter should have changed
assert new_window != initial_window or new_threshold != initial_threshold
# Parameters should be within bounds
adaptable_params = adaptive_strategy.get_adaptable_parameters()
if "window" in adaptable_params:
assert new_window >= adaptable_params["window"]["min"]
assert new_window <= adaptable_params["window"]["max"]
def test_random_search_parameter_adaptation(self, adaptive_strategy):
"""Test random search parameter adaptation."""
adaptive_strategy.adaptation_method = "random_search"
# Apply random search adaptation
adaptive_strategy.adapt_parameters_random_search()
# Parameters should potentially have changed
new_params = adaptive_strategy.base_strategy.parameters
# At least check that the method runs without error
assert isinstance(new_params, dict)
assert "window" in new_params
assert "threshold" in new_params
def test_adaptive_signal_generation(self, adaptive_strategy, sample_market_data):
"""Test adaptive signal generation with parameter updates."""
entry_signals, exit_signals = adaptive_strategy.generate_signals(
sample_market_data
)
# Basic signal validation
assert len(entry_signals) == len(sample_market_data)
assert len(exit_signals) == len(sample_market_data)
assert entry_signals.dtype == bool
assert exit_signals.dtype == bool
# Check that some adaptations occurred
assert len(adaptive_strategy.performance_history) > 0
# Check that parameter history was recorded
if len(adaptive_strategy.parameter_history) > 0:
assert isinstance(adaptive_strategy.parameter_history[0], dict)
def test_adaptation_frequency_control(self, adaptive_strategy, sample_market_data):
"""Test that adaptation occurs at correct frequency."""
# Set a specific adaptation frequency
adaptive_strategy.adaptation_frequency = 30
# Generate signals
adaptive_strategy.generate_signals(sample_market_data)
# Number of adaptations should be roughly len(data) / adaptation_frequency
expected_adaptations = len(sample_market_data) // 30
actual_adaptations = len(adaptive_strategy.performance_history)
# Allow some variance due to lookback period requirements
assert abs(actual_adaptations - expected_adaptations) <= 2
def test_adaptation_history_tracking(self, adaptive_strategy, sample_market_data):
"""Test adaptation history tracking."""
adaptive_strategy.generate_signals(sample_market_data)
history = adaptive_strategy.get_adaptation_history()
assert "performance_history" in history
assert "parameter_history" in history
assert "current_parameters" in history
assert "original_parameters" in history
assert len(history["performance_history"]) > 0
assert isinstance(history["current_parameters"], dict)
assert isinstance(history["original_parameters"], dict)
def test_reset_to_original_parameters(self, adaptive_strategy, sample_market_data):
"""Test resetting strategy to original parameters."""
# Store original parameters
original_params = adaptive_strategy.base_strategy.parameters.copy()
# Generate signals to trigger adaptations
adaptive_strategy.generate_signals(sample_market_data)
# Parameters should have changed
# Reset to original
adaptive_strategy.reset_to_original()
# Should match original parameters
assert adaptive_strategy.base_strategy.parameters == original_params
assert len(adaptive_strategy.performance_history) == 0
assert len(adaptive_strategy.parameter_history) == 0
assert adaptive_strategy.last_adaptation == 0
def test_adaptive_strategy_error_handling(self, adaptive_strategy):
"""Test error handling in adaptive strategy."""
# Test with invalid data
invalid_data = pd.DataFrame({"close": [np.nan, np.nan]})
entry_signals, exit_signals = adaptive_strategy.generate_signals(invalid_data)
# Should return valid series even with bad data
assert isinstance(entry_signals, pd.Series)
assert isinstance(exit_signals, pd.Series)
assert len(entry_signals) == len(invalid_data)
class TestOnlineLearningStrategy:
"""Test suite for OnlineLearningStrategy class."""
@pytest.fixture
def online_strategy(self):
"""Create an online learning strategy."""
return OnlineLearningStrategy(
model_type="sgd",
update_frequency=10,
feature_window=20,
confidence_threshold=0.6,
)
@pytest.fixture
def online_learning_strategy(self, online_strategy):
"""Alias for online_strategy fixture for backward compatibility."""
return online_strategy
@pytest.fixture
def sample_market_data(self):
"""Create sample market data for testing."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
# Generate realistic price data with trends
returns = np.random.normal(0.0005, 0.02, len(dates))
# Add some trending periods
returns[100:150] += 0.003 # Bull period
returns[200:250] -= 0.002 # Bear period
prices = 100 * np.cumprod(1 + returns)
volumes = np.random.randint(1000000, 5000000, len(dates))
data = pd.DataFrame(
{
"open": prices * np.random.uniform(0.98, 1.02, len(dates)),
"high": prices * np.random.uniform(1.00, 1.05, len(dates)),
"low": prices * np.random.uniform(0.95, 1.00, len(dates)),
"close": prices,
"volume": volumes,
},
index=dates,
)
# Ensure high >= max(open, close) and low <= min(open, close)
data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
return data
def test_online_learning_initialization(self, online_strategy):
"""Test online learning strategy initialization."""
assert online_strategy.model_type == "sgd"
assert online_strategy.update_frequency == 10
assert online_strategy.feature_window == 20
assert online_strategy.confidence_threshold == 0.6
assert online_strategy.model is not None
assert hasattr(online_strategy.model, "fit") # Should be sklearn model
assert not online_strategy.is_trained
assert len(online_strategy.training_buffer) == 0
# Test name and description
assert "OnlineLearning" in online_strategy.name
assert "SGD" in online_strategy.name
assert "streaming" in online_strategy.description
def test_model_initialization_error(self):
"""Test model initialization with unsupported type."""
with pytest.raises(ValueError, match="Unsupported model type"):
OnlineLearningStrategy(model_type="unsupported_model")
def test_feature_extraction(self, online_strategy, sample_market_data):
"""Test feature extraction from market data."""
# Test with sufficient data
features = online_strategy.extract_features(sample_market_data, 30)
assert isinstance(features, np.ndarray)
assert len(features) > 0
assert not np.any(np.isnan(features))
# Test with insufficient data
features = online_strategy.extract_features(sample_market_data, 1)
assert len(features) == 0
def test_target_creation(self, online_learning_strategy, sample_market_data):
"""Test target variable creation."""
# Test normal case
target = online_learning_strategy.create_target(sample_market_data, 30)
assert target in [0, 1, 2] # sell, hold, buy
# Test edge case - near end of data
target = online_learning_strategy.create_target(
sample_market_data, len(sample_market_data) - 1
)
assert target == 1 # Should default to hold
def test_model_update_mechanism(self, online_strategy, sample_market_data):
"""Test online model update mechanism."""
# Simulate model updates
online_strategy.update_model(sample_market_data, 50)
# Should not update if frequency not met
assert online_strategy.last_update == 0 # No update yet
# Force update by meeting frequency requirement
online_strategy.last_update = 40
online_strategy.update_model(sample_market_data, 51)
# Now should have updated
assert online_strategy.last_update > 40
def test_online_signal_generation(self, online_strategy, sample_market_data):
"""Test online learning signal generation."""
entry_signals, exit_signals = online_strategy.generate_signals(
sample_market_data
)
# Basic validation
assert len(entry_signals) == len(sample_market_data)
assert len(exit_signals) == len(sample_market_data)
assert entry_signals.dtype == bool
assert exit_signals.dtype == bool
# Should eventually train the model
assert online_strategy.is_trained
def test_model_info_retrieval(self, online_strategy, sample_market_data):
"""Test model information retrieval."""
# Initially untrained
info = online_strategy.get_model_info()
assert info["model_type"] == "sgd"
assert not info["is_trained"]
assert info["feature_window"] == 20
assert info["update_frequency"] == 10
assert info["confidence_threshold"] == 0.6
# Train the model
online_strategy.generate_signals(sample_market_data)
# Get info after training
trained_info = online_strategy.get_model_info()
assert trained_info["is_trained"]
# Should have coefficients if model supports them
if (
hasattr(online_strategy.model, "coef_")
and online_strategy.model.coef_ is not None
):
assert "model_coefficients" in trained_info
def test_confidence_threshold_filtering(self, online_strategy, sample_market_data):
"""Test that signals are filtered by confidence threshold."""
# Use very high confidence threshold
high_confidence_strategy = OnlineLearningStrategy(confidence_threshold=0.95)
entry_signals, exit_signals = high_confidence_strategy.generate_signals(
sample_market_data
)
# With high confidence threshold, should have fewer signals
assert entry_signals.sum() <= 5 # Very few signals expected
assert exit_signals.sum() <= 5
def test_online_strategy_error_handling(self, online_strategy):
"""Test error handling in online learning strategy."""
# Test with empty data
empty_data = pd.DataFrame(columns=["close", "volume"])
entry_signals, exit_signals = online_strategy.generate_signals(empty_data)
assert len(entry_signals) == 0
assert len(exit_signals) == 0
class TestHybridAdaptiveStrategy:
"""Test suite for HybridAdaptiveStrategy class."""
@pytest.fixture
def mock_base_strategy(self):
"""Create a mock base strategy."""
return MockBaseStrategy({"window": 20, "threshold": 0.02})
@pytest.fixture
def sample_market_data(self):
"""Create sample market data for testing."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
# Generate realistic price data with trends
returns = np.random.normal(0.0005, 0.02, len(dates))
# Add some trending periods
returns[100:150] += 0.003 # Bull period
returns[200:250] -= 0.002 # Bear period
prices = 100 * np.cumprod(1 + returns)
volumes = np.random.randint(1000000, 5000000, len(dates))
data = pd.DataFrame(
{
"open": prices * np.random.uniform(0.98, 1.02, len(dates)),
"high": prices * np.random.uniform(1.00, 1.05, len(dates)),
"low": prices * np.random.uniform(0.95, 1.00, len(dates)),
"close": prices,
"volume": volumes,
},
index=dates,
)
        # Ensure high >= max(open, close) and low <= min(open, close)
data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
return data
@pytest.fixture
def hybrid_strategy(self, mock_base_strategy):
"""Create a hybrid adaptive strategy."""
return HybridAdaptiveStrategy(
base_strategy=mock_base_strategy,
online_learning_weight=0.3,
adaptation_method="gradient",
learning_rate=0.02,
)
def test_hybrid_strategy_initialization(self, hybrid_strategy, mock_base_strategy):
"""Test hybrid strategy initialization."""
assert hybrid_strategy.base_strategy == mock_base_strategy
assert hybrid_strategy.online_learning_weight == 0.3
assert hybrid_strategy.online_strategy is not None
assert isinstance(hybrid_strategy.online_strategy, OnlineLearningStrategy)
# Test name and description
assert "HybridAdaptive" in hybrid_strategy.name
assert "MockStrategy" in hybrid_strategy.name
assert "hybrid" in hybrid_strategy.description.lower()
def test_hybrid_signal_generation(self, hybrid_strategy, sample_market_data):
"""Test hybrid signal generation combining both approaches."""
entry_signals, exit_signals = hybrid_strategy.generate_signals(
sample_market_data
)
# Basic validation
assert len(entry_signals) == len(sample_market_data)
assert len(exit_signals) == len(sample_market_data)
assert entry_signals.dtype == bool
assert exit_signals.dtype == bool
# Should have some signals (combination of both strategies)
total_signals = entry_signals.sum() + exit_signals.sum()
assert total_signals > 0
def test_signal_weighting_mechanism(self, hybrid_strategy, sample_market_data):
"""Test that signal weighting works correctly."""
# Set base strategy to generate specific pattern
hybrid_strategy.base_strategy._signal_pattern = "bullish"
# Generate signals
entry_signals, exit_signals = hybrid_strategy.generate_signals(
sample_market_data
)
# With bullish base strategy, should have more entry signals
assert entry_signals.sum() >= exit_signals.sum()
def test_hybrid_info_retrieval(self, hybrid_strategy, sample_market_data):
"""Test hybrid strategy information retrieval."""
# Generate some signals first
hybrid_strategy.generate_signals(sample_market_data)
hybrid_info = hybrid_strategy.get_hybrid_info()
assert "adaptation_history" in hybrid_info
assert "online_learning_info" in hybrid_info
assert "online_learning_weight" in hybrid_info
assert "base_weight" in hybrid_info
assert hybrid_info["online_learning_weight"] == 0.3
assert hybrid_info["base_weight"] == 0.7
# Verify nested information structure
assert "model_type" in hybrid_info["online_learning_info"]
assert "performance_history" in hybrid_info["adaptation_history"]
def test_different_weight_configurations(
self, mock_base_strategy, sample_market_data
):
"""Test hybrid strategy with different weight configurations."""
# Test heavy online learning weighting
heavy_online = HybridAdaptiveStrategy(
base_strategy=mock_base_strategy, online_learning_weight=0.8
)
entry1, exit1 = heavy_online.generate_signals(sample_market_data)
# Test heavy base strategy weighting
heavy_base = HybridAdaptiveStrategy(
base_strategy=mock_base_strategy, online_learning_weight=0.2
)
entry2, exit2 = heavy_base.generate_signals(sample_market_data)
# Both should generate valid signals
assert len(entry1) == len(entry2) == len(sample_market_data)
assert len(exit1) == len(exit2) == len(sample_market_data)
# Different weights should potentially produce different signals
# (though this is probabilistic and may not always be true)
signal_diff1 = (entry1 != entry2).sum() + (exit1 != exit2).sum()
assert signal_diff1 >= 0 # Allow for identical signals in edge cases
class TestMLStrategiesPerformance:
"""Performance and benchmark tests for ML strategies."""
@pytest.fixture
def sample_market_data(self):
"""Create sample market data for testing."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
# Generate realistic price data with trends
returns = np.random.normal(0.0005, 0.02, len(dates))
# Add some trending periods
returns[100:150] += 0.003 # Bull period
returns[200:250] -= 0.002 # Bear period
prices = 100 * np.cumprod(1 + returns)
volumes = np.random.randint(1000000, 5000000, len(dates))
data = pd.DataFrame(
{
"open": prices * np.random.uniform(0.98, 1.02, len(dates)),
"high": prices * np.random.uniform(1.00, 1.05, len(dates)),
"low": prices * np.random.uniform(0.95, 1.00, len(dates)),
"close": prices,
"volume": volumes,
},
index=dates,
)
        # Ensure high >= max(open, close) and low <= min(open, close)
data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
return data
def test_strategy_computational_efficiency(
self, sample_market_data, benchmark_timer
):
"""Test computational efficiency of ML strategies."""
strategies = [
AdaptiveStrategy(MockBaseStrategy(), adaptation_method="gradient"),
OnlineLearningStrategy(model_type="sgd"),
HybridAdaptiveStrategy(MockBaseStrategy()),
]
for strategy in strategies:
with benchmark_timer() as timer:
entry_signals, exit_signals = strategy.generate_signals(
sample_market_data
)
# Should complete within reasonable time
assert timer.elapsed < 10.0 # < 10 seconds
assert len(entry_signals) == len(sample_market_data)
assert len(exit_signals) == len(sample_market_data)
def test_memory_usage_scalability(self, benchmark_timer):
"""Test memory usage with large datasets."""
import os
import psutil
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
# Create large dataset
dates = pd.date_range(start="2020-01-01", end="2023-12-31", freq="D") # 4 years
large_data = pd.DataFrame(
{
"open": 100 + np.random.normal(0, 10, len(dates)),
"high": 105 + np.random.normal(0, 10, len(dates)),
"low": 95 + np.random.normal(0, 10, len(dates)),
"close": 100 + np.random.normal(0, 10, len(dates)),
"volume": np.random.randint(1000000, 10000000, len(dates)),
},
index=dates,
)
# Test online learning strategy (most memory intensive)
strategy = OnlineLearningStrategy()
strategy.generate_signals(large_data)
final_memory = process.memory_info().rss
memory_growth = (final_memory - initial_memory) / 1024 / 1024 # MB
# Memory growth should be reasonable (< 200MB for 4 years of data)
assert memory_growth < 200
def test_strategy_adaptation_effectiveness(self, sample_market_data):
"""Test that adaptive strategies actually improve over time."""
base_strategy = MockBaseStrategy()
adaptive_strategy = AdaptiveStrategy(
base_strategy=base_strategy, adaptation_method="gradient"
)
# Generate initial signals and measure performance
initial_entry_signals, initial_exit_signals = (
adaptive_strategy.generate_signals(sample_market_data)
)
assert len(initial_entry_signals) == len(sample_market_data)
assert len(initial_exit_signals) == len(sample_market_data)
assert len(adaptive_strategy.performance_history) > 0
# Reset and generate again (should have different adaptations)
adaptive_strategy.reset_to_original()
post_reset_entry, post_reset_exit = adaptive_strategy.generate_signals(
sample_market_data
)
assert len(post_reset_entry) == len(sample_market_data)
assert len(post_reset_exit) == len(sample_market_data)
# Should have recorded performance metrics again
assert len(adaptive_strategy.performance_history) > 0
assert len(adaptive_strategy.parameter_history) > 0
def test_concurrent_strategy_execution(self, sample_market_data):
"""Test concurrent execution of multiple ML strategies."""
import queue
import threading
results_queue = queue.Queue()
error_queue = queue.Queue()
def run_strategy(strategy_id, strategy_class):
try:
if strategy_class == AdaptiveStrategy:
strategy = AdaptiveStrategy(MockBaseStrategy())
elif strategy_class == OnlineLearningStrategy:
strategy = OnlineLearningStrategy()
else:
strategy = HybridAdaptiveStrategy(MockBaseStrategy())
entry_signals, exit_signals = strategy.generate_signals(
sample_market_data
)
results_queue.put((strategy_id, len(entry_signals), len(exit_signals)))
except Exception as e:
error_queue.put(f"Strategy {strategy_id}: {e}")
# Run multiple strategies concurrently
threads = []
strategy_classes = [
AdaptiveStrategy,
OnlineLearningStrategy,
HybridAdaptiveStrategy,
]
for i, strategy_class in enumerate(strategy_classes):
thread = threading.Thread(target=run_strategy, args=(i, strategy_class))
threads.append(thread)
thread.start()
# Wait for completion
for thread in threads:
thread.join(timeout=30) # 30 second timeout
# Check results
assert error_queue.empty(), f"Errors: {list(error_queue.queue)}"
assert results_queue.qsize() == 3
# All should have processed the full dataset
while not results_queue.empty():
strategy_id, entry_len, exit_len = results_queue.get()
assert entry_len == len(sample_market_data)
assert exit_len == len(sample_market_data)
class TestMLStrategiesErrorHandling:
"""Error handling and edge case tests for ML strategies."""
@pytest.fixture
def sample_market_data(self):
"""Create sample market data for testing."""
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
# Generate realistic price data with trends
returns = np.random.normal(0.0005, 0.02, len(dates))
# Add some trending periods
returns[100:150] += 0.003 # Bull period
returns[200:250] -= 0.002 # Bear period
prices = 100 * np.cumprod(1 + returns)
volumes = np.random.randint(1000000, 5000000, len(dates))
data = pd.DataFrame(
{
"open": prices * np.random.uniform(0.98, 1.02, len(dates)),
"high": prices * np.random.uniform(1.00, 1.05, len(dates)),
"low": prices * np.random.uniform(0.95, 1.00, len(dates)),
"close": prices,
"volume": volumes,
},
index=dates,
)
        # Ensure high >= max(open, close) and low <= min(open, close)
data["high"] = np.maximum(data["high"], np.maximum(data["open"], data["close"]))
data["low"] = np.minimum(data["low"], np.minimum(data["open"], data["close"]))
return data
@pytest.fixture
def mock_base_strategy(self):
"""Create a mock base strategy."""
return MockBaseStrategy({"window": 20, "threshold": 0.02})
def test_adaptive_strategy_with_failing_base(self, sample_market_data):
"""Test adaptive strategy when base strategy fails."""
# Create a base strategy that fails
failing_strategy = Mock(spec=Strategy)
failing_strategy.parameters = {"window": 20}
failing_strategy.generate_signals.side_effect = Exception(
"Base strategy failed"
)
adaptive_strategy = AdaptiveStrategy(failing_strategy)
# Should handle the error gracefully
entry_signals, exit_signals = adaptive_strategy.generate_signals(
sample_market_data
)
assert isinstance(entry_signals, pd.Series)
assert isinstance(exit_signals, pd.Series)
assert len(entry_signals) == len(sample_market_data)
def test_online_learning_with_insufficient_data(self):
"""Test online learning strategy with insufficient training data."""
# Very small dataset
small_data = pd.DataFrame({"close": [100, 101], "volume": [1000, 1100]})
strategy = OnlineLearningStrategy(feature_window=20) # Window larger than data
entry_signals, exit_signals = strategy.generate_signals(small_data)
# Should handle gracefully
assert len(entry_signals) == len(small_data)
assert len(exit_signals) == len(small_data)
assert not strategy.is_trained # Insufficient data to train
def test_model_prediction_failure_handling(self, sample_market_data):
"""Test handling of model prediction failures."""
strategy = OnlineLearningStrategy()
# Simulate model failure after training
with patch.object(
strategy.model, "predict", side_effect=Exception("Prediction failed")
):
entry_signals, exit_signals = strategy.generate_signals(sample_market_data)
# Should still return valid series
assert isinstance(entry_signals, pd.Series)
assert isinstance(exit_signals, pd.Series)
assert len(entry_signals) == len(sample_market_data)
def test_parameter_boundary_enforcement(self, mock_base_strategy):
"""Test that parameter adaptations respect boundaries."""
adaptive_strategy = AdaptiveStrategy(mock_base_strategy)
# Set extreme gradient that should be bounded
large_gradient = 100.0
# Store original parameter values
original_window = mock_base_strategy.parameters["window"]
# Apply extreme gradient
adaptive_strategy.adapt_parameters_gradient(large_gradient)
# Parameter should be bounded
new_window = mock_base_strategy.parameters["window"]
assert new_window != original_window
adaptable_params = adaptive_strategy.get_adaptable_parameters()
if "window" in adaptable_params:
assert new_window >= adaptable_params["window"]["min"]
assert new_window <= adaptable_params["window"]["max"]
def test_strategy_state_consistency(self, mock_base_strategy, sample_market_data):
"""Test that strategy state remains consistent after errors."""
adaptive_strategy = AdaptiveStrategy(mock_base_strategy)
# Generate initial signals successfully
initial_signals = adaptive_strategy.generate_signals(sample_market_data)
assert isinstance(initial_signals, tuple)
assert len(initial_signals) == 2
initial_state = {
"performance_history": len(adaptive_strategy.performance_history),
"parameter_history": len(adaptive_strategy.parameter_history),
"parameters": adaptive_strategy.base_strategy.parameters.copy(),
}
# Simulate error during signal generation
with patch.object(
mock_base_strategy,
"generate_signals",
side_effect=Exception("Signal generation failed"),
):
error_signals = adaptive_strategy.generate_signals(sample_market_data)
# State should remain consistent or be properly handled
assert isinstance(error_signals, tuple)
assert len(error_signals) == 2
assert isinstance(error_signals[0], pd.Series)
assert isinstance(error_signals[1], pd.Series)
assert (
len(adaptive_strategy.performance_history)
== initial_state["performance_history"]
)
assert (
len(adaptive_strategy.parameter_history)
== initial_state["parameter_history"]
)
assert adaptive_strategy.base_strategy.parameters == initial_state["parameters"]
if __name__ == "__main__":
# Run tests with detailed output
pytest.main([__file__, "-v", "--tb=short", "--asyncio-mode=auto"])
```
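The `sample_market_data` fixture above is duplicated verbatim across several test classes. A shared helper along the following lines could build the same internally consistent OHLCV frame from one place; the `make_sample_ohlcv` name, the fixed seed, and the suggestion to host it in `conftest.py` are illustrative assumptions, not part of the repository.

```python
# Illustrative sketch: a reusable OHLCV generator mirroring the repeated
# sample_market_data fixtures (helper name and seed are assumptions).
import numpy as np
import pandas as pd
import pytest


def make_sample_ohlcv(start="2023-01-01", end="2023-12-31", seed=42) -> pd.DataFrame:
    """Build a year of synthetic OHLCV data with bull and bear stretches."""
    rng = np.random.default_rng(seed)
    dates = pd.date_range(start=start, end=end, freq="D")
    returns = rng.normal(0.0005, 0.02, len(dates))
    returns[100:150] += 0.003  # bull period
    returns[200:250] -= 0.002  # bear period
    prices = 100 * np.cumprod(1 + returns)
    data = pd.DataFrame(
        {
            "open": prices * rng.uniform(0.98, 1.02, len(dates)),
            "high": prices * rng.uniform(1.00, 1.05, len(dates)),
            "low": prices * rng.uniform(0.95, 1.00, len(dates)),
            "close": prices,
            "volume": rng.integers(1_000_000, 5_000_000, len(dates)),
        },
        index=dates,
    )
    # Keep each bar consistent: high >= max(open, close), low <= min(open, close)
    data["high"] = np.maximum(data["high"], data[["open", "close"]].max(axis=1))
    data["low"] = np.minimum(data["low"], data[["open", "close"]].min(axis=1))
    return data


@pytest.fixture
def sample_market_data() -> pd.DataFrame:
    """Drop-in replacement for the per-class fixtures (e.g. in conftest.py)."""
    return make_sample_ohlcv()
```

Seeding the generator would also make count-sensitive assertions such as `entry_signals.sum() <= 5` reproducible across runs.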
--------------------------------------------------------------------------------
/maverick_mcp/data/cache.py:
--------------------------------------------------------------------------------
```python
"""
Cache utilities for Maverick-MCP.
Implements Redis-based caching with fallback to in-memory caching.
Now uses centralized Redis connection pooling for improved performance.
Includes timezone handling, smart invalidation, and performance monitoring.
"""
import asyncio
import hashlib
import json
import logging
import os
import time
import zlib
from collections import defaultdict
from collections.abc import Sequence
from datetime import UTC, date, datetime
from typing import Any, cast
import msgpack
import pandas as pd
import redis
from dotenv import load_dotenv
from maverick_mcp.config.settings import get_settings
# Redis connection pooling is handled in this module via _get_or_create_redis_pool()
# Load environment variables
load_dotenv()
# Setup logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("maverick_mcp.cache")
settings = get_settings()
# Redis configuration (kept for backwards compatibility)
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB = int(os.getenv("REDIS_DB", "0"))
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "")
REDIS_SSL = os.getenv("REDIS_SSL", "False").lower() == "true"
# Cache configuration
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "True").lower() == "true"
CACHE_TTL_SECONDS = settings.performance.cache_ttl_seconds
CACHE_VERSION = os.getenv("CACHE_VERSION", "v1")
# Cache statistics
CacheStatMap = defaultdict[str, float]
_cache_stats: CacheStatMap = defaultdict(float)
_cache_stats["hits"] = 0.0
_cache_stats["misses"] = 0.0
_cache_stats["sets"] = 0.0
_cache_stats["errors"] = 0.0
_cache_stats["serialization_time"] = 0.0
_cache_stats["deserialization_time"] = 0.0
# In-memory cache as fallback with memory management
_memory_cache: dict[str, dict[str, Any]] = {}
_memory_cache_max_size = 1000 # Will be updated to use config
# Cache metadata for version tracking
_cache_metadata: dict[str, dict[str, Any]] = {}
# Memory monitoring
_cache_memory_stats: dict[str, float] = {
"memory_cache_bytes": 0.0,
"redis_connection_count": 0.0,
"large_object_count": 0.0,
"compression_savings_bytes": 0.0,
}
def _dataframe_to_payload(df: pd.DataFrame) -> dict[str, Any]:
"""Convert a DataFrame to a JSON-serializable payload."""
normalized = ensure_timezone_naive(df)
json_payload = cast(
str,
normalized.to_json(orient="split", date_format="iso", default_handler=str),
)
payload = json.loads(json_payload)
payload["index_type"] = (
"datetime" if isinstance(normalized.index, pd.DatetimeIndex) else "other"
)
payload["index_name"] = normalized.index.name
return payload
def _payload_to_dataframe(payload: dict[str, Any]) -> pd.DataFrame:
"""Reconstruct a DataFrame from a serialized payload."""
data = payload.get("data", {})
columns = data.get("columns", [])
frame = pd.DataFrame(data.get("data", []), columns=columns)
index_values = data.get("index", [])
if payload.get("index_type") == "datetime":
index_values = pd.to_datetime(index_values)
index = normalize_timezone(pd.DatetimeIndex(index_values))
else:
index = index_values
frame.index = index
frame.index.name = payload.get("index_name")
return ensure_timezone_naive(frame)
def _json_default(value: Any) -> Any:
"""JSON serializer for unsupported types."""
if isinstance(value, datetime | date):
return value.isoformat()
if isinstance(value, pd.Timestamp):
return value.isoformat()
if isinstance(value, pd.Series):
return value.tolist()
if isinstance(value, set):
return list(value)
raise TypeError(f"Unsupported type {type(value)!r} for cache serialization")
def _decode_json_payload(raw_data: str) -> Any:
"""Decode JSON payloads with DataFrame support."""
payload = json.loads(raw_data)
if isinstance(payload, dict) and payload.get("__cache_type__") == "dataframe":
return _payload_to_dataframe(payload)
if isinstance(payload, dict) and payload.get("__cache_type__") == "dict":
result: dict[str, Any] = {}
for key, value in payload.get("data", {}).items():
if isinstance(value, dict) and value.get("__cache_type__") == "dataframe":
result[key] = _payload_to_dataframe(value)
else:
result[key] = value
return result
return payload
def normalize_timezone(index: pd.Index | Sequence[Any]) -> pd.DatetimeIndex:
"""Return a timezone-naive :class:`~pandas.DatetimeIndex` in UTC."""
dt_index = index if isinstance(index, pd.DatetimeIndex) else pd.DatetimeIndex(index)
if dt_index.tz is not None:
dt_index = dt_index.tz_convert("UTC").tz_localize(None)
return dt_index
def ensure_timezone_naive(df: pd.DataFrame) -> pd.DataFrame:
"""Ensure DataFrame has timezone-naive datetime index.
Args:
df: DataFrame with potentially timezone-aware index
Returns:
DataFrame with timezone-naive index
"""
if isinstance(df.index, pd.DatetimeIndex):
df = df.copy()
df.index = normalize_timezone(df.index)
return df
def get_cache_stats() -> dict[str, Any]:
"""Get current cache statistics with memory information.
Returns:
Dictionary containing cache performance metrics
"""
stats: dict[str, float | int] = cast(dict[str, float | int], dict(_cache_stats))
# Calculate hit rate
total_requests = stats["hits"] + stats["misses"]
hit_rate = (stats["hits"] / total_requests * 100) if total_requests > 0 else 0
stats["hit_rate_percent"] = round(hit_rate, 2)
stats["total_requests"] = total_requests
# Memory cache stats
stats["memory_cache_size"] = len(_memory_cache)
stats["memory_cache_max_size"] = _memory_cache_max_size
# Add memory statistics
stats.update(_cache_memory_stats)
# Calculate memory cache size in bytes
memory_size_bytes = 0
for entry in _memory_cache.values():
if "data" in entry:
try:
                # Check DataFrames first: every object defines __sizeof__, so the
                # generic check would otherwise shadow the deep measurement.
                if isinstance(entry["data"], pd.DataFrame):
                    memory_size_bytes += entry["data"].memory_usage(deep=True).sum()
                elif isinstance(entry["data"], str | bytes):
                    memory_size_bytes += len(entry["data"])
                elif hasattr(entry["data"], "__sizeof__"):
                    memory_size_bytes += entry["data"].__sizeof__()
except Exception:
pass # Skip if size calculation fails
stats["memory_cache_bytes"] = memory_size_bytes
stats["memory_cache_mb"] = memory_size_bytes / (1024**2)
return stats
def reset_cache_stats() -> None:
"""Reset cache statistics."""
global _cache_stats
_cache_stats.clear()
_cache_stats.update(
{
"hits": 0.0,
"misses": 0.0,
"sets": 0.0,
"errors": 0.0,
"serialization_time": 0.0,
"deserialization_time": 0.0,
}
)
def generate_cache_key(base_key: str, **kwargs) -> str:
"""Generate versioned cache key with consistent hashing.
Args:
base_key: Base cache key
**kwargs: Additional parameters to include in key
Returns:
Versioned and hashed cache key
"""
# Include cache version and sorted parameters
key_parts = [CACHE_VERSION, base_key]
# Sort kwargs for consistent key generation
if kwargs:
sorted_params = sorted(kwargs.items())
param_str = ":".join(f"{k}={v}" for k, v in sorted_params)
key_parts.append(param_str)
full_key = ":".join(str(part) for part in key_parts)
# Hash long keys to prevent Redis key length limits
if len(full_key) > 250:
key_hash = hashlib.md5(full_key.encode()).hexdigest()
return f"{CACHE_VERSION}:hashed:{key_hash}"
return full_key
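# Illustrative example (values assume the default CACHE_VERSION of "v1"):
#   generate_cache_key("stock_data", symbol="AAPL", period="1y")
#   -> "v1:stock_data:period=1y:symbol=AAPL"   (kwargs sorted by name)
# Keys longer than 250 characters collapse to "v1:hashed:<md5 hex digest>".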
def _cleanup_expired_memory_cache():
"""Clean up expired entries from memory cache and enforce size limit with memory tracking."""
current_time = time.time()
bytes_freed = 0
# Remove expired entries
expired_keys = [
k
for k, v in _memory_cache.items()
if "expiry" in v and v["expiry"] < current_time
]
for k in expired_keys:
entry = _memory_cache[k]
if "data" in entry and isinstance(entry["data"], pd.DataFrame):
bytes_freed += entry["data"].memory_usage(deep=True).sum()
del _memory_cache[k]
# Calculate current memory usage
current_memory_bytes = 0
for entry in _memory_cache.values():
if "data" in entry and isinstance(entry["data"], pd.DataFrame):
current_memory_bytes += entry["data"].memory_usage(deep=True).sum()
# Enforce memory-based size limit (100MB default)
memory_limit_bytes = 100 * 1024 * 1024 # 100MB
# Enforce size limit - remove oldest entries if over limit
if (
len(_memory_cache) > _memory_cache_max_size
or current_memory_bytes > memory_limit_bytes
):
# Sort by expiry time (oldest first)
sorted_items = sorted(
_memory_cache.items(), key=lambda x: x[1].get("expiry", float("inf"))
)
# Calculate how many to remove
num_to_remove = max(len(_memory_cache) - _memory_cache_max_size, 0)
# Remove by memory if over memory limit
if current_memory_bytes > memory_limit_bytes:
removed_memory = 0
for k, v in sorted_items:
if "data" in v and isinstance(v["data"], pd.DataFrame):
entry_size = v["data"].memory_usage(deep=True).sum()
removed_memory += entry_size
bytes_freed += entry_size
del _memory_cache[k]
num_to_remove = max(num_to_remove, 1)
if removed_memory >= (current_memory_bytes - memory_limit_bytes):
break
else:
# Remove by count
for k, v in sorted_items[:num_to_remove]:
if "data" in v and isinstance(v["data"], pd.DataFrame):
bytes_freed += v["data"].memory_usage(deep=True).sum()
del _memory_cache[k]
if num_to_remove > 0:
logger.debug(
f"Removed {num_to_remove} entries from memory cache "
f"(freed {bytes_freed / (1024**2):.2f}MB)"
)
# Update memory stats
_cache_memory_stats["memory_cache_bytes"] = current_memory_bytes - bytes_freed
# Global Redis connection pool - created once and reused
_redis_pool: redis.ConnectionPool | None = None
def _get_or_create_redis_pool() -> redis.ConnectionPool | None:
"""Create or return existing Redis connection pool."""
global _redis_pool
if _redis_pool is not None:
return _redis_pool
try:
# Build connection pool parameters
pool_params = {
"host": REDIS_HOST,
"port": REDIS_PORT,
"db": REDIS_DB,
"max_connections": settings.db.redis_max_connections,
"retry_on_timeout": settings.db.redis_retry_on_timeout,
"socket_timeout": settings.db.redis_socket_timeout,
"socket_connect_timeout": settings.db.redis_socket_connect_timeout,
"health_check_interval": 30, # Check connection health every 30 seconds
}
# Only add password if provided
if REDIS_PASSWORD:
pool_params["password"] = REDIS_PASSWORD
# Only add SSL params if SSL is enabled
if REDIS_SSL:
pool_params["ssl"] = True
pool_params["ssl_check_hostname"] = False
_redis_pool = redis.ConnectionPool(**pool_params)
logger.debug(
f"Created Redis connection pool with {settings.db.redis_max_connections} max connections"
)
return _redis_pool
except Exception as e:
logger.warning(f"Failed to create Redis connection pool: {e}")
return None
def get_redis_client() -> redis.Redis | None:
"""
Get a Redis client using the centralized connection pool.
This function uses a singleton connection pool to avoid pool exhaustion
and provides robust error handling with graceful fallback.
"""
if not CACHE_ENABLED:
return None
try:
# Get or create the connection pool
pool = _get_or_create_redis_pool()
if pool is None:
return None
# Create client using the shared pool
client = redis.Redis(
connection_pool=pool,
decode_responses=False,
)
# Test connection with a timeout to avoid hanging
client.ping()
return client # type: ignore[no-any-return]
except redis.ConnectionError as e:
logger.warning(f"Redis connection failed: {e}. Using in-memory cache.")
return None
except redis.TimeoutError as e:
logger.warning(f"Redis connection timeout: {e}. Using in-memory cache.")
return None
except Exception as e:
# Handle the IndexError: pop from empty list and other unexpected errors
logger.warning(f"Redis client error: {e}. Using in-memory cache.")
# Reset the pool if we encounter unexpected errors
global _redis_pool
_redis_pool = None
return None
def _deserialize_cached_data(data: bytes, key: str) -> Any:
"""Deserialize cached data with multiple format support and timezone handling."""
start_time = time.time()
try:
        # Try zlib-compressed payloads first (most efficient for DataFrames).
        # A zlib stream starts with 0x78; the second byte depends on the
        # compression level (0x01 for level 1, which _serialize_data uses,
        # 0x9c for the default level), so accept any standard zlib header.
        if data[:2] in (b"\x78\x01", b"\x78\x5e", b"\x78\x9c", b"\x78\xda"):
try:
decompressed = zlib.decompress(data)
# Try msgpack first
try:
result = msgpack.loads(decompressed, raw=False)
# Handle DataFrame reconstruction with timezone normalization
if isinstance(result, dict) and result.get("_type") == "dataframe":
df = pd.DataFrame.from_dict(result["data"], orient="index")
# Restore proper index
if result.get("index_data"):
if result.get("index_type") == "datetime":
df.index = pd.to_datetime(result["index_data"])
df.index = normalize_timezone(df.index)
else:
df.index = result["index_data"]
elif result.get("index_type") == "datetime":
df.index = pd.to_datetime(df.index)
df.index = normalize_timezone(df.index)
# Restore column order
if result.get("columns"):
df = df[result["columns"]]
return df
return result
except Exception as e:
logger.debug(f"Msgpack decompressed failed for {key}: {e}")
try:
return _decode_json_payload(decompressed.decode("utf-8"))
except Exception as e2:
logger.debug(f"JSON decompressed failed for {key}: {e2}")
pass
except Exception:
pass
# Try msgpack uncompressed
try:
result = msgpack.loads(data, raw=False)
if isinstance(result, dict) and result.get("_type") == "dataframe":
df = pd.DataFrame.from_dict(result["data"], orient="index")
# Restore proper index
if result.get("index_data"):
if result.get("index_type") == "datetime":
df.index = pd.to_datetime(result["index_data"])
df.index = normalize_timezone(df.index)
else:
df.index = result["index_data"]
elif result.get("index_type") == "datetime":
df.index = pd.to_datetime(df.index)
df.index = normalize_timezone(df.index)
# Restore column order
if result.get("columns"):
df = df[result["columns"]]
return df
return result
except Exception:
pass
# Fall back to JSON
try:
decoded = data.decode() if isinstance(data, bytes) else data
return _decode_json_payload(decoded)
except Exception:
pass
except Exception as e:
_cache_stats["errors"] += 1
logger.warning(f"Failed to deserialize cache data for key {key}: {e}")
return None
finally:
_cache_stats["deserialization_time"] += time.time() - start_time
_cache_stats["errors"] += 1
logger.warning(f"Failed to deserialize cache data for key {key} - no format worked")
return None
def get_from_cache(key: str) -> Any | None:
"""
Get data from the cache.
Args:
key: Cache key
Returns:
Cached data or None if not found
"""
if not CACHE_ENABLED:
return None
# Try Redis first
redis_client = get_redis_client()
if redis_client:
try:
data = redis_client.get(key)
if data:
_cache_stats["hits"] += 1
logger.debug(f"Cache hit for {key} (Redis)")
result = _deserialize_cached_data(data, key) # type: ignore[arg-type]
return result
except Exception as e:
_cache_stats["errors"] += 1
logger.warning(f"Error reading from Redis cache: {e}")
# Fall back to in-memory cache
if key in _memory_cache:
entry = _memory_cache[key]
if "expiry" not in entry or entry["expiry"] > time.time():
_cache_stats["hits"] += 1
logger.debug(f"Cache hit for {key} (memory)")
return entry["data"]
else:
# Clean up expired entry
del _memory_cache[key]
_cache_stats["misses"] += 1
logger.debug(f"Cache miss for {key}")
return None
def _serialize_data(data: Any, key: str) -> bytes:
"""Serialize data efficiently based on type with optimized formats and memory tracking."""
start_time = time.time()
original_size = 0
compressed_size = 0
try:
# Special handling for DataFrames - use msgpack with timezone normalization
if isinstance(data, pd.DataFrame):
original_size = data.memory_usage(deep=True).sum()
# Track large objects
if original_size > 10 * 1024 * 1024: # 10MB threshold
_cache_memory_stats["large_object_count"] += 1
logger.debug(
f"Serializing large DataFrame for {key}: {original_size / (1024**2):.2f}MB"
)
# Ensure timezone-naive DataFrame
df = ensure_timezone_naive(data)
# Try msgpack first (most efficient for DataFrames)
try:
# Convert to msgpack-serializable format with proper index handling
df_dict = {
"_type": "dataframe",
"data": df.to_dict("index"),
"index_type": (
"datetime"
if isinstance(df.index, pd.DatetimeIndex)
else "other"
),
"columns": list(df.columns),
"index_data": [str(idx) for idx in df.index],
}
msgpack_data = cast(bytes, msgpack.packb(df_dict))
compressed = zlib.compress(msgpack_data, level=1)
compressed_size = len(compressed)
# Track compression savings
if original_size > compressed_size:
_cache_memory_stats["compression_savings_bytes"] += (
original_size - compressed_size
)
return compressed
except Exception as e:
logger.debug(f"Msgpack DataFrame serialization failed for {key}: {e}")
json_payload = {
"__cache_type__": "dataframe",
"data": _dataframe_to_payload(df),
}
compressed = zlib.compress(
json.dumps(json_payload).encode("utf-8"), level=1
)
compressed_size = len(compressed)
if original_size > compressed_size:
_cache_memory_stats["compression_savings_bytes"] += (
original_size - compressed_size
)
return compressed
# For dictionaries with DataFrames (like backtest results)
if isinstance(data, dict) and any(
isinstance(v, pd.DataFrame) for v in data.values()
):
# Ensure all DataFrames are timezone-naive
processed_data = {}
for k, v in data.items():
if isinstance(v, pd.DataFrame):
processed_data[k] = ensure_timezone_naive(v)
else:
processed_data[k] = v
try:
# Try msgpack for mixed dict with DataFrames
serializable_data = {}
for k, v in processed_data.items():
if isinstance(v, pd.DataFrame):
serializable_data[k] = {
"_type": "dataframe",
"data": v.to_dict("index"),
"index_type": (
"datetime"
if isinstance(v.index, pd.DatetimeIndex)
else "other"
),
}
else:
serializable_data[k] = v
msgpack_data = cast(bytes, msgpack.packb(serializable_data))
compressed = zlib.compress(msgpack_data, level=1)
return compressed
except Exception:
payload = {
"__cache_type__": "dict",
"data": {
key: (
{
"__cache_type__": "dataframe",
"data": _dataframe_to_payload(value),
}
if isinstance(value, pd.DataFrame)
else value
)
for key, value in processed_data.items()
},
}
compressed = zlib.compress(
json.dumps(payload, default=_json_default).encode("utf-8"),
level=1,
)
return compressed
# For simple data types, try msgpack first (more efficient than JSON)
if isinstance(data, dict | list | str | int | float | bool | type(None)):
try:
return cast(bytes, msgpack.packb(data))
except Exception:
# Fall back to JSON
return json.dumps(data, default=_json_default).encode("utf-8")
raise TypeError(f"Unsupported cache data type {type(data)!r} for key {key}")
except TypeError as exc:
_cache_stats["errors"] += 1
logger.warning(f"Unsupported data type for cache key {key}: {exc}")
raise
except Exception as e:
_cache_stats["errors"] += 1
logger.warning(f"Failed to serialize data for key {key}: {e}")
# Fall back to JSON string representation
try:
return json.dumps(str(data)).encode("utf-8")
except Exception:
return b"null"
finally:
_cache_stats["serialization_time"] += time.time() - start_time
def save_to_cache(key: str, data: Any, ttl: int | None = None) -> bool:
"""
Save data to the cache.
Args:
key: Cache key
data: Data to cache
ttl: Time-to-live in seconds (default: CACHE_TTL_SECONDS)
Returns:
True if saved successfully, False otherwise
"""
if not CACHE_ENABLED:
return False
resolved_ttl = CACHE_TTL_SECONDS if ttl is None else ttl
# Serialize data efficiently
try:
serialized_data = _serialize_data(data, key)
except TypeError as exc:
logger.warning(f"Skipping cache for {key}: {exc}")
return False
# Store cache metadata
_cache_metadata[key] = {
"created_at": datetime.now(UTC).isoformat(),
"ttl": resolved_ttl,
"size_bytes": len(serialized_data),
"version": CACHE_VERSION,
}
success = False
# Try Redis first
redis_client = get_redis_client()
if redis_client:
try:
redis_client.setex(key, resolved_ttl, serialized_data)
logger.debug(f"Saved to Redis cache: {key}")
success = True
except Exception as e:
_cache_stats["errors"] += 1
logger.warning(f"Error saving to Redis cache: {e}")
if not success:
# Fall back to in-memory cache
_memory_cache[key] = {"data": data, "expiry": time.time() + resolved_ttl}
logger.debug(f"Saved to memory cache: {key}")
success = True
# Clean up memory cache if needed
if len(_memory_cache) > _memory_cache_max_size:
_cleanup_expired_memory_cache()
if success:
_cache_stats["sets"] += 1
return success
def cleanup_redis_pool() -> None:
"""Cleanup Redis connection pool."""
global _redis_pool
if _redis_pool:
try:
_redis_pool.disconnect()
logger.debug("Redis connection pool disconnected")
except Exception as e:
logger.warning(f"Error disconnecting Redis pool: {e}")
finally:
_redis_pool = None
def clear_cache(pattern: str | None = None) -> int:
"""
Clear cache entries matching the pattern.
Args:
pattern: Pattern to match keys (e.g., "stock:*")
If None, clears all cache
Returns:
Number of entries cleared
"""
count = 0
# Clear from Redis
redis_client = get_redis_client()
if redis_client:
try:
if pattern:
keys = cast(list[bytes], redis_client.keys(pattern))
if keys:
delete_result = cast(int, redis_client.delete(*keys))
count += delete_result
else:
flush_result = cast(int, redis_client.flushdb())
count += flush_result
logger.info(f"Cleared {count} entries from Redis cache")
except Exception as e:
logger.warning(f"Error clearing Redis cache: {e}")
# Clear from memory cache
if pattern:
# Simple pattern matching for memory cache (only supports prefix*)
if pattern.endswith("*"):
prefix = pattern[:-1]
memory_keys = [k for k in _memory_cache.keys() if k.startswith(prefix)]
else:
memory_keys = [k for k in _memory_cache.keys() if k == pattern]
for k in memory_keys:
del _memory_cache[k]
count += len(memory_keys)
else:
count += len(_memory_cache)
_memory_cache.clear()
logger.info(f"Cleared {count} total cache entries")
return count
class CacheManager:
"""
Enhanced cache manager with async support and additional methods.
This manager now integrates with the centralized Redis connection pool
for improved performance and resource management.
"""
def __init__(self):
"""Initialize the cache manager."""
self._redis_client = None
self._initialized = False
self._use_performance_redis = True # Flag to use new performance module
def _ensure_client(self) -> redis.Redis | None:
"""Ensure Redis client is initialized with connection pooling."""
if not self._initialized:
# Always use the new robust connection pooling approach
self._redis_client = get_redis_client()
self._initialized = True
return self._redis_client
    async def get(self, key: str) -> Any | None:
        """Async wrapper for get_from_cache."""
        return await asyncio.get_running_loop().run_in_executor(
            None, get_from_cache, key
        )
    async def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
        """Async wrapper for save_to_cache."""
        return await asyncio.get_running_loop().run_in_executor(
            None, save_to_cache, key, value, ttl
        )
async def set_with_ttl(self, key: str, value: str, ttl: int) -> bool:
"""Set a value with specific TTL."""
if not CACHE_ENABLED:
return False
client = self._ensure_client()
if client:
try:
client.setex(key, ttl, value)
return True
except Exception as e:
logger.warning(f"Error setting value with TTL: {e}")
# Fallback to memory cache
_memory_cache[key] = {"data": value, "expiry": time.time() + ttl}
return True
async def set_many_with_ttl(self, items: list[tuple[str, str, int]]) -> bool:
"""Set multiple values with TTL in a batch."""
if not CACHE_ENABLED:
return False
client = self._ensure_client()
if client:
try:
pipe = client.pipeline()
for key, value, ttl in items:
pipe.setex(key, ttl, value)
pipe.execute()
return True
except Exception as e:
logger.warning(f"Error in batch set with TTL: {e}")
# Fallback to memory cache
for key, value, ttl in items:
_memory_cache[key] = {"data": value, "expiry": time.time() + ttl}
return True
async def get_many(self, keys: list[str]) -> dict[str, Any]:
"""Get multiple values at once using pipeline for better performance."""
results: dict[str, Any] = {}
if not CACHE_ENABLED:
return results
client = self._ensure_client()
if client:
try:
# Use pipeline for better performance with multiple operations
pipe = client.pipeline(transaction=False)
for key in keys:
pipe.get(key)
values = pipe.execute()
for key, value in zip(keys, values, strict=False): # type: ignore[arg-type]
if value:
decoded_value: Any
if isinstance(value, bytes):
decoded_value = value.decode()
else:
decoded_value = value
if isinstance(decoded_value, str):
try:
# Try to decode JSON if it's stored as JSON
results[key] = json.loads(decoded_value)
continue
except json.JSONDecodeError:
pass
# If not JSON or decoding fails, store as-is
results[key] = decoded_value
except Exception as e:
logger.warning(f"Error in batch get: {e}")
# Fallback to memory cache for missing keys
for key in keys:
if key not in results and key in _memory_cache:
entry = _memory_cache[key]
if "expiry" not in entry or entry["expiry"] > time.time():
results[key] = entry["data"]
return results
async def delete(self, key: str) -> bool:
"""Delete a key from cache."""
if not CACHE_ENABLED:
return False
deleted = False
client = self._ensure_client()
if client:
try:
deleted = bool(client.delete(key))
except Exception as e:
logger.warning(f"Error deleting key: {e}")
# Also delete from memory cache
if key in _memory_cache:
del _memory_cache[key]
deleted = True
return deleted
async def delete_pattern(self, pattern: str) -> int:
"""Delete all keys matching a pattern."""
count = 0
if not CACHE_ENABLED:
return count
client = self._ensure_client()
if client:
try:
keys = cast(list[bytes], client.keys(pattern))
if keys:
delete_result = cast(int, client.delete(*keys))
count += delete_result
except Exception as e:
logger.warning(f"Error deleting pattern: {e}")
# Also delete from memory cache
if pattern.endswith("*"):
prefix = pattern[:-1]
memory_keys = [k for k in _memory_cache.keys() if k.startswith(prefix)]
for k in memory_keys:
del _memory_cache[k]
count += 1
return count
async def exists(self, key: str) -> bool:
"""Check if a key exists."""
if not CACHE_ENABLED:
return False
client = self._ensure_client()
if client:
try:
return bool(client.exists(key))
except Exception as e:
logger.warning(f"Error checking key existence: {e}")
# Fallback to memory cache
if key in _memory_cache:
entry = _memory_cache[key]
return "expiry" not in entry or entry["expiry"] > time.time()
return False
async def count_keys(self, pattern: str) -> int:
"""Count keys matching a pattern."""
if not CACHE_ENABLED:
return 0
count = 0
client = self._ensure_client()
if client:
try:
cursor = 0
while True:
cursor, keys = client.scan(cursor, match=pattern, count=1000) # type: ignore[misc]
count += len(keys)
if cursor == 0:
break
except Exception as e:
logger.warning(f"Error counting keys: {e}")
# Add memory cache count
if pattern.endswith("*"):
prefix = pattern[:-1]
count += sum(1 for k in _memory_cache.keys() if k.startswith(prefix))
return count
async def batch_save(self, items: list[tuple[str, Any, int | None]]) -> int:
"""
Save multiple items to cache using pipeline for better performance.
Args:
items: List of tuples (key, data, ttl)
Returns:
Number of items successfully saved
"""
if not CACHE_ENABLED:
return 0
saved_count = 0
client = self._ensure_client()
if client:
try:
pipe = client.pipeline(transaction=False)
for key, data, ttl in items:
if ttl is None:
ttl = CACHE_TTL_SECONDS
# Convert data to JSON
json_data = json.dumps(data)
pipe.setex(key, ttl, json_data)
results = pipe.execute()
saved_count = sum(1 for r in results if r)
logger.debug(f"Batch saved {saved_count} items to Redis cache")
except Exception as e:
logger.warning(f"Error in batch save to Redis: {e}")
        # Fallback: mirror all items into the memory cache if the Redis batch
        # did not fully succeed, so every item is cached in at least one store.
        if saved_count < len(items):
            for key, data, ttl in items:
                if ttl is None:
                    ttl = CACHE_TTL_SECONDS
                _memory_cache[key] = {"data": data, "expiry": time.time() + ttl}
            saved_count = len(items)
return saved_count
async def batch_delete(self, keys: list[str]) -> int:
"""
Delete multiple keys from cache using pipeline for better performance.
Args:
keys: List of keys to delete
Returns:
Number of keys deleted
"""
if not CACHE_ENABLED:
return 0
deleted_count = 0
client = self._ensure_client()
if client and keys:
try:
# Use single delete command for multiple keys
deleted_result = client.delete(*keys)
deleted_count = cast(int, deleted_result)
logger.debug(f"Batch deleted {deleted_count} keys from Redis cache")
except Exception as e:
logger.warning(f"Error in batch delete from Redis: {e}")
# Also delete from memory cache
for key in keys:
if key in _memory_cache:
del _memory_cache[key]
deleted_count += 1
return deleted_count
async def batch_exists(self, keys: list[str]) -> dict[str, bool]:
"""
Check existence of multiple keys using pipeline for better performance.
Args:
keys: List of keys to check
Returns:
Dictionary mapping key to existence boolean
"""
results: dict[str, bool] = {}
if not CACHE_ENABLED:
return dict.fromkeys(keys, False)
client = self._ensure_client()
if client:
try:
pipe = client.pipeline(transaction=False)
for key in keys:
pipe.exists(key)
existence_results = pipe.execute()
for key, exists in zip(keys, existence_results, strict=False):
results[key] = bool(exists)
except Exception as e:
logger.warning(f"Error in batch exists check: {e}")
# Check memory cache for missing keys
for key in keys:
if key not in results and key in _memory_cache:
entry = _memory_cache[key]
results[key] = "expiry" not in entry or entry["expiry"] > time.time()
elif key not in results:
results[key] = False
return results
async def batch_get_or_set(
self, items: list[tuple[str, Any, int | None]]
) -> dict[str, Any]:
"""
Get multiple values, setting missing ones atomically using pipeline.
Args:
items: List of tuples (key, default_value, ttl)
Returns:
Dictionary of key-value pairs
"""
if not CACHE_ENABLED:
return {key: default for key, default, _ in items}
results: dict[str, Any] = {}
keys = [item[0] for item in items]
# First, try to get all values
existing = await self.get_many(keys)
# Identify missing keys
missing_items = [item for item in items if item[0] not in existing]
# Set missing values if any
if missing_items:
await self.batch_save(missing_items)
# Add default values to results
for key, default_value, _ in missing_items:
results[key] = default_value
# Add existing values to results
results.update(existing)
return results
```
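For orientation, the sketch below exercises the module's public surface. It is illustrative only: it assumes caching is enabled (the default) and that the package is importable, and a running Redis server is optional because every call falls back to the in-memory cache.

```python
# Minimal usage sketch for maverick_mcp/data/cache.py (keys and values are illustrative).
import asyncio

import pandas as pd

from maverick_mcp.data.cache import (
    CacheManager,
    generate_cache_key,
    get_cache_stats,
    get_from_cache,
    save_to_cache,
)

# Synchronous helpers: cache a DataFrame under a versioned key.
prices = pd.DataFrame(
    {"close": [100.0, 101.5]}, index=pd.date_range("2024-01-02", periods=2)
)
key = generate_cache_key("stock_data", symbol="AAPL", interval="1d")
save_to_cache(key, prices, ttl=300)  # msgpack + zlib when possible, JSON otherwise
cached = get_from_cache(key)  # DataFrame on a hit, None on a miss


# Async wrapper with pipelined batch helpers.
async def main() -> None:
    manager = CacheManager()
    await manager.set("quote:AAPL", {"price": 101.5}, ttl=60)
    print(await manager.get("quote:AAPL"))
    print(await manager.batch_exists(["quote:AAPL", "quote:MSFT"]))


asyncio.run(main())
print(get_cache_stats()["hit_rate_percent"])
```

`clear_cache()` and `cleanup_redis_pool()` are available for teardown if this is run inside a test suite.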
--------------------------------------------------------------------------------
/tests/test_deep_research_parallel_execution.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive test suite for DeepResearchAgent parallel execution functionality.
This test suite covers:
- Parallel vs sequential execution modes
- Subagent creation and orchestration
- Task routing to specialized subagents
- Parallel execution fallback mechanisms
- Result synthesis from parallel tasks
- Performance characteristics of parallel execution
"""
import asyncio
import time
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import MemorySaver
from pydantic import ConfigDict
from maverick_mcp.agents.deep_research import (
BaseSubagent,
CompetitiveResearchAgent,
DeepResearchAgent,
FundamentalResearchAgent,
SentimentResearchAgent,
TechnicalResearchAgent,
)
from maverick_mcp.utils.parallel_research import (
ParallelResearchConfig,
ResearchResult,
ResearchTask,
)
class MockLLM(BaseChatModel):
"""Mock LLM for testing."""
# Allow extra fields to be set on this Pydantic model
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
def __init__(self, **kwargs):
# Extract our custom fields before calling super()
self._response_content = kwargs.pop("response_content", "Mock response")
super().__init__(**kwargs)
def _generate(self, messages, stop=None, **kwargs):
# This method should not be called in async tests
raise NotImplementedError("Use ainvoke for async tests")
async def ainvoke(self, messages, config=None, **kwargs):
"""Mock async invocation."""
await asyncio.sleep(0.01) # Simulate processing time
return AIMessage(content=self._response_content)
@property
def _llm_type(self) -> str:
return "mock_llm"
class TestDeepResearchAgentParallelExecution:
"""Test DeepResearchAgent parallel execution capabilities."""
@pytest.fixture
def mock_llm(self):
"""Create mock LLM."""
return MockLLM(
response_content='{"KEY_INSIGHTS": ["Test insight"], "SENTIMENT": {"direction": "bullish", "confidence": 0.8}, "CREDIBILITY": 0.9}'
)
@pytest.fixture
def parallel_config(self):
"""Create test parallel configuration."""
return ParallelResearchConfig(
max_concurrent_agents=3,
timeout_per_agent=5,
enable_fallbacks=True,
rate_limit_delay=0.1,
)
@pytest.fixture
def deep_research_agent(self, mock_llm, parallel_config):
"""Create DeepResearchAgent with parallel execution enabled."""
return DeepResearchAgent(
llm=mock_llm,
persona="moderate",
checkpointer=MemorySaver(),
enable_parallel_execution=True,
parallel_config=parallel_config,
)
@pytest.fixture
def sequential_agent(self, mock_llm):
"""Create DeepResearchAgent with sequential execution."""
return DeepResearchAgent(
llm=mock_llm,
persona="moderate",
checkpointer=MemorySaver(),
enable_parallel_execution=False,
)
def test_agent_initialization_parallel_enabled(self, deep_research_agent):
"""Test agent initialization with parallel execution enabled."""
assert deep_research_agent.enable_parallel_execution is True
assert deep_research_agent.parallel_config is not None
assert deep_research_agent.parallel_orchestrator is not None
assert deep_research_agent.task_distributor is not None
assert deep_research_agent.parallel_config.max_concurrent_agents == 3
def test_agent_initialization_sequential(self, sequential_agent):
"""Test agent initialization with sequential execution."""
assert sequential_agent.enable_parallel_execution is False
# These components should still be initialized for potential future use
assert sequential_agent.parallel_orchestrator is not None
@pytest.mark.asyncio
async def test_parallel_execution_mode_selection(self, deep_research_agent):
"""Test parallel execution mode selection."""
# Mock search providers to be available
mock_provider = AsyncMock()
deep_research_agent.search_providers = [mock_provider]
with (
patch.object(
deep_research_agent,
"_execute_parallel_research",
new_callable=AsyncMock,
) as mock_parallel,
patch.object(deep_research_agent.graph, "ainvoke") as mock_sequential,
patch.object(
deep_research_agent,
"_ensure_search_providers_loaded",
return_value=None,
),
):
mock_parallel.return_value = {
"status": "success",
"execution_mode": "parallel",
"agent_type": "deep_research",
}
# Test with parallel execution enabled (default)
result = await deep_research_agent.research_comprehensive(
topic="AAPL analysis", session_id="test_123"
)
# Should use parallel execution
mock_parallel.assert_called_once()
mock_sequential.assert_not_called()
assert result["execution_mode"] == "parallel"
@pytest.mark.asyncio
async def test_sequential_execution_mode_selection(self, sequential_agent):
"""Test sequential execution mode selection."""
# Mock search providers to be available
mock_provider = AsyncMock()
sequential_agent.search_providers = [mock_provider]
with (
patch.object(
sequential_agent, "_execute_parallel_research"
) as mock_parallel,
patch.object(sequential_agent.graph, "ainvoke") as mock_sequential,
patch.object(
sequential_agent, "_ensure_search_providers_loaded", return_value=None
),
):
mock_sequential.return_value = {
"status": "success",
"persona": "moderate",
"research_confidence": 0.8,
}
# Test with parallel execution disabled
await sequential_agent.research_comprehensive(
topic="AAPL analysis", session_id="test_123"
)
# Should use sequential execution
mock_parallel.assert_not_called()
mock_sequential.assert_called_once()
@pytest.mark.asyncio
async def test_parallel_execution_override(self, deep_research_agent):
"""Test overriding parallel execution at runtime."""
# Mock search providers to be available
mock_provider = AsyncMock()
deep_research_agent.search_providers = [mock_provider]
with (
patch.object(
deep_research_agent, "_execute_parallel_research"
) as mock_parallel,
patch.object(deep_research_agent.graph, "ainvoke") as mock_sequential,
patch.object(
deep_research_agent,
"_ensure_search_providers_loaded",
return_value=None,
),
):
mock_sequential.return_value = {"status": "success", "persona": "moderate"}
# Override parallel execution to false
await deep_research_agent.research_comprehensive(
topic="AAPL analysis",
session_id="test_123",
use_parallel_execution=False,
)
# Should use sequential despite agent default
mock_parallel.assert_not_called()
mock_sequential.assert_called_once()
@pytest.mark.asyncio
async def test_parallel_execution_fallback(self, deep_research_agent):
"""Test fallback to sequential when parallel execution fails."""
# Mock search providers to be available
mock_provider = AsyncMock()
deep_research_agent.search_providers = [mock_provider]
with (
patch.object(
deep_research_agent,
"_execute_parallel_research",
new_callable=AsyncMock,
) as mock_parallel,
patch.object(deep_research_agent.graph, "ainvoke") as mock_sequential,
patch.object(
deep_research_agent,
"_ensure_search_providers_loaded",
return_value=None,
),
):
# Parallel execution fails
mock_parallel.side_effect = RuntimeError("Parallel execution failed")
mock_sequential.return_value = {
"status": "success",
"persona": "moderate",
"research_confidence": 0.7,
}
result = await deep_research_agent.research_comprehensive(
topic="AAPL analysis", session_id="test_123"
)
# Should attempt parallel then fall back to sequential
mock_parallel.assert_called_once()
mock_sequential.assert_called_once()
assert result["status"] == "success"
@pytest.mark.asyncio
async def test_execute_parallel_research_task_distribution(
self, deep_research_agent
):
"""Test parallel research task distribution."""
with (
patch.object(
deep_research_agent.task_distributor, "distribute_research_tasks"
) as mock_distribute,
patch.object(
deep_research_agent.parallel_orchestrator,
"execute_parallel_research",
new_callable=AsyncMock,
) as mock_execute,
):
# Mock task distribution
mock_tasks = [
ResearchTask(
"test_123_fundamental", "fundamental", "AAPL", ["earnings"]
),
ResearchTask("test_123_sentiment", "sentiment", "AAPL", ["news"]),
]
mock_distribute.return_value = mock_tasks
# Mock orchestrator execution
mock_result = ResearchResult()
mock_result.successful_tasks = 2
mock_result.failed_tasks = 0
mock_result.synthesis = {"confidence_score": 0.85}
mock_execute.return_value = mock_result
initial_state = {
"persona": "moderate",
"research_topic": "AAPL analysis",
"session_id": "test_123",
"focus_areas": ["earnings", "sentiment"],
}
await deep_research_agent._execute_parallel_research(
topic="AAPL analysis",
session_id="test_123",
depth="standard",
focus_areas=["earnings", "sentiment"],
start_time=datetime.now(),
initial_state=initial_state,
)
# Verify task distribution was called correctly
mock_distribute.assert_called_once_with(
topic="AAPL analysis",
session_id="test_123",
focus_areas=["earnings", "sentiment"],
)
# Verify orchestrator was called with distributed tasks
mock_execute.assert_called_once()
args, kwargs = mock_execute.call_args
assert kwargs["tasks"] == mock_tasks
@pytest.mark.asyncio
async def test_subagent_task_routing(self, deep_research_agent):
"""Test routing tasks to appropriate subagents."""
# Test fundamental routing
fundamental_task = ResearchTask(
"test_fundamental", "fundamental", "AAPL", ["earnings"]
)
with patch(
"maverick_mcp.agents.deep_research.FundamentalResearchAgent"
) as mock_fundamental:
mock_subagent = AsyncMock()
mock_subagent.execute_research.return_value = {
"research_type": "fundamental"
}
mock_fundamental.return_value = mock_subagent
# This would normally be called by the orchestrator
# We're testing the routing logic directly
await deep_research_agent._execute_subagent_task(fundamental_task)
mock_fundamental.assert_called_once_with(deep_research_agent)
mock_subagent.execute_research.assert_called_once_with(fundamental_task)
@pytest.mark.asyncio
async def test_unknown_task_type_fallback(self, deep_research_agent):
"""Test fallback for unknown task types."""
unknown_task = ResearchTask("test_unknown", "unknown_type", "AAPL", ["test"])
with patch(
"maverick_mcp.agents.deep_research.FundamentalResearchAgent"
) as mock_fundamental:
mock_subagent = AsyncMock()
mock_subagent.execute_research.return_value = {
"research_type": "fundamental"
}
mock_fundamental.return_value = mock_subagent
await deep_research_agent._execute_subagent_task(unknown_task)
# Should fall back to fundamental analysis
mock_fundamental.assert_called_once_with(deep_research_agent)
@pytest.mark.asyncio
async def test_parallel_result_synthesis(self, deep_research_agent, mock_llm):
"""Test synthesis of results from parallel tasks."""
# Create mock task results
task_results = {
"test_123_fundamental": ResearchTask(
"test_123_fundamental", "fundamental", "AAPL", ["earnings"]
),
"test_123_sentiment": ResearchTask(
"test_123_sentiment", "sentiment", "AAPL", ["news"]
),
}
# Set tasks as completed with results
task_results["test_123_fundamental"].status = "completed"
task_results["test_123_fundamental"].result = {
"insights": ["Strong earnings growth"],
"sentiment": {"direction": "bullish", "confidence": 0.8},
"credibility_score": 0.9,
}
task_results["test_123_sentiment"].status = "completed"
task_results["test_123_sentiment"].result = {
"insights": ["Positive market sentiment"],
"sentiment": {"direction": "bullish", "confidence": 0.7},
"credibility_score": 0.8,
}
# Mock LLM synthesis response
mock_llm._response_content = "Synthesized analysis showing strong bullish outlook based on fundamental and sentiment analysis"
result = await deep_research_agent._synthesize_parallel_results(task_results)
assert result is not None
assert "synthesis" in result
assert "key_insights" in result
assert "overall_sentiment" in result
assert len(result["key_insights"]) > 0
assert result["overall_sentiment"]["direction"] == "bullish"
@pytest.mark.asyncio
async def test_synthesis_with_mixed_results(self, deep_research_agent):
"""Test synthesis with mixed successful and failed tasks."""
task_results = {
"test_123_fundamental": ResearchTask(
"test_123_fundamental", "fundamental", "AAPL", ["earnings"]
),
"test_123_technical": ResearchTask(
"test_123_technical", "technical", "AAPL", ["charts"]
),
"test_123_sentiment": ResearchTask(
"test_123_sentiment", "sentiment", "AAPL", ["news"]
),
}
        # Two tasks succeed (fundamental and sentiment); one fails (technical)
task_results["test_123_fundamental"].status = "completed"
task_results["test_123_fundamental"].result = {
"insights": ["Strong fundamentals"],
"sentiment": {"direction": "bullish", "confidence": 0.8},
}
task_results["test_123_technical"].status = "failed"
task_results["test_123_technical"].error = "Technical analysis failed"
task_results["test_123_sentiment"].status = "completed"
task_results["test_123_sentiment"].result = {
"insights": ["Mixed sentiment"],
"sentiment": {"direction": "neutral", "confidence": 0.6},
}
result = await deep_research_agent._synthesize_parallel_results(task_results)
# Should handle mixed results gracefully
assert result is not None
assert len(result["key_insights"]) > 0
assert "task_breakdown" in result
assert result["task_breakdown"]["test_123_technical"]["status"] == "failed"
@pytest.mark.asyncio
async def test_synthesis_with_no_successful_results(self, deep_research_agent):
"""Test synthesis when all tasks fail."""
task_results = {
"test_123_fundamental": ResearchTask(
"test_123_fundamental", "fundamental", "AAPL", ["earnings"]
),
"test_123_sentiment": ResearchTask(
"test_123_sentiment", "sentiment", "AAPL", ["news"]
),
}
# Both tasks failed
task_results["test_123_fundamental"].status = "failed"
task_results["test_123_fundamental"].error = "API timeout"
task_results["test_123_sentiment"].status = "failed"
task_results["test_123_sentiment"].error = "No data available"
result = await deep_research_agent._synthesize_parallel_results(task_results)
# Should handle gracefully
assert result is not None
assert result["confidence_score"] == 0.0
assert "No research results available" in result["synthesis"]
@pytest.mark.asyncio
async def test_synthesis_llm_failure_fallback(self, deep_research_agent):
"""Test fallback when LLM synthesis fails."""
task_results = {
"test_123_fundamental": ResearchTask(
"test_123_fundamental", "fundamental", "AAPL", ["earnings"]
),
}
task_results["test_123_fundamental"].status = "completed"
task_results["test_123_fundamental"].result = {
"insights": ["Good insights"],
"sentiment": {"direction": "bullish", "confidence": 0.8},
}
# Mock LLM to fail
with patch.object(
deep_research_agent.llm, "ainvoke", side_effect=RuntimeError("LLM failed")
):
result = await deep_research_agent._synthesize_parallel_results(
task_results
)
# Should use fallback synthesis
assert result is not None
assert "fallback synthesis" in result["synthesis"].lower()
@pytest.mark.asyncio
async def test_format_parallel_research_response(self, deep_research_agent):
"""Test formatting of parallel research response."""
# Create mock research result
research_result = ResearchResult()
research_result.successful_tasks = 2
research_result.failed_tasks = 0
research_result.total_execution_time = 1.5
research_result.parallel_efficiency = 2.1
research_result.synthesis = {
"confidence_score": 0.85,
"key_findings": ["Finding 1", "Finding 2"],
}
# Mock task results with sources
task1 = ResearchTask(
"test_123_fundamental", "fundamental", "AAPL", ["earnings"]
)
task1.status = "completed"
task1.result = {
"sources": [
{
"title": "AAPL Earnings Report",
"url": "https://example.com/earnings",
"credibility_score": 0.9,
}
]
}
research_result.task_results = {"test_123_fundamental": task1}
start_time = datetime.now()
formatted_result = await deep_research_agent._format_parallel_research_response(
research_result=research_result,
topic="AAPL analysis",
session_id="test_123",
depth="standard",
initial_state={"persona": "moderate"},
start_time=start_time,
)
# Verify formatted response structure
assert formatted_result["status"] == "success"
assert formatted_result["agent_type"] == "deep_research"
assert formatted_result["execution_mode"] == "parallel"
assert formatted_result["research_topic"] == "AAPL analysis"
assert formatted_result["confidence_score"] == 0.85
assert "parallel_execution_stats" in formatted_result
assert formatted_result["parallel_execution_stats"]["successful_tasks"] == 2
assert len(formatted_result["citations"]) > 0
@pytest.mark.asyncio
async def test_aggregated_sentiment_calculation(self, deep_research_agent):
"""Test aggregation of sentiment from multiple sources."""
sentiment_scores = [
{"direction": "bullish", "confidence": 0.8},
{"direction": "bullish", "confidence": 0.6},
{"direction": "neutral", "confidence": 0.7},
{"direction": "bearish", "confidence": 0.5},
]
result = deep_research_agent._calculate_aggregated_sentiment(sentiment_scores)
assert result is not None
assert "direction" in result
assert "confidence" in result
assert "consensus" in result
assert "source_count" in result
assert result["source_count"] == 4
@pytest.mark.asyncio
async def test_parallel_recommendation_derivation(self, deep_research_agent):
"""Test derivation of investment recommendations from parallel analysis."""
# Test strong bullish signal
bullish_sentiment = {"direction": "bullish", "confidence": 0.9}
recommendation = deep_research_agent._derive_parallel_recommendation(
bullish_sentiment
)
assert "strong buy" in recommendation.lower() or "buy" in recommendation.lower()
# Test bearish signal
bearish_sentiment = {"direction": "bearish", "confidence": 0.8}
recommendation = deep_research_agent._derive_parallel_recommendation(
bearish_sentiment
)
assert (
"caution" in recommendation.lower() or "negative" in recommendation.lower()
)
# Test neutral/mixed signals
neutral_sentiment = {"direction": "neutral", "confidence": 0.5}
recommendation = deep_research_agent._derive_parallel_recommendation(
neutral_sentiment
)
assert "neutral" in recommendation.lower() or "mixed" in recommendation.lower()
class TestSpecializedSubagents:
"""Test specialized research subagent functionality."""
@pytest.fixture
def mock_parent_agent(self):
"""Create mock parent DeepResearchAgent."""
parent = Mock()
parent.llm = MockLLM()
parent.search_providers = []
parent.content_analyzer = Mock()
parent.persona = Mock()
parent.persona.name = "moderate"
parent._calculate_source_credibility = Mock(return_value=0.8)
return parent
def test_base_subagent_initialization(self, mock_parent_agent):
"""Test BaseSubagent initialization."""
subagent = BaseSubagent(mock_parent_agent)
assert subagent.parent == mock_parent_agent
assert subagent.llm == mock_parent_agent.llm
assert subagent.search_providers == mock_parent_agent.search_providers
assert subagent.content_analyzer == mock_parent_agent.content_analyzer
assert subagent.persona == mock_parent_agent.persona
@pytest.mark.asyncio
async def test_fundamental_research_agent(self, mock_parent_agent):
"""Test FundamentalResearchAgent execution."""
# Mock content analyzer
mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
return_value={
"insights": ["Strong earnings growth"],
"sentiment": {"direction": "bullish", "confidence": 0.8},
"risk_factors": ["Market volatility"],
"opportunities": ["Dividend growth"],
"credibility_score": 0.9,
}
)
subagent = FundamentalResearchAgent(mock_parent_agent)
# Mock search results
with patch.object(subagent, "_perform_specialized_search") as mock_search:
mock_search.return_value = [
{
"title": "AAPL Earnings Report",
"url": "https://example.com/earnings",
"content": "Apple reported strong quarterly earnings...",
"credibility_score": 0.9,
}
]
task = ResearchTask(
"fund_task", "fundamental", "AAPL analysis", ["earnings"]
)
result = await subagent.execute_research(task)
assert result["research_type"] == "fundamental"
assert len(result["insights"]) > 0
assert "sentiment" in result
assert result["sentiment"]["direction"] == "bullish"
assert len(result["sources"]) > 0
def test_fundamental_query_generation(self, mock_parent_agent):
"""Test fundamental analysis query generation."""
subagent = FundamentalResearchAgent(mock_parent_agent)
queries = subagent._generate_fundamental_queries("AAPL")
assert len(queries) > 0
assert any("earnings" in query.lower() for query in queries)
assert any("revenue" in query.lower() for query in queries)
assert any("valuation" in query.lower() for query in queries)
@pytest.mark.asyncio
async def test_technical_research_agent(self, mock_parent_agent):
"""Test TechnicalResearchAgent execution."""
mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
return_value={
"insights": ["Bullish chart pattern"],
"sentiment": {"direction": "bullish", "confidence": 0.7},
"risk_factors": ["Support level break"],
"opportunities": ["Breakout potential"],
"credibility_score": 0.8,
}
)
subagent = TechnicalResearchAgent(mock_parent_agent)
with patch.object(subagent, "_perform_specialized_search") as mock_search:
mock_search.return_value = [
{
"title": "AAPL Technical Analysis",
"url": "https://example.com/technical",
"content": "Apple stock showing strong technical indicators...",
"credibility_score": 0.8,
}
]
task = ResearchTask("tech_task", "technical", "AAPL analysis", ["charts"])
result = await subagent.execute_research(task)
assert result["research_type"] == "technical"
assert "price_action" in result["focus_areas"]
assert "technical_indicators" in result["focus_areas"]
def test_technical_query_generation(self, mock_parent_agent):
"""Test technical analysis query generation."""
subagent = TechnicalResearchAgent(mock_parent_agent)
queries = subagent._generate_technical_queries("AAPL")
assert any("technical analysis" in query.lower() for query in queries)
assert any("chart pattern" in query.lower() for query in queries)
assert any(
"rsi" in query.lower() or "macd" in query.lower() for query in queries
)
@pytest.mark.asyncio
async def test_sentiment_research_agent(self, mock_parent_agent):
"""Test SentimentResearchAgent execution."""
mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
return_value={
"insights": ["Positive analyst sentiment"],
"sentiment": {"direction": "bullish", "confidence": 0.9},
"risk_factors": ["Market sentiment shift"],
"opportunities": ["Upgrade potential"],
"credibility_score": 0.85,
}
)
subagent = SentimentResearchAgent(mock_parent_agent)
with patch.object(subagent, "_perform_specialized_search") as mock_search:
mock_search.return_value = [
{
"title": "AAPL Analyst Upgrade",
"url": "https://example.com/upgrade",
"content": "Apple receives analyst upgrade...",
"credibility_score": 0.85,
}
]
task = ResearchTask("sent_task", "sentiment", "AAPL analysis", ["news"])
result = await subagent.execute_research(task)
assert result["research_type"] == "sentiment"
assert "market_sentiment" in result["focus_areas"]
assert result["sentiment"]["confidence"] > 0.8
@pytest.mark.asyncio
async def test_competitive_research_agent(self, mock_parent_agent):
"""Test CompetitiveResearchAgent execution."""
mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
return_value={
"insights": ["Strong competitive position"],
"sentiment": {"direction": "bullish", "confidence": 0.7},
"risk_factors": ["Increased competition"],
"opportunities": ["Market expansion"],
"credibility_score": 0.8,
}
)
subagent = CompetitiveResearchAgent(mock_parent_agent)
with patch.object(subagent, "_perform_specialized_search") as mock_search:
mock_search.return_value = [
{
"title": "AAPL Market Share Analysis",
"url": "https://example.com/marketshare",
"content": "Apple maintains strong market position...",
"credibility_score": 0.8,
}
]
task = ResearchTask(
"comp_task", "competitive", "AAPL analysis", ["market_share"]
)
result = await subagent.execute_research(task)
assert result["research_type"] == "competitive"
assert "competitive_position" in result["focus_areas"]
assert "market_share" in result["focus_areas"]
@pytest.mark.asyncio
async def test_subagent_search_deduplication(self, mock_parent_agent):
"""Test search result deduplication in subagents."""
subagent = BaseSubagent(mock_parent_agent)
# Mock search providers with duplicate results
mock_provider1 = AsyncMock()
mock_provider1.search.return_value = [
{"url": "https://example.com/article1", "title": "Article 1"},
{"url": "https://example.com/article2", "title": "Article 2"},
]
mock_provider2 = AsyncMock()
mock_provider2.search.return_value = [
{"url": "https://example.com/article1", "title": "Article 1"}, # Duplicate
{"url": "https://example.com/article3", "title": "Article 3"},
]
subagent.search_providers = [mock_provider1, mock_provider2]
results = await subagent._perform_specialized_search(
"test topic", ["test query"], max_results=10
)
# Should deduplicate by URL
urls = [result["url"] for result in results]
assert len(urls) == len(set(urls)) # No duplicates
assert len(results) == 3 # Should have 3 unique results
@pytest.mark.asyncio
async def test_subagent_search_error_handling(self, mock_parent_agent):
"""Test error handling in subagent search."""
subagent = BaseSubagent(mock_parent_agent)
# Mock provider that fails
mock_provider = AsyncMock()
mock_provider.search.side_effect = RuntimeError("Search failed")
subagent.search_providers = [mock_provider]
# Should handle errors gracefully and return empty results
results = await subagent._perform_specialized_search(
"test topic", ["test query"], max_results=10
)
assert results == [] # Should return empty list on error
@pytest.mark.asyncio
async def test_subagent_content_analysis_error_handling(self, mock_parent_agent):
"""Test content analysis error handling in subagents."""
# Mock content analyzer that fails
mock_parent_agent.content_analyzer.analyze_content = AsyncMock(
side_effect=RuntimeError("Analysis failed")
)
subagent = BaseSubagent(mock_parent_agent)
search_results = [
{
"title": "Test Article",
"url": "https://example.com/test",
"content": "Test content",
}
]
# Should handle analysis errors gracefully
results = await subagent._analyze_search_results(
search_results, "test_analysis"
)
# Should return empty results when analysis fails
assert results == []
@pytest.mark.integration
class TestDeepResearchParallelIntegration:
"""Integration tests for DeepResearchAgent parallel execution."""
@pytest.fixture
def integration_agent(self):
"""Create agent for integration testing."""
llm = MockLLM(
'{"KEY_INSIGHTS": ["Integration insight"], "SENTIMENT": {"direction": "bullish", "confidence": 0.8}}'
)
config = ParallelResearchConfig(
max_concurrent_agents=2,
timeout_per_agent=5,
enable_fallbacks=True,
rate_limit_delay=0.05,
)
return DeepResearchAgent(
llm=llm,
persona="moderate",
enable_parallel_execution=True,
parallel_config=config,
)
@pytest.mark.asyncio
async def test_end_to_end_parallel_research(self, integration_agent):
"""Test complete end-to-end parallel research workflow."""
# Mock the search providers and subagent execution
with patch.object(integration_agent, "_execute_subagent_task") as mock_execute:
mock_execute.return_value = {
"research_type": "fundamental",
"insights": ["Strong financial health", "Growing revenue"],
"sentiment": {"direction": "bullish", "confidence": 0.8},
"risk_factors": ["Market volatility"],
"opportunities": ["Expansion potential"],
"credibility_score": 0.85,
"sources": [
{
"title": "Financial Report",
"url": "https://example.com/report",
"credibility_score": 0.9,
}
],
}
start_time = time.time()
result = await integration_agent.research_comprehensive(
topic="Apple Inc comprehensive financial analysis",
session_id="integration_test_123",
depth="comprehensive",
focus_areas=["fundamentals", "sentiment", "competitive"],
)
execution_time = time.time() - start_time
# Verify result structure
assert result["status"] == "success"
assert result["agent_type"] == "deep_research"
assert result["execution_mode"] == "parallel"
assert (
result["research_topic"] == "Apple Inc comprehensive financial analysis"
)
assert result["confidence_score"] > 0
assert len(result["citations"]) > 0
assert "parallel_execution_stats" in result
# Verify performance characteristics
assert execution_time < 10 # Should complete reasonably quickly
assert result["execution_time_ms"] > 0
# Verify parallel execution stats
stats = result["parallel_execution_stats"]
assert stats["total_tasks"] > 0
assert stats["successful_tasks"] >= 0
assert stats["parallel_efficiency"] > 0
@pytest.mark.asyncio
async def test_parallel_vs_sequential_performance(self, integration_agent):
"""Test performance comparison between parallel and sequential execution."""
topic = "Microsoft Corp investment analysis"
session_id = "perf_test_123"
# Mock subagent execution with realistic delay
async def mock_subagent_execution(task):
await asyncio.sleep(0.1) # Simulate work
return {
"research_type": task.task_type,
"insights": [f"Insight from {task.task_type}"],
"sentiment": {"direction": "bullish", "confidence": 0.7},
"credibility_score": 0.8,
"sources": [],
}
with patch.object(
integration_agent,
"_execute_subagent_task",
side_effect=mock_subagent_execution,
):
# Test parallel execution
start_parallel = time.time()
parallel_result = await integration_agent.research_comprehensive(
topic=topic, session_id=session_id, use_parallel_execution=True
)
            time.time() - start_parallel  # parallel wall time (measured but not asserted; timing varies by environment)
# Test sequential execution
start_sequential = time.time()
sequential_result = await integration_agent.research_comprehensive(
topic=topic,
session_id=f"{session_id}_seq",
use_parallel_execution=False,
)
            time.time() - start_sequential  # sequential wall time (measured but not asserted; timing varies by environment)
# Verify both succeeded
assert parallel_result["status"] == "success"
assert sequential_result["status"] == "success"
# Parallel should generally be faster (though not guaranteed in all test environments)
# At minimum, parallel efficiency should be calculated
if "parallel_execution_stats" in parallel_result:
assert (
parallel_result["parallel_execution_stats"]["parallel_efficiency"]
> 0
)
@pytest.mark.asyncio
async def test_research_quality_consistency(self, integration_agent):
"""Test that parallel and sequential execution produce consistent quality."""
topic = "Tesla Inc strategic analysis"
# Mock consistent subagent responses
mock_response = {
"research_type": "fundamental",
"insights": ["Consistent insight 1", "Consistent insight 2"],
"sentiment": {"direction": "bullish", "confidence": 0.75},
"credibility_score": 0.8,
"sources": [
{
"title": "Source",
"url": "https://example.com",
"credibility_score": 0.8,
}
],
}
with patch.object(
integration_agent, "_execute_subagent_task", return_value=mock_response
):
parallel_result = await integration_agent.research_comprehensive(
topic=topic,
session_id="quality_test_parallel",
use_parallel_execution=True,
)
sequential_result = await integration_agent.research_comprehensive(
topic=topic,
session_id="quality_test_sequential",
use_parallel_execution=False,
)
# Both should succeed with reasonable confidence
assert parallel_result["status"] == "success"
assert sequential_result["status"] == "success"
assert parallel_result["confidence_score"] > 0.5
assert sequential_result["confidence_score"] > 0.5
```
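The sentiment-aggregation test above (`test_aggregated_sentiment_calculation`) only pins down the output contract of `_calculate_aggregated_sentiment` (the `direction`, `confidence`, `consensus`, and `source_count` keys), not the weighting scheme. The sketch below is an illustrative, confidence-weighted implementation of that contract under assumed weighting rules; the real method in `maverick_mcp/agents/deep_research.py` may weight sources differently.

```python
from collections import defaultdict


def aggregate_sentiment(scores: list[dict]) -> dict:
    """Illustrative confidence-weighted aggregation matching the tested contract."""
    if not scores:
        return {"direction": "neutral", "confidence": 0.0, "consensus": 0.0, "source_count": 0}

    # Weight each direction by the confidence of the sources voting for it.
    weights: dict[str, float] = defaultdict(float)
    for score in scores:
        weights[score["direction"]] += score.get("confidence", 0.5)

    total_weight = sum(weights.values())
    direction, top_weight = max(weights.items(), key=lambda item: item[1])
    agreeing = sum(1 for score in scores if score["direction"] == direction)

    return {
        "direction": direction,
        "confidence": round(top_weight / total_weight, 3) if total_weight else 0.0,
        "consensus": round(agreeing / len(scores), 3),  # share of sources agreeing with the winner
        "source_count": len(scores),
    }
```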
--------------------------------------------------------------------------------
/maverick_mcp/api/server.py:
--------------------------------------------------------------------------------
```python
"""
MaverickMCP Server Implementation - Simple Stock Analysis MCP Server.
This module implements a simplified FastMCP server focused on stock analysis with:
- No authentication required
- No billing system
- Core stock data and technical analysis functionality
- Multi-transport support (stdio, SSE, streamable-http)
"""
# Configure warnings filter BEFORE any other imports to suppress known deprecation warnings
import warnings
warnings.filterwarnings(
"ignore",
message="pkg_resources is deprecated as an API.*",
category=UserWarning,
module="pandas_ta.*",
)
warnings.filterwarnings(
"ignore",
message="'crypt' is deprecated and slated for removal.*",
category=DeprecationWarning,
module="passlib.*",
)
warnings.filterwarnings(
"ignore",
message=".*pydantic.* is deprecated.*",
category=DeprecationWarning,
module="langchain.*",
)
warnings.filterwarnings(
"ignore",
message=".*cookie.*deprecated.*",
category=DeprecationWarning,
module="starlette.*",
)
# Suppress Plotly/Kaleido deprecation warnings from library internals
# These warnings come from the libraries themselves and can't be fixed at the user level
# Comprehensive suppression patterns for all known kaleido warnings
kaleido_patterns = [
r".*plotly\.io\.kaleido\.scope\..*is deprecated.*",
r".*Use of plotly\.io\.kaleido\.scope\..*is deprecated.*",
r".*default_format.*deprecated.*",
r".*default_width.*deprecated.*",
r".*default_height.*deprecated.*",
r".*default_scale.*deprecated.*",
r".*mathjax.*deprecated.*",
r".*plotlyjs.*deprecated.*",
]
for pattern in kaleido_patterns:
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
message=pattern,
)
# Also suppress by module to catch any we missed
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
module=r".*kaleido.*",
)
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
module=r"plotly\.io\._kaleido",
)
# Suppress websockets deprecation warnings from uvicorn internals
# These warnings come from uvicorn's use of deprecated websockets APIs and cannot be fixed at our level
warnings.filterwarnings(
"ignore",
message=".*websockets.legacy is deprecated.*",
category=DeprecationWarning,
)
warnings.filterwarnings(
"ignore",
message=".*websockets.server.WebSocketServerProtocol is deprecated.*",
category=DeprecationWarning,
)
# Broad suppression for all websockets deprecation warnings from third-party libs
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
module="websockets.*",
)
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
module="uvicorn.protocols.websockets.*",
)
# ruff: noqa: E402 - Imports after warnings config for proper deprecation warning suppression
import argparse
import json
import sys
import uuid
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Protocol, cast
from fastapi import FastAPI
from fastmcp import FastMCP
# Import tool registry for direct registration
# This avoids Claude Desktop's issue with mounted router tool names
from maverick_mcp.api.routers.tool_registry import register_all_router_tools
from maverick_mcp.config.settings import settings
from maverick_mcp.data.models import get_db
from maverick_mcp.data.performance import (
cleanup_performance_systems,
initialize_performance_systems,
)
from maverick_mcp.providers.market_data import MarketDataProvider
from maverick_mcp.providers.stock_data import StockDataProvider
from maverick_mcp.utils.logging import get_logger, setup_structured_logging
from maverick_mcp.utils.monitoring import initialize_monitoring
from maverick_mcp.utils.structured_logger import (
get_logger_manager,
setup_backtesting_logging,
)
from maverick_mcp.utils.tracing import initialize_tracing
# Connection manager temporarily disabled for compatibility
if TYPE_CHECKING: # pragma: no cover - import used for static typing only
from maverick_mcp.infrastructure.connection_manager import MCPConnectionManager
# Monkey-patch FastMCP's create_sse_app to register both /sse and /sse/ routes
# This allows both paths to work without 307 redirects
# Fixes the mcp-remote tool registration failure issue
from fastmcp.server import http as fastmcp_http
from starlette.middleware import Middleware
from starlette.routing import BaseRoute, Route
_original_create_sse_app = fastmcp_http.create_sse_app
def _patched_create_sse_app(
server: Any,
message_path: str,
sse_path: str,
auth: Any | None = None,
debug: bool = False,
routes: list[BaseRoute] | None = None,
middleware: list[Middleware] | None = None,
) -> Any:
"""Patched version of create_sse_app that registers both /sse and /sse/ paths.
This prevents 307 redirects by registering both path variants explicitly,
fixing tool registration failures with mcp-remote that occurred when clients
used /sse instead of /sse/.
"""
import sys
print(
f"🔧 Patched create_sse_app called with sse_path={sse_path}",
file=sys.stderr,
flush=True,
)
# Call the original create_sse_app function
app = _original_create_sse_app(
server=server,
message_path=message_path,
sse_path=sse_path,
auth=auth,
debug=debug,
routes=routes,
middleware=middleware,
)
# Register both path variants (with and without trailing slash)
# Find the SSE endpoint handler from the existing routes
sse_endpoint = None
for route in app.router.routes:
if isinstance(route, Route) and route.path == sse_path:
sse_endpoint = route.endpoint
break
if sse_endpoint:
# Determine the alternative path
if sse_path.endswith("/"):
alt_path = sse_path.rstrip("/") # Remove trailing slash
else:
alt_path = sse_path + "/" # Add trailing slash
# Register the alternative path
new_route = Route(
alt_path,
endpoint=sse_endpoint,
methods=["GET"],
)
app.router.routes.insert(0, new_route)
print(
f"✅ Registered SSE routes: {sse_path} AND {alt_path}",
file=sys.stderr,
flush=True,
)
else:
print(
f"⚠️ Could not find SSE endpoint for {sse_path}",
file=sys.stderr,
flush=True,
)
return app
# Apply the monkey-patch
fastmcp_http.create_sse_app = _patched_create_sse_app
class FastMCPProtocol(Protocol):
"""Protocol describing the FastMCP interface we rely upon."""
fastapi_app: FastAPI | None
dependencies: list[Any]
def resource(
self, uri: str
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
def event(
self, name: str
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]: ...
def prompt(
self, name: str | None = None, *, description: str | None = None
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
def tool(
self, name: str | None = None, *, description: str | None = None
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
def run(self, *args: Any, **kwargs: Any) -> None: ...
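# When serving over stdio, route logs to stderr so stdout stays reserved for MCP
# protocol messages (the __main__ block reconfigures logging the same way for stdio).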
_use_stderr = "--transport" in sys.argv and "stdio" in sys.argv
# Setup enhanced structured logging for backtesting
setup_backtesting_logging(
log_level=settings.api.log_level.upper(),
enable_debug=settings.api.debug,
log_file="logs/maverick_mcp.log" if not _use_stderr else None,
)
# Also setup the original logging for compatibility
setup_structured_logging(
log_level=settings.api.log_level.upper(),
log_format="json" if settings.api.debug else "text",
use_stderr=_use_stderr,
)
logger = get_logger("maverick_mcp.server")
logger_manager = get_logger_manager()
# Initialize FastMCP with enhanced connection management
_fastmcp_instance = FastMCP(
name=settings.app_name,
)
_fastmcp_instance.dependencies = []
mcp = cast(FastMCPProtocol, _fastmcp_instance)
# Initialize connection manager for stability
connection_manager: "MCPConnectionManager | None" = None
# TEMPORARILY DISABLED: MCP logging middleware - was breaking SSE transport
# TODO: Fix middleware to work properly with SSE transport
# logger.info("Adding comprehensive MCP logging middleware...")
# try:
# from maverick_mcp.api.middleware.mcp_logging import add_mcp_logging_middleware
#
# # Add logging middleware with debug mode based on settings
# include_payloads = settings.api.debug or settings.api.log_level.upper() == "DEBUG"
# import logging as py_logging
# add_mcp_logging_middleware(
# mcp,
# include_payloads=include_payloads,
# max_payload_length=3000, # Larger payloads in debug mode
# log_level=getattr(py_logging, settings.api.log_level.upper())
# )
# logger.info("✅ MCP logging middleware added successfully")
#
# # Add console notification
# print("🔧 MCP Server Enhanced Logging Enabled")
# print(" 📊 Tool calls will be logged with execution details")
# print(" 🔍 Protocol messages will be tracked for debugging")
# print(" ⏱️ Timeout detection and warnings active")
# print()
#
# except Exception as e:
# logger.warning(f"Failed to add MCP logging middleware: {e}")
# print("⚠️ Warning: MCP logging middleware could not be added")
# Initialize monitoring and observability systems
logger.info("Initializing monitoring and observability systems...")
# Initialize core monitoring
initialize_monitoring()
# Initialize distributed tracing
initialize_tracing()
# Initialize backtesting metrics collector
logger.info("Initializing backtesting metrics system...")
try:
from maverick_mcp.monitoring.metrics import get_backtesting_metrics
backtesting_collector = get_backtesting_metrics()
logger.info("✅ Backtesting metrics system initialized successfully")
# Log metrics system capabilities
print("🎯 Enhanced Backtesting Metrics System Enabled")
print(" 📊 Strategy performance tracking active")
print(" 🔄 API rate limiting and failure monitoring enabled")
print(" 💾 Resource usage monitoring configured")
print(" 🚨 Anomaly detection and alerting ready")
print(" 📈 Prometheus metrics available at /metrics")
print()
except Exception as e:
logger.warning(f"Failed to initialize backtesting metrics: {e}")
print("⚠️ Warning: Backtesting metrics system could not be initialized")
logger.info("Monitoring and observability systems initialized")
# ENHANCED CONNECTION MANAGEMENT: Register tools through connection manager
# This ensures tools persist through connection cycles and prevents tools from disappearing
logger.info("Initializing enhanced connection management system...")
# Import connection manager and SSE optimizer
# Connection management imports disabled for compatibility
# from maverick_mcp.infrastructure.connection_manager import initialize_connection_management
# from maverick_mcp.infrastructure.sse_optimizer import apply_sse_optimizations
# Register all tools from routers directly for basic functionality
register_all_router_tools(_fastmcp_instance)
logger.info("Tools registered successfully")
# Register monitoring and health endpoints directly with FastMCP
from maverick_mcp.api.routers.health_enhanced import router as health_router
from maverick_mcp.api.routers.monitoring import router as monitoring_router
# Add monitoring and health endpoints to the FastMCP app's FastAPI instance
if hasattr(mcp, "fastapi_app") and mcp.fastapi_app:
mcp.fastapi_app.include_router(monitoring_router, tags=["monitoring"])
mcp.fastapi_app.include_router(health_router, tags=["health"])
logger.info("Monitoring and health endpoints registered with FastAPI application")
# Initialize enhanced health monitoring system
logger.info("Initializing enhanced health monitoring system...")
try:
from maverick_mcp.monitoring.health_monitor import get_health_monitor
from maverick_mcp.utils.circuit_breaker import initialize_all_circuit_breakers
# Initialize circuit breakers for all external APIs
circuit_breaker_success = initialize_all_circuit_breakers()
if circuit_breaker_success:
logger.info("✅ Circuit breakers initialized for all external APIs")
print("🛡️ Enhanced Circuit Breaker Protection Enabled")
print(" 🔄 yfinance, Tiingo, FRED, OpenRouter, Exa APIs protected")
print(" 📊 Failure detection and automatic recovery active")
print(" 🚨 Circuit breaker monitoring and alerting enabled")
else:
logger.warning("⚠️ Some circuit breakers failed to initialize")
# Get health monitor (will be started later in async context)
health_monitor = get_health_monitor()
logger.info("✅ Health monitoring system prepared")
print("🏥 Comprehensive Health Monitoring System Ready")
print(" 📈 Real-time component health tracking")
print(" 🔍 Database, cache, and external API monitoring")
print(" 💾 Resource usage monitoring (CPU, memory, disk)")
print(" 📊 Status dashboard with historical metrics")
print(" 🚨 Automated alerting and recovery actions")
print(
" 🩺 Health endpoints: /health, /health/detailed, /health/ready, /health/live"
)
print()
except Exception as e:
logger.warning(f"Failed to initialize enhanced health monitoring: {e}")
print("⚠️ Warning: Enhanced health monitoring could not be fully initialized")
# Add enhanced health endpoint as a resource
@mcp.resource("health://")
def health_resource() -> dict[str, Any]:
"""
Enhanced comprehensive health check endpoint.
Provides detailed system health including:
- Component status (database, cache, external APIs)
- Circuit breaker states
- Resource utilization
- Performance metrics
Financial Disclaimer: This health check is for system monitoring only and does not
provide any investment or financial advice.
"""
try:
import asyncio
from maverick_mcp.api.routers.health_enhanced import _get_detailed_health_status
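        # This resource handler is synchronous, so the async health check runs on a
        # temporary event loop; the previously active loop (if any) is restored afterwards.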
loop_policy = asyncio.get_event_loop_policy()
try:
previous_loop = loop_policy.get_event_loop()
except RuntimeError:
previous_loop = None
loop = loop_policy.new_event_loop()
try:
asyncio.set_event_loop(loop)
health_status = loop.run_until_complete(_get_detailed_health_status())
finally:
loop.close()
if previous_loop is not None:
asyncio.set_event_loop(previous_loop)
else:
asyncio.set_event_loop(None)
# Add service-specific information
health_status.update(
{
"service": settings.app_name,
"version": "1.0.0",
"mode": "backtesting_with_enhanced_monitoring",
}
)
return health_status
except Exception as e:
logger.error(f"Health resource check failed: {e}")
return {
"status": "unhealthy",
"service": settings.app_name,
"version": "1.0.0",
"error": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
# Add status dashboard endpoint as a resource
@mcp.resource("dashboard://")
def status_dashboard_resource() -> dict[str, Any]:
"""
Comprehensive status dashboard with real-time metrics.
Provides aggregated health status, performance metrics, alerts,
and historical trends for the backtesting system.
"""
try:
import asyncio
from maverick_mcp.monitoring.status_dashboard import get_dashboard_data
loop_policy = asyncio.get_event_loop_policy()
try:
previous_loop = loop_policy.get_event_loop()
except RuntimeError:
previous_loop = None
loop = loop_policy.new_event_loop()
try:
asyncio.set_event_loop(loop)
dashboard_data = loop.run_until_complete(get_dashboard_data())
finally:
loop.close()
if previous_loop is not None:
asyncio.set_event_loop(previous_loop)
else:
asyncio.set_event_loop(None)
return dashboard_data
except Exception as e:
logger.error(f"Dashboard resource failed: {e}")
return {
"error": "Failed to generate dashboard",
"message": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
# Add performance dashboard endpoint as a resource (keep existing)
@mcp.resource("performance://")
def performance_dashboard() -> dict[str, Any]:
"""
Performance metrics dashboard showing backtesting system health.
Provides real-time performance metrics, resource usage, and operational statistics
for the backtesting infrastructure.
"""
try:
dashboard_metrics = logger_manager.create_dashboard_metrics()
# Add additional context
dashboard_metrics.update(
{
"service": settings.app_name,
"environment": settings.environment,
"version": "1.0.0",
"dashboard_type": "backtesting_performance",
"generated_at": datetime.now(UTC).isoformat(),
}
)
return dashboard_metrics
except Exception as e:
logger.error(f"Failed to generate performance dashboard: {e}", exc_info=True)
return {
"error": "Failed to generate performance dashboard",
"message": str(e),
"timestamp": datetime.now(UTC).isoformat(),
}
# Prompts for Trading and Investing
@mcp.prompt()
def technical_analysis(ticker: str, timeframe: str = "daily") -> str:
"""Generate a comprehensive technical analysis prompt for a stock."""
return f"""Please perform a comprehensive technical analysis for {ticker} on the {timeframe} timeframe.
Use the available tools to:
1. Fetch historical price data and current stock information
2. Generate a full technical analysis including:
- Trend analysis (primary, secondary trends)
- Support and resistance levels
- Moving averages (SMA, EMA analysis)
- Key indicators (RSI, MACD, Stochastic)
- Volume analysis and patterns
- Chart patterns identification
3. Create a technical chart visualization
4. Provide a short-term outlook
Focus on:
- Price action and volume confirmation
- Convergence/divergence of indicators
- Risk/reward setup quality
- Key decision levels for traders
Present findings in a structured format with clear entry/exit suggestions if applicable."""
@mcp.prompt()
def stock_screening_report(strategy: str = "momentum") -> str:
"""Generate a stock screening report based on different strategies."""
strategies = {
"momentum": "high momentum and relative strength",
"value": "undervalued with strong fundamentals",
"growth": "high growth potential",
"quality": "strong balance sheets and consistent earnings",
}
strategy_desc = strategies.get(strategy.lower(), "balanced approach")
return f"""Please generate a comprehensive stock screening report focused on {strategy_desc}.
Use the screening tools to:
1. Retrieve Maverick bullish stocks (for momentum/growth strategies)
2. Get Maverick bearish stocks (for short opportunities)
3. Fetch trending stocks (for breakout setups)
4. Analyze the top candidates with technical indicators
For each recommended stock:
- Current technical setup and score
- Key levels (support, resistance, stop loss)
- Risk/reward analysis
- Volume and momentum characteristics
- Sector/industry context
Organize results by:
1. Top picks (highest conviction)
2. Watch list (developing setups)
3. Avoid list (deteriorating technicals)
Include market context and any relevant economic factors."""
# Simplified portfolio and watchlist tools (no authentication required)
@mcp.tool()
async def get_user_portfolio_summary() -> dict[str, Any]:
"""
Get basic portfolio summary and stock analysis capabilities.
Returns available features and sample stock data.
"""
return {
"mode": "simple_stock_analysis",
"features": {
"stock_data": True,
"technical_analysis": True,
"market_screening": True,
"portfolio_analysis": True,
"real_time_quotes": True,
},
"sample_data": "Use get_watchlist() to see sample stock data",
"usage": "All stock analysis tools are available without restrictions",
"last_updated": datetime.now(UTC).isoformat(),
}
@mcp.tool()
async def get_watchlist(limit: int = 20) -> dict[str, Any]:
"""
Get sample watchlist with real-time stock data.
Provides stock data for popular tickers to demonstrate functionality.
"""
# Sample watchlist for demonstration
watchlist_tickers = [
"AAPL",
"MSFT",
"GOOGL",
"AMZN",
"TSLA",
"META",
"NVDA",
"JPM",
"V",
"JNJ",
"UNH",
"PG",
"HD",
"MA",
"DIS",
][:limit]
import asyncio
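    # The provider and DB session below are synchronous, so the watchlist is built
    # in a worker thread (asyncio.to_thread) to avoid blocking the event loop.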
def _build_watchlist() -> dict[str, Any]:
db_session = next(get_db())
try:
provider = StockDataProvider(db_session=db_session)
watchlist_data: list[dict[str, Any]] = []
for ticker in watchlist_tickers:
try:
info = provider.get_stock_info(ticker)
current_price = info.get("currentPrice", 0)
previous_close = info.get("previousClose", current_price)
change = current_price - previous_close
change_pct = (
(change / previous_close * 100) if previous_close else 0
)
ticker_data = {
"ticker": ticker,
"name": info.get("longName", ticker),
"current_price": round(current_price, 2),
"change": round(change, 2),
"change_percent": round(change_pct, 2),
"volume": info.get("volume", 0),
"market_cap": info.get("marketCap", 0),
"bid": info.get("bid", 0),
"ask": info.get("ask", 0),
"bid_size": info.get("bidSize", 0),
"ask_size": info.get("askSize", 0),
"last_trade_time": datetime.now(UTC).isoformat(),
}
watchlist_data.append(ticker_data)
except Exception as exc:
logger.error(f"Error fetching data for {ticker}: {str(exc)}")
continue
return {
"watchlist": watchlist_data,
"count": len(watchlist_data),
"mode": "simple_stock_analysis",
"last_updated": datetime.now(UTC).isoformat(),
}
finally:
db_session.close()
return await asyncio.to_thread(_build_watchlist)
# Market Overview Tools (full access)
@mcp.tool()
async def get_market_overview() -> dict[str, Any]:
"""
Get comprehensive market overview including indices, sectors, and market breadth.
Provides full market data without restrictions.
"""
try:
# Create market provider instance
import asyncio
provider = MarketDataProvider()
indices, sectors, breadth = await asyncio.gather(
provider.get_market_summary_async(),
provider.get_sector_performance_async(),
provider.get_market_overview_async(),
)
overview = {
"indices": indices,
"sectors": sectors,
"market_breadth": breadth,
"last_updated": datetime.now(UTC).isoformat(),
"mode": "simple_stock_analysis",
}
vix_value = indices.get("current_price", 0)
overview["volatility"] = {
"vix": vix_value,
"vix_change": indices.get("change_percent", 0),
"fear_level": (
"extreme"
if vix_value > 30
else (
"high"
if vix_value > 20
else "moderate"
if vix_value > 15
else "low"
)
),
}
return overview
except Exception as e:
logger.error(f"Error getting market overview: {str(e)}")
return {"error": str(e), "status": "error"}
@mcp.tool()
async def get_economic_calendar(days_ahead: int = 7) -> dict[str, Any]:
"""
Get upcoming economic events and indicators.
Provides full access to economic calendar data.
"""
try:
# Get economic calendar events (placeholder implementation)
events: list[
dict[str, Any]
] = [] # macro_provider doesn't have get_economic_calendar method
return {
"events": events,
"days_ahead": days_ahead,
"event_count": len(events),
"mode": "simple_stock_analysis",
"last_updated": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Error getting economic calendar: {str(e)}")
return {"error": str(e), "status": "error"}
@mcp.tool()
async def get_mcp_connection_status() -> dict[str, Any]:
"""
Get current MCP connection status for debugging connection stability issues.
Returns detailed information about active connections, tool registration status,
and connection health metrics to help diagnose disappearing tools.
"""
try:
global connection_manager
if connection_manager is None:
return {
"error": "Connection manager not initialized",
"status": "error",
"server_mode": "simple_stock_analysis",
"timestamp": datetime.now(UTC).isoformat(),
}
# Get connection status from manager
status = connection_manager.get_connection_status()
# Add additional debugging info
status.update(
{
"server_mode": "simple_stock_analysis",
"mcp_server_name": settings.app_name,
"transport_modes": ["stdio", "sse", "streamable-http"],
"debugging_info": {
"tools_should_be_visible": status["tools_registered"],
"recommended_action": (
"Tools are registered and should be visible"
if status["tools_registered"]
else "Tools not registered - check connection manager"
),
},
"timestamp": datetime.now(UTC).isoformat(),
}
)
return status
except Exception as e:
logger.error(f"Error getting connection status: {str(e)}")
return {
"error": str(e),
"status": "error",
"timestamp": datetime.now(UTC).isoformat(),
}
# Resources (public access)
@mcp.resource("stock://{ticker}")
def stock_resource(ticker: str) -> Any:
"""Get the latest stock data for a given ticker"""
db_session = next(get_db())
try:
provider = StockDataProvider(db_session=db_session)
df = provider.get_stock_data(ticker)
payload = cast(str, df.to_json(orient="split", date_format="iso"))
return json.loads(payload)
finally:
db_session.close()
@mcp.resource("stock://{ticker}/{start_date}/{end_date}")
def stock_resource_with_dates(ticker: str, start_date: str, end_date: str) -> Any:
"""Get stock data for a given ticker and date range"""
db_session = next(get_db())
try:
provider = StockDataProvider(db_session=db_session)
df = provider.get_stock_data(ticker, start_date, end_date)
payload = cast(str, df.to_json(orient="split", date_format="iso"))
return json.loads(payload)
finally:
db_session.close()
@mcp.resource("stock_info://{ticker}")
def stock_info_resource(ticker: str) -> dict[str, Any]:
"""Get detailed information about a stock"""
db_session = next(get_db())
try:
provider = StockDataProvider(db_session=db_session)
info = provider.get_stock_info(ticker)
# Convert any non-serializable objects to strings
return {
k: (
str(v)
if not isinstance(
v, int | float | bool | str | list | dict | type(None)
)
else v
)
for k, v in info.items()
}
finally:
db_session.close()
@mcp.resource("portfolio://my-holdings")
def portfolio_holdings_resource() -> dict[str, Any]:
"""
Get your current portfolio holdings as an MCP resource.
This resource provides AI-enriched context about your portfolio for Claude to use
in conversations. It includes all positions with current prices and P&L calculations.
Returns:
Dictionary containing portfolio holdings with performance metrics
"""
from maverick_mcp.api.routers.portfolio import get_my_portfolio
try:
# Get portfolio with current prices
portfolio_data = get_my_portfolio(
user_id="default",
portfolio_name="My Portfolio",
include_current_prices=True,
)
if portfolio_data.get("status") == "error":
return {
"error": portfolio_data.get("error", "Unknown error"),
"uri": "portfolio://my-holdings",
"description": "Error retrieving portfolio holdings",
}
# Add resource metadata
portfolio_data["uri"] = "portfolio://my-holdings"
portfolio_data["description"] = (
"Your current stock portfolio with live prices and P&L"
)
portfolio_data["mimeType"] = "application/json"
return portfolio_data
except Exception as e:
logger.error(f"Portfolio holdings resource failed: {e}")
return {
"error": str(e),
"uri": "portfolio://my-holdings",
"description": "Failed to retrieve portfolio holdings",
}
# Main execution block
if __name__ == "__main__":
import asyncio
from maverick_mcp.config.validation import validate_environment
from maverick_mcp.utils.shutdown import graceful_shutdown
# Parse command line arguments
parser = argparse.ArgumentParser(
description=f"{settings.app_name} Simple Stock Analysis MCP Server"
)
parser.add_argument(
"--transport",
choices=["stdio", "sse", "streamable-http"],
default="sse",
help="Transport method to use (default: sse)",
)
parser.add_argument(
"--port",
type=int,
default=settings.api.port,
help=f"Port to run the server on (default: {settings.api.port})",
)
parser.add_argument(
"--host",
default=settings.api.host,
help=f"Host to run the server on (default: {settings.api.host})",
)
args = parser.parse_args()
# Reconfigure logging for stdio transport to use stderr
if args.transport == "stdio":
setup_structured_logging(
log_level=settings.api.log_level.upper(),
log_format="json" if settings.api.debug else "text",
use_stderr=True,
)
# Validate environment before starting
# For stdio transport, use lenient validation to support testing
fail_on_validation_error = args.transport != "stdio"
logger.info("Validating environment configuration...")
validate_environment(fail_on_error=fail_on_validation_error)
# Initialize performance systems and health monitoring
async def init_systems():
logger.info("Initializing performance optimization systems...")
try:
performance_status = await initialize_performance_systems()
logger.info(f"Performance systems initialized: {performance_status}")
except Exception as e:
logger.error(f"Failed to initialize performance systems: {e}")
# Initialize background health monitoring
logger.info("Starting background health monitoring...")
try:
from maverick_mcp.monitoring.health_monitor import start_health_monitoring
await start_health_monitoring()
logger.info("✅ Background health monitoring started")
except Exception as e:
logger.error(f"Failed to start health monitoring: {e}")
asyncio.run(init_systems())
# Initialize connection management and transport optimizations
async def init_connection_management():
global connection_manager
# Initialize connection manager (removed for linting)
logger.info("Enhanced connection management system initialized")
# Apply SSE transport optimizations (removed for linting)
logger.info("SSE transport optimizations applied")
# Add connection event handlers for monitoring
@mcp.event("connection_opened")
async def on_connection_open(session_id: str | None = None) -> str:
"""Handle new MCP connection with enhanced stability."""
if connection_manager is None:
fallback_session_id = session_id or str(uuid.uuid4())
logger.info(
"MCP connection opened without manager: %s", fallback_session_id[:8]
)
return fallback_session_id
try:
actual_session_id = await connection_manager.handle_new_connection(
session_id
)
logger.info(f"MCP connection opened: {actual_session_id[:8]}")
return actual_session_id
except Exception as e:
logger.error(f"Failed to handle connection open: {e}")
raise
@mcp.event("connection_closed")
async def on_connection_close(session_id: str) -> None:
"""Handle MCP connection close with cleanup."""
if connection_manager is None:
logger.info(
"MCP connection close received without manager: %s", session_id[:8]
)
return
try:
await connection_manager.handle_connection_close(session_id)
logger.info(f"MCP connection closed: {session_id[:8]}")
except Exception as e:
logger.error(f"Failed to handle connection close: {e}")
@mcp.event("message_received")
async def on_message_received(session_id: str, message: dict[str, Any]) -> None:
"""Update session activity on message received."""
if connection_manager is None:
logger.debug(
"Skipping session activity update; connection manager disabled."
)
return
try:
await connection_manager.update_session_activity(session_id)
except Exception as e:
logger.error(f"Failed to update session activity: {e}")
logger.info("Connection event handlers registered")
# Connection management disabled for compatibility
# asyncio.run(init_connection_management())
logger.info(f"Starting {settings.app_name} simple stock analysis server")
# Add initialization delay for connection stability
import time
logger.info("Adding startup delay for connection stability...")
time.sleep(3) # 3 second delay to ensure full initialization
logger.info("Startup delay completed, server ready for connections")
# Use graceful shutdown handler
with graceful_shutdown(f"{settings.app_name}-{args.transport}") as shutdown_handler:
# Log startup configuration
logger.info(
"Server configuration",
extra={
"transport": args.transport,
"host": args.host,
"port": args.port,
"mode": "simple_stock_analysis",
"auth_enabled": False,
"debug_mode": settings.api.debug,
"environment": settings.environment,
},
)
# Register performance systems cleanup
async def cleanup_performance():
"""Cleanup performance optimization systems during shutdown."""
try:
await cleanup_performance_systems()
except Exception as e:
logger.error(f"Error cleaning up performance systems: {e}")
shutdown_handler.register_cleanup(cleanup_performance)
# Register health monitoring cleanup
async def cleanup_health_monitoring():
"""Cleanup health monitoring during shutdown."""
try:
from maverick_mcp.monitoring.health_monitor import (
stop_health_monitoring,
)
await stop_health_monitoring()
logger.info("Health monitoring stopped")
except Exception as e:
logger.error(f"Error stopping health monitoring: {e}")
shutdown_handler.register_cleanup(cleanup_health_monitoring)
# Register connection manager cleanup
async def cleanup_connection_manager():
"""Cleanup connection manager during shutdown."""
try:
if connection_manager:
await connection_manager.shutdown()
logger.info("Connection manager shutdown complete")
except Exception as e:
logger.error(f"Error shutting down connection manager: {e}")
shutdown_handler.register_cleanup(cleanup_connection_manager)
# Register cache cleanup
def close_cache():
"""Close Redis connections during shutdown."""
from maverick_mcp.data.cache import get_redis_client
try:
redis_client = get_redis_client()
if redis_client:
logger.info("Closing Redis connections...")
redis_client.close()
logger.info("Redis connections closed")
except Exception as e:
logger.error(f"Error closing Redis: {e}")
shutdown_handler.register_cleanup(close_cache)
# Run with the appropriate transport
if args.transport == "stdio":
logger.info(f"Starting {settings.app_name} server with stdio transport")
mcp.run(
transport="stdio",
debug=settings.api.debug,
log_level=settings.api.log_level.upper(),
)
elif args.transport == "streamable-http":
logger.info(
f"Starting {settings.app_name} server with streamable-http transport on http://{args.host}:{args.port}"
)
mcp.run(
transport="streamable-http",
port=args.port,
host=args.host,
)
else: # sse
logger.info(
f"Starting {settings.app_name} server with SSE transport on http://{args.host}:{args.port}"
)
mcp.run(
transport="sse",
port=args.port,
host=args.host,
path="/sse", # No trailing slash - both /sse and /sse/ will work with the monkey-patch
)
```
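The patched `create_sse_app` above avoids 307 redirects by registering both the `/sse` and `/sse/` path variants as explicit routes. Below is a minimal standalone sketch of that idea, assuming `starlette` (and `httpx` for its test client) is installed; the handler is a stand-in for the real SSE endpoint, not the server's actual implementation.

```python
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient


async def sse_stub(request):
    # Stand-in handler; the real server streams SSE events here.
    return PlainTextResponse("ok")


# Registering both path variants means neither request is answered with a
# 307 redirect (Starlette only redirects when one variant is missing).
app = Starlette(routes=[
    Route("/sse", sse_stub, methods=["GET"]),
    Route("/sse/", sse_stub, methods=["GET"]),
])

client = TestClient(app)
assert client.get("/sse", follow_redirects=False).status_code == 200
assert client.get("/sse/", follow_redirects=False).status_code == 200
```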